Line data Source code
1 : /*
2 : * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 : *
4 : * Use of this source code is governed by a BSD-style license
5 : * that can be found in the LICENSE file in the root of the source
6 : * tree. An additional intellectual property rights grant can be found
7 : * in the file PATENTS. All contributing project authors may
8 : * be found in the AUTHORS file in the root of the source tree.
9 : */
10 :
11 : #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
12 : #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
13 :
14 : // MSVC++ requires this to be set before any other includes to get M_PI.
15 : #ifndef _USE_MATH_DEFINES
16 : #define _USE_MATH_DEFINES
17 : #endif
18 :
19 : #include <math.h>
20 :
21 : #include <memory>
22 : #include <vector>
23 :
24 : #include "webrtc/common_audio/lapped_transform.h"
25 : #include "webrtc/common_audio/channel_buffer.h"
26 : #include "webrtc/modules/audio_processing/beamformer/array_util.h"
27 : #include "webrtc/modules/audio_processing/beamformer/complex_matrix.h"
28 :
29 : namespace webrtc {
30 :
31 0 : class PostFilterTransform : public LappedTransform::Callback {  // LappedTransform callback that holds its own LappedTransform (|transform_|).
32 : public:
33 : PostFilterTransform(size_t num_channels,
34 : size_t chunk_length,
35 : float* window,  // NOTE(review): presumably a window of |fft_size| samples -- confirm in the .cc file.
36 : size_t fft_size);
37 :
38 : void ProcessChunk(float* const* data, float* final_mask);  // NOTE(review): |data| looks like per-channel sample pointers; |final_mask| is presumably kept in |final_mask_| -- confirm.
39 :
40 : protected:
41 : void ProcessAudioBlock(const complex<float>* const* input,  // Overrides LappedTransform::Callback::ProcessAudioBlock.
42 : size_t num_input_channels,
43 : size_t num_freq_bins,
44 : size_t num_output_channels,
45 : complex<float>* const* output) override;
46 :
47 : private:
48 : LappedTransform transform_;
49 : const size_t num_freq_bins_;  // Const member, so fixed at construction.
50 : float* final_mask_;  // Raw pointer. NOTE(review): presumably the mask handed to ProcessChunk() and not owned here -- confirm.
51 : };
52 :
53 : // Enhances sound sources coming directly in front of a uniform linear array
54 : // and suppresses sound sources coming from all other directions. Operates on
55 : // multichannel signals and produces single-channel output.
56 : //
57 : // The implemented nonlinear postfilter algorithm is taken from "A Robust
58 : // Nonlinear Beamforming Postprocessor" by Bastiaan Kleijn.
59 0 : class NonlinearBeamformer : public LappedTransform::Callback {
60 : public:
61 : static const float kHalfBeamWidthRadians;  // Declared only; the value is defined out of line.
62 :
63 : explicit NonlinearBeamformer(
64 : const std::vector<Point>& array_geometry,
65 : size_t num_postfilter_channels = 1u,
66 : SphericalPointf target_direction =  // Default is SphericalPointf(pi/2, 0, 1). NOTE(review): presumably (azimuth, elevation, radius) -- confirm in array_util.h.
67 : SphericalPointf(static_cast<float>(M_PI) / 2.f, 0.f, 1.f));
68 : ~NonlinearBeamformer() override;
69 :
70 : // Sample rate corresponds to the lower band.
71 : // Needs to be called before the NonlinearBeamformer can be used.
72 : virtual void Initialize(int chunk_size_ms, int sample_rate_hz);
73 :
74 : // Analyzes one time-domain chunk of audio. The audio is expected to be split
75 : // into frequency bands inside the ChannelBuffer. The number of frames and
76 : // channels must correspond to the constructor parameters.
77 : virtual void AnalyzeChunk(const ChannelBuffer<float>& data);
78 :
79 : // Applies the postfilter mask to one chunk of audio. The audio is expected to
80 : // be split into frequency bands inside the ChannelBuffer. The number of
81 : // frames and channels must correspond to the constructor parameters.
82 : virtual void PostFilter(ChannelBuffer<float>* data);
83 :
84 : virtual void AimAt(const SphericalPointf& target_direction);  // NOTE(review): presumably re-targets the beam at |target_direction| -- confirm in the .cc file.
85 :
86 : virtual bool IsInBeam(const SphericalPointf& spherical_point);  // NOTE(review): presumably tests |spherical_point| against kHalfBeamWidthRadians around the target -- confirm.
87 :
88 : // After processing each block |is_target_present_| is set to true if the
89 : // target signal is present and to false otherwise. This method can be called
90 : // to know if the data is target signal or interference and process it
91 : // accordingly.
92 : virtual bool is_target_present();
93 :
94 : protected:
95 : // Process one frequency-domain block of audio. This is where the fun
96 : // happens. Implements LappedTransform::Callback.
97 : void ProcessAudioBlock(const complex<float>* const* input,
98 : size_t num_input_channels,
99 : size_t num_freq_bins,
100 : size_t num_output_channels,
101 : complex<float>* const* output) override;
102 :
103 : private:
104 : FRIEND_TEST_ALL_PREFIXES(NonlinearBeamformerTest,  // Gives the named unit test access to private members.
105 : InterfAnglesTakeAmbiguityIntoAccount);
106 :
107 : typedef Matrix<float> MatrixF;
108 : typedef ComplexMatrix<float> ComplexMatrixF;
109 : typedef complex<float> complex_f;
110 :
111 : void InitLowFrequencyCorrectionRanges();  // Setup helpers (this and the Init*/Normalize* methods below). NOTE(review): call sites and order are not visible here -- see the .cc file.
112 : void InitHighFrequencyCorrectionRanges();
113 : void InitInterfAngles();
114 : void InitDelaySumMasks();
115 : void InitTargetCovMats();
116 : void InitDiffuseCovMats();
117 : void InitInterfCovMats();
118 : void NormalizeCovMats();
119 :
120 : // Calculates postfilter masks that minimize the mean squared error of our
121 : // estimation of the desired signal.
122 : float CalculatePostfilterMask(const ComplexMatrixF& interf_cov_mat,
123 : float rpsiw,
124 : float ratio_rxiw_rxim,
125 : float rmxi_r);
126 :
127 : // Prevents the postfilter masks from degenerating too quickly (a cause of
128 : // musical noise).
129 : void ApplyMaskTimeSmoothing();
130 : void ApplyMaskFrequencySmoothing();
131 :
132 : // The postfilter masks are unreliable at low frequencies. Calculates a better
133 : // mask by averaging mid-low frequency values.
134 : void ApplyLowFrequencyCorrection();
135 :
136 : // Postfilter masks are also unreliable at high frequencies. Average mid-high
137 : // frequency masks to calculate a single mask per block which can be applied
138 : // in the time-domain. Further, we average these block-masks over a chunk,
139 : // resulting in one postfilter mask per audio chunk. This allows us to skip
140 : // both transforming and blocking the high-frequency signal.
141 : void ApplyHighFrequencyCorrection();
142 :
143 : // Compute the means needed for the above frequency correction.
144 : float MaskRangeMean(size_t start_bin, size_t end_bin);  // NOTE(review): whether |end_bin| is inclusive is not visible here -- confirm in the .cc file.
145 :
146 : // Applies post-filter mask to |input| and store in |output|.
147 : void ApplyPostFilter(const complex_f* input, complex_f* output);
148 :
149 : void EstimateTargetPresence();  // NOTE(review): presumably updates |is_target_present_| -- confirm in the .cc file.
150 :
151 : static const size_t kFftSize = 256;
152 : static const size_t kNumFreqBins = kFftSize / 2 + 1;  // 129 bins for kFftSize == 256.
153 :
154 : // Deals with the fft transform and blocking.
155 : size_t chunk_length_;
156 : std::unique_ptr<LappedTransform> process_transform_;
157 : std::unique_ptr<PostFilterTransform> postfilter_transform_;
158 : float window_[kFftSize];
159 :
160 : // Parameters exposed to the user.
161 : const size_t num_input_channels_;
162 : const size_t num_postfilter_channels_;
163 : int sample_rate_hz_;
164 :
165 : const std::vector<Point> array_geometry_;
166 : // The normal direction of the array if it has one and it is in the xy-plane.
167 : const rtc::Optional<Point> array_normal_;
168 :
169 : // Minimum spacing between microphone pairs.
170 : const float min_mic_spacing_;
171 :
172 : // Calculated based on user-input and constants in the .cc file.
173 : size_t low_mean_start_bin_;
174 : size_t low_mean_end_bin_;
175 : size_t high_mean_start_bin_;
176 : size_t high_mean_end_bin_;
177 :
178 : // Quickly varying mask updated every block.
179 : float new_mask_[kNumFreqBins];
180 : // Time smoothed mask.
181 : float time_smooth_mask_[kNumFreqBins];
182 : // Time and frequency smoothed mask.
183 : float final_mask_[kNumFreqBins];
184 :
185 : float target_angle_radians_;
186 : // Angles of the interferer scenarios.
187 : std::vector<float> interf_angles_radians_;
188 : // The angle between the target and the interferer scenarios.
189 : const float away_radians_;
190 :
191 : // Array of length |kNumFreqBins|, Matrix of size |1| x |num_channels_|. NOTE(review): no |num_channels_| member is declared in this class; presumably |num_input_channels_| is meant -- confirm.
192 : ComplexMatrixF delay_sum_masks_[kNumFreqBins];
193 :
194 : // Arrays of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
195 : // |num_input_channels_|.
196 : ComplexMatrixF target_cov_mats_[kNumFreqBins];
197 : ComplexMatrixF uniform_cov_mat_[kNumFreqBins];
198 : // Array of length |kNumFreqBins|, Matrix of size |num_input_channels_| x
199 : // |num_input_channels_|. The vector has a size equal to the number of
200 : // interferer scenarios.
201 : std::vector<std::unique_ptr<ComplexMatrixF>> interf_cov_mats_[kNumFreqBins];
202 :
203 : // Of length |kNumFreqBins|.
204 : float wave_numbers_[kNumFreqBins];
205 :
206 : // Preallocated for ProcessAudioBlock()
207 : // Of length |kNumFreqBins|.
208 : float rxiws_[kNumFreqBins];
209 : // The vector has a size equal to the number of interferer scenarios.
210 : std::vector<float> rpsiws_[kNumFreqBins];
211 :
212 : // The microphone normalization factor.
213 : ComplexMatrixF eig_m_;
214 :
215 : // For processing the high-frequency input signal.
216 : float high_pass_postfilter_mask_;
217 : float old_high_pass_mask_;
218 :
219 : // True when the target signal is present.
220 : bool is_target_present_;
221 : // Number of blocks after which the data is considered interference if the
222 : // mask does not pass |kMaskSignalThreshold|.
223 : size_t hold_target_blocks_;
224 : // Number of blocks since the last mask that passed |kMaskSignalThreshold|.
225 : size_t interference_blocks_count_;
226 : };
227 :
228 : } // namespace webrtc
229 :
230 : #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_NONLINEAR_BEAMFORMER_H_
|