samer@0
|
1 % mscaler - Dynamic additive and multiplicative normalisation in Matlab
|
samer@0
|
2 %
|
samer@0
|
3 % mscaler ::
|
samer@0
|
4 % model(real) ~'model for observations',
|
samer@0
|
5 % options {
|
samer@0
|
6 % scale :: nonneg /1 ~'initial scale factor';
|
samer@0
|
7 % offset :: real /0 ~'initial offset';
|
samer@0
|
8 % scale_rate :: nonneg/0.02 ~'scale adaptation rate';
|
samer@0
|
9 % offset_rate :: nonneg/1e-7 ~'offset adaptation rate';
|
samer@0
|
10 % batch :: natural/4 ~'batch size for updates'
|
samer@0
|
11 % }
|
samer@0
|
12 % -> arrow({[[N]]},{[[N]]},mscaler_state).
|
samer@0
|
13
|
samer@0
|
function o=mscaler(model,varargin)
	% MSCALER - adaptive offset/scale normaliser built on LOOP.
	%
	% Each input frame x is mapped to y = (x-offset)/scale. The offset and
	% scale are adapted online from phi = score(y), where score = scorefn(model):
	%    scale  <- scale*exp(scale_rate*(mean(y.*phi)-1))
	%    offset <- offset + offset_rate*scale*mean(phi)
	% The multiplicative (exp) form keeps the scale positive. The rules stop
	% moving when mean(y.*phi)==1 and mean(phi)==0. (Presumably phi is the
	% model's score function -d/dy log p(y); confirm against scorefn.)
	%
	% If batch==1, the parameters are updated after every frame. Otherwise
	% statistics are accumulated across frames, and an update is applied once
	% at least BATCH samples have been seen.
	% A frame with no elements carries no information and leaves the
	% parameters unchanged.
	%
	% Returns the object built by LOOP, whose state is either
	% [scale;offset] (batch==1) or {accumulator, [scale;offset]}.

	% NOTE(review): opts.nargout is never read in this function. It is kept
	% so that 'nargout' remains an accepted option; confirm whether it is needed.
	opts=options('scale',1,'offset',0, ...
		'scale_rate',0.02,'offset_rate',1e-7,'batch',4, ...
		'nargout', 1, ...
		varargin{:});

	rates=[opts.scale_rate;opts.offset_rate];
	batch=opts.batch;
	score=scorefn(model);
	z3 = [0;0;0];   % empty accumulator: [count; sum(y.*phi)-count; sum(phi)]

	if (batch==1)
		o=loop(@update,@(s)[opts.scale;opts.offset]);
	else
		o=loop(@update_batched,@(s){z3,[opts.scale;opts.offset]});
	end

	% Sufficient statistics for column vectors y and phi:
	% [number of samples; sum(y.*phi)-n; sum(phi)].
	function ss=stats(y,phi)
		n=size(y,1);
		ss=[n;sum(y.*phi)-n;sum(phi)];
	end

	% Normalise x with params=[scale;offset] and evaluate the model score.
	function [y,phi]=infer(params,x)
		y = (x-params(2))/params(1);
		phi = score(y);
	end

	% Apply one adaptation step.
	% stat1 = mean(y.*phi)-1 drives the scale; stat2 = mean(phi) drives the offset.
	function params=updparams(params,stat1,stat2)
		params = [ params(1)*exp(rates(1)*stat1);
		           params(2)+rates(2)*params(1)*stat2 ];
	end

	% Unbatched step: normalise the frame, then update [scale;offset] from it.
	function [y,state]=update(x,state)
		[y,phi] = infer(state,x);
		% An empty frame would give mean(...)=NaN and poison the state for good.
		if isempty(y), return; end
		state = updparams(state, mean(y(:).*phi(:))-1, mean(phi(:)));
	end

	% Batched step: accumulate statistics, and update the parameters once
	% at least BATCH samples have been collected.
	function [y,state]=update_batched(x,state)
		params=state{2};
		[y,phi] = infer(params,x);
		ss = state{1} + stats(y(:),phi(:));
		% The ss(1)==0 test avoids 0/0 when batch<=0 and no samples have arrived yet.
		if ss(1)<batch || ss(1)==0
			state{1}=ss;
		else
			state = {z3, updparams(params,ss(2)/ss(1),ss(3)/ss(1))};
		end
	end
end
|