diff toolboxes/RBM/gen_training_rbm.m @ 0:e9a9cd732c1e tip

first hg version after svn
author wolffd
date Tue, 10 Feb 2015 15:05:51 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/toolboxes/RBM/gen_training_rbm.m	Tue Feb 10 15:05:51 2015 +0000
@@ -0,0 +1,97 @@
+function [W visB hidB] = gen_training_krbm(conf,W,mW,train_file,train_label)
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+% Training Knowledge Based RBM for generative classification         %  
+% conf: training setting                                             %
+% W: weights of connections                                          %
+% mW: mask of connections                                            %
+% -*-sontran2012-*-                                                  %
+%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
+%% load data
+vars = whos('-file', train_file);
+A = load(train_file,vars(1).name);
+data = A.(vars(1).name);
+vars = whos('-file', train_label);
+A = load(train_label,vars(1).name);
+label = A.(vars(1).name);
+assert(~isempty(data),'[KRBM-GEN] Data is empty'); 
+assert(size(data,1) == size(label,1),'[KRBM-GEN] Number of data and label mismatch'); 
+Classes = unique(label)';
+lNum = size(Classes,2);
+data = [data discrete2softmax(label,Classes)]                              
+%% initialization
+visNum  = size(data,2);
+hidNum  = conf.hidNum;
+sNum  = conf.sNum;
+lr    = conf.params(1);
+N     = 10;                                                                     % Number of epoch training with lr_1                     
+W     = [W;0.1*randn(visNum - size(W,1),size(W,2))];
+W     = [W 0.1*randn(size(W,1),hidNum-size(W,2))];
+
+DW    = zeros(size(W));
+visB  = zeros(1,visNum);
+DVB   = zeros(1,visNum);
+hidB  = zeros(1,hidNum);
+DHB   = zeros(1,hidNum);
+visP  = zeros(sNum,visNum);
+visN  = zeros(sNum,visNum);
+visNs = zeros(sNum,visNum);
+hidP  = zeros(sNum,hidNum);
+hidPs = zeros(sNum,hidNum);
+hidN  = zeros(sNum,hidNum);
+hidNs = zeros(sNum,hidNum);
+%% Reconstruction error & evaluation error & early stopping
+mse    = 0;
+omse   = 0;
+inc_count = 0;
+MAX_INC = 3;                                                                % If the error increase MAX_INC times continuously, then stop training
+%% Average best settings
+n_best  = 1;
+aW  = size(W);
+aVB = size(visB);
+aHB = size(hidB);
+%% ==================== Start training =========================== %%
+for i=1:conf.eNum
+    if i== N+1
+        lr = conf.params(2);
+    end
+    omse = mse;
+    mse = 0;
+    for j=1:conf.bNum
+       visP = data((j-1)*conf.sNum+1:j*conf.sNum,:);
+       %up
+       hidP = logistic(visP*W + repmat(hidB,sNum,1));
+       hidPs =  1*(hidP >rand(sNum,hidNum));
+       hidNs = hidPs;
+       for k=1:conf.gNum
+           % down
+           visN  = hidNs*W' + repmat(visB,sNum,1);
+           visN(:,1:visNum-lNum) = logistic(visN(:,1:visNum-lNum));
+           visN(:,visNum-lNum+1:visNum) = softmax_activation(visN(:,visNum-lNum+1:visNum));
+           visNs = [1*(visN(:,1:visNum-lNum)>rand(sNum,visNum-lNum)) visN(:,visNum-lNum+1:visNum)];
+           if j==5 && k==1, observe_reconstruction(visN(:,1:visNum-lNum),sNum,i,28,28); end
+           % up
+           hidN  = logistic(visNs*W + repmat(hidB,sNum,1));
+           hidNs = 1*(hidN>rand(sNum,hidNum));
+       end
+       % Compute MSE for reconstruction
+       rdiff = (visP - visN);
+       mse = mse + sum(sum(rdiff.*rdiff))/(sNum*visNum);
+       % Update W,visB,hidB
+       diff = (visP'*hidP - visNs'*hidN)/sNum;
+       DW  = lr*(diff - conf.params(4)*W) +  conf.params(3)*DW;
+       W   = W + DW;
+%        W = W.*mW;
+       DVB  = lr*sum(visP - visN,1)/sNum + conf.params(3)*DVB;
+       visB = visB + DVB;
+       DHB  = lr*sum(hidP - hidN,1)/sNum + conf.params(3)*DHB;
+       hidB = hidB + DHB;
+    end
+    if mse > omse
+        inc_count = inc_count + 1
+    else
+        inc_count = 0;
+    end
+    if inc_count> MAX_INC, break; end;
+    fprintf('Epoch %d  : MSE = %f\n',i,mse);
+end
+end
\ No newline at end of file