comparison pymf/bnmf.py @ 0:26838b1f560f

initial commit of a segmenter project
author mi tian
date Thu, 02 Apr 2015 18:09:27 +0100
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:26838b1f560f
1 #!/usr/bin/python
2 #
3 # Copyright (C) Christian Thurau, 2010.
4 # Licensed under the GNU General Public License (GPL).
5 # http://www.gnu.org/licenses/gpl.txt
6 """
7 PyMF Binary Matrix Factorization [1]
8
9 BNMF(NMF) : Class for binary matrix factorization
10
11 [1]Z. Zhang, T. Li, C. H. Q. Ding, X. Zhang: Binary Matrix Factorization with
12 Applications. ICDM 2007
13 """
14
15
16 import numpy as np
17 from nmf import NMF
18
# Public API of this module: only the BNMF factorization class is exported.
__all__ = ['BNMF']
20
class BNMF(NMF):
    """
    BNMF(data, num_bases=4)

    Binary Matrix Factorization. Factorize a data matrix into two matrices s.t.
    F = | data - W*H | is minimal. H and W are restricted to binary values.

    Binarity is encouraged through a penalty term whose weight (lambda) starts
    at 1/niter and is multiplied by _LAMB_INCREASE_W / _LAMB_INCREASE_H after
    every update of the corresponding factor.

    Parameters
    ----------
    data : array_like, shape (_data_dimension, _num_samples)
        the input data
    num_bases: int, optional
        Number of bases to compute (column rank of W and row rank of H).
        4 (default)

    Attributes
    ----------
    W : "data_dimension x num_bases" matrix of basis vectors
    H : "num_bases x num_samples" matrix of coefficients
    ferr : frobenius norm (after calling .factorize())

    Example
    -------
    Applying BNMF to some rather stupid data set:

    >>> import numpy as np
    >>> from bnmf import BNMF
    >>> data = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])

    Use 2 basis vectors -> W shape(data_dimension, 2).

    >>> bnmf_mdl = BNMF(data, num_bases=2)

    Set number of iterations to 5 and start computing the factorization.

    >>> bnmf_mdl.factorize(niter=5)

    The basis vectors are now stored in bnmf_mdl.W, the coefficients in bnmf_mdl.H.
    To compute coefficients for an existing set of basis vectors simply copy W
    to bnmf_mdl.W, and set compute_w to False:

    >>> data = np.array([[0.0], [1.0]])
    >>> W = np.array([[1.0, 0.0], [0.0, 1.0]])
    >>> bnmf_mdl = BNMF(data, num_bases=2)
    >>> bnmf_mdl.W = W
    >>> bnmf_mdl.factorize(niter=10, compute_w=False)

    The result is a set of coefficients bnmf_mdl.H, s.t. data = W * bnmf_mdl.H.
    """

    # Growth factors for the binarization weights lambda_W / lambda_H.
    # After each update of W (resp. H) its lambda is multiplied by the
    # matching factor. Values > 1 give more and more weight to making the
    # factorization binary as iterations proceed. Values < 1 let the penalty
    # decay, so the update approaches a conventional (non-binary) NMF update.
    # A factor of 0 switches the penalty off for that matrix after its first
    # update, which leaves that factor non-binary.
    _LAMB_INCREASE_W = 1.1
    _LAMB_INCREASE_H = 1.1

    # Added to every denominator of the multiplicative updates to avoid
    # division by zero.
    _DENOM_EPS = 10 ** -9

    def update_h(self):
        """Apply one multiplicative update to H, then grow lambda_H.

        Computes H <- H * N / D (element-wise) with
            N = W^T X + 3*lambda_H*H^2
            D = W^T W H + 2*lambda_H*H^3 + lambda_H*H + eps
        which is the positive/negative split of the gradient of
            0.5*||X - WH||^2 + 0.5*lambda_H*||H*H - H||^2.
        The penalty term is zero only when every entry of H is 0 or 1.
        """
        # [:, :] is kept from the original: for an ndarray it is a cheap view.
        # Presumably it also makes sliceable containers (e.g. h5py datasets)
        # load as numpy arrays -- confirm against callers.
        numer = (np.dot(self.W.T, self.data[:, :])
                 + 3.0 * self._lamb_H * (self.H ** 2))
        denom = (np.dot(np.dot(self.W.T, self.W), self.H)
                 + 2.0 * self._lamb_H * (self.H ** 3)
                 + self._lamb_H * self.H
                 + self._DENOM_EPS)
        self.H *= numer / denom

        # Grow only H's own weight here. lambda_W is grown in update_w, so
        # its schedule still advances when compute_h=False.
        self._lamb_H *= self._LAMB_INCREASE_H

    def update_w(self):
        """Apply one multiplicative update to W, then grow lambda_W.

        Computes W <- W * N / D (element-wise) with
            N = X H^T + 3*lambda_W*W^2
            D = W H H^T + 2*lambda_W*W^3 + lambda_W*W + eps
        which is the gradient split of
            0.5*||X - WH||^2 + 0.5*lambda_W*||W*W - W||^2.
        """
        numer = (np.dot(self.data[:, :], self.H.T)
                 + 3.0 * self._lamb_W * (self.W ** 2))
        denom = (np.dot(self.W, np.dot(self.H, self.H.T))
                 + 2.0 * self._lamb_W * (self.W ** 3)
                 + self._lamb_W * self.W
                 + self._DENOM_EPS)
        self.W *= numer / denom

        # Scaling after W's update matches the original schedule when
        # update_w runs before update_h in an iteration: W always saw the
        # not-yet-increased lambda_W.
        self._lamb_W *= self._LAMB_INCREASE_W

    def factorize(self, niter=10, compute_w=True, compute_h=True,
                  show_progress=False, compute_err=True):
        """ Factorize s.t. WH = data

        Parameters
        ----------
        niter : int
            number of iterations. The binarization weights lambda_W and
            lambda_H both start at 1/niter (1 when niter < 1).
        compute_w : bool
            iteratively update values for W.
        compute_h : bool
            iteratively update values for H.
        show_progress : bool
            print some extra information to stdout.
        compute_err : bool
            compute Frobenius norm |data-WH| after each update and store
            it to .ferr[k].

        Updated Values
        --------------
        .W : updated values for W.
        .H : updated values for H.
        .ferr : Frobenius norm |data-WH| for each iteration.
        """
        # Start with a weak binarization penalty. update_w/update_h grow it
        # geometrically. max(..., 1) prevents ZeroDivisionError for niter=0.
        initial_lamb = 1.0 / max(niter, 1)
        self._lamb_W = initial_lamb
        self._lamb_H = initial_lamb

        NMF.factorize(self, niter=niter, compute_w=compute_w,
                      compute_h=compute_h, show_progress=show_progress,
                      compute_err=compute_err)
124
if __name__ == "__main__":
    # When executed as a script, run the examples embedded in this
    # module's docstrings as doctests.
    from doctest import testmod
    testmod()