%% Computing redundant finite difference and its adjoint

function [y] = finDiff_2D(in, dir, p)
    % FINDIFF_2D  Apply the redundant 2-D finite-difference transform or its
    % adjoint, depending on DIR.
    %   in  : for dir == +1, a 2-D array (singleton dims are squeezed away);
    %         for dir == -1, the coefficient vector produced by the forward
    %         transform (reshaped internally to p.n(1) x p.n(2) x nBands).
    %   dir : +1 selects the forward transform, -1 selects its adjoint;
    %         anything else raises an error.
    %   p   : parameter struct; p.N is passed to genHaar (2 or 4 bands) and
    %         p.n(1:2) gives the image size used by the adjoint.
    in = squeeze(in);

    wantForward = dir == 1;
    wantAdjoint = dir == -1;

    if wantForward
        % image -> stacked per-band coefficients
        y = forward(in, p);
    elseif wantAdjoint
        % stacked per-band coefficients -> image
        y = backward(in, p);
    else
        error('Direction can only assume +1 or -1 values');
    end
return


function [y] = forward(x, p)
    % Filter the 2-D array X with every kernel returned by genHaar(p.N),
    % using circular convolution evaluated in the centered Fourier domain
    % (fftshift/fftn/ifftshift). Returns a column vector holding the
    % filtered images stacked as [nFE x nPE x nBands] in column-major order.
    kernels = genHaar(p.N);
    nBands = size(kernels, 1);
    [nFE, nPE] = size(x);

    % Embed each 2x2 kernel in a full-size frame and shift it so that, after
    % ifftshift, its taps land at offsets -1 and 0 along both axes.
    H = zeros(nBands, nFE, nPE);
    H(:, 1:2, 1:2) = kernels;
    H = circshift(H, [0 floor(nFE/2)-1 floor(nPE/2)-1]);

    % Spectrum of the input, computed once and reused for every band.
    X = fftshift(fftn(ifftshift(x)));

    y = zeros(nBands, nFE, nPE);
    for b = 1:nBands
        Hb = fftshift(fftn(ifftshift(H(b,:,:))));
        y(b,:,:) = X .* squeeze(Hb);
        y(b,:,:) = fftshift(ifftn(ifftshift(y(b,:,:))));
    end

    % Move the band axis last and vectorize.
    y = reshape(permute(y, [2, 3, 1]), [], 1);
return


function [y] = backward(x, p)
    % Adjoint of forward: filter every band of the coefficient vector X with
    % its spatially flipped kernel (circular convolution in the centered
    % Fourier domain) and sum the results into a single p.n(1) x p.n(2) image.
    kernels = flipWavelet(genHaar(p.N));
    nBands = size(kernels, 1);
    nFE = p.n(1);
    nPE = p.n(2);

    % Undo forward's vectorization: [nFE x nPE x nBands] -> [nBands x nFE x nPE].
    coeffs = permute(reshape(x, nFE, nPE, nBands), [3, 1, 2]);

    % Flipped kernels are shifted by floor(n/2) (one more than in forward),
    % so their taps sit at offsets 0 and +1 after ifftshift, which mirrors
    % forward's -1/0 taps.
    G = zeros(nBands, nFE, nPE);
    G(:, 1:2, 1:2) = kernels;
    G = circshift(G, [0 floor(nFE/2) floor(nPE/2)]);

    acc = zeros(nFE, nPE);
    for b = 1:nBands
        Gb = squeeze(fftshift(fftn(ifftshift(G(b,:,:)))));
        Cb = squeeze(fftshift(fftn(ifftshift(coeffs(b,:,:)))));
        acc = acc + ( Gb .* Cb );
    end

    y = fftshift(ifftn(ifftshift(acc)));
return


% Generate kernels for Haar
function [y] = genHaar(N)
    % Build the 2x2 finite-difference kernels, one kernel per band (dim 1).
    %   N == 2 -> y is [2 x 2 x 2]: differences along kernel dim 1 and dim 2.
    %   N == 4 -> y is [4 x 2 x 2]: the two above plus the main-diagonal and
    %             anti-diagonal differences, which are weighted by sqrt(2).
    % Every kernel is finally scaled by 1/2.
    % Any other N raises an error.
    %
    % Each row vector below fills one 2x2 kernel in column-major order,
    % i.e. [k(1,1) k(2,1) k(1,2) k(2,2)].
    if N == 2
        y = zeros(2,2,2);
        y(1, :) = [-1  1  0  0]*1;
        y(2, :) = [-1  0  1  0]*1;
    elseif N == 4
        y = zeros(4,2,2);
        y(1, :) = [-1  1  0  0]*1;
        y(2, :) = [-1  0  1  0]*1;
        y(3, :) = [-1  0  0  1]*sqrt(2);
        y(4, :) = [ 0 -1  1  0]*sqrt(2);
    else
        % Without this branch y is never assigned and the scaling below
        % fails with an undefined-variable error that hides the real cause.
        error('genHaar: N must be 2 or 4');
    end
    y = y./2;
return

% Flip each 2x2 kernel along its two spatial axes (dims 2 and 3); the band axis (dim 1) is not flipped
function [y] = flipWavelet(x)
    % Reverse the array along dims 2 and 3 only. Dim 1 (the band index) is
    % untouched, so each 2-D slice x(k,:,:) is rotated by 180 degrees.
    y = flip(x, 3);
    y = flip(y, 2);
return