Mercurial > hg > camir-aes2014
comparison toolboxes/RBM/training_rbm.m @ 0:e9a9cd732c1e tip
first hg version after svn
author | wolffd |
---|---|
date | Tue, 10 Feb 2015 15:05:51 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:e9a9cd732c1e |
---|---|
function [W visB hidB] = training_rbm(conf,W,data_file)
%TRAINING_RBM Train a restricted Boltzmann machine by contrastive divergence.
%
%   [W, visB, hidB] = training_rbm(conf, W, data_file)
%
%   Inputs
%     conf       struct of training settings. Fields read by this function:
%                  hidNum  number of hidden units
%                  sNum    number of data rows per mini-batch
%                  bNum    number of mini-batches per epoch
%                  eNum    maximum number of epochs
%                  gNum    number of down/up Gibbs steps per mini-batch
%                  params  params(1) learning rate for the first N epochs,
%                          params(2) learning rate afterwards,
%                          params(3) factor applied to the previous update
%                                    (momentum),
%                          params(4) factor applied to W and subtracted from
%                                    the gradient (weight decay)
%     W          initial weights, at most visNum-by-hidNum (may be empty).
%                Missing rows/columns are filled with 0.1*randn values.
%     data_file  MAT-file; the first variable reported by WHOS is used as the
%                training matrix (one sample per row, one visible unit per
%                column).
%
%   Outputs
%     W          visNum-by-hidNum weight matrix after training
%     visB       1-by-visNum visible biases
%     hidB       1-by-hidNum hidden biases
%
%   Side effects: draws the per-epoch reconstruction error into a line plot,
%   prints the error of every epoch, and calls observe_reconstruction once
%   per epoch (batch 5, first Gibbs step).
%
%   -*-sontran2012-*-

%% load data
vars = whos('-file', data_file);
assert(~isempty(vars),'[KBRBM] Data is empty');
A    = load(data_file,vars(1).name);
data = A.(vars(1).name);
assert(~isempty(data),'[KBRBM] Data is empty');

%% initialization
visNum = size(data,2);
hidNum = conf.hidNum;
sNum   = conf.sNum;
lr     = conf.params(1);
N      = 10;    % Number of epochs trained with the first learning rate

assert(conf.bNum*sNum <= size(data,1), ...
    '[KBRBM] Need %d data rows (bNum*sNum) but only %d are available', ...
    conf.bNum*sNum, size(data,1));
assert(size(W,1) <= visNum && size(W,2) <= hidNum, ...
    '[KBRBM] Initial W is %d-by-%d, larger than %d-by-%d', ...
    size(W,1), size(W,2), visNum, hidNum);

% Grow the supplied weights to visNum-by-hidNum: keep what the caller passed
% in and append small random values for the missing rows, then columns.
% With an empty W this yields a fully random visNum-by-hidNum matrix.
W = [W; 0.1*randn(visNum - size(W,1), size(W,2))];
W = [W, 0.1*randn(size(W,1), hidNum - size(W,2))];

DW   = zeros(size(W));
visB = zeros(1,visNum);
DVB  = zeros(1,visNum);
hidB = zeros(1,hidNum);
DHB  = zeros(1,hidNum);
% These three are only assigned inside the Gibbs loop, so they need a
% defined value for the case conf.gNum == 0.
visN  = zeros(sNum,visNum);
visNs = zeros(sNum,visNum);
hidN  = zeros(sNum,hidNum);

%% Reconstruction error & early stopping
mse       = Inf;    % so that the first epoch is never counted as an increase
inc_count = 0;
MAX_INC   = 3000;   % Stop once the error has increased more than MAX_INC epochs in a row

%% Plotting
mse_plot = nan(1,conf.eNum);    % NaN entries are not drawn
h        = plot(nan);
ax       = get(h,'Parent');

%% ==================== Start training =========================== %%
for i = 1:conf.eNum
    if i == N+1
        lr = conf.params(2);
    end
    omse = mse;
    mse  = 0;
    for j = 1:conf.bNum
        visP = data((j-1)*sNum+1:j*sNum,:);
        % up
        hidP  = logistic(visP*W + repmat(hidB,sNum,1));
        hidPs = 1*(hidP > rand(sNum,hidNum));
        hidNs = hidPs;
        for k = 1:conf.gNum
            % down
            visN  = logistic(hidNs*W' + repmat(visB,sNum,1));
            visNs = 1*(visN > rand(sNum,visNum));
            % NOTE(review): the 28,28 arguments look like an image size
            % (visNum == 784) - confirm against observe_reconstruction.
            if j==5 && k==1, observe_reconstruction(visN,sNum,i,28,28); end
            % up
            hidN  = logistic(visNs*W + repmat(hidB,sNum,1));
            hidNs = 1*(hidN > rand(sNum,hidNum));
        end
        % Compute MSE for reconstruction
        rdiff = visP - visN;
        mse   = mse + sum(sum(rdiff.*rdiff))/(sNum*visNum);
        % Update W, visB, hidB
        gradW = (visP'*hidP - visNs'*hidN)/sNum;
        DW    = lr*(gradW - conf.params(4)*W) + conf.params(3)*DW;
        W     = W + DW;
        DVB   = lr*sum(visP - visN,1)/sNum + conf.params(3)*DVB;
        visB  = visB + DVB;
        DHB   = lr*sum(hidP - hidN,1)/sNum + conf.params(3)*DHB;
        hidB  = hidB + DHB;
    end

    %% Update the error plot
    mse_plot(i) = mse;
    % Address the line's own axes explicitly: the current axes may have been
    % changed by other plotting code since the line was created.
    set(ax,'XLim',[0 (conf.eNum+1)],'YLim',[0 2]);
    % Set XData and YData together so their lengths always agree.
    set(h,'XData',1:conf.eNum,'YData',mse_plot);
    drawnow;

    %% Early stopping
    if mse > omse
        inc_count = inc_count + 1;
    else
        inc_count = 0;
    end
    if inc_count > MAX_INC, break; end
    fprintf('Epoch %d : MSE = %f\n',i,mse);
end
end