view toolboxes/RBM/training_rbm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line source
function [W visB hidB] = training_rbm(conf,W,data_file)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Training RBM                                                       %
% -*-sontran2012-*-                                                  %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Trains a restricted Boltzmann machine with mini-batch contrastive
% divergence and returns the learned weights and biases.
%
% Inputs:
%   conf      - training settings. Fields read by this function:
%                 hidNum    number of hidden units
%                 sNum      number of samples (rows of data) per batch
%                 bNum      number of batches per epoch
%                 eNum      maximum number of epochs
%                 gNum      number of Gibbs (down/up) steps per update
%                 params(1) learning rate for the first N (=10) epochs
%                 params(2) learning rate for the remaining epochs
%                 params(3) factor applied to the previous update
%                           (momentum term)
%                 params(4) factor applied to W and subtracted from the
%                           weight gradient (weight-decay term)
%   W         - initial weights, may be empty or smaller than
%               visNum x hidNum; it is padded with 0.1*randn entries
%               up to visNum x hidNum.
%   data_file - MAT-file; the first variable listed by whos is used as
%               the data matrix (one sample per row, one visible unit
%               per column).
%
% Outputs:
%   W    - visNum x hidNum weight matrix
%   visB - 1 x visNum visible biases
%   hidB - 1 x hidNum hidden biases
%
% Side effects: draws the per-epoch reconstruction error in the current
% axes and prints it to the console.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% load data
vars = whos('-file', data_file);
assert(~isempty(vars),'[KBRBM] Data file contains no variables');
A = load(data_file,vars(1).name);
data = A.(vars(1).name);
assert(~isempty(data),'[KBRBM] Data is empty');
%% initialization
visNum = size(data,2);
hidNum = conf.hidNum;
sNum   = conf.sNum;
lr     = conf.params(1);
N      = 10;                                % Number of epochs trained with params(1) as learning rate
assert(size(W,1) <= visNum && size(W,2) <= hidNum, ...
    '[KBRBM] Initial W is larger than visNum x hidNum');
% Pad the supplied weights with small random values up to visNum x hidNum.
% Rows are appended first, then columns, so an empty W yields a fully
% random visNum x hidNum matrix.
W = [W; 0.1*randn(visNum - size(W,1), size(W,2))];
W = [W, 0.1*randn(size(W,1), hidNum - size(W,2))];

DW    = zeros(size(W));
visB  = zeros(1,visNum);
DVB   = zeros(1,visNum);
hidB  = zeros(1,hidNum);
DHB   = zeros(1,hidNum);
% These are only assigned inside the Gibbs loop; initialise them so they
% are defined even when conf.gNum is 0.
visN  = zeros(sNum,visNum);
visNs = zeros(sNum,visNum);
hidN  = zeros(sNum,hidNum);
%% Reconstruction error & early stopping
mse       = Inf;                            % Inf so the first epoch is never counted as an increase
inc_count = 0;
MAX_INC   = 3000;                           % If the error increases more than MAX_INC times in a row, stop training
%% Plotting
h = plot(nan);
mse_plot = nan(1,conf.eNum);                % NaN entries (epochs not run yet) are not drawn
%% ==================== Start training =========================== %%
for i=1:conf.eNum
    if i == N+1
        lr = conf.params(2);
    end
    omse = mse;
    mse = 0;
    for j=1:conf.bNum
       visP = data((j-1)*sNum+1:j*sNum,:);
       % up: hidden probabilities and binary samples for the data
       hidP  = logistic(visP*W + repmat(hidB,sNum,1));
       hidPs = 1*(hidP > rand(sNum,hidNum));
       hidNs = hidPs;
       for k=1:conf.gNum
           % down
           visN  = logistic(hidNs*W' + repmat(visB,sNum,1));
           visNs = 1*(visN > rand(sNum,visNum));
           % Visualise the first reconstruction of batch 5 in every epoch.
           % NOTE(review): the 28,28 arguments look like a fixed image
           % size; observe_reconstruction is defined elsewhere - confirm.
           if j==5 && k==1, observe_reconstruction(visN,sNum,i,28,28); end
           % up
           hidN  = logistic(visNs*W + repmat(hidB,sNum,1));
           hidNs = 1*(hidN > rand(sNum,hidNum));
       end
       % Accumulate the per-batch mean squared reconstruction error
       rdiff = (visP - visN);
       mse = mse + sum(sum(rdiff.*rdiff))/(sNum*visNum);
       % Update W,visB,hidB
       gradW = (visP'*hidP - visNs'*hidN)/sNum;
       DW   = lr*(gradW - conf.params(4)*W) + conf.params(3)*DW;
       W    = W + DW;
       DVB  = lr*sum(visP - visN,1)/sNum + conf.params(3)*DVB;
       visB = visB + DVB;
       DHB  = lr*sum(hidP - hidN,1)/sNum + conf.params(3)*DHB;
       hidB = hidB + DHB;
    end
    %% Update the error plot
    mse_plot(i) = mse;
    axis([0 (conf.eNum+1) 0 2]);
    set(h,'YData',mse_plot);
    drawnow;

    %% Early stopping on a run of consecutive error increases
    if mse > omse
        inc_count = inc_count + 1;
    else
        inc_count = 0;
    end
    if inc_count > MAX_INC, break; end
    fprintf('Epoch %d  : MSE = %f\n',i,mse);
end
end