Skip to content

Commit 5975d9c

Browse files
committed
add mvdr implement and optimize cgmm-training process
1 parent 43568b4 commit 5975d9c

File tree

6 files changed

+157
-45
lines changed

6 files changed

+157
-45
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
*.pyc
22
*.log
33
*.scp
4+
*.npy
5+
*.wav
46
data/*.pyc
57
__pycache__/
68
6ch/

apply_mvdr.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# wujian@17.10.27
4+
5+
import argparse
6+
import numpy as np
7+
import beamformer
8+
import utils
9+
from utils import MultiChannelWrapper
10+
11+
def main(args):
    """
    Enhance a multi-channel recording with the CGMM-MVDR beamformer and
    hand the result to utils.reconstruct_wave.

    M: num_channels, T: num_frames
    args attributes read here:
        sigma_noisy:  path of the .npy holding the noisy (noise + clean) sigma
        sigma_noise:  path of the .npy holding the noise sigma
        descriptor:   passed to MultiChannelWrapper
        save_dir:     destination passed to utils.reconstruct_wave
        filter_coeff: passed to utils.reconstruct_wave
    for each bin f, beamformer.apply_mvdr inputs:
        steer_vector: 1 x M
        sigma_noise[f]: M x M
        spectrum_onbin[f]: T x M
    return 1 x T
    """
    sigma_noisy = np.load(args.sigma_noisy)
    sigma_noise = np.load(args.sigma_noise)
    # clean part is estimated as noisy minus noise
    sigma_clean = sigma_noisy - sigma_noise

    wrapper = MultiChannelWrapper(args.descriptor)
    (time_steps, num_bins), spectrums = wrapper.spectrums()
    # reverse the axes so the first index is the frequency bin
    # (presumably M x T x F => F x T x M; confirm against MultiChannelWrapper)
    specs_noisy = np.transpose(spectrums, (2, 1, 0))
    # np.complex was removed in NumPy 1.24; complex128 is the dtype it aliased
    specs_enhan = np.zeros([num_bins, time_steps], dtype=np.complex128)
    for f in range(num_bins):
        steer_vector = beamformer.main_egvec(sigma_clean[f])
        enhanced = beamformer.apply_mvdr(steer_vector, sigma_noise[f], specs_noisy[f])
        # flatten the 1 x T result so it fills exactly one row
        specs_enhan[f] = np.asarray(enhanced).ravel()
    utils.reconstruct_wave(np.transpose(specs_enhan), args.save_dir,
                           filter_coeff=args.filter_coeff)
32+
33+
if __name__ == '__main__':
    # command-line entry point: three positional paths plus two optional settings
    parser = argparse.ArgumentParser(description="Apply CGMM-MVDR beamformer on multiple channel")
    parser.add_argument('descriptor', type=str,
                        help="""descriptor of multiple channel location""")
    parser.add_argument('sigma_noisy', type=str,
                        help="""sigma of noisy(noise + clean) part estimated by CGMM""")
    parser.add_argument('sigma_noise', type=str,
                        help="""sigma of noise part estimated by CGMM""")
    parser.add_argument('-s', '--save',
                        dest='save_dir', type=str, default='default.wav',
                        help="""path to save the enhanced wave""")
    # the default is given as a float to match type=float instead of relying
    # on argparse converting a string default
    parser.add_argument('-c', '--filter_coeff',
                        dest='filter_coeff', type=float, default=0.97,
                        help="""filter coefficient to apply when reconstruct wave""")
    args = parser.parse_args()
    main(args)
49+

beamformer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#!/usr/bin/env python
2+
# coding=utf-8
3+
# wujian@17.10.26
4+
5+
import numpy as np
6+
7+
def main_egvec(mat):
    """
    return the eigen vector as a estimate of steer vector,
    which has maximum eigen value

    inputs:
        mat: M x M matrix/ndarray
    return:
        1-dim ndarray of length M
    raises:
        ValueError if mat is not 2-dim
    """
    mat = np.asarray(mat)
    if mat.ndim != 2:
        raise ValueError("Input must be 2-dim matrix/ndarray")
    eigen_val, eigen_vec = np.linalg.eig(mat)
    # eigen values may come back complex; rank them by their real part
    max_index = np.argmax(eigen_val.real)
    # np.linalg.eig stores eigen vectors as COLUMNS:
    # eigen_vec[:, i] belongs to eigen_val[i], a row of eigen_vec does not
    return eigen_vec[:, max_index]
16+
17+
def apply_mvdr(steer_vector, sigma_noise, spectrum_onbin):
    r"""
    Apply the MVDR beamformer to all frames of one frequency bin.

    inputs:
        steer_vector: M values (1-dim, 1 x M or M x 1) => d
        sigma_noise: M x M => \phi_v
        spectrum_onbin: T x M => y
    w = \phi_v^{-1} * d / (d^H * \phi_v^{-1} * d) => M x 1
    s = w^H * y^T => 1 x T
    return:
        s as a 1 x T np.matrix
    """
    # T x M => M x T
    y = np.atleast_2d(np.asarray(spectrum_onbin)).T
    # any orientation => M x 1
    d = np.asarray(steer_vector).reshape(-1, 1)
    # \phi_v^{-1} * d through a linear solve, without forming the inverse
    phi_inv_d = np.linalg.solve(np.asarray(sigma_noise), d)
    # M x 1; the denominator d^H * \phi_v^{-1} * d is a 1 x 1 array
    w = phi_inv_d / np.dot(d.conj().T, phi_inv_d)
    s = np.dot(w.conj().T, y)
    # np.matrix stays the return type so existing callers see no change
    return np.asmatrix(s)

cgmm.py

Lines changed: 54 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# wujian@17.10.25
44

55
import math
6+
import os
67
import numpy as np
78

89
LOG_PI = math.log(math.pi)
@@ -45,76 +46,90 @@ def __init__(self, num_bins, time_steps, num_channels):
4546
# lambda, phi, R for noisy/noise part
4647
self.lambda_ = np.random.rand(num_bins, time_steps).astype(np.complex)
4748
self.phi = np.ones([num_bins, time_steps]).astype(np.complex)
48-
# type matrix
49-
self.R = [np.matrix(np.eye(num_channels, num_channels).astype(np.complex)) \
50-
for i in range(num_bins)]
51-
49+
50+
def init_sigma(self, sigma):
    r"""
    Cache the inverse and the determinant of every matrix in sigma.

    Inputs: sigma is a list (or tuple) of np.matrix / 2-dim ndarray
    Keeps \sigma^{-1} and det(\sigma), \sigma equals \mean(y^H * y)
    Raises: TypeError if sigma is not a list or tuple
    """
    if not isinstance(sigma, (list, tuple)):
        raise TypeError("sigma must be a list of matrices")
    # np.asmatrix leaves an np.matrix untouched and lets plain ndarrays use .I
    self.sigma_inv = [np.asmatrix(mat).I for mat in sigma]
    self.sigma_det = [np.linalg.det(mat) for mat in sigma]
58+
5259
def check_inputs(self, inputs):
5360
num_bins, time_steps, num_channels = inputs.shape
5461
assert num_bins == self.num_bins and time_steps == self.time_steps \
55-
and num_channels == self.dim, 'inputs dim does not match CGMM config'
62+
and num_channels == self.dim, 'Inputs dim does not match CGMM config'
5663

5764
def log_likelihood(self, spectrums):
5865
self.check_inputs(spectrums)
5966
posteriors = 0.0
6067
for f in range(self.num_bins):
61-
sigma_inv = self.R[f].I
62-
sigma_det = np.linalg.det(self.R[f])
6368
for t in range(self.time_steps):
6469
posteriors += self.lambda_[f, t] * gmm_posterior(spectrums[f, t], \
65-
self.phi[f, t], sigma_inv, sigma_det)
70+
self.phi[f, t], self.sigma_inv[f], self.sigma_det[f])
6671
return posteriors
6772

6873
def accu_stats(self, spectrums):
6974
self.check_inputs(spectrums)
7075
stats = np.zeros([self.num_bins, self.time_steps]).astype(np.complex)
7176
for f in range(self.num_bins):
72-
sigma_inv = self.R[f].I
73-
sigma_det = np.linalg.det(self.R[f])
7477
for t in range(self.time_steps):
7578
stats[f, t] = gmm_posterior(spectrums[f, t], self.phi[f, t], \
76-
sigma_inv, sigma_det)
79+
self.sigma_inv[f], self.sigma_det[f])
7780
return stats
7881

7982
def update_lambda(self, spectrums, stats):
8083
print('update lambda...')
84+
assert stats.shape == self.lambda_.shape
8185
for f in range(self.num_bins):
82-
sigma_inv = self.R[f].I
83-
sigma_det = np.linalg.det(self.R[f])
8486
for t in range(self.time_steps):
8587
self.lambda_[f, t] = gmm_posterior(spectrums[f, t], self.phi[f, t], \
86-
sigma_inv, sigma_det) / stats[f, t]
88+
self.sigma_inv[f], self.sigma_det[f])
89+
self.lambda_ = self.lambda_ / stats
8790

88-
def update_phi(self, spectrums):
91+
def update_phi(self, covar):
8992
print('update phi...')
9093
for f in range(self.num_bins):
91-
inv_R = self.R[f].I
9294
for t in range(self.time_steps):
93-
y = np.matrix(spectrums[f, t])
94-
self.phi[f, t] = np.trace(y.H * y * inv_R) / self.dim
95+
self.phi[f, t] = np.trace(covar[f * self.time_steps + t] * self.sigma_inv[f])
96+
self.phi = self.phi / self.dim
9597

96-
def update_R(self, spectrums):
98+
def update_R(self, covar):
9799
print('update R...')
98100
for f in range(self.num_bins):
99101
sum_lambda = self.lambda_[f].sum()
100-
self.R[f] = 0
102+
R = np.matrix(np.zeros([self.dim, self.dim]).astype(np.complex))
101103
for t in range(self.time_steps):
102-
y = np.matrix(spectrums[f, t])
103-
self.R[f] += self.lambda_[f, t] * y.H * y / self.phi[f, t]
104-
self.R[f] = self.R[f] / sum_lambda
104+
R += self.lambda_[f, t] * covar[f * self.time_steps + t] / self.phi[f, t]
105+
R = R / sum_lambda
106+
self.sigma_inv[f] = R.I
107+
self.sigma_det[f] = np.linalg.det(R)
105108

106-
def update_parameters(self, spectrums, stats):
109+
def update_parameters(self, spectrums, covar, stats):
107110
self.check_inputs(spectrums)
111+
assert len(covar) == self.num_bins * self.time_steps and type(covar) == list
108112
self.update_lambda(spectrums, stats)
109-
self.update_phi(spectrums)
110-
self.update_R(spectrums)
113+
self.update_phi(covar)
114+
self.update_R(covar)
111115

112116
class CGMMTrainer(object):
113117
def __init__(self, num_bins, time_steps, num_channels):
114118
self.noise_part = CGMM(num_bins, time_steps, num_channels)
115119
self.noisy_part = CGMM(num_bins, time_steps, num_channels)
116120
self.num_samples = num_bins * time_steps
117-
121+
122+
def init_sigma(self, spectrums):
    """
    Precompute y^H * y for every (bin, frame) and initialize both CGMM parts.

    Inputs: spectrums is num_bins x time_steps x num_channels
    Keeps self.covar, a flat list indexed by f * time_steps + t
    The noise part starts from identity matrices, the noisy part from the
    mean of the covariance matrices of each bin
    """
    # precompute the covariance matrix of each channel
    print("initialize sigma...")
    num_bins, time_steps, num_channels = spectrums.shape
    observations = (np.matrix(spectrums[f, t])
                    for f in range(num_bins) for t in range(time_steps))
    self.covar = [y.H * y for y in observations]
    # np.complex was removed in NumPy 1.24; the builtin complex gives the
    # same complex128 dtype
    identity = [np.matrix(np.eye(num_channels, dtype=complex))
                for _ in range(num_bins)]
    self.noise_part.init_sigma(identity)
    mean_covar = [sum(self.covar[f * time_steps: (f + 1) * time_steps]) / time_steps
                  for f in range(num_bins)]
    self.noisy_part.init_sigma(mean_covar)
132+
118133
def log_likelihood(self, spectrums):
119134
return (self.noise_part.log_likelihood(spectrums) + \
120135
self.noisy_part.log_likelihood(spectrums)) / self.num_samples
@@ -125,14 +140,22 @@ def accu_stats(self, spectrums):
125140
self.noise_part.accu_stats(spectrums)
126141

127142
def update_parameters(self, spectrums, stats):
128-
self.noise_part.update_parameters(spectrums, stats)
129-
self.noisy_part.update_parameters(spectrums, stats)
130-
143+
self.noise_part.update_parameters(spectrums, self.covar, stats)
144+
self.noisy_part.update_parameters(spectrums, self.covar, stats)
145+
146+
def save_param(self, dest):
    """
    Save sigma of the noisy and the noise part as .npy files under dest.

    Inputs: dest is the output directory; it is created (with parents) when
    missing, and an empty string means the current directory
    Writes sigma_noisy.npy and sigma_noise.npy, each stacking one matrix per
    entry of the corresponding sigma_inv list
    """
    # only the inverse is kept during training, invert back to get sigma
    sigma_ny = np.array([np.asarray(np.linalg.inv(mat)) for mat in self.noisy_part.sigma_inv])
    sigma_ne = np.array([np.asarray(np.linalg.inv(mat)) for mat in self.noise_part.sigma_inv])
    # os.mkdir('') raises and os.mkdir cannot create nested directories;
    # makedirs with exist_ok also avoids the check-then-create race
    if dest:
        os.makedirs(dest, exist_ok=True)
    np.save(os.path.join(dest, 'sigma_noisy'), sigma_ny)
    np.save(os.path.join(dest, 'sigma_noise'), sigma_ne)
153+
131154
def train(self, spectrums, iters=30):
155+
self.init_sigma(spectrums)
132156
print('Likelihood: ({0.real:.5f}, {0.imag:.5f}i)'.format(self.log_likelihood(spectrums)))
133157
for it in range(1, iters + 1):
134158
stats = self.accu_stats(spectrums)
135159
self.update_parameters(spectrums, stats)
136160
print('epoch {0:2d}: Likelihood = ({1.real:.5f}, {1.imag:.5f}i)'.format(it, \
137161
self.log_likelihood(spectrums)))
138-

train_cgmm.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,31 @@
33
# wujian@17.10.26
44

55
import argparse
6+
import time
67
import numpy as np
78

89
from utils import MultiChannelWrapper
910
from cgmm import CGMMTrainer
1011

1112
def train(args):
1213
wrapper = MultiChannelWrapper(args.descriptor)
13-
(time_steps, num_bins), spectrums = wrapper.spectrums(transpose=True)
14-
num_bins, time_steps, num_channels = np.array(spectrums).shape
15-
trainer = CGMMTrainer(num_bins, time_steps, num_channels)
16-
trainer.train(spectrums, iters=args.iters)
14+
(time_steps, num_bins), spectrums = wrapper.spectrums()
15+
trainer = CGMMTrainer(num_bins, time_steps, len(spectrums))
16+
start_time = time.time()
17+
trainer.train(np.transpose(spectrums), iters=args.iters)
18+
finish_time = time.time()
19+
print('Total raining time: {:.3f}s'.format(finish_time - start_time))
20+
trainer.save_param(args.save_dir)
1721

1822
if __name__ == '__main__':
1923
parser = argparse.ArgumentParser(description="Training CGMM on multiple channel")
2024
parser.add_argument('descriptor', type=str,
21-
help="""descriptor of multiple channel location, format:
22-
/path/to/channel1
23-
...
24-
/path/to/channeln""")
25+
help="""descriptor of multiple channel location""")
2526
parser.add_argument('-i', '--iters',
2627
dest='iters', type=int, default='10',
2728
help="""number of iterations to train""")
29+
parser.add_argument('-s', '--save',
30+
dest='save_dir', type=str, default='',
31+
help="""directory to save sigma of CGMM""")
2832
args = parser.parse_args()
2933
train(args)

utils.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#!/usr/bin/env python
1+
#!/usr/bin/env python
22
# coding=utf-8
33
# wujian@17.10.24
44

@@ -33,7 +33,7 @@ def pre_emphase(signal, filter_coeff=0.97):
3333
signal[0] -= filter_coeff * signal[0]
3434
return signal
3535

36-
def compute_spectrum(wave_wrapper, window_type='hamming'):
36+
def compute_spectrum(wave_wrapper, transpose=False, window_type='hamming'):
3737
"""
3838
Compute the DFT of each frames in the wrapper:
3939
1. default apply hamming-window on each frame
@@ -53,7 +53,7 @@ def compute_spectrum(wave_wrapper, window_type='hamming'):
5353
for index in range(num_frames):
5454
feature_in[: frame_size] = frames[index] * window
5555
spectrum[index] = np.fft.rfft(feature_in)
56-
return spectrum
56+
return spectrum if not transpose else np.transpose(spectrum)
5757

5858
def plot_spectrum(spectrum, frame_duration, title="samples.wav"):
5959
"""
@@ -82,7 +82,7 @@ def write_wave(samples, frame_rate, dest):
8282
dest_wave = wave.open(dest, "wb")
8383
# 1 channel; int16 default
8484
dest_wave.setparams((1, 2, frame_rate, samples.size, 'NONE', 'not compressed'))
85-
dest_wave.writeframes(samples.astype(np.int16))
85+
dest_wave.writeframes(samples.astype(np.int16).tostring())
8686
print("1 channels; 2 bytes per sample; {num_samples} samples; " \
8787
"{frame_rate} samples per sec. OUT[{path}]".format(path=dest, \
8888
num_samples=samples.size, frame_rate=frame_rate))
@@ -166,9 +166,9 @@ def subframes(self, normalize=True):
166166
return shape_per_item, frames
167167

168168
def spectrums(self, transpose=False):
169-
spects = [compute_spectrum(wrapper) for wrapper in self.wrappers]
169+
spects = [compute_spectrum(wrapper, transpose) for wrapper in self.wrappers]
170170
shape_per_item = check_status(spects)
171-
return shape_per_item, (spects if not transpose else np.transpose(spects))
171+
return shape_per_item, spects
172172

173173
def __str__(self):
174174
return '\n'.join([str(wrapper) for wrapper in self.wrappers])

0 commit comments

Comments
 (0)