LCOV - code coverage report
Current view: top level - media/webrtc/trunk/webrtc/modules/audio_processing/transient - transient_suppressor.cc (source / functions) Hit Total Coverage
Test: output.info Lines: 0 215 0.0 %
Date: 2017-07-14 16:53:18 Functions: 0 11 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :  *  Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
       3             :  *
       4             :  *  Use of this source code is governed by a BSD-style license
       5             :  *  that can be found in the LICENSE file in the root of the source
       6             :  *  tree. An additional intellectual property rights grant can be found
       7             :  *  in the file PATENTS.  All contributing project authors may
       8             :  *  be found in the AUTHORS file in the root of the source tree.
       9             :  */
      10             : 
      11             : #include "webrtc/modules/audio_processing/transient/transient_suppressor.h"
      12             : 
      13             : #include <math.h>
      14             : #include <string.h>
      15             : #include <cmath>
      16             : #include <complex>
      17             : #include <deque>
      18             : #include <set>
      19             : 
      20             : #include "webrtc/base/checks.h"
      21             : #include "webrtc/common_audio/fft4g.h"
      22             : #include "webrtc/common_audio/include/audio_util.h"
      23             : #include "webrtc/common_audio/signal_processing/include/signal_processing_library.h"
      24             : #include "webrtc/modules/audio_processing/transient/common.h"
      25             : #include "webrtc/modules/audio_processing/transient/transient_detector.h"
      26             : #include "webrtc/modules/audio_processing/ns/windows_private.h"
      27             : #include "webrtc/system_wrappers/include/logging.h"
      28             : #include "webrtc/typedefs.h"
      29             : 
      30             : namespace webrtc {
      31             : 
      32             : static const float kMeanIIRCoefficient = 0.5f;
      33             : static const float kVoiceThreshold = 0.02f;
      34             : 
      35             : // TODO(aluebs): Check if these values work also for 48kHz.
      36             : static const size_t kMinVoiceBin = 3;
      37             : static const size_t kMaxVoiceBin = 60;
      38             : 
      39             : namespace {
      40             : 
      41           0 : float ComplexMagnitude(float a, float b) {
      42           0 :   return std::abs(a) + std::abs(b);
      43             : }
      44             : 
      45             : }  // namespace
      46             : 
      47           0 : TransientSuppressor::TransientSuppressor()
      48             :     : data_length_(0),
      49             :       detection_length_(0),
      50             :       analysis_length_(0),
      51             :       buffer_delay_(0),
      52             :       complex_analysis_length_(0),
      53             :       num_channels_(0),
      54             :       window_(NULL),
      55             :       detector_smoothed_(0.f),
      56             :       keypress_counter_(0),
      57             :       chunks_since_keypress_(0),
      58             :       detection_enabled_(false),
      59             :       suppression_enabled_(false),
      60             :       use_hard_restoration_(false),
      61             :       chunks_since_voice_change_(0),
      62             :       seed_(182),
      63           0 :       using_reference_(false) {
      64           0 : }
      65             : 
      66           0 : TransientSuppressor::~TransientSuppressor() {}
      67             : 
      68           0 : int TransientSuppressor::Initialize(int sample_rate_hz,
      69             :                                     int detection_rate_hz,
      70             :                                     int num_channels) {
      71           0 :   switch (sample_rate_hz) {
      72             :     case ts::kSampleRate8kHz:
      73           0 :       analysis_length_ = 128u;
      74           0 :       window_ = kBlocks80w128;
      75           0 :       break;
      76             :     case ts::kSampleRate16kHz:
      77           0 :       analysis_length_ = 256u;
      78           0 :       window_ = kBlocks160w256;
      79           0 :       break;
      80             :     case ts::kSampleRate32kHz:
      81           0 :       analysis_length_ = 512u;
      82           0 :       window_ = kBlocks320w512;
      83           0 :       break;
      84             :     case ts::kSampleRate48kHz:
      85           0 :       analysis_length_ = 1024u;
      86           0 :       window_ = kBlocks480w1024;
      87           0 :       break;
      88             :     default:
      89           0 :       return -1;
      90             :   }
      91           0 :   if (detection_rate_hz != ts::kSampleRate8kHz &&
      92           0 :       detection_rate_hz != ts::kSampleRate16kHz &&
      93           0 :       detection_rate_hz != ts::kSampleRate32kHz &&
      94             :       detection_rate_hz != ts::kSampleRate48kHz) {
      95           0 :     return -1;
      96             :   }
      97           0 :   if (num_channels <= 0) {
      98           0 :     return -1;
      99             :   }
     100             : 
     101           0 :   detector_.reset(new TransientDetector(detection_rate_hz));
     102           0 :   data_length_ = sample_rate_hz * ts::kChunkSizeMs / 1000;
     103           0 :   if (data_length_ > analysis_length_) {
     104           0 :     RTC_NOTREACHED();
     105           0 :     return -1;
     106             :   }
     107           0 :   buffer_delay_ = analysis_length_ - data_length_;
     108             : 
     109           0 :   complex_analysis_length_ = analysis_length_ / 2 + 1;
     110           0 :   RTC_DCHECK_GE(complex_analysis_length_, kMaxVoiceBin);
     111           0 :   num_channels_ = num_channels;
     112           0 :   in_buffer_.reset(new float[analysis_length_ * num_channels_]);
     113           0 :   memset(in_buffer_.get(),
     114             :          0,
     115           0 :          analysis_length_ * num_channels_ * sizeof(in_buffer_[0]));
     116           0 :   detection_length_ = detection_rate_hz * ts::kChunkSizeMs / 1000;
     117           0 :   detection_buffer_.reset(new float[detection_length_]);
     118           0 :   memset(detection_buffer_.get(),
     119             :          0,
     120           0 :          detection_length_ * sizeof(detection_buffer_[0]));
     121           0 :   out_buffer_.reset(new float[analysis_length_ * num_channels_]);
     122           0 :   memset(out_buffer_.get(),
     123             :          0,
     124           0 :          analysis_length_ * num_channels_ * sizeof(out_buffer_[0]));
     125             :   // ip[0] must be zero to trigger initialization using rdft().
     126           0 :   size_t ip_length = 2 + sqrtf(analysis_length_);
     127           0 :   ip_.reset(new size_t[ip_length]());
     128           0 :   memset(ip_.get(), 0, ip_length * sizeof(ip_[0]));
     129           0 :   wfft_.reset(new float[complex_analysis_length_ - 1]);
     130           0 :   memset(wfft_.get(), 0, (complex_analysis_length_ - 1) * sizeof(wfft_[0]));
     131           0 :   spectral_mean_.reset(new float[complex_analysis_length_ * num_channels_]);
     132           0 :   memset(spectral_mean_.get(),
     133             :          0,
     134           0 :          complex_analysis_length_ * num_channels_ * sizeof(spectral_mean_[0]));
     135           0 :   fft_buffer_.reset(new float[analysis_length_ + 2]);
     136           0 :   memset(fft_buffer_.get(), 0, (analysis_length_ + 2) * sizeof(fft_buffer_[0]));
     137           0 :   magnitudes_.reset(new float[complex_analysis_length_]);
     138           0 :   memset(magnitudes_.get(),
     139             :          0,
     140           0 :          complex_analysis_length_ * sizeof(magnitudes_[0]));
     141           0 :   mean_factor_.reset(new float[complex_analysis_length_]);
     142             : 
     143             :   static const float kFactorHeight = 10.f;
     144             :   static const float kLowSlope = 1.f;
     145             :   static const float kHighSlope = 0.3f;
     146           0 :   for (size_t i = 0; i < complex_analysis_length_; ++i) {
     147           0 :     mean_factor_[i] =
     148           0 :         kFactorHeight /
     149           0 :             (1.f + exp(kLowSlope * static_cast<int>(i - kMinVoiceBin))) +
     150           0 :         kFactorHeight /
     151           0 :             (1.f + exp(kHighSlope * static_cast<int>(kMaxVoiceBin - i)));
     152             :   }
     153           0 :   detector_smoothed_ = 0.f;
     154           0 :   keypress_counter_ = 0;
     155           0 :   chunks_since_keypress_ = 0;
     156           0 :   detection_enabled_ = false;
     157           0 :   suppression_enabled_ = false;
     158           0 :   use_hard_restoration_ = false;
     159           0 :   chunks_since_voice_change_ = 0;
     160           0 :   seed_ = 182;
     161           0 :   using_reference_ = false;
     162           0 :   return 0;
     163             : }
     164             : 
     165           0 : int TransientSuppressor::Suppress(float* data,
     166             :                                   size_t data_length,
     167             :                                   int num_channels,
     168             :                                   const float* detection_data,
     169             :                                   size_t detection_length,
     170             :                                   const float* reference_data,
     171             :                                   size_t reference_length,
     172             :                                   float voice_probability,
     173             :                                   bool key_pressed) {
     174           0 :   if (!data || data_length != data_length_ || num_channels != num_channels_ ||
     175           0 :       detection_length != detection_length_ || voice_probability < 0 ||
     176             :       voice_probability > 1) {
     177           0 :     return -1;
     178             :   }
     179             : 
     180           0 :   UpdateKeypress(key_pressed);
     181           0 :   UpdateBuffers(data);
     182             : 
     183           0 :   int result = 0;
     184           0 :   if (detection_enabled_) {
     185           0 :     UpdateRestoration(voice_probability);
     186             : 
     187           0 :     if (!detection_data) {
     188             :       // Use the input data  of the first channel if special detection data is
     189             :       // not supplied.
     190           0 :       detection_data = &in_buffer_[buffer_delay_];
     191             :     }
     192             : 
     193           0 :     float detector_result = detector_->Detect(
     194           0 :         detection_data, detection_length, reference_data, reference_length);
     195           0 :     if (detector_result < 0) {
     196           0 :       return -1;
     197             :     }
     198             : 
     199           0 :     using_reference_ = detector_->using_reference();
     200             : 
     201             :     // |detector_smoothed_| follows the |detector_result| when this last one is
     202             :     // increasing, but has an exponential decaying tail to be able to suppress
     203             :     // the ringing of keyclicks.
     204           0 :     float smooth_factor = using_reference_ ? 0.6 : 0.1;
     205           0 :     detector_smoothed_ = detector_result >= detector_smoothed_
     206           0 :                              ? detector_result
     207           0 :                              : smooth_factor * detector_smoothed_ +
     208           0 :                                    (1 - smooth_factor) * detector_result;
     209             : 
     210           0 :     for (int i = 0; i < num_channels_; ++i) {
     211           0 :       Suppress(&in_buffer_[i * analysis_length_],
     212           0 :                &spectral_mean_[i * complex_analysis_length_],
     213           0 :                &out_buffer_[i * analysis_length_]);
     214             :     }
     215             :   }
     216             : 
     217             :   // If the suppression isn't enabled, we use the in buffer to delay the signal
     218             :   // appropriately. This also gives time for the out buffer to be refreshed with
     219             :   // new data between detection and suppression getting enabled.
     220           0 :   for (int i = 0; i < num_channels_; ++i) {
     221           0 :     memcpy(&data[i * data_length_],
     222           0 :            suppression_enabled_ ? &out_buffer_[i * analysis_length_]
     223           0 :                                 : &in_buffer_[i * analysis_length_],
     224           0 :            data_length_ * sizeof(*data));
     225             :   }
     226           0 :   return result;
     227             : }
     228             : 
     229             : // This should only be called when detection is enabled. UpdateBuffers() must
     230             : // have been called. At return, |out_buffer_| will be filled with the
     231             : // processed output.
     232           0 : void TransientSuppressor::Suppress(float* in_ptr,
     233             :                                    float* spectral_mean,
     234             :                                    float* out_ptr) {
     235             :   // Go to frequency domain.
     236           0 :   for (size_t i = 0; i < analysis_length_; ++i) {
     237             :     // TODO(aluebs): Rename windows
     238           0 :     fft_buffer_[i] = in_ptr[i] * window_[i];
     239             :   }
     240             : 
     241           0 :   WebRtc_rdft(analysis_length_, 1, fft_buffer_.get(), ip_.get(), wfft_.get());
     242             : 
     243             :   // Since WebRtc_rdft puts R[n/2] in fft_buffer_[1], we move it to the end
     244             :   // for convenience.
     245           0 :   fft_buffer_[analysis_length_] = fft_buffer_[1];
     246           0 :   fft_buffer_[analysis_length_ + 1] = 0.f;
     247           0 :   fft_buffer_[1] = 0.f;
     248             : 
     249           0 :   for (size_t i = 0; i < complex_analysis_length_; ++i) {
     250           0 :     magnitudes_[i] = ComplexMagnitude(fft_buffer_[i * 2],
     251           0 :                                       fft_buffer_[i * 2 + 1]);
     252             :   }
     253             :   // Restore audio if necessary.
     254           0 :   if (suppression_enabled_) {
     255           0 :     if (use_hard_restoration_) {
     256           0 :       HardRestoration(spectral_mean);
     257             :     } else {
     258           0 :       SoftRestoration(spectral_mean);
     259             :     }
     260             :   }
     261             : 
     262             :   // Update the spectral mean.
     263           0 :   for (size_t i = 0; i < complex_analysis_length_; ++i) {
     264           0 :     spectral_mean[i] = (1 - kMeanIIRCoefficient) * spectral_mean[i] +
     265           0 :                        kMeanIIRCoefficient * magnitudes_[i];
     266             :   }
     267             : 
     268             :   // Back to time domain.
     269             :   // Put R[n/2] back in fft_buffer_[1].
     270           0 :   fft_buffer_[1] = fft_buffer_[analysis_length_];
     271             : 
     272           0 :   WebRtc_rdft(analysis_length_,
     273             :               -1,
     274             :               fft_buffer_.get(),
     275             :               ip_.get(),
     276           0 :               wfft_.get());
     277           0 :   const float fft_scaling = 2.f / analysis_length_;
     278             : 
     279           0 :   for (size_t i = 0; i < analysis_length_; ++i) {
     280           0 :     out_ptr[i] += fft_buffer_[i] * window_[i] * fft_scaling;
     281             :   }
     282           0 : }
     283             : 
     284           0 : void TransientSuppressor::UpdateKeypress(bool key_pressed) {
     285           0 :   const int kKeypressPenalty = 1000 / ts::kChunkSizeMs;
     286           0 :   const int kIsTypingThreshold = 1000 / ts::kChunkSizeMs;
     287           0 :   const int kChunksUntilNotTyping = 4000 / ts::kChunkSizeMs;  // 4 seconds.
     288             : 
     289           0 :   if (key_pressed) {
     290           0 :     keypress_counter_ += kKeypressPenalty;
     291           0 :     chunks_since_keypress_ = 0;
     292           0 :     detection_enabled_ = true;
     293             :   }
     294           0 :   keypress_counter_ = std::max(0, keypress_counter_ - 1);
     295             : 
     296           0 :   if (keypress_counter_ > kIsTypingThreshold) {
     297           0 :     if (!suppression_enabled_) {
     298           0 :       LOG(LS_INFO) << "[ts] Transient suppression is now enabled.";
     299             :     }
     300           0 :     suppression_enabled_ = true;
     301           0 :     keypress_counter_ = 0;
     302             :   }
     303             : 
     304           0 :   if (detection_enabled_ &&
     305           0 :       ++chunks_since_keypress_ > kChunksUntilNotTyping) {
     306           0 :     if (suppression_enabled_) {
     307           0 :       LOG(LS_INFO) << "[ts] Transient suppression is now disabled.";
     308             :     }
     309           0 :     detection_enabled_ = false;
     310           0 :     suppression_enabled_ = false;
     311           0 :     keypress_counter_ = 0;
     312             :   }
     313           0 : }
     314             : 
     315           0 : void TransientSuppressor::UpdateRestoration(float voice_probability) {
     316           0 :   const int kHardRestorationOffsetDelay = 3;
     317           0 :   const int kHardRestorationOnsetDelay = 80;
     318             : 
     319           0 :   bool not_voiced = voice_probability < kVoiceThreshold;
     320             : 
     321           0 :   if (not_voiced == use_hard_restoration_) {
     322           0 :     chunks_since_voice_change_ = 0;
     323             :   } else {
     324           0 :     ++chunks_since_voice_change_;
     325             : 
     326           0 :     if ((use_hard_restoration_ &&
     327           0 :          chunks_since_voice_change_ > kHardRestorationOffsetDelay) ||
     328           0 :         (!use_hard_restoration_ &&
     329           0 :          chunks_since_voice_change_ > kHardRestorationOnsetDelay)) {
     330           0 :       use_hard_restoration_ = not_voiced;
     331           0 :       chunks_since_voice_change_ = 0;
     332             :     }
     333             :   }
     334           0 : }
     335             : 
     336             : // Shift buffers to make way for new data. Must be called after
     337             : // |detection_enabled_| is updated by UpdateKeypress().
     338           0 : void TransientSuppressor::UpdateBuffers(float* data) {
     339             :   // TODO(aluebs): Change to ring buffer.
     340           0 :   memmove(in_buffer_.get(),
     341           0 :           &in_buffer_[data_length_],
     342           0 :           (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
     343           0 :               sizeof(in_buffer_[0]));
     344             :   // Copy new chunk to buffer.
     345           0 :   for (int i = 0; i < num_channels_; ++i) {
     346           0 :     memcpy(&in_buffer_[buffer_delay_ + i * analysis_length_],
     347           0 :            &data[i * data_length_],
     348           0 :            data_length_ * sizeof(*data));
     349             :   }
     350           0 :   if (detection_enabled_) {
     351             :     // Shift previous chunk in out buffer.
     352           0 :     memmove(out_buffer_.get(),
     353           0 :             &out_buffer_[data_length_],
     354           0 :             (buffer_delay_ + (num_channels_ - 1) * analysis_length_) *
     355           0 :                 sizeof(out_buffer_[0]));
     356             :     // Initialize new chunk in out buffer.
     357           0 :     for (int i = 0; i < num_channels_; ++i) {
     358           0 :       memset(&out_buffer_[buffer_delay_ + i * analysis_length_],
     359             :              0,
     360           0 :              data_length_ * sizeof(out_buffer_[0]));
     361             :     }
     362             :   }
     363           0 : }
     364             : 
     365             : // Restores the unvoiced signal if a click is present.
     366             : // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
     367             : // the spectral mean. The attenuation depends on |detector_smoothed_|.
     368             : // If a restoration takes place, the |magnitudes_| are updated to the new value.
     369           0 : void TransientSuppressor::HardRestoration(float* spectral_mean) {
     370             :   const float detector_result =
     371           0 :       1.f - pow(1.f - detector_smoothed_, using_reference_ ? 200.f : 50.f);
     372             :   // To restore, we get the peaks in the spectrum. If higher than the previous
     373             :   // spectral mean we adjust them.
     374           0 :   for (size_t i = 0; i < complex_analysis_length_; ++i) {
     375           0 :     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0) {
     376             :       // RandU() generates values on [0, int16::max()]
     377           0 :       const float phase = 2 * ts::kPi * WebRtcSpl_RandU(&seed_) /
     378           0 :           std::numeric_limits<int16_t>::max();
     379           0 :       const float scaled_mean = detector_result * spectral_mean[i];
     380             : 
     381           0 :       fft_buffer_[i * 2] = (1 - detector_result) * fft_buffer_[i * 2] +
     382           0 :                            scaled_mean * cosf(phase);
     383           0 :       fft_buffer_[i * 2 + 1] = (1 - detector_result) * fft_buffer_[i * 2 + 1] +
     384           0 :                                scaled_mean * sinf(phase);
     385           0 :       magnitudes_[i] = magnitudes_[i] -
     386           0 :                        detector_result * (magnitudes_[i] - spectral_mean[i]);
     387             :     }
     388             :   }
     389           0 : }
     390             : 
     391             : // Restores the voiced signal if a click is present.
     392             : // Attenuates by a certain factor every peak in the |fft_buffer_| that exceeds
     393             : // the spectral mean and that is lower than some function of the current block
     394             : // frequency mean. The attenuation depends on |detector_smoothed_|.
     395             : // If a restoration takes place, the |magnitudes_| are updated to the new value.
     396           0 : void TransientSuppressor::SoftRestoration(float* spectral_mean) {
     397             :   // Get the spectral magnitude mean of the current block.
     398           0 :   float block_frequency_mean = 0;
     399           0 :   for (size_t i = kMinVoiceBin; i < kMaxVoiceBin; ++i) {
     400           0 :     block_frequency_mean += magnitudes_[i];
     401             :   }
     402           0 :   block_frequency_mean /= (kMaxVoiceBin - kMinVoiceBin);
     403             : 
     404             :   // To restore, we get the peaks in the spectrum. If higher than the
     405             :   // previous spectral mean and lower than a factor of the block mean
     406             :   // we adjust them. The factor is a double sigmoid that has a minimum in the
     407             :   // voice frequency range (300Hz - 3kHz).
     408           0 :   for (size_t i = 0; i < complex_analysis_length_; ++i) {
     409           0 :     if (magnitudes_[i] > spectral_mean[i] && magnitudes_[i] > 0 &&
     410           0 :         (using_reference_ ||
     411           0 :          magnitudes_[i] < block_frequency_mean * mean_factor_[i])) {
     412             :       const float new_magnitude =
     413           0 :           magnitudes_[i] -
     414           0 :           detector_smoothed_ * (magnitudes_[i] - spectral_mean[i]);
     415           0 :       const float magnitude_ratio = new_magnitude / magnitudes_[i];
     416             : 
     417           0 :       fft_buffer_[i * 2] *= magnitude_ratio;
     418           0 :       fft_buffer_[i * 2 + 1] *= magnitude_ratio;
     419           0 :       magnitudes_[i] = new_magnitude;
     420             :     }
     421             :   }
     422           0 : }
     423             : 
     424             : }  // namespace webrtc

Generated by: LCOV version 1.13