cannam@95
|
1 /*
|
cannam@95
|
2 * Copyright (c) 2003, 2007-11 Matteo Frigo
|
cannam@95
|
3 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
|
cannam@95
|
4 *
|
cannam@95
|
5 * This program is free software; you can redistribute it and/or modify
|
cannam@95
|
6 * it under the terms of the GNU General Public License as published by
|
cannam@95
|
7 * the Free Software Foundation; either version 2 of the License, or
|
cannam@95
|
8 * (at your option) any later version.
|
cannam@95
|
9 *
|
cannam@95
|
10 * This program is distributed in the hope that it will be useful,
|
cannam@95
|
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
cannam@95
|
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
cannam@95
|
13 * GNU General Public License for more details.
|
cannam@95
|
14 *
|
cannam@95
|
15 * You should have received a copy of the GNU General Public License
|
cannam@95
|
16 * along with this program; if not, write to the Free Software
|
cannam@95
|
17 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
cannam@95
|
18 *
|
cannam@95
|
19 */
|
cannam@95
|
20
|
cannam@95
|
21
|
cannam@95
|
22 /* direct DFT solver, if we have a codelet */
|
cannam@95
|
23
|
cannam@95
|
24 #include "dft.h"
|
cannam@95
|
25
|
cannam@95
|
26 typedef struct {
|
cannam@95
|
27 solver super;
|
cannam@95
|
28 const kdft_desc *desc;
|
cannam@95
|
29 kdft k;
|
cannam@95
|
30 int bufferedp;
|
cannam@95
|
31 } S;
|
cannam@95
|
32
|
cannam@95
|
33 typedef struct {
|
cannam@95
|
34 plan_dft super;
|
cannam@95
|
35
|
cannam@95
|
36 stride is, os, bufstride;
|
cannam@95
|
37 INT n, vl, ivs, ovs;
|
cannam@95
|
38 kdft k;
|
cannam@95
|
39 const S *slv;
|
cannam@95
|
40 } P;
|
cannam@95
|
41
|
cannam@95
|
42 static void dobatch(const P *ego, R *ri, R *ii, R *ro, R *io,
|
cannam@95
|
43 R *buf, INT batchsz)
|
cannam@95
|
44 {
|
cannam@95
|
45 X(cpy2d_pair_ci)(ri, ii, buf, buf+1,
|
cannam@95
|
46 ego->n, WS(ego->is, 1), WS(ego->bufstride, 1),
|
cannam@95
|
47 batchsz, ego->ivs, 2);
|
cannam@95
|
48
|
cannam@95
|
49 if (IABS(WS(ego->os, 1)) < IABS(ego->ovs)) {
|
cannam@95
|
50 /* transform directly to output */
|
cannam@95
|
51 ego->k(buf, buf+1, ro, io,
|
cannam@95
|
52 ego->bufstride, ego->os, batchsz, 2, ego->ovs);
|
cannam@95
|
53 } else {
|
cannam@95
|
54 /* transform to buffer and copy back */
|
cannam@95
|
55 ego->k(buf, buf+1, buf, buf+1,
|
cannam@95
|
56 ego->bufstride, ego->bufstride, batchsz, 2, 2);
|
cannam@95
|
57 X(cpy2d_pair_co)(buf, buf+1, ro, io,
|
cannam@95
|
58 ego->n, WS(ego->bufstride, 1), WS(ego->os, 1),
|
cannam@95
|
59 batchsz, 2, ego->ovs);
|
cannam@95
|
60 }
|
cannam@95
|
61 }
|
cannam@95
|
62
|
cannam@95
|
63 static INT compute_batchsize(INT n)
|
cannam@95
|
64 {
|
cannam@95
|
65 /* round up to multiple of 4 */
|
cannam@95
|
66 n += 3;
|
cannam@95
|
67 n &= -4;
|
cannam@95
|
68
|
cannam@95
|
69 return (n + 2);
|
cannam@95
|
70 }
|
cannam@95
|
71
|
cannam@95
|
72 static void apply_buf(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
cannam@95
|
73 {
|
cannam@95
|
74 const P *ego = (const P *) ego_;
|
cannam@95
|
75 R *buf;
|
cannam@95
|
76 INT vl = ego->vl, n = ego->n, batchsz = compute_batchsize(n);
|
cannam@95
|
77 INT i;
|
cannam@95
|
78 size_t bufsz = n * batchsz * 2 * sizeof(R);
|
cannam@95
|
79
|
cannam@95
|
80 BUF_ALLOC(R *, buf, bufsz);
|
cannam@95
|
81
|
cannam@95
|
82 for (i = 0; i < vl - batchsz; i += batchsz) {
|
cannam@95
|
83 dobatch(ego, ri, ii, ro, io, buf, batchsz);
|
cannam@95
|
84 ri += batchsz * ego->ivs; ii += batchsz * ego->ivs;
|
cannam@95
|
85 ro += batchsz * ego->ovs; io += batchsz * ego->ovs;
|
cannam@95
|
86 }
|
cannam@95
|
87 dobatch(ego, ri, ii, ro, io, buf, vl - i);
|
cannam@95
|
88
|
cannam@95
|
89 BUF_FREE(buf, bufsz);
|
cannam@95
|
90 }
|
cannam@95
|
91
|
cannam@95
|
92 static void apply(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
cannam@95
|
93 {
|
cannam@95
|
94 const P *ego = (const P *) ego_;
|
cannam@95
|
95 ASSERT_ALIGNED_DOUBLE;
|
cannam@95
|
96 ego->k(ri, ii, ro, io, ego->is, ego->os, ego->vl, ego->ivs, ego->ovs);
|
cannam@95
|
97 }
|
cannam@95
|
98
|
cannam@95
|
99 static void apply_extra_iter(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
cannam@95
|
100 {
|
cannam@95
|
101 const P *ego = (const P *) ego_;
|
cannam@95
|
102 INT vl = ego->vl;
|
cannam@95
|
103
|
cannam@95
|
104 ASSERT_ALIGNED_DOUBLE;
|
cannam@95
|
105
|
cannam@95
|
106 /* for 4-way SIMD when VL is odd: iterate over an
|
cannam@95
|
107 even vector length VL, and then execute the last
|
cannam@95
|
108 iteration as a 2-vector with vector stride 0. */
|
cannam@95
|
109 ego->k(ri, ii, ro, io, ego->is, ego->os, vl - 1, ego->ivs, ego->ovs);
|
cannam@95
|
110
|
cannam@95
|
111 ego->k(ri + (vl - 1) * ego->ivs, ii + (vl - 1) * ego->ivs,
|
cannam@95
|
112 ro + (vl - 1) * ego->ovs, io + (vl - 1) * ego->ovs,
|
cannam@95
|
113 ego->is, ego->os, 1, 0, 0);
|
cannam@95
|
114 }
|
cannam@95
|
115
|
cannam@95
|
116 static void destroy(plan *ego_)
|
cannam@95
|
117 {
|
cannam@95
|
118 P *ego = (P *) ego_;
|
cannam@95
|
119 X(stride_destroy)(ego->is);
|
cannam@95
|
120 X(stride_destroy)(ego->os);
|
cannam@95
|
121 X(stride_destroy)(ego->bufstride);
|
cannam@95
|
122 }
|
cannam@95
|
123
|
cannam@95
|
124 static void print(const plan *ego_, printer *p)
|
cannam@95
|
125 {
|
cannam@95
|
126 const P *ego = (const P *) ego_;
|
cannam@95
|
127 const S *s = ego->slv;
|
cannam@95
|
128 const kdft_desc *d = s->desc;
|
cannam@95
|
129
|
cannam@95
|
130 if (ego->slv->bufferedp)
|
cannam@95
|
131 p->print(p, "(dft-directbuf/%D-%D%v \"%s\")",
|
cannam@95
|
132 compute_batchsize(d->sz), d->sz, ego->vl, d->nam);
|
cannam@95
|
133 else
|
cannam@95
|
134 p->print(p, "(dft-direct-%D%v \"%s\")", d->sz, ego->vl, d->nam);
|
cannam@95
|
135 }
|
cannam@95
|
136
|
cannam@95
|
137 static int applicable_buf(const solver *ego_, const problem *p_,
|
cannam@95
|
138 const planner *plnr)
|
cannam@95
|
139 {
|
cannam@95
|
140 const S *ego = (const S *) ego_;
|
cannam@95
|
141 const problem_dft *p = (const problem_dft *) p_;
|
cannam@95
|
142 const kdft_desc *d = ego->desc;
|
cannam@95
|
143 INT vl;
|
cannam@95
|
144 INT ivs, ovs;
|
cannam@95
|
145 INT batchsz;
|
cannam@95
|
146
|
cannam@95
|
147 return (
|
cannam@95
|
148 1
|
cannam@95
|
149 && p->sz->rnk == 1
|
cannam@95
|
150 && p->vecsz->rnk == 1
|
cannam@95
|
151 && p->sz->dims[0].n == d->sz
|
cannam@95
|
152
|
cannam@95
|
153 /* check strides etc */
|
cannam@95
|
154 && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs)
|
cannam@95
|
155
|
cannam@95
|
156 /* UGLY if IS <= IVS */
|
cannam@95
|
157 && !(NO_UGLYP(plnr) &&
|
cannam@95
|
158 X(iabs)(p->sz->dims[0].is) <= X(iabs)(ivs))
|
cannam@95
|
159
|
cannam@95
|
160 && (batchsz = compute_batchsize(d->sz), 1)
|
cannam@95
|
161 && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io,
|
cannam@95
|
162 2 * batchsz, p->sz->dims[0].os,
|
cannam@95
|
163 batchsz, 2, ovs, plnr))
|
cannam@95
|
164 && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io,
|
cannam@95
|
165 2 * batchsz, p->sz->dims[0].os,
|
cannam@95
|
166 vl % batchsz, 2, ovs, plnr))
|
cannam@95
|
167
|
cannam@95
|
168
|
cannam@95
|
169 && (0
|
cannam@95
|
170 /* can operate out-of-place */
|
cannam@95
|
171 || p->ri != p->ro
|
cannam@95
|
172
|
cannam@95
|
173 /* can operate in-place as long as strides are the same */
|
cannam@95
|
174 || X(tensor_inplace_strides2)(p->sz, p->vecsz)
|
cannam@95
|
175
|
cannam@95
|
176 /* can do it if the problem fits in the buffer, no matter
|
cannam@95
|
177 what the strides are */
|
cannam@95
|
178 || vl <= batchsz
|
cannam@95
|
179 )
|
cannam@95
|
180 );
|
cannam@95
|
181 }
|
cannam@95
|
182
|
cannam@95
|
183 static int applicable(const solver *ego_, const problem *p_,
|
cannam@95
|
184 const planner *plnr, int *extra_iterp)
|
cannam@95
|
185 {
|
cannam@95
|
186 const S *ego = (const S *) ego_;
|
cannam@95
|
187 const problem_dft *p = (const problem_dft *) p_;
|
cannam@95
|
188 const kdft_desc *d = ego->desc;
|
cannam@95
|
189 INT vl;
|
cannam@95
|
190 INT ivs, ovs;
|
cannam@95
|
191
|
cannam@95
|
192 return (
|
cannam@95
|
193 1
|
cannam@95
|
194 && p->sz->rnk == 1
|
cannam@95
|
195 && p->vecsz->rnk <= 1
|
cannam@95
|
196 && p->sz->dims[0].n == d->sz
|
cannam@95
|
197
|
cannam@95
|
198 /* check strides etc */
|
cannam@95
|
199 && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs)
|
cannam@95
|
200
|
cannam@95
|
201 && ((*extra_iterp = 0,
|
cannam@95
|
202 (d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
cannam@95
|
203 p->sz->dims[0].is, p->sz->dims[0].os,
|
cannam@95
|
204 vl, ivs, ovs, plnr)))
|
cannam@95
|
205 ||
|
cannam@95
|
206 (*extra_iterp = 1,
|
cannam@95
|
207 ((d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
cannam@95
|
208 p->sz->dims[0].is, p->sz->dims[0].os,
|
cannam@95
|
209 vl - 1, ivs, ovs, plnr))
|
cannam@95
|
210 &&
|
cannam@95
|
211 (d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
cannam@95
|
212 p->sz->dims[0].is, p->sz->dims[0].os,
|
cannam@95
|
213 2, 0, 0, plnr)))))
|
cannam@95
|
214
|
cannam@95
|
215 && (0
|
cannam@95
|
216 /* can operate out-of-place */
|
cannam@95
|
217 || p->ri != p->ro
|
cannam@95
|
218
|
cannam@95
|
219 /* can always compute one transform */
|
cannam@95
|
220 || vl == 1
|
cannam@95
|
221
|
cannam@95
|
222 /* can operate in-place as long as strides are the same */
|
cannam@95
|
223 || X(tensor_inplace_strides2)(p->sz, p->vecsz)
|
cannam@95
|
224 )
|
cannam@95
|
225 );
|
cannam@95
|
226 }
|
cannam@95
|
227
|
cannam@95
|
228
|
cannam@95
|
229 static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
|
cannam@95
|
230 {
|
cannam@95
|
231 const S *ego = (const S *) ego_;
|
cannam@95
|
232 P *pln;
|
cannam@95
|
233 const problem_dft *p;
|
cannam@95
|
234 iodim *d;
|
cannam@95
|
235 const kdft_desc *e = ego->desc;
|
cannam@95
|
236
|
cannam@95
|
237 static const plan_adt padt = {
|
cannam@95
|
238 X(dft_solve), X(null_awake), print, destroy
|
cannam@95
|
239 };
|
cannam@95
|
240
|
cannam@95
|
241 UNUSED(plnr);
|
cannam@95
|
242
|
cannam@95
|
243 if (ego->bufferedp) {
|
cannam@95
|
244 if (!applicable_buf(ego_, p_, plnr))
|
cannam@95
|
245 return (plan *)0;
|
cannam@95
|
246 pln = MKPLAN_DFT(P, &padt, apply_buf);
|
cannam@95
|
247 } else {
|
cannam@95
|
248 int extra_iterp = 0;
|
cannam@95
|
249 if (!applicable(ego_, p_, plnr, &extra_iterp))
|
cannam@95
|
250 return (plan *)0;
|
cannam@95
|
251 pln = MKPLAN_DFT(P, &padt, extra_iterp ? apply_extra_iter : apply);
|
cannam@95
|
252 }
|
cannam@95
|
253
|
cannam@95
|
254 p = (const problem_dft *) p_;
|
cannam@95
|
255 d = p->sz->dims;
|
cannam@95
|
256 pln->k = ego->k;
|
cannam@95
|
257 pln->n = d[0].n;
|
cannam@95
|
258 pln->is = X(mkstride)(pln->n, d[0].is);
|
cannam@95
|
259 pln->os = X(mkstride)(pln->n, d[0].os);
|
cannam@95
|
260 pln->bufstride = X(mkstride)(pln->n, 2 * compute_batchsize(pln->n));
|
cannam@95
|
261
|
cannam@95
|
262 X(tensor_tornk1)(p->vecsz, &pln->vl, &pln->ivs, &pln->ovs);
|
cannam@95
|
263 pln->slv = ego;
|
cannam@95
|
264
|
cannam@95
|
265 X(ops_zero)(&pln->super.super.ops);
|
cannam@95
|
266 X(ops_madd2)(pln->vl / e->genus->vl, &e->ops, &pln->super.super.ops);
|
cannam@95
|
267
|
cannam@95
|
268 if (ego->bufferedp)
|
cannam@95
|
269 pln->super.super.ops.other += 4 * pln->n * pln->vl;
|
cannam@95
|
270
|
cannam@95
|
271 pln->super.super.could_prune_now_p = !ego->bufferedp;
|
cannam@95
|
272 return &(pln->super.super);
|
cannam@95
|
273 }
|
cannam@95
|
274
|
cannam@95
|
275 static solver *mksolver(kdft k, const kdft_desc *desc, int bufferedp)
|
cannam@95
|
276 {
|
cannam@95
|
277 static const solver_adt sadt = { PROBLEM_DFT, mkplan, 0 };
|
cannam@95
|
278 S *slv = MKSOLVER(S, &sadt);
|
cannam@95
|
279 slv->k = k;
|
cannam@95
|
280 slv->desc = desc;
|
cannam@95
|
281 slv->bufferedp = bufferedp;
|
cannam@95
|
282 return &(slv->super);
|
cannam@95
|
283 }
|
cannam@95
|
284
|
cannam@95
|
285 solver *X(mksolver_dft_direct)(kdft k, const kdft_desc *desc)
|
cannam@95
|
286 {
|
cannam@95
|
287 return mksolver(k, desc, 0);
|
cannam@95
|
288 }
|
cannam@95
|
289
|
cannam@95
|
290 solver *X(mksolver_dft_directbuf)(kdft k, const kdft_desc *desc)
|
cannam@95
|
291 {
|
cannam@95
|
292 return mksolver(k, desc, 1);
|
cannam@95
|
293 }
|