Mercurial > hg > qm-dsp
comparison dsp/segmentation/cluster_melt.c @ 18:8e90a56b4b5f
* merge in segmentation code from soundbite plugin/library repository
author | cannam |
---|---|
date | Wed, 09 Jan 2008 10:46:25 +0000 |
parents | |
children | e5907ae6de17 |
comparison
equal
deleted
inserted
replaced
17:a120ac7b26b2 | 18:8e90a56b4b5f |
---|---|
1 /* | |
2 * cluster.c | |
3 * cluster_melt | |
4 * | |
5 * Created by Mark Levy on 21/02/2006. | |
6 * Copyright 2006 Centre for Digital Music, Queen Mary, University of London. All rights reserved. | |
7 * | |
8 */ | |
9 | |
#include <math.h>
#include <stdlib.h>

#include "cluster_melt.h"
13 | |
/* Default weight applied to the neighbour-mismatch count (duration prior).
   No trailing semicolon, so the macro can be used inside expressions. */
#define DEFAULT_LAMBDA 0.02
/* Default neighbourhood half-width, used when no positive limit is supplied. */
#define DEFAULT_LIMIT 20
16 | |
17 double kldist(double* a, double* b, int n) { | |
18 /* NB assume that all a[i], b[i] are non-negative | |
19 because a, b represent probability distributions */ | |
20 double q, d; | |
21 int i; | |
22 | |
23 d = 0; | |
24 for (i = 0; i < n; i++) | |
25 { | |
26 q = (a[i] + b[i]) / 2.0; | |
27 if (q > 0) | |
28 { | |
29 if (a[i] > 0) | |
30 d += a[i] * log(a[i] / q); | |
31 if (b[i] > 0) | |
32 d += b[i] * log(b[i] / q); | |
33 } | |
34 } | |
35 return d; | |
36 } | |
37 | |
38 void cluster_melt(double *h, int m, int n, double *Bsched, int t, int k, int l, int *c) { | |
39 double lambda, sum, beta, logsumexp, maxlp; | |
40 int i, j, a, b, b0, b1, limit, B, it, maxiter, maxiter0, maxiter1; | |
41 double** cl; /* reference histograms for each cluster */ | |
42 int** nc; /* neighbour counts for each histogram */ | |
43 double** lp; /* soft assignment probs for each histogram */ | |
44 int* oldc; /* previous hard assignments (to check convergence) */ | |
45 | |
46 /* NB h is passed as a 1d row major array */ | |
47 | |
48 /* parameter values */ | |
49 lambda = DEFAULT_LAMBDA; | |
50 if (l > 0) | |
51 limit = l; | |
52 else | |
53 limit = DEFAULT_LIMIT; /* use default if no valid neighbourhood limit supplied */ | |
54 B = 2 * limit + 1; | |
55 maxiter0 = 20; /* number of iterations at initial temperature */ | |
56 maxiter1 = 5; /* number of iterations at subsequent temperatures */ | |
57 | |
58 /* allocate memory */ | |
59 cl = (double**) malloc(k*sizeof(double*)); | |
60 for (i= 0; i < k; i++) | |
61 cl[i] = (double*) malloc(m*sizeof(double)); | |
62 | |
63 nc = (int**) malloc(n*sizeof(int*)); | |
64 for (i= 0; i < n; i++) | |
65 nc[i] = (int*) malloc(k*sizeof(int)); | |
66 | |
67 lp = (double**) malloc(n*sizeof(double*)); | |
68 for (i= 0; i < n; i++) | |
69 lp[i] = (double*) malloc(k*sizeof(double)); | |
70 | |
71 oldc = (int*) malloc(n * sizeof(int)); | |
72 | |
73 /* initialise */ | |
74 for (i = 0; i < k; i++) | |
75 { | |
76 sum = 0; | |
77 for (j = 0; j < m; j++) | |
78 { | |
79 cl[i][j] = rand(); /* random initial reference histograms */ | |
80 sum += cl[i][j] * cl[i][j]; | |
81 } | |
82 sum = sqrt(sum); | |
83 for (j = 0; j < m; j++) | |
84 { | |
85 cl[i][j] /= sum; /* normalise */ | |
86 } | |
87 } | |
88 //print_array(cl, k, m); | |
89 | |
90 for (i = 0; i < n; i++) | |
91 c[i] = 1; /* initially assign all histograms to cluster 1 */ | |
92 | |
93 for (a = 0; a < t; a++) | |
94 { | |
95 beta = Bsched[a]; | |
96 | |
97 if (a == 0) | |
98 maxiter = maxiter0; | |
99 else | |
100 maxiter = maxiter1; | |
101 | |
102 for (it = 0; it < maxiter; it++) | |
103 { | |
104 //if (it == maxiter - 1) | |
105 // mexPrintf("hasn't converged after %d iterations\n", maxiter); | |
106 | |
107 for (i = 0; i < n; i++) | |
108 { | |
109 /* save current hard assignments */ | |
110 oldc[i] = c[i]; | |
111 | |
112 /* calculate soft assignment logprobs for each cluster */ | |
113 sum = 0; | |
114 for (j = 0; j < k; j++) | |
115 { | |
116 lp[i][ j] = -beta * kldist(cl[j], &h[i*m], m); | |
117 | |
118 /* update matching neighbour counts for this histogram, based on current hard assignments */ | |
119 /* old version: | |
120 nc[i][j] = 0; | |
121 if (i >= limit && i <= n - 1 - limit) | |
122 { | |
123 for (b = i - limit; b <= i + limit; b++) | |
124 { | |
125 if (c[b] == j+1) | |
126 nc[i][j]++; | |
127 } | |
128 nc[i][j] = B - nc[i][j]; | |
129 } | |
130 */ | |
131 b0 = i - limit; | |
132 if (b0 < 0) | |
133 b0 = 0; | |
134 b1 = i + limit; | |
135 if (b1 >= n) | |
136 b1 = n - 1; | |
137 nc[i][j] = b1 - b0 + 1; /* = B except at edges */ | |
138 for (b = b0; b <= b1; b++) | |
139 if (c[b] == j+1) | |
140 nc[i][j]--; | |
141 | |
142 sum += exp(lp[i][j]); | |
143 } | |
144 | |
145 /* normalise responsibilities and add duration logprior */ | |
146 logsumexp = log(sum); | |
147 for (j = 0; j < k; j++) | |
148 lp[i][j] -= logsumexp + lambda * nc[i][j]; | |
149 } | |
150 //print_array(lp, n, k); | |
151 /* | |
152 for (i = 0; i < n; i++) | |
153 { | |
154 for (j = 0; j < k; j++) | |
155 mexPrintf("%d ", nc[i][j]); | |
156 mexPrintf("\n"); | |
157 } | |
158 */ | |
159 | |
160 | |
161 /* update the assignments now that we know the duration priors | |
162 based on the current assignments */ | |
163 for (i = 0; i < n; i++) | |
164 { | |
165 maxlp = lp[i][0]; | |
166 c[i] = 1; | |
167 for (j = 1; j < k; j++) | |
168 if (lp[i][j] > maxlp) | |
169 { | |
170 maxlp = lp[i][j]; | |
171 c[i] = j+1; | |
172 } | |
173 } | |
174 | |
175 /* break if assignments haven't changed */ | |
176 i = 0; | |
177 while (i < n && oldc[i] == c[i]) | |
178 i++; | |
179 if (i == n) | |
180 break; | |
181 | |
182 /* update reference histograms now we know new responsibilities */ | |
183 for (j = 0; j < k; j++) | |
184 { | |
185 for (b = 0; b < m; b++) | |
186 { | |
187 cl[j][b] = 0; | |
188 for (i = 0; i < n; i++) | |
189 { | |
190 cl[j][b] += exp(lp[i][j]) * h[i*m+b]; | |
191 } | |
192 } | |
193 | |
194 sum = 0; | |
195 for (i = 0; i < n; i++) | |
196 sum += exp(lp[i][j]); | |
197 for (b = 0; b < m; b++) | |
198 cl[j][b] /= sum; /* normalise */ | |
199 } | |
200 | |
201 //print_array(cl, k, m); | |
202 //mexPrintf("\n\n"); | |
203 } | |
204 } | |
205 | |
206 /* free memory */ | |
207 for (i = 0; i < k; i++) | |
208 free(cl[i]); | |
209 free(cl); | |
210 for (i = 0; i < n; i++) | |
211 free(nc[i]); | |
212 free(nc); | |
213 for (i = 0; i < n; i++) | |
214 free(lp[i]); | |
215 free(lp); | |
216 free(oldc); | |
217 } | |
218 | |
219 |