LCOV - code coverage report
Current view: top level - media/webrtc/trunk/webrtc/modules/audio_processing/level_controller - signal_classifier.cc (source / functions) Hit Total Coverage
Test: output.info Lines: 0 88 0.0 %
Date: 2017-07-14 16:53:18 Functions: 0 9 0.0 %
Legend: Lines: hit not hit

          Line data    Source code
       1             : /*
       2             :  *  Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
       3             :  *
       4             :  *  Use of this source code is governed by a BSD-style license
       5             :  *  that can be found in the LICENSE file in the root of the source
       6             :  *  tree. An additional intellectual property rights grant can be found
       7             :  *  in the file PATENTS.  All contributing project authors may
       8             :  *  be found in the AUTHORS file in the root of the source tree.
       9             :  */
      10             : 
      11             : #include "webrtc/modules/audio_processing/level_controller/signal_classifier.h"
      12             : 
      13             : #include <algorithm>
      14             : #include <numeric>
      15             : #include <vector>
      16             : 
      17             : #include "webrtc/base/array_view.h"
      18             : #include "webrtc/base/constructormagic.h"
      19             : #include "webrtc/modules/audio_processing/audio_buffer.h"
      20             : #include "webrtc/modules/audio_processing/level_controller/down_sampler.h"
      21             : #include "webrtc/modules/audio_processing/level_controller/noise_spectrum_estimator.h"
      22             : #include "webrtc/modules/audio_processing/logging/apm_data_dumper.h"
      23             : 
      24             : namespace webrtc {
      25             : namespace {
      26             : 
      27           0 : void RemoveDcLevel(rtc::ArrayView<float> x) {  // Removes the DC offset from x in place by subtracting its arithmetic mean.
      28           0 :   RTC_DCHECK_LT(0, x.size());  // x must be non-empty: the mean below divides by x.size().
      29           0 :   float mean = std::accumulate(x.data(), x.data() + x.size(), 0.f);  // Sum of all samples, accumulated as float.
      30           0 :   mean /= x.size();
      31             : 
      32           0 :   for (float& v : x) {
      33           0 :     v -= mean;
      34             :   }
      35           0 : }
      36             : 
      37           0 : void PowerSpectrum(const OouraFft* ooura_fft,  // Fills spectrum with squared magnitudes taken from the FFT of x.
      38             :                    rtc::ArrayView<const float> x,
      39             :                    rtc::ArrayView<float> spectrum) {
      40           0 :   RTC_DCHECK_EQ(65, spectrum.size());  // Fixed sizes: 128 input samples -> 65 output values.
      41           0 :   RTC_DCHECK_EQ(128, x.size());
      42             :   float X[128];
      43           0 :   std::copy(x.data(), x.data() + x.size(), X);  // Work on a local copy; x is read-only.
      44           0 :   ooura_fft->Fft(X);  // The transform result is read back from X below.
      45             : 
      46           0 :   float* X_p = X;  // Walks X sequentially; the DCHECKs below verify it stays at the expected index.
      47           0 :   RTC_DCHECK_EQ(X_p, &X[0]);
      48           0 :   spectrum[0] = (*X_p) * (*X_p);  // X[0] alone forms bin 0 (presumably the DC term of the packed real-FFT output -- confirm).
      49           0 :   ++X_p;
      50           0 :   RTC_DCHECK_EQ(X_p, &X[1]);
      51           0 :   spectrum[64] = (*X_p) * (*X_p);  // X[1] alone forms bin 64 (presumably the Nyquist term -- confirm).
      52           0 :   for (int k = 1; k < 64; ++k) {  // Bins 1..63 each combine the pair X[2k], X[2k+1] (presumably real/imaginary parts).
      53           0 :     ++X_p;
      54           0 :     RTC_DCHECK_EQ(X_p, &X[2 * k]);
      55           0 :     spectrum[k] = (*X_p) * (*X_p);
      56           0 :     ++X_p;
      57           0 :     RTC_DCHECK_EQ(X_p, &X[2 * k + 1]);
      58           0 :     spectrum[k] += (*X_p) * (*X_p);
      59             :   }
      60           0 : }
      61             : 
      62           0 : webrtc::SignalClassifier::SignalType ClassifySignal(  // Classifies a frame by comparing signal_spectrum with noise_spectrum band by band.
      63             :     rtc::ArrayView<const float> signal_spectrum,  // Indexed at 1..39 below; no size check is made here.
      64             :     rtc::ArrayView<const float> noise_spectrum,  // Indexed at 1..39 below; no size check is made here.
      65             :     ApmDataDumper* data_dumper) {  // Dereferenced unconditionally; must not be null.
      66           0 :   int num_stationary_bands = 0;
      67           0 :   int num_highly_nonstationary_bands = 0;
      68             : 
      69             :   // Detect stationary and highly nonstationary bands.
      70           0 :   for (size_t k = 1; k < 40; k++) {  // Only bands 1..39 are examined; band 0 and bands >= 40 are ignored.
      71           0 :     if (signal_spectrum[k] < 3 * noise_spectrum[k] &&  // Stationary: signal strictly within a factor of 3 of the noise, in both directions.
      72           0 :         signal_spectrum[k] * 3 > noise_spectrum[k]) {
      73           0 :       ++num_stationary_bands;
      74           0 :     } else if (signal_spectrum[k] > 9 * noise_spectrum[k]) {  // Highly nonstationary: signal more than 9x the noise.
      75           0 :       ++num_highly_nonstationary_bands;
      76             :     }
      77             :   }
      78             : 
      79           0 :   data_dumper->DumpRaw("lc_num_stationary_bands", 1, &num_stationary_bands);  // Both counters are handed to the data dumper.
      80             :   data_dumper->DumpRaw("lc_num_highly_nonstationary_bands", 1,
      81           0 :                        &num_highly_nonstationary_bands);
      82             : 
      83             :   // Use the detected number of bands to classify the overall signal
      84             :   // stationarity.
      85           0 :   if (num_stationary_bands > 15) {  // More than 15 of the 39 bands must agree; stationarity is tested first.
      86           0 :     return SignalClassifier::SignalType::kStationary;
      87           0 :   } else if (num_highly_nonstationary_bands > 15) {
      88           0 :     return SignalClassifier::SignalType::kHighlyNonStationary;
      89             :   } else {
      90           0 :     return SignalClassifier::SignalType::kNonStationary;  // Fallback when neither count exceeds 15.
      91             :   }
      92             : }
      93             : 
      94             : }  // namespace
      95             : 
      96           0 : SignalClassifier::FrameExtender::FrameExtender(size_t frame_size,  // Creates a zero-filled history buffer of extended_frame_size - frame_size samples.
      97           0 :                                                size_t extended_frame_size)
      98           0 :     : x_old_(extended_frame_size - frame_size, 0.f) {}  // NOTE(review): size_t subtraction wraps if extended_frame_size < frame_size; not checked here.
      99             : 
     100             : SignalClassifier::FrameExtender::~FrameExtender() = default;  // Defaulted destructor; no custom cleanup.
     101             : 
     102           0 : void SignalClassifier::FrameExtender::ExtendFrame(  // Builds x_extended = [saved history | x], then saves the tail of x_extended as the next history.
     103             :     rtc::ArrayView<const float> x,
     104             :     rtc::ArrayView<float> x_extended) {
     105           0 :   RTC_DCHECK_EQ(x_old_.size() + x.size(), x_extended.size());  // Output must be exactly history length + input length.
     106           0 :   std::copy(x_old_.data(), x_old_.data() + x_old_.size(), x_extended.data());  // History goes first...
     107           0 :   std::copy(x.data(), x.data() + x.size(), x_extended.data() + x_old_.size());  // ...followed by the new frame.
     108           0 :   std::copy(x_extended.data() + x_extended.size() - x_old_.size(),  // Keep the last x_old_.size() samples of the extended frame for the next call.
     109           0 :             x_extended.data() + x_extended.size(), x_old_.data());
     110           0 : }
     111             : 
     112           0 : SignalClassifier::SignalClassifier(ApmDataDumper* data_dumper)  // Stores the dumper, passes it to the down-sampler and noise estimator, then runs Initialize().
     113             :     : data_dumper_(data_dumper),
     114           0 :       down_sampler_(data_dumper_),  // NOTE(review): uses the member, not the parameter; relies on data_dumper_ being declared before these members -- confirm in the header.
     115           0 :       noise_spectrum_estimator_(data_dumper_) {
     116           0 :   Initialize(AudioProcessing::kSampleRate48kHz);  // Initial configuration uses AudioProcessing::kSampleRate48kHz.
     117           0 : }
     118           0 : SignalClassifier::~SignalClassifier() {}  // Empty body; no explicit cleanup.
     119             : 
     120           0 : void SignalClassifier::Initialize(int sample_rate_hz) {  // Resets the classifier state for the given sample rate.
     121           0 :   down_sampler_.Initialize(sample_rate_hz);
     122           0 :   noise_spectrum_estimator_.Initialize();
     123           0 :   frame_extender_.reset(new FrameExtender(80, 128));  // Fresh extender: 80-sample frames extended to 128 samples, matching the buffers in Analyze().
     124           0 :   sample_rate_hz_ = sample_rate_hz;
     125           0 :   initialization_frames_left_ = 2;  // While > 0, Analyze() passes true as the second argument of the noise estimator's Update().
     126           0 :   consistent_classification_counter_ = 3;  // While > 0, Analyze() reports kNonStationary.
     127           0 :   last_signal_type_ = SignalClassifier::SignalType::kNonStationary;
     128           0 : }
     129             : 
     130           0 : void SignalClassifier::Analyze(const AudioBuffer& audio,  // Classifies one frame of audio, writes the result to *signal_type and updates the noise estimate.
     131             :                                SignalType* signal_type) {  // Output parameter; written unconditionally, so it must not be null.
     132           0 :   RTC_DCHECK_EQ(audio.num_frames(), sample_rate_hz_ / 100);  // Expects sample_rate_hz_ / 100 frames, i.e. 10 ms of audio at sample_rate_hz_.
     133             : 
     134             :   // Compute the signal power spectrum.
     135             :   float downsampled_frame[80];
     136           0 :   down_sampler_.DownSample(rtc::ArrayView<const float>(  // Only channel 0 of the buffer is analyzed.
     137           0 :                                audio.channels_const_f()[0], audio.num_frames()),
     138           0 :                            downsampled_frame);
     139             :   float extended_frame[128];
     140           0 :   frame_extender_->ExtendFrame(downsampled_frame, extended_frame);  // 80 new samples plus 48 saved from the previous call.
     141           0 :   RemoveDcLevel(extended_frame);
     142             :   float signal_spectrum[65];
     143           0 :   PowerSpectrum(&ooura_fft_, extended_frame, signal_spectrum);
     144             : 
     145             :   // Classify the signal based on the estimate of the noise spectrum and the
     146             :   // signal spectrum estimate.
     147           0 :   *signal_type = ClassifySignal(signal_spectrum,  // Uses the noise estimate as it was before this frame's Update() below.
     148             :                                 noise_spectrum_estimator_.GetNoiseSpectrum(),
     149           0 :                                 data_dumper_);
     150             : 
     151             :   // Update the noise spectrum based on the signal spectrum.
     152           0 :   noise_spectrum_estimator_.Update(signal_spectrum,  // Second argument is true while initialization frames remain.
     153           0 :                                    initialization_frames_left_ > 0);
     154             : 
     155             :   // Update the number of frames until a reliable signal spectrum is achieved.
     156           0 :   initialization_frames_left_ = std::max(0, initialization_frames_left_ - 1);  // Counts down and saturates at 0.
     157             : 
     158           0 :   if (last_signal_type_ == *signal_type) {  // Same raw type as the previous frame: count down toward 0.
     159           0 :     consistent_classification_counter_ =
     160           0 :         std::max(0, consistent_classification_counter_ - 1);
     161             :   } else {  // Raw type changed: remember it and restart the countdown at 3.
     162           0 :     last_signal_type_ = *signal_type;
     163           0 :     consistent_classification_counter_ = 3;
     164             :   }
     165             : 
     166           0 :   if (consistent_classification_counter_ > 0) {  // Until the countdown reaches 0, the reported type is forced to kNonStationary.
     167           0 :     *signal_type = SignalClassifier::SignalType::kNonStationary;
     168             :   }
     169           0 : }
     170             : 
     171             : }  // namespace webrtc

Generated by: LCOV version 1.13