Chris@82
|
1 /*
|
Chris@82
|
2 * Copyright (c) 2003, 2007-14 Matteo Frigo
|
Chris@82
|
3 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
|
Chris@82
|
4 *
|
Chris@82
|
5 * This program is free software; you can redistribute it and/or modify
|
Chris@82
|
6 * it under the terms of the GNU General Public License as published by
|
Chris@82
|
7 * the Free Software Foundation; either version 2 of the License, or
|
Chris@82
|
8 * (at your option) any later version.
|
Chris@82
|
9 *
|
Chris@82
|
10 * This program is distributed in the hope that it will be useful,
|
Chris@82
|
11 * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
Chris@82
|
12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
Chris@82
|
13 * GNU General Public License for more details.
|
Chris@82
|
14 *
|
Chris@82
|
15 * You should have received a copy of the GNU General Public License
|
Chris@82
|
16 * along with this program; if not, write to the Free Software
|
Chris@82
|
17 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
Chris@82
|
18 *
|
Chris@82
|
19 */
|
Chris@82
|
20
|
Chris@82
|
21
|
Chris@82
|
22 /* direct DFT solver, if we have a codelet */
|
Chris@82
|
23
|
Chris@82
|
24 #include "dft/dft.h"
|
Chris@82
|
25
|
Chris@82
|
26 typedef struct {
|
Chris@82
|
27 solver super;
|
Chris@82
|
28 const kdft_desc *desc;
|
Chris@82
|
29 kdft k;
|
Chris@82
|
30 int bufferedp;
|
Chris@82
|
31 } S;
|
Chris@82
|
32
|
Chris@82
|
33 typedef struct {
|
Chris@82
|
34 plan_dft super;
|
Chris@82
|
35
|
Chris@82
|
36 stride is, os, bufstride;
|
Chris@82
|
37 INT n, vl, ivs, ovs;
|
Chris@82
|
38 kdft k;
|
Chris@82
|
39 const S *slv;
|
Chris@82
|
40 } P;
|
Chris@82
|
41
|
Chris@82
|
42 static void dobatch(const P *ego, R *ri, R *ii, R *ro, R *io,
|
Chris@82
|
43 R *buf, INT batchsz)
|
Chris@82
|
44 {
|
Chris@82
|
45 X(cpy2d_pair_ci)(ri, ii, buf, buf+1,
|
Chris@82
|
46 ego->n, WS(ego->is, 1), WS(ego->bufstride, 1),
|
Chris@82
|
47 batchsz, ego->ivs, 2);
|
Chris@82
|
48
|
Chris@82
|
49 if (IABS(WS(ego->os, 1)) < IABS(ego->ovs)) {
|
Chris@82
|
50 /* transform directly to output */
|
Chris@82
|
51 ego->k(buf, buf+1, ro, io,
|
Chris@82
|
52 ego->bufstride, ego->os, batchsz, 2, ego->ovs);
|
Chris@82
|
53 } else {
|
Chris@82
|
54 /* transform to buffer and copy back */
|
Chris@82
|
55 ego->k(buf, buf+1, buf, buf+1,
|
Chris@82
|
56 ego->bufstride, ego->bufstride, batchsz, 2, 2);
|
Chris@82
|
57 X(cpy2d_pair_co)(buf, buf+1, ro, io,
|
Chris@82
|
58 ego->n, WS(ego->bufstride, 1), WS(ego->os, 1),
|
Chris@82
|
59 batchsz, 2, ego->ovs);
|
Chris@82
|
60 }
|
Chris@82
|
61 }
|
Chris@82
|
62
|
Chris@82
|
63 static INT compute_batchsize(INT n)
|
Chris@82
|
64 {
|
Chris@82
|
65 /* round up to multiple of 4 */
|
Chris@82
|
66 n += 3;
|
Chris@82
|
67 n &= -4;
|
Chris@82
|
68
|
Chris@82
|
69 return (n + 2);
|
Chris@82
|
70 }
|
Chris@82
|
71
|
Chris@82
|
72 static void apply_buf(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
Chris@82
|
73 {
|
Chris@82
|
74 const P *ego = (const P *) ego_;
|
Chris@82
|
75 R *buf;
|
Chris@82
|
76 INT vl = ego->vl, n = ego->n, batchsz = compute_batchsize(n);
|
Chris@82
|
77 INT i;
|
Chris@82
|
78 size_t bufsz = n * batchsz * 2 * sizeof(R);
|
Chris@82
|
79
|
Chris@82
|
80 BUF_ALLOC(R *, buf, bufsz);
|
Chris@82
|
81
|
Chris@82
|
82 for (i = 0; i < vl - batchsz; i += batchsz) {
|
Chris@82
|
83 dobatch(ego, ri, ii, ro, io, buf, batchsz);
|
Chris@82
|
84 ri += batchsz * ego->ivs; ii += batchsz * ego->ivs;
|
Chris@82
|
85 ro += batchsz * ego->ovs; io += batchsz * ego->ovs;
|
Chris@82
|
86 }
|
Chris@82
|
87 dobatch(ego, ri, ii, ro, io, buf, vl - i);
|
Chris@82
|
88
|
Chris@82
|
89 BUF_FREE(buf, bufsz);
|
Chris@82
|
90 }
|
Chris@82
|
91
|
Chris@82
|
92 static void apply(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
Chris@82
|
93 {
|
Chris@82
|
94 const P *ego = (const P *) ego_;
|
Chris@82
|
95 ASSERT_ALIGNED_DOUBLE;
|
Chris@82
|
96 ego->k(ri, ii, ro, io, ego->is, ego->os, ego->vl, ego->ivs, ego->ovs);
|
Chris@82
|
97 }
|
Chris@82
|
98
|
Chris@82
|
99 static void apply_extra_iter(const plan *ego_, R *ri, R *ii, R *ro, R *io)
|
Chris@82
|
100 {
|
Chris@82
|
101 const P *ego = (const P *) ego_;
|
Chris@82
|
102 INT vl = ego->vl;
|
Chris@82
|
103
|
Chris@82
|
104 ASSERT_ALIGNED_DOUBLE;
|
Chris@82
|
105
|
Chris@82
|
106 /* for 4-way SIMD when VL is odd: iterate over an
|
Chris@82
|
107 even vector length VL, and then execute the last
|
Chris@82
|
108 iteration as a 2-vector with vector stride 0. */
|
Chris@82
|
109 ego->k(ri, ii, ro, io, ego->is, ego->os, vl - 1, ego->ivs, ego->ovs);
|
Chris@82
|
110
|
Chris@82
|
111 ego->k(ri + (vl - 1) * ego->ivs, ii + (vl - 1) * ego->ivs,
|
Chris@82
|
112 ro + (vl - 1) * ego->ovs, io + (vl - 1) * ego->ovs,
|
Chris@82
|
113 ego->is, ego->os, 1, 0, 0);
|
Chris@82
|
114 }
|
Chris@82
|
115
|
Chris@82
|
116 static void destroy(plan *ego_)
|
Chris@82
|
117 {
|
Chris@82
|
118 P *ego = (P *) ego_;
|
Chris@82
|
119 X(stride_destroy)(ego->is);
|
Chris@82
|
120 X(stride_destroy)(ego->os);
|
Chris@82
|
121 X(stride_destroy)(ego->bufstride);
|
Chris@82
|
122 }
|
Chris@82
|
123
|
Chris@82
|
124 static void print(const plan *ego_, printer *p)
|
Chris@82
|
125 {
|
Chris@82
|
126 const P *ego = (const P *) ego_;
|
Chris@82
|
127 const S *s = ego->slv;
|
Chris@82
|
128 const kdft_desc *d = s->desc;
|
Chris@82
|
129
|
Chris@82
|
130 if (ego->slv->bufferedp)
|
Chris@82
|
131 p->print(p, "(dft-directbuf/%D-%D%v \"%s\")",
|
Chris@82
|
132 compute_batchsize(d->sz), d->sz, ego->vl, d->nam);
|
Chris@82
|
133 else
|
Chris@82
|
134 p->print(p, "(dft-direct-%D%v \"%s\")", d->sz, ego->vl, d->nam);
|
Chris@82
|
135 }
|
Chris@82
|
136
|
Chris@82
|
137 static int applicable_buf(const solver *ego_, const problem *p_,
|
Chris@82
|
138 const planner *plnr)
|
Chris@82
|
139 {
|
Chris@82
|
140 const S *ego = (const S *) ego_;
|
Chris@82
|
141 const problem_dft *p = (const problem_dft *) p_;
|
Chris@82
|
142 const kdft_desc *d = ego->desc;
|
Chris@82
|
143 INT vl;
|
Chris@82
|
144 INT ivs, ovs;
|
Chris@82
|
145 INT batchsz;
|
Chris@82
|
146
|
Chris@82
|
147 return (
|
Chris@82
|
148 1
|
Chris@82
|
149 && p->sz->rnk == 1
|
Chris@82
|
150 && p->vecsz->rnk == 1
|
Chris@82
|
151 && p->sz->dims[0].n == d->sz
|
Chris@82
|
152
|
Chris@82
|
153 /* check strides etc */
|
Chris@82
|
154 && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs)
|
Chris@82
|
155
|
Chris@82
|
156 /* UGLY if IS <= IVS */
|
Chris@82
|
157 && !(NO_UGLYP(plnr) &&
|
Chris@82
|
158 X(iabs)(p->sz->dims[0].is) <= X(iabs)(ivs))
|
Chris@82
|
159
|
Chris@82
|
160 && (batchsz = compute_batchsize(d->sz), 1)
|
Chris@82
|
161 && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io,
|
Chris@82
|
162 2 * batchsz, p->sz->dims[0].os,
|
Chris@82
|
163 batchsz, 2, ovs, plnr))
|
Chris@82
|
164 && (d->genus->okp(d, 0, ((const R *)0) + 1, p->ro, p->io,
|
Chris@82
|
165 2 * batchsz, p->sz->dims[0].os,
|
Chris@82
|
166 vl % batchsz, 2, ovs, plnr))
|
Chris@82
|
167
|
Chris@82
|
168
|
Chris@82
|
169 && (0
|
Chris@82
|
170 /* can operate out-of-place */
|
Chris@82
|
171 || p->ri != p->ro
|
Chris@82
|
172
|
Chris@82
|
173 /* can operate in-place as long as strides are the same */
|
Chris@82
|
174 || X(tensor_inplace_strides2)(p->sz, p->vecsz)
|
Chris@82
|
175
|
Chris@82
|
176 /* can do it if the problem fits in the buffer, no matter
|
Chris@82
|
177 what the strides are */
|
Chris@82
|
178 || vl <= batchsz
|
Chris@82
|
179 )
|
Chris@82
|
180 );
|
Chris@82
|
181 }
|
Chris@82
|
182
|
Chris@82
|
183 static int applicable(const solver *ego_, const problem *p_,
|
Chris@82
|
184 const planner *plnr, int *extra_iterp)
|
Chris@82
|
185 {
|
Chris@82
|
186 const S *ego = (const S *) ego_;
|
Chris@82
|
187 const problem_dft *p = (const problem_dft *) p_;
|
Chris@82
|
188 const kdft_desc *d = ego->desc;
|
Chris@82
|
189 INT vl;
|
Chris@82
|
190 INT ivs, ovs;
|
Chris@82
|
191
|
Chris@82
|
192 return (
|
Chris@82
|
193 1
|
Chris@82
|
194 && p->sz->rnk == 1
|
Chris@82
|
195 && p->vecsz->rnk <= 1
|
Chris@82
|
196 && p->sz->dims[0].n == d->sz
|
Chris@82
|
197
|
Chris@82
|
198 /* check strides etc */
|
Chris@82
|
199 && X(tensor_tornk1)(p->vecsz, &vl, &ivs, &ovs)
|
Chris@82
|
200
|
Chris@82
|
201 && ((*extra_iterp = 0,
|
Chris@82
|
202 (d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
Chris@82
|
203 p->sz->dims[0].is, p->sz->dims[0].os,
|
Chris@82
|
204 vl, ivs, ovs, plnr)))
|
Chris@82
|
205 ||
|
Chris@82
|
206 (*extra_iterp = 1,
|
Chris@82
|
207 ((d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
Chris@82
|
208 p->sz->dims[0].is, p->sz->dims[0].os,
|
Chris@82
|
209 vl - 1, ivs, ovs, plnr))
|
Chris@82
|
210 &&
|
Chris@82
|
211 (d->genus->okp(d, p->ri, p->ii, p->ro, p->io,
|
Chris@82
|
212 p->sz->dims[0].is, p->sz->dims[0].os,
|
Chris@82
|
213 2, 0, 0, plnr)))))
|
Chris@82
|
214
|
Chris@82
|
215 && (0
|
Chris@82
|
216 /* can operate out-of-place */
|
Chris@82
|
217 || p->ri != p->ro
|
Chris@82
|
218
|
Chris@82
|
219 /* can always compute one transform */
|
Chris@82
|
220 || vl == 1
|
Chris@82
|
221
|
Chris@82
|
222 /* can operate in-place as long as strides are the same */
|
Chris@82
|
223 || X(tensor_inplace_strides2)(p->sz, p->vecsz)
|
Chris@82
|
224 )
|
Chris@82
|
225 );
|
Chris@82
|
226 }
|
Chris@82
|
227
|
Chris@82
|
228
|
Chris@82
|
229 static plan *mkplan(const solver *ego_, const problem *p_, planner *plnr)
|
Chris@82
|
230 {
|
Chris@82
|
231 const S *ego = (const S *) ego_;
|
Chris@82
|
232 P *pln;
|
Chris@82
|
233 const problem_dft *p;
|
Chris@82
|
234 iodim *d;
|
Chris@82
|
235 const kdft_desc *e = ego->desc;
|
Chris@82
|
236
|
Chris@82
|
237 static const plan_adt padt = {
|
Chris@82
|
238 X(dft_solve), X(null_awake), print, destroy
|
Chris@82
|
239 };
|
Chris@82
|
240
|
Chris@82
|
241 UNUSED(plnr);
|
Chris@82
|
242
|
Chris@82
|
243 if (ego->bufferedp) {
|
Chris@82
|
244 if (!applicable_buf(ego_, p_, plnr))
|
Chris@82
|
245 return (plan *)0;
|
Chris@82
|
246 pln = MKPLAN_DFT(P, &padt, apply_buf);
|
Chris@82
|
247 } else {
|
Chris@82
|
248 int extra_iterp = 0;
|
Chris@82
|
249 if (!applicable(ego_, p_, plnr, &extra_iterp))
|
Chris@82
|
250 return (plan *)0;
|
Chris@82
|
251 pln = MKPLAN_DFT(P, &padt, extra_iterp ? apply_extra_iter : apply);
|
Chris@82
|
252 }
|
Chris@82
|
253
|
Chris@82
|
254 p = (const problem_dft *) p_;
|
Chris@82
|
255 d = p->sz->dims;
|
Chris@82
|
256 pln->k = ego->k;
|
Chris@82
|
257 pln->n = d[0].n;
|
Chris@82
|
258 pln->is = X(mkstride)(pln->n, d[0].is);
|
Chris@82
|
259 pln->os = X(mkstride)(pln->n, d[0].os);
|
Chris@82
|
260 pln->bufstride = X(mkstride)(pln->n, 2 * compute_batchsize(pln->n));
|
Chris@82
|
261
|
Chris@82
|
262 X(tensor_tornk1)(p->vecsz, &pln->vl, &pln->ivs, &pln->ovs);
|
Chris@82
|
263 pln->slv = ego;
|
Chris@82
|
264
|
Chris@82
|
265 X(ops_zero)(&pln->super.super.ops);
|
Chris@82
|
266 X(ops_madd2)(pln->vl / e->genus->vl, &e->ops, &pln->super.super.ops);
|
Chris@82
|
267
|
Chris@82
|
268 if (ego->bufferedp)
|
Chris@82
|
269 pln->super.super.ops.other += 4 * pln->n * pln->vl;
|
Chris@82
|
270
|
Chris@82
|
271 pln->super.super.could_prune_now_p = !ego->bufferedp;
|
Chris@82
|
272 return &(pln->super.super);
|
Chris@82
|
273 }
|
Chris@82
|
274
|
Chris@82
|
275 static solver *mksolver(kdft k, const kdft_desc *desc, int bufferedp)
|
Chris@82
|
276 {
|
Chris@82
|
277 static const solver_adt sadt = { PROBLEM_DFT, mkplan, 0 };
|
Chris@82
|
278 S *slv = MKSOLVER(S, &sadt);
|
Chris@82
|
279 slv->k = k;
|
Chris@82
|
280 slv->desc = desc;
|
Chris@82
|
281 slv->bufferedp = bufferedp;
|
Chris@82
|
282 return &(slv->super);
|
Chris@82
|
283 }
|
Chris@82
|
284
|
Chris@82
|
285 solver *X(mksolver_dft_direct)(kdft k, const kdft_desc *desc)
|
Chris@82
|
286 {
|
Chris@82
|
287 return mksolver(k, desc, 0);
|
Chris@82
|
288 }
|
Chris@82
|
289
|
Chris@82
|
290 solver *X(mksolver_dft_directbuf)(kdft k, const kdft_desc *desc)
|
Chris@82
|
291 {
|
Chris@82
|
292 return mksolver(k, desc, 1);
|
Chris@82
|
293 }
|