comparison src/fftw-3.3.3/genfft/fft.ml @ 95:89f5e221ed7b

Add FFTW3
author Chris Cannam <cannam@all-day-breakfast.com>
date Wed, 20 Mar 2013 15:35:50 +0000
parents
children
comparison
equal deleted inserted replaced
94:d278df1123f9 95:89f5e221ed7b
1 (*
2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
3 * Copyright (c) 2003, 2007-11 Matteo Frigo
4 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
5 *
6 * This program is free software; you can redistribute it and/or modify
7 * it under the terms of the GNU General Public License as published by
8 * the Free Software Foundation; either version 2 of the License, or
9 * (at your option) any later version.
10 *
11 * This program is distributed in the hope that it will be useful,
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 * GNU General Public License for more details.
15 *
16 * You should have received a copy of the GNU General Public License
17 * along with this program; if not, write to the Free Software
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 *
20 *)
21
22
23 (* This is the part of the generator that actually computes the FFT
24 in symbolic form *)
25
26 open Complex
27 open Util
28
(* Choose a suitable factor of n.  The first choice is the largest
   divisor i of n with i * i <= n and gcd(i, n / i) = 1, which lets the
   prime-factor algorithm apply.  If that search only finds 1, fall back
   to the largest divisor i of n with i * i <= n.  The result is 1
   exactly when n has no divisor in [2, sqrt n] (e.g. n prime, n <= 1). *)
let choose_factor n =
  (* largest i >= 1 with i * i <= n, i dividing n and [accept i];
     1 if no larger candidate qualifies *)
  let largest_divisor accept =
    let rec scan i best =
      if i * i > n then best
      else scan (i + 1) (if n mod i = 0 && accept i then i else best)
    in
    scan 1 1
  in
  let coprime = largest_divisor (fun i -> gcd i (n / i) = 1) in
  if coprime > 1 then coprime else largest_divisor (fun _ -> true)
50
(* [is_power_of_two n] holds iff n is a positive power of two
   (1 = 2^0 included): such n have exactly one bit set, so clearing the
   lowest set bit with [n land (n - 1)] leaves zero. *)
let is_power_of_two n = n > 0 && n land (n - 1) = 0
52
(* Direct summation DFT of size n, used when n is prime.  When
   n >= !Magic.rader_min it defers to Rader's algorithm instead.  For
   n = 2 each output is the plain twiddled sum.  Otherwise each output
   i > 0 is assembled from a "real" sum and an "imaginary" sum indexed by
   min i (n - i), because the twiddles for i and n - i are complex
   conjugates and the two outputs can share those partial sums. *)
let rec dft_prime sign n input =
  (* [sum filter i]: output i, with every twiddle passed through [filter] *)
  let sum filter i =
    sigma 0 n (fun j -> filter (exp n (sign * i * j)) @* input j)
  in
  if n >= !Magic.rader_min then dft_rader sign n input
  else if n = 2 then array n (sum identity)
  else begin
    let re_part = array n (sum real)
    and im_part = array n (sum ((times Complex.i) @@ imag)) in
    array n (fun i ->
      if i = 0 then
        (* DC term: pair input j with input (n - j) to expose common
           subexpressions *)
        input 0 @+ sigma 1 ((n + 1) / 2) (fun j -> input j @+ input (n - j))
      else begin
        let i' = min i (n - i) in
        if i < n - i then re_part i' @+ im_part i'
        else re_part i' @- im_part i'
      end)
  end
79
80
(* Rader's algorithm for prime p.  Reindexing inputs and outputs by
   powers of g = find_generator p turns every output except index 0 into
   a cyclic convolution of length p - 1.  That convolution is computed
   with FFTs, using one of two strategies selected by
   !Magic.alternate_convolution. *)
and dft_rader sign p input =
  let half = times (inverse_int 2) in

  (* [make_product n a b]: pointwise a i * (b i / n); the 1/n normalizes
     the unscaled inverse transform applied afterwards *)
  let make_product n a b =
    let scale = inverse_int n in
    array n (fun i -> a i @* (scale @* b i))
  in

  (* Cyclic convolution of a and b (both of length n) by forward DFTs, a
     pointwise product and one inverse DFT.  [addtoall] is injected as a
     DC term before the inverse transform, so it ends up added to every
     element of the result.  Returns (sum of the elements of a,
     convolution). *)
  let convolve_by_fft n a b addtoall =
    let fa = dft 1 n a and fb = dft 1 n b in
    let prod = make_product n fa fb in
    let with_dc =
      array n (fun i -> prod i @+ (if i = 0 then addtoall else zero)) in
    let total = fa 0 in
    let conv = dft (-1) n with_dc in
    (total, conv)
  in

  (* Same contract as [convolve_by_fft], but first splits a and b into
     even and odd parts (x i +/- x (-i)) / 2.  Empirically this works
     better for small sizes; the reason was never pinned down. *)
  let convolve_by_fft_alt n a b addtoall =
    let mirror x i = x ((n - i) mod n) in
    let even_part x = array n (fun i -> half (x i @+ mirror x i))
    and odd_part x = array n (fun i -> half (x i @- mirror x i)) in
    let fap = dft 1 n (even_part a)
    and fam = dft 1 n (odd_part a)
    and fbp = dft 1 n (even_part b)
    and fbm = dft 1 n (odd_part b) in
    let pp = make_product n fap fbp
    and pm = make_product n fap fbm
    and mp = make_product n fam fbp
    and mm = make_product n fam fbm in
    let total = fap 0 @+ fam 0 in
    let same_parity = array n (fun i ->
      (pp i @+ mm i) @+ (if i = 0 then addtoall else zero))
    and mixed_parity = array n (fun i -> pm i @+ mp i) in
    let c1 = dft (-1) n same_parity
    and c2 = dft (-1) n mixed_parity in
    (total, array n (fun i -> c1 i @+ c2 i))
  in

  let gen_convolution =
    if p <= !Magic.alternate_convolution then convolve_by_fft_alt
    else convolve_by_fft
  in

  let g = find_generator p in
  (* g^(p-2) mod p, i.e. g^-1 mod p by Fermat -- assumes p is prime *)
  let ginv = pow_mod g (p - 2) p in
  let input_perm = array p (fun i -> input (pow_mod g i p))
  and omega_perm = array p (fun i -> exp p (sign * pow_mod ginv i p))
  and output_perm = array p (fun i -> pow_mod ginv i p) in
  let (sum, conv) =
    gen_convolution (p - 1) input_perm omega_perm (input 0) in
  array p (fun k ->
    if k = 0 then input 0 @+ sum
    else
      (* find the position j with ginv^j = k in the output permutation *)
      conv (suchthat 0 (fun j -> k = output_perm j)))
161
(* Our modified version of the conjugate-pair split-radix algorithm,
   which reduces the number of multiplications by rescaling the
   sub-transforms (power-of-two n's only). *)
and newsplit sign n input =
  (* Reduce |k| modulo n/4, then fold it into [0, n/8] using the symmetry
     r -> n/4 - r; this is the index reduction used by the scale factors. *)
  let reduce_index n k =
    let r = abs k mod (n / 4) in
    if r <= n / 8 then r else n / 4 - r
  in
  (* recursive scale factor s(n, k) *)
  let rec s n k =
    if n <= 4 then one
    else
      let k' = reduce_index n k in
      s (n / 4) k' @* real (exp n k')
  (* its reciprocal, 1 / s(n, k) *)
  and sinv n k =
    if n <= 4 then one
    else
      let k' = reduce_index n k in
      sinv (n / 4) k' @* sec n k'
  in
  (* s(n, k) / s(2n, k) *)
  let sdiv2 n k = s n k @* sinv (2 * n) k in
  (* s(n, k) / s(4n, k): sec(4n, .) at k reduced as for size 4n, i.e.
     |k| mod n folded into [0, n/2] *)
  let sdiv4 n k = sec (4 * n) (reduce_index (4 * n) k) in
  let t n k = exp n k @* sdiv4 (n / 4) k in
  let dft1 input = input in
  let dft2 input = array 2 (fun k -> input 0 @+ (input 1 @* exp 2 k)) in

  (* One conjugate-pair split-radix step of size n.  [half] transforms
     the even-indexed inputs (size n/2).  Inputs 4j+1 and 4j-1 each go
     through a size-n/4 [newsplitS]; the first result is multiplied by
     [twiddle k] and the second by its conjugate.  Returns (evens, odds):
     the caller forms output k from [evens (k mod (n/2))] and [odds k]. *)
  let rec butterfly half twiddle sign n input =
    let q = n / 4 in
    let evens = half sign (n / 2) (fun j -> input (j * 2))
    and odd1 = newsplitS sign q (fun j -> input (j * 4 + 1))
    and odd3 = newsplitS sign q (fun j -> input ((n + j * 4 - 1) mod n)) in
    let w = array n (fun k -> twiddle k @* odd1 (k mod q))
    and w' = array n (fun k -> conj (twiddle k) @* odd3 (k mod q)) in
    (evens, array n (fun k -> w k @+ w' k))

  (* unscaled top-level transform; twiddles are tabulated once per k *)
  and newsplit0 sign n input =
    match n with
    | 1 -> dft1 input
    | 2 -> dft2 input
    | _ ->
      let twid = array n (fun k -> s (n / 4) k @* exp n (sign * k)) in
      let (evens, odds) = butterfly newsplit0 twid sign n input in
      array n (fun k -> evens (k mod (n / 2)) @+ odds k)

  and newsplitS sign n input =
    match n with
    | 1 -> dft1 input
    | 2 -> dft2 input
    | _ ->
      let (evens, odds) =
        butterfly newsplitS2 (fun k -> t n (sign * k)) sign n input in
      array n (fun k -> evens (k mod (n / 2)) @+ odds k)

  (* like newsplitS, but the odd half is rescaled by sdiv2 *)
  and newsplitS2 sign n input =
    match n with
    | 1 -> dft1 input
    | 2 -> dft2 input
    | _ ->
      let (evens, odds) =
        butterfly newsplitS4 (fun k -> t n (sign * k)) sign n input in
      let scaled = array n (fun k -> odds k @* sdiv2 n k) in
      array n (fun k -> evens (k mod (n / 2)) @+ scaled k)

  (* like newsplitS, but every output is rescaled (by sdiv4, or by
     sinv 8 in the size-2 base case) *)
  and newsplitS4 sign n input =
    match n with
    | 1 -> dft1 input
    | 2 ->
      let f = dft2 input in
      array 2 (fun k -> f k @* sinv 8 k)
    | _ ->
      let (evens, odds) =
        butterfly newsplitS2 (fun k -> t n (sign * k)) sign n input in
      array n (fun k -> (evens (k mod (n / 2)) @+ odds k) @* sdiv4 n k)
  in
  newsplit0 sign n input
230 in array 2 (fun k -> (f k) @* (sinv 8 k))
231 else let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
232 and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
233 and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
234 let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
235 and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
236 let ww = array n (fun k -> w k @+ w' k) in
237 array n (fun k -> (u (k mod (n / 2)) @+ ww k) @* (sdiv4 n k))
238
239 in newsplit0 sign n input
240
241 and dft sign n input =
242 let rec cooley_tukey sign n1 n2 input =
243 let tmp1 =
244 array n2 (fun i2 ->
245 dft sign n1 (fun i1 -> input (i1 * n2 + i2))) in
246 let tmp2 =
247 array n1 (fun i1 ->
248 array n2 (fun i2 ->
249 exp n (sign * i1 * i2) @* tmp1 i2 i1)) in
250 let tmp3 = array n1 (fun i1 -> dft sign n2 (tmp2 i1)) in
251 (fun i -> tmp3 (i mod n1) (i / n1))
252
253 (*
254 * This is "exponent -1" split-radix by Dan Bernstein.
255 *)
256 and split_radix_dit sign n input =
257 let f0 = dft sign (n / 2) (fun i -> input (i * 2))
258 and f10 = dft sign (n / 4) (fun i -> input (i * 4 + 1))
259 and f11 = dft sign (n / 4) (fun i -> input ((n + i * 4 - 1) mod n)) in
260 let g10 = array n (fun k ->
261 exp n (sign * k) @* f10 (k mod (n / 4)))
262 and g11 = array n (fun k ->
263 exp n (- sign * k) @* f11 (k mod (n / 4))) in
264 let g1 = array n (fun k -> g10 k @+ g11 k) in
265 array n (fun k -> f0 (k mod (n / 2)) @+ g1 k)
266
267 and split_radix_dif sign n input =
268 let n2 = n / 2 and n4 = n / 4 in
269 let x0 = array n2 (fun i -> input i @+ input (i + n2))
270 and x10 = array n4 (fun i -> input i @- input (i + n2))
271 and x11 = array n4 (fun i ->
272 input (i + n4) @- input (i + n2 + n4)) in
273 let x1 k i =
274 exp n (k * i * sign) @* (x10 i @+ exp 4 (k * sign) @* x11 i) in
275 let f0 = dft sign n2 x0
276 and f1 = array 4 (fun k -> dft sign n4 (x1 k)) in
277 array n (fun k ->
278 if k mod 2 = 0 then f0 (k / 2)
279 else let k' = k mod 4 in f1 k' ((k - k') / 4))
280
281 and prime_factor sign n1 n2 input =
282 let tmp1 = array n2 (fun i2 ->
283 dft sign n1 (fun i1 -> input ((i1 * n2 + i2 * n1) mod n)))
284 in let tmp2 = array n1 (fun i1 ->
285 dft sign n2 (fun k2 -> tmp1 k2 i1))
286 in fun i -> tmp2 (i mod n1) (i mod n2)
287
288 in let algorithm sign n =
289 let r = choose_factor n in
290 if List.mem n !Magic.rader_list then
291 (* special cases *)
292 dft_rader sign n
293 else if (r == 1) then (* n is prime *)
294 dft_prime sign n
295 else if (gcd r (n / r)) == 1 then
296 prime_factor sign r (n / r)
297 else if (n mod 4 = 0 && n > 4) then
298 if !Magic.newsplit && is_power_of_two n then
299 newsplit sign n
300 else if !Magic.dif_split_radix then
301 split_radix_dif sign n
302 else
303 split_radix_dit sign n
304 else
305 cooley_tukey sign r (n / r)
306 in
307 array n (algorithm sign n input)