module em_onecolumn;

// Expectation-maximisation over a single spectral column: the column is
// explained by per-note pitch weights and per-note instrument (source)
// weights applied to fixed per-instrument note templates.

mm = load may.mathmisc;
vec = load may.vector;
mat = load may.matrix;

plot = load may.plot;

// True if note lies within the inclusive lowest..highest range recorded
// for the given instrument.
inRange ranges instrument note =
    note >= ranges[instrument].lowest and note <= ranges[instrument].highest;

// Scale v so that its elements sum to 1. A vector whose sum is not
// positive is returned unchanged.
normalise v =
   (s = vec.sum v;
    if s > 0 then vec.divideBy s v
    else v
    fi);

// Normalise a sources matrix (one row per note, one column per instrument)
// so that each note's weights sum to 1 across instruments.
// A row whose total is not positive (for example a note outside every
// instrument's range after initialise) is divided by 1 and so stays
// all-zero. Dividing it by its zero total would fill the row with NaN, and
// if that note lies between the overall lowest and highest notes,
// performMaximisation would read those NaNs and every pitch weight would
// become NaN.
normaliseSources s =
   (rows = [0 .. mat.height s - 1];
    cols = [0 .. mat.width s - 1];
    denoms = vec.fromList
       (map do note:
            total = fold do acc inst: acc + mat.at s note inst done 0 cols;
            if total > 0 then total else 1 fi
        done rows);
    mat.fromColumns
       (map do inst: vec.divide (mat.getColumn inst s) denoms done cols));

// Build the initial model state.
//   ranges    - hash keyed by instrument; each value has lowest and highest
//   templates - hash keyed by instrument; each value is a matrix whose
//               columns are looked up by note number (see distributionsFor)
//   notes     - number of notes (length of the pitch vector)
// Instruments are ordered by their sorted keys. The templates and ranges
// arrays and the columns of the sources matrix all follow that order and
// are addressed by instrument index elsewhere in this module.
initialise ranges templates notes =
   (instruments = sort (keys ranges);
    {
        pitches = // z in the original. 1 per note
            normalise (vec.randoms notes),
        sources = normaliseSources // u in the original. 1 per note-instrument
           (mat.fromColumns
               (map do instrument:
                   (vec.fromList
                       (map do note:
                            // Start with equal weight for every in-range note.
                            if inRange ranges instrument note then 1 else 0 fi
                        done [0..notes-1]))
                done instruments)),
        instruments = array instruments,
        instCount = length instruments,
        noteCount = notes,
        // Each template column is normalised to sum to 1.
        templates = array
           (map do iname:
                m = templates[iname];
                mat.fromColumns (map normalise (mat.asColumns m))
            done instruments),
        ranges = array
           (map do iname:
                ranges[iname]
            done instruments),
        // Smallest lowest note and largest highest note over all instruments.
        // NOTE(review): head fails if ranges is empty.
        lowest = head
           (sort (map do iname: ranges[iname].lowest done instruments)),
        highest = head (reverse
           (sort (map do iname: ranges[iname].highest done instruments))),
    });

// Small positive floor used to seed sums so that later divisions never see 0.
epsilon = 1e-16;

// Keep the elements of a list for which predicate holds.
select predicate = concatMap do v: if predicate v then [v] else [] fi done;

// The model quantities for one (instrument index, note) pair:
//   w - the instrument's template column for the note
//   p - the note's pitch weight
//   s - the note's source weight for the instrument
distributionsFor data instNo note =
    {
        w = mat.getColumn note data.templates[instNo],
        p = vec.at data.pitches note,
        s = mat.at data.sources note instNo,
    };

// Expectation step for one column. Returns a record with:
//   estimate - epsilon plus, for each instrument and each note in that
//              instrument's range, the template column scaled by p * s
//   q        - the normalised column divided by estimate (vec.divide)
performExpectation data column =
   (column = normalise column;
    estimate =
        fold do acc inst:
            fold do acc note:
                { w, p, s } = distributionsFor data inst note;
                vec.add [acc, vec.scaled (p * s) w];
            done acc [data.ranges[inst].lowest ..
                      data.ranges[inst].highest]
        done (vec.consts epsilon (vec.length column)) [0..data.instCount-1];
    { estimate, q = vec.divide column estimate });

// Maximisation step. Uses the q returned by performExpectation to recompute
// both of the following, and returns data with those two fields replaced:
//   pitches - normalised to sum to 1
//   sources - normalised per note across instruments
performMaximisation data column q =
   (column = normalise column;

    // Pitch weight per note: epsilon plus the sum over all instruments and
    // bins of s * p * w[bin] * q[bin]. Notes outside the overall
    // lowest..highest span get epsilon.
    pitches = vec.fromList
       (map do note:
            if note >= data.lowest and note <= data.highest then
                fold do acc inst:
                    { w, p, s } = distributionsFor data inst note;
                    fold do acc bin:
                        acc + s * p * (vec.at w bin) * (vec.at q bin);
                    done acc [0..vec.length column - 1]
                done epsilon [0..data.instCount-1];
            else epsilon
            fi
        done [0..data.noteCount-1]);
    pitches = normalise pitches;

    // Source weight per (note, instrument): epsilon plus the sum over bins
    // of s * p * w[bin] * q[bin]. Notes outside the instrument's range get
    // epsilon.
    sources = mat.fromColumns
       (map do inst: vec.fromList
           (map do note:
               (if not inRange data.ranges inst note then epsilon else
                    { w, p, s } = distributionsFor data inst note;
                    fold do acc bin:
                        acc + s * p * (vec.at w bin) * (vec.at q bin);
                    done epsilon [0..vec.length column - 1]
                fi);
            done [0..data.noteCount-1])
        done [0..data.instCount-1]);

    sources = normaliseSources sources;

    data with {
        pitches,
        sources,
    });

{
    initialise,
    performExpectation,
    performMaximisation,
}