view src/samer/maths/opt/MinimiserBase.java @ 8:5e3cbbf173aa tip

Reorganise some more
author samer
date Fri, 05 Apr 2019 22:41:58 +0100
parents bf79fb79ee13
children
line wrap: on
line source
/*
 *	Copyright (c) 2000, Samer Abdallah, King's College London.
 *	All rights reserved.
 *
 *	This software is provided AS IS and WITHOUT ANY WARRANTY;
 *	without even the implied warranty of MERCHANTABILITY or
 *	FITNESS FOR A PARTICULAR PURPOSE.
 */

package	samer.maths.opt;
import  samer.maths.*;
import  samer.core.*;
import  samer.core.types.*;
import  samer.core.util.*;
import  samer.tools.*;
import	samer.core.util.swing.*;
import	java.awt.*;
import	java.util.*;

/**
		This is a base class for running multidimensional optimisation.
		It requires the following objects to exist in object space
		(they are used by the State base class default constructor):

		Functionx	"functionx":	the function to be minimised
		VVector		"vector":		the vector to work with

		ConstrainedMinimiser also requires "constraintClass" to
		exist.

		Note that the constructor MinimiserBase(Vec, Functionx) receives
		the vector and the function directly as arguments.
  */


public abstract class MinimiserBase extends State implements SafeTask, Agent
{
	// Observable views onto the state's vectors: x1/g1 wrap P1.x/P1.g,
	// x2/g2 wrap P2.x/P2.g and vh wraps the direction h (P1, P2 and h are
	// not declared here; presumably inherited from State -- confirm).
	protected VVector					x1,x2,g1,g2,vh;
	protected VDouble					beta;		// initial step length passed to step()
	protected CubicLineSearch		ls;
	protected LSCondition			lstest;		// termination test for the line search
	protected AbsXFConvergence	xfconv;		// convergence test on x and f tolerances
	protected GConvergence			gconv;		// convergence test on the gradient tolerance
	protected double					XTOL=1e-4;	// smallest acceptable line-search step (see LSCondition)
	protected VInteger				vmaxiter;	// iteration limit reported by getMaxiter()
	protected boolean					flange;		// signal vector changes every iteration

	protected Meter					lsiters, iters;
	protected Meter					steplength;
	protected LED						sig1, sig2, sig3;

	/** Viewables registered through {@link #add(Viewable)}; shown by getViewer(). */
	private final ArrayList<Viewable> viewables;

	/**
	 * Creates the optimiser: vector views, convergence tests, line search,
	 * progress meters, status LEDs and tunable parameters.
	 *
	 * @param x    the vector to work with, passed to the State constructor
	 * @param func the function to be minimised, passed to the State constructor
	 */
	public MinimiserBase(Vec x, Functionx func)
	{
		super(x, func);
		viewables = new ArrayList<Viewable>();

		// Shell auto-registration is switched off while the objects below are
		// created. The previous setting is restored in the finally block, so a
		// failure part-way through cannot leave it switched off globally.
		boolean oldreg=Shell.setAutoRegister(false);
		try {
			// --- views, convergence tests and line search ---
			x1 = new VVector("x1", P1.x); // independent view?
			x2 = new VVector("x2", P2.x);
			g1 = new VVector("g1", P1.g);
			g2 = new VVector("g2", P2.g);
			vh  = new VVector("h", h);
			beta = new VDouble("beta",1);
			vmaxiter=new VInteger("maxiter",400);
			xfconv=new AbsXFConvergence();
			gconv=new GConvergence();
			gconv.setGTolerance(0.005);
			flange=Shell.getBoolean("vprogress",false);
			ls = new CubicLineSearch(this);
			ls.setSafeguardFactor(0.125);
			lstest = new LSCondition();	// also registers the "ETA" parameter

			// --- progress meters and status LEDs ---
			steplength = createMeter("step length",2,false,new Color(200,170,0));
			lsiters = createMeter("line search iterations",20,false,new Color(90,100,255));
			iters   = createMeter("iterations",2*x1.size(), true, new Color(200,50,0));
			sig1 = new LED(new Color(240,190,20));
			sig2 = new LED(new Color(90,100,255));
			sig3 = new LED(new Color(250,60,20));
			sig1.setToolTipText("Tiny line search step");
			sig2.setToolTipText("Resetting Hessian, trying steepest descent");
			sig3.setToolTipText("Maximum iterations reached");

			// --- locally registered, user-tunable parameters ---
			add(beta);
			add(vmaxiter);

			// "xtol" drives both the optimiser's x tolerance and the
			// line search's minimum step length XTOL.
			add(new VParameter("xtol", new Parameter() {
					public void setParameter(double t) {
						xfconv.setXTolerance(t);
						XTOL=t;
					}
				} )
			);

			add(new VParameter("ftol", new Parameter() {
					public void setParameter(double t) {
						xfconv.setFTolerance(t);
					}
				} )
			);

			add(new VParameter("gtol", new Parameter() {
					public void setParameter(double t) {
						gconv.setGTolerance(t);
					}
				} )
			);

			add(new VParameter("ZETA", new Parameter() {
					public void setParameter(double t) {
						ls.setSafeguardFactor(t);
					}
				} )
			);
		} finally {
			Shell.setAutoRegister(oldreg);
		}
	}

	/**
	 * Builds a panel showing the progress meters, the status LEDs, the
	 * vector views (x1, x2, g1, g2, h) and every locally registered viewable.
	 *
	 * @return a new viewer panel for this minimiser
	 */
	public Viewer getViewer() {
		return new VPanel() {
			{
				setLayout( new StackLayout());

				VPanel meterbox = new VPanel();
				meterbox.setLayout( new StackLayout(8));
				meterbox.add(steplength);
				meterbox.add(iters);
				meterbox.add(lsiters);

				VPanel signalbox = new VPanel();
				signalbox.setLayout( new FlowLayout(FlowLayout.LEFT));
				signalbox.add(sig1);
				signalbox.add(sig2);
				signalbox.add(sig3);

				add(meterbox);
				add(signalbox);

				add(x1); add(x2); add(g1); add(g2); add(vh);

				// finally, everything registered via MinimiserBase.add()
				for (Viewable vbl : viewables) {
					add(vbl);
				}
			}
		};
	}

	/** Releases resources; currently just delegates to State. */
	public void dispose() { super.dispose(); }

	/**
	 * Registers a viewable locally so that it appears in getViewer() and
	 * getViewables().
	 *
	 * @param vbl the viewable to register
	 */
	public void add(Viewable vbl) { viewables.add(vbl); }

	/** @return a new array holding the locally registered viewables, in registration order */
	public Viewable[] getViewables() { return viewables.toArray(new Viewable[0]); }

	/** @return the current iteration limit (the "maxiter" parameter) */
	public int  getMaxiter() { return vmaxiter.value; }

	/** Per-iteration hook: notifies x1 observers when "vprogress" is enabled. */
	protected void perIteration() {	if (flange) x1.changed(); }

	/**
	 * End-of-optimisation hook. Moves to the second point if its function
	 * value is no worse, updates the iteration meter and switches the
	 * "maximum iterations reached" LED on when i has reached the limit.
	 *
	 * @param i number of iterations performed
	 */
	protected void perOptimisation(int i) {
		if (P2.f<=P1.f) {
			move();
			if (flange) x1.changed();
		}

		iters.next(i);
		if (i<vmaxiter.value) sig3.off();
		else sig3.on();
	}

	public void starting() {}
	public void stopping() {}

	/** Registers the interactive commands handled by {@link #execute}. */
	public void getCommands(Agent.Registry r) {
		r.add("eval").add("step").add("move").add("turn").add("info");
	}

	/**
	 * Runs one interactive command. Unknown commands are ignored.
	 * <ul>
	 * <li>"step": step by beta and notify x2/g2 observers</li>
	 * <li>"info": print toString() to the Shell</li>
	 * <li>"move": record alpha on the step-length meter, move(), notify x1/g1</li>
	 * <li>"eval": evaluate(), notify x1/g1, then setSlope()</li>
	 * <li>"turn": set h from the negated P1 gradient, setSlope(), notify h</li>
	 * </ul>
	 *
	 * @param cmd the command name
	 * @param env the calling environment (unused here)
	 */
	public void execute(String cmd, Environment env) throws Exception
	{
		if (cmd.equals("step")) {
			step(beta.value);
			g2.changed();
			x2.changed();
		} else if (cmd.equals("info")) {
			Shell.print(toString());
		} else if (cmd.equals("move")) {
			steplength.next(alpha);
			move();
			x1.changed();
			g1.changed();
		} else if (cmd.equals("eval")) {
			evaluate();
			x1.changed();
			g1.changed();
			setSlope();
		} else if (cmd.equals("turn")) {
			// presumably writes -P1.g into h (steepest-descent direction) -- confirm Mathx.negate semantics
			Mathx.negate(P1.g,h);
			setSlope();
			vh.changed();
		}
	}

	/**
	 * Termination test for the line search. test() returns true when:
	 * <ol>
	 * <li>the search has overrun: the 98th call since init() always succeeds;</li>
	 * <li>the step is tiny (alpha*normh &lt; XTOL), which also sets {@code tiny};</li>
	 * <li>both a sufficient-decrease condition (factor MU) and a curvature
	 *     condition (factor ETA) hold. These have the form of the strong Wolfe
	 *     conditions if {@code s} is the slope along h -- assumption, confirm.</li>
	 * </ol>
	 */
	class LSCondition implements Condition
	{
		double MU=1e-4;		// ensure function decreases enough
		double ETA=0.1;		// ensure gradient is smaller than last
		int		count;		// number of test() calls since init()
		boolean tiny;		// true if the search stopped because the step fell below XTOL

		LSCondition()
		{
			// expose ETA as a tunable parameter; MU is not exposed here
			add(new VParameter("ETA",new DoubleModel() {
					public double get() { return ETA; }
					public void set(double t) { ETA=t; }
				} )
			);
		}

		/** Resets the per-search state; call before each line search. */
		public void init() { tiny=false; count=0; }

		public boolean test()
		{
			// safety valve against a line search that never terminates
			if (count++>96) {
				Shell.trace("\n**** LINE SEARCH OVERRUN ****\n");
				return true;
			}
			if (alpha*normh<XTOL) { tiny=true; return true; }
			return (P2.f<P1.f+MU*alpha*P1.s) && (Math.abs(P2.s)<-ETA*P1.s);
		}
	}

	/**
	 * Creates a coloured meter with domain [0, max], set up inside a Shell
	 * context named {@code label}. The context is popped even if setup fails.
	 *
	 * @param label  tooltip text and Shell context name
	 * @param max    upper end of the meter's domain
	 * @param reinit passed to Meter.exposeMap
	 * @param color  foreground colour; the background is a darker shade of it
	 * @return the configured meter
	 */
	private static Meter createMeter(String label, double max, boolean reinit, Color color) {
		Shell.push(label);
		try {
			Meter m=new Meter();
			m.setForeground(color);
			m.setBackground(color.darker().darker());
			m.getMap().setDomain(0,max);
			m.setToolTipText(label);
			m.exposeMap(reinit);
			return m;
		} finally {
			Shell.pop();
		}
	}
}