function [W visB hidB] = gen_training_krbm(conf,W,mW,train_file,train_label)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Training Knowledge Based RBM for generative classification          %
% conf        : training settings (eNum epochs, bNum batches, sNum    %
%               batch size, gNum Gibbs steps, hidNum hidden units,    %
%               params = [lr1 lr2 momentum weightcost])               %
% W           : initial weight matrix; grown below so its rows cover  %
%               the data units + label softmax units and its columns  %
%               cover conf.hidNum hidden units                        %
% mW          : mask of connections (currently unused: the masking    %
%               line is commented out in the update step)             %
% train_file  : MAT-file whose first stored variable is the data      %
% train_label : MAT-file whose first stored variable is the labels    %
% Returns trained weights W and visible/hidden biases visB, hidB.    %
% -*-sontran2012-*-                                                   %
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% load data (first variable found in each MAT-file)
vars = whos('-file', train_file);
A = load(train_file,vars(1).name);
data = A.(vars(1).name);
vars = whos('-file', train_label);
A = load(train_label,vars(1).name);
label = A.(vars(1).name);
assert(~isempty(data),'[KRBM-GEN] Data is empty');
assert(size(data,1) == size(label,1),'[KRBM-GEN] Number of data and label mismatch');
Classes = unique(label)';
lNum = size(Classes,2);
% Append a 1-of-K (softmax) encoding of the labels to the visible layer.
% FIX: added the missing semicolon -- the whole augmented data matrix
% used to be echoed to the console.
data = [data discrete2softmax(label,Classes)];
%% initialization
visNum = size(data,2);
hidNum = conf.hidNum;
sNum = conf.sNum;
lr = conf.params(1);
N = 10; % number of epochs trained with the first learning rate lr1
% Grow W with small random values to cover the added label units and any
% extra hidden units (backward-compatible with a smaller input W).
W = [W;0.1*randn(visNum - size(W,1),size(W,2))];
W = [W 0.1*randn(size(W,1),hidNum-size(W,2))];

DW = zeros(size(W));        % momentum buffer for W
visB = zeros(1,visNum);     % visible biases
DVB = zeros(1,visNum);      % momentum buffer for visB
hidB = zeros(1,hidNum);     % hidden biases
DHB = zeros(1,hidNum);      % momentum buffer for hidB
visP = zeros(sNum,visNum);  % positive-phase visible states (data batch)
visN = zeros(sNum,visNum);  % negative-phase visible probabilities
visNs = zeros(sNum,visNum); % negative-phase visible samples
hidP = zeros(sNum,hidNum);  % positive-phase hidden probabilities
hidPs = zeros(sNum,hidNum); % positive-phase hidden samples
hidN = zeros(sNum,hidNum);  % negative-phase hidden probabilities
hidNs = zeros(sNum,hidNum); % negative-phase hidden samples
%% Reconstruction error & evaluation error & early stopping
mse = 0;
omse = 0;
inc_count = 0;
MAX_INC = 3; % stop training after MAX_INC consecutive error increases
%% Average best settings
% NOTE(review): these store SIZES, not copies of the parameters, and are
% never used below -- kept unchanged for compatibility; verify intent.
n_best = 1;
aW = size(W);
aVB = size(visB);
aHB = size(hidB);
%% ==================== Start training =========================== %%
for i=1:conf.eNum
    if i == N+1
        lr = conf.params(2); % switch to the second learning rate
    end
    omse = mse;
    mse = 0;
    for j=1:conf.bNum
        visP = data((j-1)*conf.sNum+1:j*conf.sNum,:);
        % up: positive-phase hidden probabilities and binary samples
        hidP = logistic(visP*W + repmat(hidB,sNum,1));
        hidPs = 1*(hidP > rand(sNum,hidNum));
        hidNs = hidPs;
        for k=1:conf.gNum % CD-k Gibbs chain
            % down: logistic for the data units, softmax for the label units
            visN = hidNs*W' + repmat(visB,sNum,1);
            visN(:,1:visNum-lNum) = logistic(visN(:,1:visNum-lNum));
            visN(:,visNum-lNum+1:visNum) = softmax_activation(visN(:,visNum-lNum+1:visNum));
            % sample data units; keep label probabilities as-is
            visNs = [1*(visN(:,1:visNum-lNum)>rand(sNum,visNum-lNum)) visN(:,visNum-lNum+1:visNum)];
            % NOTE(review): 28x28 assumes MNIST-shaped images -- confirm
            if j==5 && k==1, observe_reconstruction(visN(:,1:visNum-lNum),sNum,i,28,28); end
            % up
            hidN = logistic(visNs*W + repmat(hidB,sNum,1));
            hidNs = 1*(hidN>rand(sNum,hidNum));
        end
        % Accumulate reconstruction MSE for this batch
        rdiff = (visP - visN);
        mse = mse + sum(sum(rdiff.*rdiff))/(sNum*visNum);
        % CD update of W, visB, hidB
        % (params(3) = momentum, params(4) = weight decay)
        diff = (visP'*hidP - visNs'*hidN)/sNum;
        DW = lr*(diff - conf.params(4)*W) + conf.params(3)*DW;
        W = W + DW;
        % W = W.*mW;
        DVB = lr*sum(visP - visN,1)/sNum + conf.params(3)*DVB;
        visB = visB + DVB;
        DHB = lr*sum(hidP - hidN,1)/sNum + conf.params(3)*DHB;
        hidB = hidB + DHB;
    end
    % FIX: report before the early-stop check so the final epoch's MSE is
    % printed even when training stops on this epoch.
    fprintf('Epoch %d : MSE = %f\n',i,mse);
    % Early stopping: count consecutive epochs where the error increased.
    % FIX: skip the comparison on epoch 1 (omse is still its initial 0,
    % so every first epoch used to count as an "increase"), and added the
    % missing semicolon so inc_count is no longer echoed to the console.
    if i > 1 && mse > omse
        inc_count = inc_count + 1;
    else
        inc_count = 0;
    end
    if inc_count > MAX_INC, break; end;
end
end