function demprior(action);
%DEMPRIOR Demonstrate sampling from a multi-parameter Gaussian prior.
%
%	Description
%	This function plots the functions represented by a multi-layer
%	perceptron network when the weights are set to values drawn from a
%	Gaussian prior distribution.  The parameters AW1, AB1, AW2 and AB2
%	control the inverse variances of the first-layer weights, the hidden
%	unit biases, the second-layer weights and the output unit biases
%	respectively.  Their values can be adjusted on a logarithmic scale
%	using the sliders, or by typing values into the text boxes and
%	pressing the return key.
%
%	DEMPRIOR(ACTION) dispatches on the string ACTION, which is how the
%	GUI callbacks re-enter this function:
%	  'initialize'  build the figure and draw a first set of samples
%	                (this is the default when ACTION is omitted)
%	  'update'      a slider moved: refresh the text boxes and replot
%	  'newval'      a text box changed: refresh the sliders and replot
%	  'replot'      draw a fresh set of samples from the prior
%	  'help'        open the help window
%	Any other value of ACTION does nothing.
%
%	The figure's UserData holds the handle list
%	[fig, 4 sliders, 4 edit boxes, axes]; every action other than
%	'initialize' and 'help' looks its controls up there via GCF.
%
%	See also
%	MLP
%

%	Copyright (c) Ian T Nabney (1996-2001)

if nargin<1,
  action='initialize';
end;

% Format used wherever a hyperparameter is written into an edit box.
% (Deliberately not called "format", which would shadow the built-in
% FORMAT command inside this function.)
valfmt = '%8f';

if strcmp(action,'initialize')

  % One entry per hyperparameter, listed bottom-to-top as they appear in
  % the figure.  The y-coordinates are kept as literal tables so that the
  % layout is exactly the hand-tuned one.
  names  = {'aw1', 'ab1', 'aw2', 'ab2'};
  alphas = [0.01 0.1 1.0 1.0];          % initial inverse variances
  framey = [0.08 0.30 0.52 0.74];       % enclosing frame
  slidey = [0.10 0.32 0.54 0.76];       % slider
  texty  = [0.17 0.39 0.61 0.83];       % label and edit box
  nparams = length(alphas);

  % Sliders work on log10 of the hyperparameter
  minval = -5; maxval = 5;

  % Create FIGURE
  fig=figure( ...
    'Name','Sampling from a Gaussian prior', ...
    'Position', [50 50 480 380], ...
    'NumberTitle','off', ...
    'Color', [0.8 0.8 0.8], ...
    'Visible','on');

  % The TITLE BAR frame
  uicontrol(fig,  ...
    'Style','frame', ...
    'Units','normalized', ...
    'HorizontalAlignment', 'center', ...
    'Position', [0.5 0.82 0.45 0.1], ...
    'BackgroundColor',[0.60 0.60 0.60]);

  % The TITLE BAR text
  uicontrol(fig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position', [0.54 0.85 0.40 0.05], ...
    'HorizontalAlignment', 'left', ...
    'String', 'Sampling from a Gaussian prior');

  % Frames to enclose sliders.  These are created before the controls
  % that sit on top of them so that the stacking order is right.
  for k = 1:nparams
    uicontrol(fig, ...
      'Style', 'frame', ...
      'Units', 'normalized', ...
      'BackgroundColor', [0.6 0.6 0.6], ...
      'Position', [0.05 framey(k) 0.35 0.18]);
  end

  % Frame text (parameter names)
  for k = 1:nparams
    uicontrol(fig, ...
      'Style', 'text', ...
      'Units', 'normalized', ...
      'HorizontalAlignment', 'left', ...
      'BackgroundColor', [0.6 0.6 0.6], ...
      'Position', [0.07 texty(k) 0.06 0.07], ...
      'String', names{k});
  end

  % Sliders.  Min/Max are given before Value so that the initial value is
  % never outside the slider's range.  The handle vector is grown by
  % assignment rather than preallocated with ZEROS so that it has
  % whatever handle type UICONTROL returns.
  for k = 1:nparams
    slides(k) = uicontrol(fig, ...
      'Style', 'slider', ...
      'Units', 'normalized', ...
      'Min', minval, 'Max', maxval, ...
      'Value', log10(alphas(k)), ...
      'BackgroundColor', [0.8 0.8 0.8], ...
      'Position', [0.07 slidey(k) 0.31 0.05], ...
      'Callback', 'demprior update');
  end

  % The graph box
  haxes = axes('Position', [0.5 0.28 0.45 0.45], ...
    'Units', 'normalized', ...
    'Visible', 'on');

  % Text display of hyper-parameter values
  for k = 1:nparams
    edits(k) = uicontrol(fig, ...
      'Style', 'edit', ...
      'Units', 'normalized', ...
      'Position', [0.15 texty(k) 0.23 0.07], ...
      'String', sprintf(valfmt, alphas(k)), ...
      'Callback', 'demprior newval');
  end

  % The SAMPLE button
  uicontrol(fig, ...
    'Style','push', ...
    'Units','normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position',[0.5 0.08 0.13 0.1], ...
    'String','Sample', ...
    'Callback','demprior replot');

  % The CLOSE button
  uicontrol(fig, ...
    'Style','push', ...
    'Units','normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position',[0.82 0.08 0.13 0.1], ...
    'String','Close', ...
    'Callback','close(gcf)');

  % The HELP button
  uicontrol(fig, ...
    'Style','push', ...
    'Units','normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position',[0.66 0.08 0.13 0.1], ...
    'String','Help', ...
    'Callback','demprior help');

  % Save handles to objects.  Layout relied on by the other actions:
  % (1) figure, (2:5) sliders, (6:9) edit boxes, (10) axes.
  hndlList=[fig slides edits haxes];
  set(fig, 'UserData', hndlList);

  demprior('replot')


elseif strcmp(action, 'update'),

  % Update when a slider is moved: show 10^(slider value) in each box.

  hndlList = get(gcf, 'UserData');
  slides = hndlList(2:5);
  edits  = hndlList(6:9);

  for k = 1:length(slides)
    set(edits(k), 'String', sprintf(valfmt, 10^get(slides(k), 'Value')));
  end

  demprior('replot');

elseif strcmp(action, 'newval'),

  % Update when text is changed: move each slider to log10 of the value
  % typed in its box.

  hndlList = get(gcf, 'UserData');
  slides = hndlList(2:5);
  edits  = hndlList(6:9);

  for k = 1:length(slides)
    lo = get(slides(k), 'Min');
    hi = get(slides(k), 'Max');

    % Read at most one number, so that text such as '1 2' cannot produce
    % a vector-valued slider setting.
    val = sscanf(get(edits(k), 'String'), '%f', 1);

    if isempty(val) || ~isfinite(val) || val <= 0
      % Not a usable inverse variance (blank, non-numeric, NaN/Inf, or
      % non-positive, for which log10 is -Inf or complex): keep the
      % slider where it is and restore the box from it.
      logval = get(slides(k), 'Value');
      resync = 1;
    else
      % A slider Value outside [Min, Max] is not valid, so clamp to the
      % slider's own range.
      logval = min(max(log10(val), lo), hi);
      resync = (logval ~= log10(val));
    end

    set(slides(k), 'Value', logval);
    if resync
      % The typed text was not used as-is; show what is actually in use.
      set(edits(k), 'String', sprintf(valfmt, 10^logval));
    end
  end

  demprior('replot');

elseif strcmp(action, 'replot'),

  % Re-sample from the prior and plot graphs.

  oldFigNumber=watchon;

  hndlList = get(gcf, 'UserData');
  aw1slide = hndlList(2);
  ab1slide = hndlList(3);
  aw2slide = hndlList(4);
  ab2slide = hndlList(5);
  haxes = hndlList(10);

  % The sliders hold log10 of each hyperparameter
  aw1 = 10^get(aw1slide, 'Value');
  ab1 = 10^get(ab1slide, 'Value');
  aw2 = 10^get(aw2slide, 'Value');
  ab2 = 10^get(ab2slide, 'Value');

  axes(haxes);
  cla
  set(gca, ...
    'Box', 'on', ...
    'Color', [0 0 0], ...
    'XColor', [0 0 0], ...
    'YColor', [0 0 0], ...
    'FontSize', 14);
  axis([-1 1 -10 10]);
  set(gca,'DefaultLineLineWidth', 2);

  nhidden = 12;
  prior = mlpprior(1, nhidden, 1, aw1, ab1, aw2, ab2);
  xvals = (-1:0.005:1)';    % column of network inputs
  nsample = 10;             % Number of samples from prior.
  hold on
  % Dashed lines along y = 0 and x = 0 (one line per matrix column)
  plot([-1 0; 1 0], [0 -10; 0 10], 'b--');
  net = mlp(1, nhidden, 1, 'linear', prior);
  for i = 1:nsample
    % Re-initialise the weights from the prior and plot the resulting
    % network function
    net = mlpinit(net, prior);
    yvals = mlpfwd(net, xvals);
    plot(xvals, yvals, 'y');
  end

  watchoff(oldFigNumber);

elseif strcmp(action, 'help'),

  % Provide help to user.

  oldFigNumber=watchon;

  helpfig = figure('Position', [100 100 480 400], ...
    'Name', 'Help', ...
    'NumberTitle', 'off', ...
    'Color', [0.8 0.8 0.8], ...
    'Visible','on');

  % The HELP TITLE BAR frame
  uicontrol(helpfig,  ...
    'Style','frame', ...
    'Units','normalized', ...
    'HorizontalAlignment', 'center', ...
    'Position', [0.05 0.82 0.9 0.1], ...
    'BackgroundColor',[0.60 0.60 0.60]);

  % The HELP TITLE BAR text
  uicontrol(helpfig, ...
    'Style', 'text', ...
    'Units', 'normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position', [0.26 0.85 0.6 0.05], ...
    'HorizontalAlignment', 'left', ...
    'String', 'Help: Sampling from a Gaussian Prior');

  % Each continuation piece starts with its own separating space.
  helpstr1 = strcat( ...
    'This demonstration shows the effects of sampling from a Gaussian', ...
    ' prior over weights for a two-layer feed-forward network. The', ...
    ' parameters aw1, ab1, aw2 and ab2 control the inverse variances of', ...
    ' the first-layer weights, the hidden unit biases, the second-layer', ...
    ' weights and the output unit biases respectively. Their values can', ...
    ' be adjusted on a logarithmic scale using the sliders, or by', ...
    ' typing values into the text boxes and pressing the return key.', ...
    ' After setting these values, press the ''Sample'' button to see a', ...
    ' new sample from the prior. ');
  helpstr2 = strcat( ...
    'Observe how aw1 controls the horizontal length-scale of the', ...
    ' variation in the functions, ab1 controls the input range over', ...
    ' such variations occur, aw2 sets the vertical scale of the output', ...
    ' and ab2 sets the vertical off-set of the output. The network has', ...
    ' 12 hidden units. ');
  % Two paragraphs separated by a blank line
  hstr = {helpstr1, '', helpstr2};

  % The HELP text
  helpui = uicontrol(helpfig, ...
    'Style', 'edit', ...
    'Units', 'normalized', ...
    'ForegroundColor', [0 0 0], ...
    'HorizontalAlignment', 'left', ...
    'BackgroundColor', [1 1 1], ...
    'Min', 0, ...
    'Max', 2, ...
    'Position', [0.05 0.2 0.9 0.8]);

  % Wrap the text at 70 columns and set the control's height to the
  % height TEXTWRAP recommends.
  [hstrw , newpos] = textwrap(helpui, hstr, 70);
  set(helpui, 'String', hstrw, 'Position', [0.05, 0.2, 0.9, newpos(4)]);


  % The CLOSE button
  uicontrol(helpfig, ...
    'Style','push', ...
    'Units','normalized', ...
    'BackgroundColor', [0.6 0.6 0.6], ...
    'Position',[0.4 0.05 0.2 0.1], ...
    'String','Close', ...
    'Callback','close(gcf)');

  watchoff(oldFigNumber);

end;