Line data Source code
1 : /*
2 : * Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
3 : *
4 : * Use of this source code is governed by a BSD-style license
5 : * that can be found in the LICENSE file in the root of the source
6 : * tree. An additional intellectual property rights grant can be found
7 : * in the file PATENTS. All contributing project authors may
8 : * be found in the AUTHORS file in the root of the source tree.
9 : */
10 :
11 : #include "webrtc/modules/audio_processing/transient/transient_suppressor.h"
12 :
#include <math.h>
#include <string.h>
#include <algorithm>
#include <cmath>
#include <complex>
#include <deque>
#include <limits>
#include <set>
19 :
20 : #include "webrtc/base/checks.h"
21 : #include "webrtc/common_audio/fft4g.h"
22 : #include "webrtc/common_audio/include/audio_util.h"
23 : #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
24 : #include "webrtc/modules/audio_processing/transient/common.h"
25 : #include "webrtc/modules/audio_processing/transient/transient_detector.h"
26 : #include "webrtc/modules/audio_processing/ns/windows_private.h"
27 : #include "webrtc/system_wrappers/include/logging.h"
28 : #include "webrtc/typedefs.h"
29 :
30 : namespace webrtc {
31 :
// Weight given to the newest magnitudes when the per-bin spectral mean is
// updated; the previous mean keeps the remaining (1 - coefficient).
static const float kMeanIIRCoefficient = 0.5f;
// Voice probabilities below this value are treated as "not voiced" when
// choosing between hard and soft restoration.
static const float kVoiceThreshold = 0.02f;

// Bin range used both for the block frequency mean in the soft restoration
// and as the centers of the two sigmoids that make up |mean_factor_|.
// TODO(aluebs): Check if these values work also for 48kHz.
static const size_t kMinVoiceBin = 3;
static const size_t kMaxVoiceBin = 60;
38 :
39 : namespace {
40 :
// Cheap magnitude estimate of the complex value (a + j*b): the sum of the
// absolute values of both parts (an L1 norm), not the Euclidean modulus.
float ComplexMagnitude(float a, float b) {
  const float abs_real = std::fabs(a);
  const float abs_imag = std::fabs(b);
  return abs_real + abs_imag;
}
44 :
45 : } // namespace
46 :
47 0 : TransientSuppressor::TransientSuppressor()
48 : : data_length_(0),
49 : detection_length_(0),
50 : analysis_length_(0),
51 : buffer_delay_(0),
52 : complex_analysis_length_(0),
53 : num_channels_(0),
54 : window_(NULL),
55 : detector_smoothed_(0.f),
56 : keypress_counter_(0),
57 : chunks_since_keypress_(0),
58 : detection_enabled_(false),
59 : suppression_enabled_(false),
60 : use_hard_restoration_(false),
61 : chunks_since_voice_change_(0),
62 : seed_(182),
63 0 : using_reference_(false) {
64 0 : }
65 :
66 0 : TransientSuppressor::~TransientSuppressor() {}
67 :
68 0 : int TransientSuppressor::Initialize(int sample_rate_hz,
69 : int detection_rate_hz,
70 : int num_channels) {
71 0 : switch (sample_rate_hz) {
72 : case ts::kSampleRate8kHz:
73 0 : analysis_length_ = 128u;
74 0 : window_ = kBlocks80w128;
75 0 : break;
76 : case ts::kSampleRate16kHz:
77 0 : analysis_length_ = 256u;
78 0 : window_ = kBlocks160w256;
79 0 : break;
80 : case ts::kSampleRate32kHz:
81 0 : analysis_length_ = 512u;
82 0 : window_ = kBlocks320w512;
83 0 : break;
84 : case ts::kSampleRate48kHz:
85 0 : analysis_length_ = 1024u;
86 0 : window_ = kBlocks480w1024;
87 0 : break;
88 : default:
89 0 : return -1;
90 : }
91 0 : if (detection_rate_hz != ts::kSampleRate8kHz &&
92 0 : detection_rate_hz != ts::kSampleRate16kHz &&
93 0 : detection_rate_hz != ts::kSampleRate32kHz &&
94 : detection_rate_hz != ts::kSampleRate48kHz) {
95 0 : return -1;
96 : }
97 0 : if (num_channels <= 0) {
98 0 : return -1;
99 : }
100 :
101 0 : detector_.reset(new TransientDetector(detection_rate_hz));
102 0 : data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
103 0 : if (data_length_ > analysis_length_) {
104 0 : RTC_NOTREACHED();
105 0 : return -1;
106 : }
107 0 : buffer_delay_ = analysis_length_ - data_length_;
108 :
109 0 : complex_analysis_length_ = analysis_length_ / 2 + 1;
110 0 : RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
111 0 : num_channels_ = num_channels;
112 0 : in_buffer_.reset(new float[analysis_length_ * num_channels_]);
113 0 : memset(in_buffer_.get(),
114 : 0,
115 0 : analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
116 0 : detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
117 0 : detection_buffer_.reset(new float[detection_length_]);
118 0 : memset(detection_buffer_.get(),
119 : 0,
120 0 : detection_length_ * sizeof(detection_buffer_[0]));
121 0 : out_buffer_.reset(new float[analysis_length_ * num_channels_]);
122 0 : memset(out_buffer_.get(),
123 : 0,
124 0 : analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
125 : // ip[0] must be zero to trigger initialization using rdft().
126 0 : size_t ip_length = 2 + sqrtf(analysis_length_);
127 0 : ip_.reset(new size_t[ip_length]());
128 0 : memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
129 0 : wfft_.reset(new float[complex_analysis_length_ - 1]);
130 0 : memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
131 0 : spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
132 0 : memset(spectral_mean_.get(),
133 : 0,
134 0 : complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
135 0 : fft_buffer_.reset(new float[analysis_length_ + 2]);
136 0 : memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
137 0 : magnitudes_.reset(new float[complex_analysis_length_]);
138 0 : memset(magnitudes_.get(),
139 : 0,
140 0 : complex_analysis_length_ * sizeof(magnitudes_[0]));
141 0 : mean_factor_.reset(new float[complex_analysis_length_]);
142 :
143 : static const float kFactorHeight = 10.f;
144 : static const float kLowSlope = 1.f;
145 : static const float kHighSlope = 0.3f;
146 0 : for (size_t i = 0; i < complex_analysis_length_; ++i) {
147 0 : mean_factor_[i] =
148 0 : kFactorHeight /
149 0 : (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
150 0 : kFactorHeight /
151 0 : (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
152 : }
153 0 : detector_smoothed_ = 0.f;
154 0 : keypress_counter_ = 0;
155 0 : chunks_since_keypress_ = 0;
156 0 : detection_enabled_ = false;
157 0 : suppression_enabled_ = false;
158 0 : use_hard_restoration_ = false;
159 0 : chunks_since_voice_change_ = 0;
160 0 : seed_ = 182;
161 0 : using_reference_ = false;
162 0 : return 0;
163 : }
164 :
165 0 : int TransientSuppressor::Suppress(float* data,
166 : size_t data_length,
167 : int num_channels,
168 : const float* detection_data,
169 : size_t detection_length,
170 : const float* reference_data,
171 : size_t reference_length,
172 : float voice_probability,
173 : bool key_pressed) {
174 0 : if (!data || data_length != data_length_ || num_channels != num_channels_ ||
175 0 : detection_length != detection_length_ || voice_probability < 0 ||
176 : voice_probability > 1) {
177 0 : return -1;
178 : }
179 :
180 0 : UpdateKeypress(key_pressed);
181 0 : UpdateBuffers(data);
182 :
183 0 : int result = 0;
184 0 : if (detection_enabled_) {
185 0 : UpdateRestoration(voice_probability);
186 :
187 0 : if (!detection_data) {
188 : // Use the input data of the first channel if special detection data is
189 : // not supplied.
190 0 : detection_data = &in_buffer_[buffer_delay_];
191 : }
192 :
193 0 : float detector_result = detector_->Detect(
194 0 : detection_data, detection_length, reference_data, reference_length);
195 0 : if (detector_result < 0) {
196 0 : return -1;
197 : }
198 :
199 0 : using_reference_ = detector_->using_reference();
200 :
201 : // |detector_smoothed_| follows the |detector_result| when this last one is
202 : // increasing, but has an exponential decaying tail to be able to suppress
203 : // the ringing of keyclicks.
204 0 : float smooth_factor = using_reference_ ? 0.6 : 0.1;
205 0 : detector_smoothed_ = detector_result >= detector_smoothed_
206 0 : ? detector_result
207 0 : : smooth_factor * detector_smoothed_ +
208 0 : (1 - smooth_factor) * detector_result;
209 :
210 0 : for (int i = 0; i < num_channels_; ++i) {
211 0 : Suppress(&in_buffer_[i * analysis_length_],
212 0 : &spectral_mean_[i * complex_analysis_length_],
213 0 : &out_buffer_[i * analysis_length_]);
214 : }
215 : }
216 :
217 : // If the suppression isn't enabled, we use the in buffer to delay the signal
218 : // appropriately. This also gives time for the out buffer to be refreshed with
219 : // new data between detection and suppression getting enabled.
220 0 : for (int i = 0; i < num_channels_; ++i) {
221 0 : memcpy(&data[i * data_length_],
222 0 : suppression_enabled_ ? &out_buffer_[i * analysis_length_]
223 0 : : &in_buffer_[i * analysis_length_],
224 0 : data_length_ * sizeof(*data));
225 : }
226 0 : return result;
227 : }
228 :
229 : // This should only be called when detection is enabled. UpdateBuffers() must
230 : // have been called. At return, |out_buffer_| will be filled with the
231 : // processed output.
232 0 : void TransientSuppressor::Suppress(float* in_ptr,
233 : float* spectral_mean,
234 : float* out_ptr) {
235 : // Go to frequency domain.
236 0 : for (size_t i = 0; i < analysis_length_; ++i) {
237 : // TODO(aluebs): Rename windows
238 0 : fft_buffer_[i] = in_ptr[i] * window_[i];
239 : }
240 :
241 0 : WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
242 :
243 : // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
244 : // for convenience.
245 0 : fft_buffer_[analysis_length_] = fft_buffer_[1];
246 0 : fft_buffer_[analysis_length_ + 1] = 0.f;
247 0 : fft_buffer_[1] = 0.f;
248 :
249 0 : for (size_t i = 0; i < complex_analysis_length_; ++i) {
250 0 : magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2],
251 0 : fft_buffer_[i * 2 + 1]);
252 : }
253 : // Restore audio if necessary.
254 0 : if (suppression_enabled_) {
255 0 : if (use_hard_restoration_) {
256 0 : HardRestoration(spectral_mean);
257 : } else {
258 0 : SoftRestoration(spectral_mean);
259 : }
260 : }
261 :
262 : // Update the spectral mean.
263 0 : for (size_t i = 0; i < complex_analysis_length_; ++i) {
264 0 : spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
265 0 : kMeanIIRCoefficient * magnitudes_[i];
266 : }
267 :
268 : // Back to time domain.
269 : // Put R[n/2] back in fft_buffer_[1].
270 0 : fft_buffer_[1] = fft_buffer_[analysis_length_];
271 :
272 0 : WebRtc_rdft(analysis_length_,
273 : -1,
274 : fft_buffer_.get(),
275 : ip_.get(),
276 0 : wfft_.get());
277 0 : const float fft_scaling = 2.f / analysis_length_;
278 :
279 0 : for (size_t i = 0; i < analysis_length_; ++i) {
280 0 : out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
281 : }
282 0 : }
283 :
284 0 : void TransientSuppressor::UpdateKeypress(bool key_pressed) {
285 0 : const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
286 0 : const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
287 0 : const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs; // 4 seconds.
288 :
289 0 : if (key_pressed) {
290 0 : keypress_counter_ += kKeypressPenalty;
291 0 : chunks_since_keypress_ = 0;
292 0 : detection_enabled_ = true;
293 : }
294 0 : keypress_counter_ = std::max(0, keypress_counter_ - 1);
295 :
296 0 : if (keypress_counter_ > kIsTypingThreshold) {
297 0 : if (!suppression_enabled_) {
298 0 : LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
299 : }
300 0 : suppression_enabled_ = true;
301 0 : keypress_counter_ = 0;
302 : }
303 :
304 0 : if (detection_enabled_ &&
305 0 : ++chunks_since_keypress_ > kChunksUntilNotTyping) {
306 0 : if (suppression_enabled_) {
307 0 : LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
308 : }
309 0 : detection_enabled_ = false;
310 0 : suppression_enabled_ = false;
311 0 : keypress_counter_ = 0;
312 : }
313 0 : }
314 :
315 0 : void TransientSuppressor::UpdateRestoration(float voice_probability) {
316 0 : const int kHardRestorationOffsetDelay = 3;
317 0 : const int kHardRestorationOnsetDelay = 80;
318 :
319 0 : bool not_voiced = voice_probability < kVoiceThreshold;
320 :
321 0 : if (not_voiced == use_hard_restoration_) {
322 0 : chunks_since_voice_change_ = 0;
323 : } else {
324 0 : ++chunks_since_voice_change_;
325 :
326 0 : if ((use_hard_restoration_ &&
327 0 : chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
328 0 : (!use_hard_restoration_ &&
329 0 : chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
330 0 : use_hard_restoration_ = not_voiced;
331 0 : chunks_since_voice_change_ = 0;
332 : }
333 : }
334 0 : }
335 :
336 : // Shift buffers to make way for new data. Must be called after
337 : // |detection_enabled_| is updated by UpdateKeypress().
338 0 : void TransientSuppressor::UpdateBuffers(float* data) {
339 : // TODO(aluebs): Change to ring buffer.
340 0 : memmove(in_buffer_.get(),
341 0 : &in_buffer_[data_length_],
342 0 : (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
343 0 : sizeof(in_buffer_[0]));
344 : // Copy new chunk to buffer.
345 0 : for (int i = 0; i < num_channels_; ++i) {
346 0 : memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
347 0 : &data[i * data_length_],
348 0 : data_length_ * sizeof(*data));
349 : }
350 0 : if (detection_enabled_) {
351 : // Shift previous chunk in out buffer.
352 0 : memmove(out_buffer_.get(),
353 0 : &out_buffer_[data_length_],
354 0 : (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
355 0 : sizeof(out_buffer_[0]));
356 : // Initialize new chunk in out buffer.
357 0 : for (int i = 0; i < num_channels_; ++i) {
358 0 : memset(&out_buffer_[buffer_delay_ + i * analysis_length_],
359 : 0,
360 0 : data_length_ * sizeof(out_buffer_[0]));
361 : }
362 : }
363 0 : }
364 :
365 : // Restores the unvoiced signal if a click is present.
366 : // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
367 : // the spectral mean. The attenuation depends on |detector_smoothed_|.
368 : // If a restoration takes place, the |magnitudes_| are updated to the new value.
369 0 : void TransientSuppressor::HardRestoration(float* spectral_mean) {
370 : const float detector_result =
371 0 : 1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
372 : // To restore, we get the peaks in the spectrum. If higher than the previous
373 : // spectral mean we adjust them.
374 0 : for (size_t i = 0; i < complex_analysis_length_; ++i) {
375 0 : if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
376 : // RandU() generates values on [0, int16::max()]
377 0 : const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
378 0 : std::numeric_limits<int16_t>::max();
379 0 : const float scaled_mean = detector_result * spectral_mean[i];
380 :
381 0 : fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
382 0 : scaled_mean * cosf(phase);
383 0 : fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
384 0 : scaled_mean * sinf(phase);
385 0 : magnitudes_[i] = magnitudes_[i] -
386 0 : detector_result * (magnitudes_[i] - spectral_mean[i]);
387 : }
388 : }
389 0 : }
390 :
391 : // Restores the voiced signal if a click is present.
392 : // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
393 : // the spectral mean and that is lower than some function of the current block
394 : // frequency mean. The attenuation depends on |detector_smoothed_|.
395 : // If a restoration takes place, the |magnitudes_| are updated to the new value.
396 0 : void TransientSuppressor::SoftRestoration(float* spectral_mean) {
397 : // Get the spectral magnitude mean of the current block.
398 0 : float block_frequency_mean = 0;
399 0 : for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
400 0 : block_frequency_mean += magnitudes_[i];
401 : }
402 0 : block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
403 :
404 : // To restore, we get the peaks in the spectrum. If higher than the
405 : // previous spectral mean and lower than a factor of the block mean
406 : // we adjust them. The factor is a double sigmoid that has a minimum in the
407 : // voice frequency range (300Hz - 3kHz).
408 0 : for (size_t i = 0; i < complex_analysis_length_; ++i) {
409 0 : if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
410 0 : (using_reference_ ||
411 0 : magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
412 : const float new_magnitude =
413 0 : magnitudes_[i] -
414 0 : detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
415 0 : const float magnitude_ratio = new_magnitude / magnitudes_[i];
416 :
417 0 : fft_buffer_[i * 2] *= magnitude_ratio;
418 0 : fft_buffer_[i * 2 + 1] *= magnitude_ratio;
419 0 : magnitudes_[i] = new_magnitude;
420 : }
421 : }
422 0 : }
423 :
424 : } // namespace webrtc
|