comparison pymf/laesa.py @ 0:26838b1f560f

initial commit of a segmenter project
author mi tian
date Thu, 02 Apr 2015 18:09:27 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:26838b1f560f
1 #!/usr/bin/python
2 #
3 # Copyright (C) Christian Thurau, 2010.
4 # Licensed under the GNU General Public License (GPL).
5 # http://www.gnu.org/licenses/gpl.txt
6 """
7 PyMF LAESA
8 """
9
10
11 import scipy.sparse
12 import numpy as np
13
14 from dist import *
15 from sivm import SIVM
16
17 __all__ = ["LAESA"]
18
class LAESA(SIVM):
    """
    LAESA(data, num_bases=4)


    LAESA-style basis selection (the base-prototype selection step of the
    Linear Approximating and Eliminating Search Algorithm). Factorize a data
    matrix into two matrices s.t. F = | data - W*H | is minimal. H is
    restricted to convexity. The columns of W are samples taken from data:
    the first one is chosen by the inherited init_sivm(), and every further
    basis vector is the sample whose distance to its nearest already-selected
    basis vector is largest (farthest-first / max-min distance selection).
    Only the selection of W is defined here; the distance measure, the
    computation of H and factorize() are inherited from SIVM.

    Parameters
    ----------
    data : array_like, shape (_data_dimension, _num_samples)
        the input data
    num_bases: int, optional
        Number of bases to compute (column rank of W and row rank of H).
        4 (default)

    Attributes
    ----------
    W : "data_dimension x num_bases" matrix of basis vectors
    H : "num bases x num_samples" matrix of coefficients
    ferr : frobenius norm (after calling .factorize())
    select : list of the data column indices used as basis vectors, in
        selection order (after computing W)

    Example
    -------
    Applying LAESA to a small toy data set:

    >>> import numpy as np
    >>> data = np.array([[1.0, 0.0, 2.0], [0.0, 1.0, 1.0]])
    >>> laesa_mdl = LAESA(data, num_bases=2)
    >>> laesa_mdl.factorize()

    The basis vectors are now stored in laesa_mdl.W, the coefficients in
    laesa_mdl.H. To compute coefficients for an existing set of basis vectors
    simply copy W to laesa_mdl.W, and set compute_w to False:

    >>> data = np.array([[1.5, 1.3], [1.2, 0.3]])
    >>> W = np.array([[1.0, 0.0], [0.0, 1.0]])
    >>> laesa_mdl = LAESA(data, num_bases=2)
    >>> laesa_mdl.W = W
    >>> laesa_mdl.factorize(niter=1, compute_w=False)

    The result is a set of coefficients laesa_mdl.H, s.t. data is
    approximated by W * laesa_mdl.H.
    """
    def update_w(self):
        """Select the basis vectors W by max-min distance sampling.

        Resets self.select via init_sivm(), extends it to self._num_bases
        sample indices and sets self.W to the matching columns of self.data,
        in selection order.
        """
        # init_sivm() (inherited) resets self.select and seeds it with the
        # first basis index.
        self.init_sivm()

        # mindist[j]: distance of sample j to its closest selected sample.
        # NOTE(review): only the last seed is used here, which assumes
        # init_sivm() leaves exactly one index in self.select -- confirm.
        mindist = self._distance(self.select[-1])

        for step in range(self._num_bases - 1):
            if step > 0:
                # Fold in the distances to the sample picked in the previous
                # step. On the first pass mindist already reflects the seed,
                # so recomputing its distances would be wasted work.
                d = self._distance(self.select[-1])
                mindist = np.where(d < mindist, d, mindist)

            # The sample farthest from the current selection becomes the
            # next basis vector.
            self.select.append(np.argmax(mindist))
            self._logger.info('cur_nodes: ' + str(self.select))

        # h5py datasets only accept fancy indices in increasing order, so the
        # columns are read in sorted index order ...
        self.W = self.data[:, np.sort(self.select)]

        # ... and then permuted back: argsort(argsort(select))[k] is the rank
        # of select[k] among the sorted indices, so afterwards
        # W[:, k] == data[:, select[k]].
        self.W = self.W[:, np.argsort(np.argsort(self.select))]
83
84
if __name__ == "__main__":
    # When executed as a script, run the usage examples embedded in the
    # docstrings of this module as doctests.
    from doctest import testmod
    testmod()