Example: Sparse deconvolution
Deconvolution of a spike signal, comparing two penalty functions: the L1 norm penalty and the nn garrote penalty. The algorithm is based on majorization-minimization (MM) with quadratic majorizers and uses a fast solver for banded systems.
Reference: Ivan Selesnick, "Penalty and Shrinkage Functions for Sparse Signal Processing", NYU-Poly, 2012. selesi@poly.edu
http://eeweb.poly.edu/iselesni/lecture_notes/
Start
close all
clear
randn('state',0);       % set state so that example can be reproduced
printme = @(txt) print('-dpdf', sprintf('figures/Example_%s',txt));   % save current figure as PDF
Create spike signal
N = 100;                % N : length of signal
s = zeros(N,1);
k = [20 45 70];
a = [2 -1 1];           % a : spike amplitudes
s(k) = a;               % s : spike signal

figure(1)
clf
subplot(2,1,1)
stem(s, 'marker', 'none')
box off
title('Sparse signal');
ylim1 = [-1.5 2.5];
ylim(ylim1)
printme('original')
Create observed signal
The simulated observed signal is obtained by convolving the spike signal with a 4-point moving-average impulse response and adding white Gaussian noise.
L = 4;                      % L : length of impulse response
h = ones(L,1)/L;            % h : impulse response
M = N + L - 1;              % M : length of observed signal
sigma = 0.03;               % sigma : standard deviation of AWGN
w = sigma * randn(M,1);     % w : additive zero-mean Gaussian noise (AWGN)
y = conv(h,s) + w;          % y : observed data

figure(2)
clf
subplot(2,1,1)
plot(y)
box off
xlim([0 M])
title('Observed signal');
printme('observed')
Create convolution matrix H
Create the convolution matrix H using the MATLAB sparse matrix functions 'sparse' and 'spdiags'. Making H a sparse matrix has three benefits:
- less memory is used,
- multiplying vectors by H is faster,
- fast algorithms for solving banded systems can be used.
H = sparse(M,N);
e = ones(N,1);
for i = 0:L-1
    H = H + spdiags(h(i+1)*e, -i, M, N);    % H : convolution matrix (sparse)
end
issparse(H)         % confirm that H is a sparse matrix
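The loop adds one diagonal per filter tap. As a sketch of an equivalent construction (not the author's code), spdiags can also place all L diagonals in one call: column i of the input matrix supplies the constant value h(i) for diagonal -(i-1). The check against conv uses a random input vector, which is an arbitrary choice.

```matlab
% Sketch: build the M-by-N convolution matrix with a single spdiags call.
N = 100;                          % signal length (as in the example)
L = 4;                            % impulse-response length
h = ones(L,1)/L;                  % 4-point moving-average impulse response
M = N + L - 1;                    % length of the full convolution

B = repmat(h.', N, 1);            % N-by-L; column i holds the constant h(i)
H2 = spdiags(B, -(0:L-1), M, N);  % h(i) goes on sub-diagonal -(i-1)

% Compare H2*x with conv(h,x) for a random test vector
x = randn(N,1);
err = max(abs(H2*x - conv(h,x)));
fprintf('Maximum error = %g\n', err)
```

Because every diagonal is constant, the row ordering that spdiags uses for non-square matrices does not matter here.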
ans = 1
Verify that H*s is the same as conv(h,s)
err = H*s - conv(h,s);
max_err = max(abs(err));
fprintf('Maximum error = %g\n', max_err)
Maximum error = 0
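Display the structure of the top-left 24-by-21 block of the convolution matrix. Note that the matrix is banded (sparse): its nonzeros lie on the main diagonal and the L-1 = 3 diagonals below it.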
figure(1)
clf
spy(H(1:24,1:21))
Least squares solution
Find the regularized least squares solution to the deconvolution problem. The parameter lambda weights a quadratic penalty on x.
lambda = 0.4;                                   % lambda : regularization parameter
x_L2 = (H'*H + lambda*speye(N)) \ (H' * y);     % x_L2 : least squares solution
rmse_L2 = sqrt(mean((x_L2 - s).^2));
fprintf('Least square solution: RMSE = %f\n', rmse_L2);

figure(2)
clf
subplot(2,1,1)
stem(x_L2, 'marker', 'none')
box off
title('Deconvolution (least square solution)');
axnote(sprintf('RMSE = %.4f', rmse_L2))
printme('L2')
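The backslash expression solves the normal equations of a Tikhonov-regularized least squares problem. Setting the gradient of the cost to zero gives exactly the linear system used in the code:

```latex
x_{L2} = \arg\min_{x} \; \|y - Hx\|_2^2 + \lambda \|x\|_2^2

-2H^T (y - Hx) + 2\lambda x = 0
\;\;\Longrightarrow\;\;
(H^T H + \lambda I)\, x_{L2} = H^T y
```

Since H is banded, H^T H + lambda I is banded too, which is what allows a banded solver to be used.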
Least square solution: RMSE = 0.192453
Verify banded system solver
Verify that MATLAB calls LAPACK's fast solver for banded systems.
spparms('spumoni', 3);      % set sparse monitor flag to obtain diagnostic output
x_L2 = (H'*H + lambda*speye(N)) \ (H' * y);     % x_L2 : least squares solution
spparms('spumoni', 0);      % reset sparse monitor flag (no diagnostic output)
sp\: bandwidth = 3+1+3.
sp\: is A diagonal? no.
sp\: is band density (1.00) > bandden (0.50) to try banded solver? yes.
sp\: is LAPACK's banded solver successful? yes.
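The "3+1+3" in the diagnostic presumably reads as 3 sub-diagonals, the main diagonal, and 3 super-diagonals. The sketch below computes these counts directly from the nonzero pattern of the system matrix, rebuilding H with the same loop and parameters as the example; a column of H covers rows n to n+3, so two columns overlap only when their indices differ by at most L-1 = 3.

```matlab
% Sketch: count the nonzero sub- and super-diagonals of A = H'*H + lambda*I.
N = 100; L = 4; M = N + L - 1; lambda = 0.4;
h = ones(L,1)/L;
H = sparse(M, N);
for i = 0:L-1
    H = H + spdiags(h(i+1)*ones(N,1), -i, M, N);   % banded convolution matrix
end
A = H'*H + lambda*speye(N);       % matrix solved by the backslash call

[r, c] = find(A);                 % row/column indices of the nonzeros
lower_bw = max(r - c);            % number of nonzero sub-diagonals
upper_bw = max(c - r);            % number of nonzero super-diagonals
fprintf('bandwidth = %d+1+%d\n', lower_bw, upper_bw)
```

This prints `bandwidth = 3+1+3`, matching the solver diagnostic.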
Sparse deconvolution - L1 norm penalty
The penalty function is phi(x) = lam |x|, where lam is set equal to the threshold T computed in the code.
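The function deconv_MM is called here but not listed in this section. Assuming it minimizes the cost below (consistent with the optimality scatter plots, which compare H^T(y - Hx) with phi'(x)), a quadratic MM step can be written as follows. For penalties such as these, symmetric with phi'(x)/x non-increasing for x > 0, each phi(x_n) is majorized by a parabola that touches it at the current estimate v. Minimizing the resulting quadratic gives a linear system, and the matrix inversion lemma rewrites it in terms of W = diag(x/phi'(x)), which is why wfun(x) = x/phi'(x) (here |x|/lam) is passed to the solver:

```latex
F(x) = \tfrac{1}{2}\|y - Hx\|_2^2 + \sum_{n} \phi(x_n)

\phi(x) \le \frac{\phi'(v)}{2v}\, x^2 + \phi(v) - \frac{v\,\phi'(v)}{2},
  \qquad \text{with equality at } x = v

x^{(k+1)} = \bigl(H^T H + \Lambda^{(k)}\bigr)^{-1} H^T y,
  \qquad \Lambda^{(k)} = \operatorname{diag}\!\left(\frac{\phi'(x^{(k)}_n)}{x^{(k)}_n}\right)

x^{(k+1)} = W^{(k)} H^T \bigl(I + H W^{(k)} H^T\bigr)^{-1} y,
  \qquad W^{(k)} = \operatorname{diag}\!\left(\frac{x^{(k)}_n}{\phi'(x^{(k)}_n)}\right)
```

The last form never divides by x_n, and I + H W H^T is an M-by-M banded matrix.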
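T = 3 * sigma * sqrt(sum(abs(h).^2));   % T : threshold using '3-sigma' rule
lam = T;
Nit = 50;                               % Nit : number of MM iterations
phi_L1 = @(x) lam * abs(x);             % penalty function
wfun_L1 = @(x) abs(x)/lam;              % x / phi'(x), passed to the MM solver
dphi_L1 = @(x) lam * sign(x);           % derivative of the penalty (x ~= 0)
[x1, cost1] = deconv_MM(y, phi_L1, wfun_L1, H, Nit);    % deconv_MM : not listed in this section
rmse1 = sqrt(mean((x1 - s).^2));
fprintf('L1 norm solution: RMSE = %f\n', rmse1);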
L1 norm solution: RMSE = 0.034072
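One way to read the '3-sigma' rule: if w is white Gaussian noise with standard deviation sigma, each entry of H'*w has standard deviation sigma*||h||_2 (every column of H contains the whole impulse response), so noise alone rarely pushes |H'*w| above T = 3*sigma*||h||_2. The Monte Carlo sketch below checks this interpretation; the number of trials is an arbitrary choice.

```matlab
% Sketch: empirical check of the '3-sigma' threshold T = 3*sigma*||h||_2.
randn('state', 0);
N = 100; L = 4; M = N + L - 1; sigma = 0.03;
h = ones(L,1)/L;
H = sparse(M, N);
for i = 0:L-1
    H = H + spdiags(h(i+1)*ones(N,1), -i, M, N);   % banded convolution matrix
end

Ntrials = 2000;                       % number of noise realizations (arbitrary)
W = sigma * randn(M, Ntrials);        % independent AWGN realizations, one per column
V = H' * W;                           % noise as seen through H'
T = 3 * sigma * norm(h);              % same as 3*sigma*sqrt(sum(abs(h).^2))

emp_std = std(V(:));                  % should be close to sigma*||h||_2
frac_above = mean(abs(V(:)) > T);     % fraction of entries exceeding T
fprintf('empirical std = %.4f, sigma*||h|| = %.4f\n', emp_std, sigma*norm(h))
fprintf('fraction with |H''w| > T = %.4f\n', frac_above)
```

For Gaussian noise the expected exceedance fraction is about 0.3%, so only a few noise-driven entries of H'y compete with the threshold.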
Display cost function history
figure(1)
clf
plot(1:Nit, cost1, '.-')
title('Cost function history');
xlabel('Iteration')
xlim([0 Nit])
box off
printme('CostFunction_L1')
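MM methods guarantee that the cost does not increase from one iteration to the next. Since deconv_MM itself is not listed in this section, the loop below is a hypothetical sketch of such a routine for the L1 penalty, not the original function. It assumes the cost 0.5*||y - H*x||^2 + sum(phi(x)), a weight function wfun(x) = x ./ phi'(x), and the initialization x = H'*y (a guess). The update x = W*H'*((I + H*W*H') \ y) avoids dividing by x, and I + H*W*H' is banded.

```matlab
% Hypothetical sketch of a quadratic-MM deconvolution loop (L1 penalty).
randn('state', 0);
N = 100; L = 4; M = N + L - 1; sigma = 0.03;
h = ones(L,1)/L;
s = zeros(N,1); s([20 45 70]) = [2 -1 1];        % spike signal from the example
H = sparse(M, N);
for i = 0:L-1
    H = H + spdiags(h(i+1)*ones(N,1), -i, M, N); % banded convolution matrix
end
y = H*s + sigma*randn(M,1);                      % observed data

lam  = 3*sigma*norm(h);                          % '3-sigma' threshold
phi  = @(x) lam*abs(x);                          % L1 penalty
wfun = @(x) abs(x)/lam;                          % x ./ phi'(x) for the L1 penalty

Nit = 50;
x = H'*y;                                        % initialization (assumption)
cost = zeros(1, Nit);
for k = 1:Nit
    W = spdiags(wfun(x), 0, N, N);               % W = diag(x ./ phi'(x))
    F = speye(M) + H*W*H';                       % banded M-by-M matrix
    x = W * (H' * (F \ y));                      % x = W H' (I + H W H')^(-1) y
    cost(k) = 0.5*sum(abs(y - H*x).^2) + sum(phi(x));
end
rmse = sqrt(mean((x - s).^2));
fprintf('RMSE = %f\n', rmse)
```

One consequence of this form is that an entry of x that becomes exactly zero stays zero, so the initialization should not contain zeros.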
The L1 solution is much closer to the original signal than the least squares solution (RMSE 0.0341 versus 0.1925).
figure(2)
clf
subplot(2,1,1)
stem(x1, 'marker', 'none')
box off
ylim(ylim1);
title('Deconvolution (L1 norm penalty)');
axnote(sprintf('RMSE = %.4f', rmse1))
printme('L1')
Verify optimality conditions
v1 = H'*(y-H*x1);
t = [linspace(-2, -eps, 100) linspace(eps, 2, 100)];
figure(1)
clf
plot(x1, v1, '.')
line(t, dphi_L1(t), 'linestyle', ':')
box off
ylim([-1 1]*lam*1.5)
xlabel('x')
ylabel('H^T(y - Hx)')
title('Optimality scatter plot - L1 penalty');
printme('scatter_L1')
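For the L1 penalty, a minimizer satisfies v(n) = lam*sign(x(n)) wherever x(n) is nonzero and |v(n)| <= lam wherever x(n) = 0, with v = H'*(y - H*x). The scatter plot checks this visually; the sketch below checks it numerically. To keep it self-contained it uses the special case H = I (denoising), where soft thresholding gives the exact minimizer; the data values and tolerance are arbitrary choices. MM estimates are only approximately zero, hence the tolerance.

```matlab
% Sketch: numerical check of the L1 optimality (subgradient) conditions,
% shown for H = I, where the exact L1 minimizer is soft thresholding.
lam = 0.5;
y = [2; -0.3; 0.1; -1.2; 0.5];
x = sign(y) .* max(abs(y) - lam, 0);   % soft threshold = exact solution for H = I
v = y - x;                             % H'*(y - H*x) with H = I

tol = 1e-6;                            % treat |x(n)| <= tol as zero
nz = abs(x) > tol;
cond_nonzero = max(abs(v(nz) - lam*sign(x(nz))));   % should be ~0
cond_zero    = max(abs(v(~nz))) - lam;              % should be <= 0
fprintf('nonzero-entry error = %g, zero-entry margin = %g\n', cond_nonzero, cond_zero)
```

With the example's variables, the same two conditions can be evaluated with x = x1, v = v1 and a tolerance such as the 1e-5 used in the comparison plot.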
Sparse deconvolution - nn garrote penalty
The penalty function corresponds to the non-negative (nn) garrote shrinkage function.
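To see the correspondence, consider the scalar problem of minimizing 0.5(y - x)^2 + phi(x), i.e. the case H = I. With phi'(x) as defined by dphi in the code, the stationarity condition y = x + phi'(x) can be solved for x, which recovers the nn garrote rule (and x = 0 for |y| <= T, since phi'(0+) = T). The same algebra gives the wfun expression x/phi'(x):

```latex
\phi'(x) = \tfrac{1}{2}\bigl(\sqrt{x^2 + 4T^2} - |x|\bigr)\operatorname{sign}(x)

x > 0: \quad y = x + \phi'(x) = \tfrac{1}{2}\bigl(x + \sqrt{x^2 + 4T^2}\bigr)

(2y - x)^2 = x^2 + 4T^2
\;\Longrightarrow\; 4y^2 - 4xy = 4T^2
\;\Longrightarrow\; x = y - \frac{T^2}{y}, \quad y > T

\frac{x}{\phi'(x)} = \frac{2|x|}{\sqrt{x^2 + 4T^2} - |x|}
  = \frac{|x|\bigl(\sqrt{x^2 + 4T^2} + |x|\bigr)}{2T^2}
```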
dphi = @(x) 0.5 * (-abs(x) + sqrt(x.^2 + 4*T^2)) .* sign(x);
phi = @(x) T^2*asinh(abs(x)/(2*T)) + 0.25 * (abs(x).*sqrt(4*T^2 + x.^2) - x.^2);
wfun = @(x) 1/(2*T^2) * abs(x) .* (sqrt(x.^2 + 4*T^2) + abs(x));      % x / dphi(x)
[x2, cost2] = deconv_MM(y, phi, wfun, H, Nit);
rmse2 = sqrt(mean((x2 - s).^2));
fprintf('nn-garrote solution: RMSE = %f\n', rmse2);
nn-garrote solution: RMSE = 0.004564
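The three nn-garrote formulas must agree with one another: dphi should be the derivative of phi, wfun should equal x ./ dphi(x), and the scalar minimizer should follow the garrote rule y - T^2/y. The sketch below checks all three numerically on a grid of positive x; the value T = 0.1, the grid, and the finite-difference step are arbitrary choices.

```matlab
% Sketch: numerical consistency check of the nn-garrote penalty formulas.
T = 0.1;                                   % hypothetical threshold value
dphi = @(x) 0.5 * (-abs(x) + sqrt(x.^2 + 4*T^2)) .* sign(x);
phi  = @(x) T^2*asinh(abs(x)/(2*T)) + 0.25*(abs(x).*sqrt(4*T^2 + x.^2) - x.^2);
wfun = @(x) 1/(2*T^2) * abs(x) .* (sqrt(x.^2 + 4*T^2) + abs(x));

x = linspace(0.01, 2, 200)';               % grid of positive values

% 1) dphi is the derivative of phi (central finite differences)
d = 1e-6;
fd = (phi(x + d) - phi(x - d)) / (2*d);
err_deriv = max(abs(fd - dphi(x)));

% 2) wfun(x) = x ./ dphi(x), compared in relative terms
err_w = max(abs(wfun(x) - x ./ dphi(x)) ./ wfun(x));

% 3) If y = x + phi'(x), the nn garrote rule y - T^2/y recovers x
y = x + dphi(x);
err_garrote = max(abs((y - T^2 ./ y) - x));

fprintf('deriv err = %g, wfun rel err = %g, garrote err = %g\n', ...
        err_deriv, err_w, err_garrote)
```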
Display cost function history
figure(1)
clf
plot(1:Nit, cost2, '.-')
title('Cost function history');
xlabel('Iteration')
xlim([0 Nit])
box off
printme('CostFunction_garrote')
figure(2)
clf
subplot(2,1,1)
stem(x2, 'marker', 'none')
box off
ylim(ylim1);
title('Deconvolution (nn-garrote)');
axnote(sprintf('RMSE = %.4f', rmse2))
printme('garrote')
Verify (local) optimality conditions
v2 = H'*(y-H*x2);
figure(1)
clf
plot(x2, v2, '.')
line(t, dphi(t), 'linestyle', ':')
box off
ylim([-1 1]*T*1.5)
xlim(t([1 end]))
xlabel('x')
ylabel('H^T(y - Hx)')
title('(local) optimality scatter plot - nn garrote penalty');
printme('scatter_garrote')
Comparison
The nn garrote penalty gives a more accurate result than the L1 norm penalty (RMSE 0.0046 versus 0.0341). The result obtained using the L1 norm penalty is attenuated: its spike amplitudes are smaller in magnitude than those of the true signal.
n = 1:N;
k = abs(s) > 1e-5;
k1 = abs(x1) > 1e-5;
k2 = abs(x2) > 1e-5;
figure(1)
clf
subplot(3,1,[1 2])
plot(n(k), s(k), 'o')
hold on
plot(n(k2), x2(k2), '+', n(k1), x1(k1), 'x')
line([0 N], [0 0])
hold off
xlim([0 N])
legend('true', 'nn garrote', 'L1', 'location', 'se')
box off
xlabel('n')
ylabel('x(n)')
printme('compare')
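The attenuation can be understood from the scalar case H = I. There the L1 penalty leads to soft thresholding, which subtracts T from every value above the threshold. The nn garrote rule, y - T^2/y, subtracts only T^2/y, and this bias shrinks as |y| grows. In the deconvolution setting the link is indirect, so the sketch below is only an illustration; T and the test values are arbitrary. The test values avoid y = 0, where T^2./y would be undefined.

```matlab
% Sketch: bias of the two scalar shrinkage rules (case H = I).
T = 0.1;                                          % hypothetical threshold
soft    = @(y) sign(y) .* max(abs(y) - T, 0);     % L1 penalty -> soft threshold
garrote = @(y) (abs(y) > T) .* (y - T^2 ./ y);    % nn garrote rule (y ~= 0)

y = [0.05 0.2 1 2]';
bias_soft    = y - soft(y);      % equals T for every |y| > T
bias_garrote = y - garrote(y);   % equals T^2/|y| for |y| > T
disp([y bias_soft bias_garrote])
```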