comparison fft/fftw/fftw-3.3.4/rdft/ct-hc2c-direct.c @ 19:26056e866c29

Add FFTW to comparison table
author Chris Cannam
date Tue, 06 Oct 2015 13:08:39 +0100
parents
children
comparison
equal deleted inserted replaced
18:8db794ca3e0b 19:26056e866c29
1 /*
2 * Copyright (c) 2003, 2007-14 Matteo Frigo
3 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
4 *
5 * This program is free software; you can redistribute it and/or modify
6 * it under the terms of the GNU General Public License as published by
7 * the Free Software Foundation; either version 2 of the License, or
8 * (at your option) any later version.
9 *
10 * This program is distributed in the hope that it will be useful,
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13 * GNU General Public License for more details.
14 *
15 * You should have received a copy of the GNU General Public License
16 * along with this program; if not, write to the Free Software
17 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
18 *
19 */
20
21
22 #include "ct-hc2c.h"
23
24 typedef struct {
25 hc2c_solver super;
26 const hc2c_desc *desc;
27 int bufferedp;
28 khc2c k;
29 } S;
30
31 typedef struct {
32 plan_hc2c super;
33 khc2c k;
34 plan *cld0, *cldm; /* children for 0th and middle butterflies */
35 INT r, m, v, extra_iter;
36 INT ms, vs;
37 stride rs, brs;
38 twid *td;
39 const S *slv;
40 } P;
41
42 /*************************************************************
43 Nonbuffered code
44 *************************************************************/
45 static void apply(const plan *ego_, R *cr, R *ci)
46 {
47 const P *ego = (const P *) ego_;
48 plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
49 plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
50 INT i, m = ego->m, v = ego->v;
51 INT ms = ego->ms, vs = ego->vs;
52
53 for (i = 0; i < v; ++i, cr += vs, ci += vs) {
54 cld0->apply((plan *) cld0, cr, ci, cr, ci);
55 ego->k(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
56 ego->td->W, ego->rs, 1, (m+1)/2, ms);
57 cldm->apply((plan *) cldm, cr + (m/2)*ms, ci + (m/2)*ms,
58 cr + (m/2)*ms, ci + (m/2)*ms);
59 }
60 }
61
62 static void apply_extra_iter(const plan *ego_, R *cr, R *ci)
63 {
64 const P *ego = (const P *) ego_;
65 plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
66 plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
67 INT i, m = ego->m, v = ego->v;
68 INT ms = ego->ms, vs = ego->vs;
69 INT mm = (m-1)/2;
70
71 for (i = 0; i < v; ++i, cr += vs, ci += vs) {
72 cld0->apply((plan *) cld0, cr, ci, cr, ci);
73
74 /* for 4-way SIMD when (m+1)/2-1 is odd: iterate over an
75 even vector length MM-1, and then execute the last
76 iteration as a 2-vector with vector stride 0. The
77 twiddle factors of the second half of the last iteration
78 are bogus, but we only store the results of the first
79 half. */
80 ego->k(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
81 ego->td->W, ego->rs, 1, mm, ms);
82 ego->k(cr + mm*ms, ci + mm*ms, cr + (m-mm)*ms, ci + (m-mm)*ms,
83 ego->td->W, ego->rs, mm, mm+2, 0);
84 cldm->apply((plan *) cldm, cr + (m/2)*ms, ci + (m/2)*ms,
85 cr + (m/2)*ms, ci + (m/2)*ms);
86 }
87
88 }
89
90 /*************************************************************
91 Buffered code
92 *************************************************************/
93
94 /* should not be 2^k to avoid associativity conflicts */
95 static INT compute_batchsize(INT radix)
96 {
97 /* round up to multiple of 4 */
98 radix += 3;
99 radix &= -4;
100
101 return (radix + 2);
102 }
103
104 static void dobatch(const P *ego, R *Rp, R *Ip, R *Rm, R *Im,
105 INT mb, INT me, INT extra_iter, R *bufp)
106 {
107 INT b = WS(ego->brs, 1);
108 INT rs = WS(ego->rs, 1);
109 INT ms = ego->ms;
110 R *bufm = bufp + b - 2;
111
112 X(cpy2d_pair_ci)(Rp + mb * ms, Ip + mb * ms, bufp, bufp + 1,
113 ego->r / 2, rs, b,
114 me - mb, ms, 2);
115 X(cpy2d_pair_ci)(Rm - mb * ms, Im - mb * ms, bufm, bufm + 1,
116 ego->r / 2, rs, b,
117 me - mb, -ms, -2);
118 ego->k(bufp, bufp + 1, bufm, bufm + 1, ego->td->W,
119 ego->brs, mb, me + extra_iter, 2);
120 X(cpy2d_pair_co)(bufp, bufp + 1, Rp + mb * ms, Ip + mb * ms,
121 ego->r / 2, b, rs,
122 me - mb, 2, ms);
123 X(cpy2d_pair_co)(bufm, bufm + 1, Rm - mb * ms, Im - mb * ms,
124 ego->r / 2, b, rs,
125 me - mb, -2, -ms);
126 }
127
128 static void apply_buf(const plan *ego_, R *cr, R *ci)
129 {
130 const P *ego = (const P *) ego_;
131 plan_rdft2 *cld0 = (plan_rdft2 *) ego->cld0;
132 plan_rdft2 *cldm = (plan_rdft2 *) ego->cldm;
133 INT i, j, ms = ego->ms, v = ego->v;
134 INT batchsz = compute_batchsize(ego->r);
135 R *buf;
136 INT mb = 1, me = (ego->m+1) / 2;
137 size_t bufsz = ego->r * batchsz * 2 * sizeof(R);
138
139 BUF_ALLOC(R *, buf, bufsz);
140
141 for (i = 0; i < v; ++i, cr += ego->vs, ci += ego->vs) {
142 R *Rp = cr;
143 R *Ip = ci;
144 R *Rm = cr + ego->m * ms;
145 R *Im = ci + ego->m * ms;
146
147 cld0->apply((plan *) cld0, Rp, Ip, Rp, Ip);
148
149 for (j = mb; j + batchsz < me; j += batchsz)
150 dobatch(ego, Rp, Ip, Rm, Im, j, j + batchsz, 0, buf);
151
152 dobatch(ego, Rp, Ip, Rm, Im, j, me, ego->extra_iter, buf);
153
154 cldm->apply((plan *) cldm,
155 Rp + me * ms, Ip + me * ms,
156 Rp + me * ms, Ip + me * ms);
157
158 }
159
160 BUF_FREE(buf, bufsz);
161 }
162
163 /*************************************************************
164 common code
165 *************************************************************/
166 static void awake(plan *ego_, enum wakefulness wakefulness)
167 {
168 P *ego = (P *) ego_;
169
170 X(plan_awake)(ego->cld0, wakefulness);
171 X(plan_awake)(ego->cldm, wakefulness);
172 X(twiddle_awake)(wakefulness, &ego->td, ego->slv->desc->tw,
173 ego->r * ego->m, ego->r,
174 (ego->m - 1) / 2 + ego->extra_iter);
175 }
176
177 static void destroy(plan *ego_)
178 {
179 P *ego = (P *) ego_;
180 X(plan_destroy_internal)(ego->cld0);
181 X(plan_destroy_internal)(ego->cldm);
182 X(stride_destroy)(ego->rs);
183 X(stride_destroy)(ego->brs);
184 }
185
186 static void print(const plan *ego_, printer *p)
187 {
188 const P *ego = (const P *) ego_;
189 const S *slv = ego->slv;
190 const hc2c_desc *e = slv->desc;
191
192 if (slv->bufferedp)
193 p->print(p, "(hc2c-directbuf/%D-%D/%D/%D%v \"%s\"%(%p%)%(%p%))",
194 compute_batchsize(ego->r),
195 ego->r, X(twiddle_length)(ego->r, e->tw),
196 ego->extra_iter, ego->v, e->nam,
197 ego->cld0, ego->cldm);
198 else
199 p->print(p, "(hc2c-direct-%D/%D/%D%v \"%s\"%(%p%)%(%p%))",
200 ego->r, X(twiddle_length)(ego->r, e->tw),
201 ego->extra_iter, ego->v, e->nam,
202 ego->cld0, ego->cldm);
203 }
204
205 static int applicable0(const S *ego, rdft_kind kind,
206 INT r, INT rs,
207 INT m, INT ms,
208 INT v, INT vs,
209 const R *cr, const R *ci,
210 const planner *plnr,
211 INT *extra_iter)
212 {
213 const hc2c_desc *e = ego->desc;
214 UNUSED(v);
215
216 return (
217 1
218 && r == e->radix
219 && kind == e->genus->kind
220
221 /* first v-loop iteration */
222 && ((*extra_iter = 0,
223 e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
224 rs, 1, (m+1)/2, ms, plnr))
225 ||
226 (*extra_iter = 1,
227 ((e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
228 rs, 1, (m-1)/2, ms, plnr))
229 &&
230 (e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
231 rs, (m-1)/2, (m-1)/2 + 2, 0, plnr)))))
232
233 /* subsequent v-loop iterations */
234 && (cr += vs, ci += vs, 1)
235
236 && e->genus->okp(cr + ms, ci + ms, cr + (m-1)*ms, ci + (m-1)*ms,
237 rs, 1, (m+1)/2 - *extra_iter, ms, plnr)
238 );
239 }
240
241 static int applicable0_buf(const S *ego, rdft_kind kind,
242 INT r, INT rs,
243 INT m, INT ms,
244 INT v, INT vs,
245 const R *cr, const R *ci,
246 const planner *plnr, INT *extra_iter)
247 {
248 const hc2c_desc *e = ego->desc;
249 INT batchsz, brs;
250 UNUSED(v); UNUSED(rs); UNUSED(ms); UNUSED(vs);
251
252 return (
253 1
254 && r == e->radix
255 && kind == e->genus->kind
256
257 /* ignore cr, ci, use buffer */
258 && (cr = (const R *)0, ci = cr + 1,
259 batchsz = compute_batchsize(r),
260 brs = 4 * batchsz, 1)
261
262 && e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2,
263 brs, 1, 1+batchsz, 2, plnr)
264
265 && ((*extra_iter = 0,
266 e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2,
267 brs, 1, 1 + (((m-1)/2) % batchsz), 2, plnr))
268 ||
269 (*extra_iter = 1,
270 e->genus->okp(cr, ci, cr + brs - 2, ci + brs - 2,
271 brs, 1, 1 + 1 + (((m-1)/2) % batchsz), 2, plnr)))
272
273 );
274 }
275
276 static int applicable(const S *ego, rdft_kind kind,
277 INT r, INT rs,
278 INT m, INT ms,
279 INT v, INT vs,
280 R *cr, R *ci,
281 const planner *plnr, INT *extra_iter)
282 {
283 if (ego->bufferedp) {
284 if (!applicable0_buf(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr,
285 extra_iter))
286 return 0;
287 } else {
288 if (!applicable0(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr,
289 extra_iter))
290 return 0;
291 }
292
293 if (NO_UGLYP(plnr) && X(ct_uglyp)((ego->bufferedp? (INT)512 : (INT)16),
294 v, m * r, r))
295 return 0;
296
297 return 1;
298 }
299
300 static plan *mkcldw(const hc2c_solver *ego_, rdft_kind kind,
301 INT r, INT rs,
302 INT m, INT ms,
303 INT v, INT vs,
304 R *cr, R *ci,
305 planner *plnr)
306 {
307 const S *ego = (const S *) ego_;
308 P *pln;
309 const hc2c_desc *e = ego->desc;
310 plan *cld0 = 0, *cldm = 0;
311 INT imid = (m / 2) * ms;
312 INT extra_iter;
313
314 static const plan_adt padt = {
315 0, awake, print, destroy
316 };
317
318 if (!applicable(ego, kind, r, rs, m, ms, v, vs, cr, ci, plnr,
319 &extra_iter))
320 return (plan *)0;
321
322 cld0 = X(mkplan_d)(
323 plnr,
324 X(mkproblem_rdft2_d)(X(mktensor_1d)(r, rs, rs),
325 X(mktensor_0d)(),
326 TAINT(cr, vs), TAINT(ci, vs),
327 TAINT(cr, vs), TAINT(ci, vs),
328 kind));
329 if (!cld0) goto nada;
330
331 cldm = X(mkplan_d)(
332 plnr,
333 X(mkproblem_rdft2_d)(((m % 2) ?
334 X(mktensor_0d)() : X(mktensor_1d)(r, rs, rs) ),
335 X(mktensor_0d)(),
336 TAINT(cr + imid, vs), TAINT(ci + imid, vs),
337 TAINT(cr + imid, vs), TAINT(ci + imid, vs),
338 kind == R2HC ? R2HCII : HC2RIII));
339 if (!cldm) goto nada;
340
341 if (ego->bufferedp)
342 pln = MKPLAN_HC2C(P, &padt, apply_buf);
343 else
344 pln = MKPLAN_HC2C(P, &padt, extra_iter ? apply_extra_iter : apply);
345
346 pln->k = ego->k;
347 pln->td = 0;
348 pln->r = r; pln->rs = X(mkstride)(r, rs);
349 pln->m = m; pln->ms = ms;
350 pln->v = v; pln->vs = vs;
351 pln->slv = ego;
352 pln->brs = X(mkstride)(r, 4 * compute_batchsize(r));
353 pln->cld0 = cld0;
354 pln->cldm = cldm;
355 pln->extra_iter = extra_iter;
356
357 X(ops_zero)(&pln->super.super.ops);
358 X(ops_madd2)(v * (((m - 1) / 2) / e->genus->vl),
359 &e->ops, &pln->super.super.ops);
360 X(ops_madd2)(v, &cld0->ops, &pln->super.super.ops);
361 X(ops_madd2)(v, &cldm->ops, &pln->super.super.ops);
362
363 if (ego->bufferedp)
364 pln->super.super.ops.other += 4 * r * m * v;
365
366 return &(pln->super.super);
367
368 nada:
369 X(plan_destroy_internal)(cld0);
370 X(plan_destroy_internal)(cldm);
371 return 0;
372 }
373
374 static void regone(planner *plnr, khc2c codelet,
375 const hc2c_desc *desc,
376 hc2c_kind hc2ckind,
377 int bufferedp)
378 {
379 S *slv = (S *)X(mksolver_hc2c)(sizeof(S), desc->radix, hc2ckind, mkcldw);
380 slv->k = codelet;
381 slv->desc = desc;
382 slv->bufferedp = bufferedp;
383 REGISTER_SOLVER(plnr, &(slv->super.super));
384 }
385
386 void X(regsolver_hc2c_direct)(planner *plnr, khc2c codelet,
387 const hc2c_desc *desc,
388 hc2c_kind hc2ckind)
389 {
390 regone(plnr, codelet, desc, hc2ckind, /* bufferedp */0);
391 regone(plnr, codelet, desc, hc2ckind, /* bufferedp */1);
392 }