c@243
|
1 /*
|
c@243
|
2 * cluster.c
|
c@243
|
3 * cluster_melt
|
c@243
|
4 *
|
c@243
|
5 * Created by Mark Levy on 21/02/2006.
|
c@309
|
6 * Copyright 2006 Centre for Digital Music, Queen Mary, University of London.
|
c@309
|
7
|
c@309
|
8 This program is free software; you can redistribute it and/or
|
c@309
|
9 modify it under the terms of the GNU General Public License as
|
c@309
|
10 published by the Free Software Foundation; either version 2 of the
|
c@309
|
11 License, or (at your option) any later version. See the file
|
c@309
|
12 COPYING included with this distribution for more information.
|
c@243
|
13 *
|
c@243
|
14 */
|
c@243
|
15
|
c@243
|
16 #include <stdlib.h>
|
c@243
|
17
|
c@243
|
18 #include "cluster_melt.h"
|
c@243
|
19
|
c@243
|
/* Default weight applied to the neighbour-mismatch count when it is
   subtracted from each cluster's log-probability in cluster_melt(). */
#define DEFAULT_LAMBDA 0.02
/* Default neighbourhood half-width (in histograms) used by cluster_melt()
   when the caller does not supply a positive limit. */
#define DEFAULT_LIMIT 20
|
c@243
|
22
|
c@243
|
23 double kldist(double* a, double* b, int n) {
|
c@243
|
24 /* NB assume that all a[i], b[i] are non-negative
|
c@243
|
25 because a, b represent probability distributions */
|
c@243
|
26 double q, d;
|
c@243
|
27 int i;
|
c@243
|
28
|
c@243
|
29 d = 0;
|
c@243
|
30 for (i = 0; i < n; i++)
|
c@243
|
31 {
|
c@243
|
32 q = (a[i] + b[i]) / 2.0;
|
c@243
|
33 if (q > 0)
|
c@243
|
34 {
|
c@243
|
35 if (a[i] > 0)
|
c@243
|
36 d += a[i] * log(a[i] / q);
|
c@243
|
37 if (b[i] > 0)
|
c@243
|
38 d += b[i] * log(b[i] / q);
|
c@243
|
39 }
|
c@243
|
40 }
|
c@243
|
41 return d;
|
c@243
|
42 }
|
c@243
|
43
|
c@243
|
44 void cluster_melt(double *h, int m, int n, double *Bsched, int t, int k, int l, int *c) {
|
c@243
|
45 double lambda, sum, beta, logsumexp, maxlp;
|
c@414
|
46 int i, j, a, b, b0, b1, limit, /* B, */ it, maxiter, maxiter0, maxiter1;
|
c@243
|
47 double** cl; /* reference histograms for each cluster */
|
c@243
|
48 int** nc; /* neighbour counts for each histogram */
|
c@243
|
49 double** lp; /* soft assignment probs for each histogram */
|
c@243
|
50 int* oldc; /* previous hard assignments (to check convergence) */
|
c@243
|
51
|
c@243
|
52 /* NB h is passed as a 1d row major array */
|
c@243
|
53
|
c@243
|
54 /* parameter values */
|
c@243
|
55 lambda = DEFAULT_LAMBDA;
|
c@243
|
56 if (l > 0)
|
c@243
|
57 limit = l;
|
c@243
|
58 else
|
c@243
|
59 limit = DEFAULT_LIMIT; /* use default if no valid neighbourhood limit supplied */
|
c@414
|
60 // B = 2 * limit + 1;
|
c@243
|
61 maxiter0 = 20; /* number of iterations at initial temperature */
|
c@243
|
62 maxiter1 = 5; /* number of iterations at subsequent temperatures */
|
c@243
|
63
|
c@243
|
64 /* allocate memory */
|
c@243
|
65 cl = (double**) malloc(k*sizeof(double*));
|
c@243
|
66 for (i= 0; i < k; i++)
|
c@243
|
67 cl[i] = (double*) malloc(m*sizeof(double));
|
c@243
|
68
|
c@243
|
69 nc = (int**) malloc(n*sizeof(int*));
|
c@243
|
70 for (i= 0; i < n; i++)
|
c@243
|
71 nc[i] = (int*) malloc(k*sizeof(int));
|
c@243
|
72
|
c@243
|
73 lp = (double**) malloc(n*sizeof(double*));
|
c@243
|
74 for (i= 0; i < n; i++)
|
c@243
|
75 lp[i] = (double*) malloc(k*sizeof(double));
|
c@243
|
76
|
c@243
|
77 oldc = (int*) malloc(n * sizeof(int));
|
c@243
|
78
|
c@243
|
79 /* initialise */
|
c@243
|
80 for (i = 0; i < k; i++)
|
c@243
|
81 {
|
c@243
|
82 sum = 0;
|
c@243
|
83 for (j = 0; j < m; j++)
|
c@243
|
84 {
|
c@243
|
85 cl[i][j] = rand(); /* random initial reference histograms */
|
c@243
|
86 sum += cl[i][j] * cl[i][j];
|
c@243
|
87 }
|
c@243
|
88 sum = sqrt(sum);
|
c@243
|
89 for (j = 0; j < m; j++)
|
c@243
|
90 {
|
c@243
|
91 cl[i][j] /= sum; /* normalise */
|
c@243
|
92 }
|
c@243
|
93 }
|
c@243
|
94 //print_array(cl, k, m);
|
c@243
|
95
|
c@243
|
96 for (i = 0; i < n; i++)
|
c@243
|
97 c[i] = 1; /* initially assign all histograms to cluster 1 */
|
c@243
|
98
|
c@243
|
99 for (a = 0; a < t; a++)
|
c@243
|
100 {
|
c@243
|
101 beta = Bsched[a];
|
c@243
|
102
|
c@243
|
103 if (a == 0)
|
c@243
|
104 maxiter = maxiter0;
|
c@243
|
105 else
|
c@243
|
106 maxiter = maxiter1;
|
c@243
|
107
|
c@243
|
108 for (it = 0; it < maxiter; it++)
|
c@243
|
109 {
|
c@243
|
110 //if (it == maxiter - 1)
|
c@243
|
111 // mexPrintf("hasn't converged after %d iterations\n", maxiter);
|
c@243
|
112
|
c@243
|
113 for (i = 0; i < n; i++)
|
c@243
|
114 {
|
c@243
|
115 /* save current hard assignments */
|
c@243
|
116 oldc[i] = c[i];
|
c@243
|
117
|
c@243
|
118 /* calculate soft assignment logprobs for each cluster */
|
c@243
|
119 sum = 0;
|
c@243
|
120 for (j = 0; j < k; j++)
|
c@243
|
121 {
|
c@243
|
122 lp[i][ j] = -beta * kldist(cl[j], &h[i*m], m);
|
c@243
|
123
|
c@243
|
124 /* update matching neighbour counts for this histogram, based on current hard assignments */
|
c@243
|
125 /* old version:
|
c@243
|
126 nc[i][j] = 0;
|
c@243
|
127 if (i >= limit && i <= n - 1 - limit)
|
c@243
|
128 {
|
c@243
|
129 for (b = i - limit; b <= i + limit; b++)
|
c@243
|
130 {
|
c@243
|
131 if (c[b] == j+1)
|
c@243
|
132 nc[i][j]++;
|
c@243
|
133 }
|
c@243
|
134 nc[i][j] = B - nc[i][j];
|
c@243
|
135 }
|
c@243
|
136 */
|
c@243
|
137 b0 = i - limit;
|
c@243
|
138 if (b0 < 0)
|
c@243
|
139 b0 = 0;
|
c@243
|
140 b1 = i + limit;
|
c@243
|
141 if (b1 >= n)
|
c@243
|
142 b1 = n - 1;
|
c@243
|
143 nc[i][j] = b1 - b0 + 1; /* = B except at edges */
|
c@243
|
144 for (b = b0; b <= b1; b++)
|
c@243
|
145 if (c[b] == j+1)
|
c@243
|
146 nc[i][j]--;
|
c@243
|
147
|
c@243
|
148 sum += exp(lp[i][j]);
|
c@243
|
149 }
|
c@243
|
150
|
c@243
|
151 /* normalise responsibilities and add duration logprior */
|
c@243
|
152 logsumexp = log(sum);
|
c@243
|
153 for (j = 0; j < k; j++)
|
c@243
|
154 lp[i][j] -= logsumexp + lambda * nc[i][j];
|
c@243
|
155 }
|
c@243
|
156 //print_array(lp, n, k);
|
c@243
|
157 /*
|
c@243
|
158 for (i = 0; i < n; i++)
|
c@243
|
159 {
|
c@243
|
160 for (j = 0; j < k; j++)
|
c@243
|
161 mexPrintf("%d ", nc[i][j]);
|
c@243
|
162 mexPrintf("\n");
|
c@243
|
163 }
|
c@243
|
164 */
|
c@243
|
165
|
c@243
|
166
|
c@243
|
167 /* update the assignments now that we know the duration priors
|
c@243
|
168 based on the current assignments */
|
c@243
|
169 for (i = 0; i < n; i++)
|
c@243
|
170 {
|
c@243
|
171 maxlp = lp[i][0];
|
c@243
|
172 c[i] = 1;
|
c@243
|
173 for (j = 1; j < k; j++)
|
c@243
|
174 if (lp[i][j] > maxlp)
|
c@243
|
175 {
|
c@243
|
176 maxlp = lp[i][j];
|
c@243
|
177 c[i] = j+1;
|
c@243
|
178 }
|
c@243
|
179 }
|
c@243
|
180
|
c@243
|
181 /* break if assignments haven't changed */
|
c@243
|
182 i = 0;
|
c@243
|
183 while (i < n && oldc[i] == c[i])
|
c@243
|
184 i++;
|
c@243
|
185 if (i == n)
|
c@243
|
186 break;
|
c@243
|
187
|
c@243
|
188 /* update reference histograms now we know new responsibilities */
|
c@243
|
189 for (j = 0; j < k; j++)
|
c@243
|
190 {
|
c@243
|
191 for (b = 0; b < m; b++)
|
c@243
|
192 {
|
c@243
|
193 cl[j][b] = 0;
|
c@243
|
194 for (i = 0; i < n; i++)
|
c@243
|
195 {
|
c@243
|
196 cl[j][b] += exp(lp[i][j]) * h[i*m+b];
|
c@243
|
197 }
|
c@243
|
198 }
|
c@243
|
199
|
c@243
|
200 sum = 0;
|
c@243
|
201 for (i = 0; i < n; i++)
|
c@243
|
202 sum += exp(lp[i][j]);
|
c@243
|
203 for (b = 0; b < m; b++)
|
c@243
|
204 cl[j][b] /= sum; /* normalise */
|
c@243
|
205 }
|
c@243
|
206
|
c@243
|
207 //print_array(cl, k, m);
|
c@243
|
208 //mexPrintf("\n\n");
|
c@243
|
209 }
|
c@243
|
210 }
|
c@243
|
211
|
c@243
|
212 /* free memory */
|
c@243
|
213 for (i = 0; i < k; i++)
|
c@243
|
214 free(cl[i]);
|
c@243
|
215 free(cl);
|
c@243
|
216 for (i = 0; i < n; i++)
|
c@243
|
217 free(nc[i]);
|
c@243
|
218 free(nc);
|
c@243
|
219 for (i = 0; i < n; i++)
|
c@243
|
220 free(lp[i]);
|
c@243
|
221 free(lp);
|
c@243
|
222 free(oldc);
|
c@243
|
223 }
|
c@243
|
224
|
c@243
|
225
|