Line data Source code
1 : /*
2 : * Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
3 : *
4 : * Use of this source code is governed by a BSD-style license
5 : * that can be found in the LICENSE file in the root of the source
6 : * tree. An additional intellectual property rights grant can be found
7 : * in the file PATENTS. All contributing project authors may
8 : * be found in the AUTHORS file in the root of the source tree.
9 : */
10 :
11 : #include "webrtc/modules/audio_processing/level_controller/signal_classifier.h"
12 :
13 : #include <algorithm>
14 : #include <numeric>
15 : #include <vector>
16 :
17 : #include "webrtc/base/array_view.h"
18 : #include "webrtc/base/constructormagic.h"
19 : #include "webrtc/modules/audio_processing/audio_buffer.h"
20 : #include "webrtc/modules/audio_processing/level_controller/down_sampler.h"
21 : #include "webrtc/modules/audio_processing/level_controller/noise_spectrum_estimator.h"
22 : #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
23 :
24 : namespace webrtc {
25 : namespace {
26 :
27 0 : void RemoveDcLevel(rtc::ArrayView<float> x) {
28 0 : RTC_DCHECK_LT(0, x.size());
29 0 : float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);
30 0 : mean /= x.size();
31 :
32 0 : for (float& v : x) {
33 0 : v -= mean;
34 : }
35 0 : }
36 :
37 0 : void PowerSpectrum(const OouraFft* ooura_fft,
38 : rtc::ArrayView<const float> x,
39 : rtc::ArrayView<float> spectrum) {
40 0 : RTC_DCHECK_EQ(65, spectrum.size());
41 0 : RTC_DCHECK_EQ(128, x.size());
42 : float X[128];
43 0 : std::copy(x.data(), x.data() + x.size(), X);
44 0 : ooura_fft->Fft(X);
45 :
46 0 : float* X_p = X;
47 0 : RTC_DCHECK_EQ(X_p, &X[0]);
48 0 : spectrum[0] = (*X_p) * (*X_p);
49 0 : ++X_p;
50 0 : RTC_DCHECK_EQ(X_p, &X[1]);
51 0 : spectrum[64] = (*X_p) * (*X_p);
52 0 : for (int k = 1; k < 64; ++k) {
53 0 : ++X_p;
54 0 : RTC_DCHECK_EQ(X_p, &X[2 * k]);
55 0 : spectrum[k] = (*X_p) * (*X_p);
56 0 : ++X_p;
57 0 : RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
58 0 : spectrum[k] += (*X_p) * (*X_p);
59 : }
60 0 : }
61 :
62 0 : webrtc::SignalClassifier::SignalType ClassifySignal(
63 : rtc::ArrayView<const float> signal_spectrum,
64 : rtc::ArrayView<const float> noise_spectrum,
65 : ApmDataDumper* data_dumper) {
66 0 : int num_stationary_bands = 0;
67 0 : int num_highly_nonstationary_bands = 0;
68 :
69 : // Detect stationary and highly nonstationary bands.
70 0 : for (size_t k = 1; k < 40; k++) {
71 0 : if (signal_spectrum[k] < 3 * noise_spectrum[k] &&
72 0 : signal_spectrum[k] * 3 > noise_spectrum[k]) {
73 0 : ++num_stationary_bands;
74 0 : } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {
75 0 : ++num_highly_nonstationary_bands;
76 : }
77 : }
78 :
79 0 : data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);
80 : data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
81 0 : &num_highly_nonstationary_bands);
82 :
83 : // Use the detected number of bands to classify the overall signal
84 : // stationarity.
85 0 : if (num_stationary_bands > 15) {
86 0 : return SignalClassifier::SignalType::kStationary;
87 0 : } else if (num_highly_nonstationary_bands > 15) {
88 0 : return SignalClassifier::SignalType::kHighlyNonStationary;
89 : } else {
90 0 : return SignalClassifier::SignalType::kNonStationary;
91 : }
92 : }
93 :
94 : } // namespace
95 :
96 0 : SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,
97 0 : size_t extended_frame_size)
98 0 : : x_old_(extended_frame_size - frame_size, 0.f) {}
99 :
100 : SignalClassifier::FrameExtender::~FrameExtender() = default;
101 :
102 0 : void SignalClassifier::FrameExtender::ExtendFrame(
103 : rtc::ArrayView<const float> x,
104 : rtc::ArrayView<float> x_extended) {
105 0 : RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());
106 0 : std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());
107 0 : std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());
108 0 : std::copy(x_extended.data() + x_extended.size() - x_old_.size(),
109 0 : x_extended.data() + x_extended.size(), x_old_.data());
110 0 : }
111 :
112 0 : SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)
113 : : data_dumper_(data_dumper),
114 0 : down_sampler_(data_dumper_),
115 0 : noise_spectrum_estimator_(data_dumper_) {
116 0 : Initialize(AudioProcessing::kSampleRate48kHz);
117 0 : }
118 0 : SignalClassifier::~SignalClassifier() {}
119 :
120 0 : void SignalClassifier::Initialize(int sample_rate_hz) {
121 0 : down_sampler_.Initialize(sample_rate_hz);
122 0 : noise_spectrum_estimator_.Initialize();
123 0 : frame_extender_.reset(new FrameExtender(80, 128));
124 0 : sample_rate_hz_ = sample_rate_hz;
125 0 : initialization_frames_left_ = 2;
126 0 : consistent_classification_counter_ = 3;
127 0 : last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
128 0 : }
129 :
130 0 : void SignalClassifier::Analyze(const AudioBuffer& audio,
131 : SignalType* signal_type) {
132 0 : RTC_DCHECK_EQ(audio.num_frames(), sample_rate_hz_ / 100);
133 :
134 : // Compute the signal power spectrum.
135 : float downsampled_frame[80];
136 0 : down_sampler_.DownSample(rtc::ArrayView<const float>(
137 0 : audio.channels_const_f()[0], audio.num_frames()),
138 0 : downsampled_frame);
139 : float extended_frame[128];
140 0 : frame_extender_->ExtendFrame(downsampled_frame, extended_frame);
141 0 : RemoveDcLevel(extended_frame);
142 : float signal_spectrum[65];
143 0 : PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
144 :
145 : // Classify the signal based on the estimate of the noise spectrum and the
146 : // signal spectrum estimate.
147 0 : *signal_type = ClassifySignal(signal_spectrum,
148 : noise_spectrum_estimator_.GetNoiseSpectrum(),
149 0 : data_dumper_);
150 :
151 : // Update the noise spectrum based on the signal spectrum.
152 0 : noise_spectrum_estimator_.Update(signal_spectrum,
153 0 : initialization_frames_left_ > 0);
154 :
155 : // Update the number of frames until a reliable signal spectrum is achieved.
156 0 : initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);
157 :
158 0 : if (last_signal_type_ == *signal_type) {
159 0 : consistent_classification_counter_ =
160 0 : std::max(0, consistent_classification_counter_ - 1);
161 : } else {
162 0 : last_signal_type_ = *signal_type;
163 0 : consistent_classification_counter_ = 3;
164 : }
165 :
166 0 : if (consistent_classification_counter_ > 0) {
167 0 : *signal_type = SignalClassifier::SignalType::kNonStationary;
168 : }
169 0 : }
170 :
171 : } // namespace webrtc
|