(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-11 Matteo Frigo
 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 *
 *)


(* This is the part of the generator that actually computes the FFT
   in symbolic form.  Each transform below takes a function [input]
   from index to symbolic value and returns another such function
   holding the transformed values. *)

open Complex
open Util

(* Choose a suitable factor of n for splitting an n-point transform.
   Only divisors i of n with i * i <= n are considered.
   - First choice: the largest such i with gcd(i, n / i) = 1, which lets
     [dft] use the twiddle-free prime-factor algorithm.
   - Second choice: the largest such divisor i, if any.
   Returns 1 when n has no divisor in 2 .. floor(sqrt n), which for
   n >= 2 means that n is prime. *)
let choose_factor n =
  (* Largest divisor i of n with i * i <= n that also satisfies [ok];
     1 when no divisor greater than 1 qualifies.  [ok] is only evaluated
     on actual divisors. *)
  let largest_small_divisor ok =
    let rec loop i best =
      if i * i > n then best
      else if n mod i = 0 && ok i then loop (i + 1) i
      else loop (i + 1) best
    in
    loop 1 1
  in
  let coprime = largest_small_divisor (fun i -> gcd i (n / i) = 1) in
  if coprime > 1 then coprime
  else largest_small_divisor (fun _ -> true)

(* true iff n is a positive power of two (1, 2, 4, 8, ...) *)
let is_power_of_two n = n > 0 && (n - 1) land n = 0

(* Direct O(n^2) DFT of length n.  [dft] uses it when [choose_factor]
   finds no split (n prime, or n = 1).  [sign] multiplies the exponent
   passed to [exp].  Sizes n >= !Magic.rader_min are delegated to
   [dft_rader] instead. *)
let rec dft_prime sign n input =
  (* output i = sum over j of filter (exp n (sign * i * j)) * input j *)
  let sum filter i =
    sigma 0 n (fun j ->
      let coeff = filter (exp n (sign * i * j)) in
      coeff @* (input j))
  in
  (* plain summation; used for n = 2 *)
  let direct = array n (sum identity)
  (* Used for the other sizes.  Each twiddle is split into its [real]
     part and (Complex.i times) its [imag] part.  Outputs i and n - i
     then share the two partial sums at index min i (n - i) and differ
     only in the sign of the second one. *)
  and paired =
    let sumr = array n (sum real)
    and sumi = array n (sum ((times Complex.i) @@ imag)) in
    array n (fun i ->
      if i = 0 then
        (* expose some common subexpressions *)
        input 0 @+
        sigma 1 ((n + 1) / 2) (fun j -> input j @+ input (n - j))
      else
        let i' = min i (n - i) in
        if i < n - i then sumr i' @+ sumi i'
        else sumr i' @- sumi i')
  in
  if n >= !Magic.rader_min then dft_rader sign n input
  else if n = 2 then direct
  else paired

(* FFT generator for prime n = p using Rader's algorithm.  The algorithm
   turns the FFT into a cyclic convolution of length p - 1, which can
   then be performed in a variety of ways (see [gen_convolution]).
   Inputs are permuted by powers of g = find_generator p, and twiddles
   and outputs by powers of ginv = g^(p-2) mod p, which is the inverse
   of g mod p for prime p. *)
and dft_rader sign p input =
  let half =
    let one_half = inverse_int 2 in
    times one_half

  (* Pointwise product of two length-n spectra, scaled by 1/n.  [dft]
     applies no normalization of its own, so the inverse transform of
     this product is the cyclic convolution. *)
  and make_product n a b =
    let scale_factor = inverse_int n in
    array n (fun i -> a i @* (scale_factor @* b i))
  in

  (* Generates a convolution using FFTs.  (All arguments are the same
     as to gen_convolution, below.) *)
  let gen_convolution_by_fft n a b addtoall =
    let fft_a = dft 1 n a
    and fft_b = dft 1 n b in
    let fft_ab = make_product n fft_a fft_b
    (* a constant added to the zero-frequency term shows up in every
       element after the inverse transform *)
    and dc_term i = if i = 0 then addtoall else zero in
    let fft_ab1 = array n (fun i -> fft_ab i @+ dc_term i)
    and sum = fft_a 0 in
    let conv = dft (-1) n fft_ab1 in
    (sum, conv)

  (* Alternate routine for convolution.  It splits a and b into halves
     symmetric (p) and antisymmetric (m) under i -> (n - i) mod n,
     convolves the four combinations, and adds the results.  The
     original author notes that it seems to work better for small
     sizes, without knowing why. *)
  and gen_convolution_by_fft_alt n a b addtoall =
    let ap = array n (fun i -> half (a i @+ a ((n - i) mod n)))
    and am = array n (fun i -> half (a i @- a ((n - i) mod n)))
    and bp = array n (fun i -> half (b i @+ b ((n - i) mod n)))
    and bm = array n (fun i -> half (b i @- b ((n - i) mod n))) in

    let fft_ap = dft 1 n ap
    and fft_am = dft 1 n am
    and fft_bp = dft 1 n bp
    and fft_bm = dft 1 n bm in

    let fft_abpp = make_product n fft_ap fft_bp
    and fft_abpm = make_product n fft_ap fft_bm
    and fft_abmp = make_product n fft_am fft_bp
    and fft_abmm = make_product n fft_am fft_bm
    and sum = fft_ap 0 @+ fft_am 0
    and dc_term i = if i = 0 then addtoall else zero in

    let fft_ab1 = array n (fun i -> (fft_abpp i @+ fft_abmm i) @+ dc_term i)
    and fft_ab2 = array n (fun i -> fft_abpm i @+ fft_abmp i) in
    let conv1 = dft (-1) n fft_ab1
    and conv2 = dft (-1) n fft_ab2 in
    let conv = array n (fun i -> conv1 i @+ conv2 i) in
    (sum, conv)
  in

  (* Generator of the convolution of a and b, both of length n.
     addtoall is added to all of the elements of the result.  Returns a
     (sum, convolution) pair where sum is the sum of the elements of a. *)
  let gen_convolution =
    if p <= !Magic.alternate_convolution then gen_convolution_by_fft_alt
    else gen_convolution_by_fft
  in

  let g = find_generator p in
  let ginv = pow_mod g (p - 2) p in
  let input_perm = array p (fun i -> input (pow_mod g i p))
  and omega_perm = array p (fun i -> exp p (sign * (pow_mod ginv i p)))
  and output_perm = array p (fun i -> pow_mod ginv i p) in
  (* input 0 contributes equally to every nonzero output, so it is
     passed as the constant added to every convolution element *)
  let (sum, conv) =
    gen_convolution (p - 1) input_perm omega_perm (input 0)
  in
  array p (fun i ->
    if i = 0 then
      input 0 @+ sum
    else
      (* position of output i in the permuted output order *)
      let i' = suchthat 0 (fun i' -> i = output_perm i') in
      conv i')

(* Our modified version of the conjugate-pair split-radix algorithm,
   which reduces the number of multiplications by rescaling the
   sub-transforms (power-of-two n's only). *)
and newsplit sign n input =
  let rec s n k = (* recursive scale factor *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (s (n / 4) k4') @* (real (exp n k4'))

  and sinv n k = (* 1 / s(n,k) *)
    if n <= 4 then
      one
    else
      let k4 = (abs k) mod (n / 4) in
      let k4' = if k4 <= (n / 8) then k4 else (n/4 - k4) in
      (sinv (n / 4) k4') @* (sec n k4')
  in

  let sdiv2 n k = (s n k) @* (sinv (2*n) k) (* s(n,k) / s(2*n,k) *)
  and sdiv4 n k = (* s(n,k) / s(4*n,k) *)
    let k4 = (abs k) mod n in
    sec (4*n) (if k4 <= (n / 2) then k4 else (n - k4))
  in

  (* twiddle exp n k, multiplied by s(n/4,k) / s(n,k) *)
  let t n k = (exp n k) @* (sdiv4 (n/4) k)

  and dft1 input = input
  (* exp 2 k is the same for either sign, so sign is not needed here *)
  and dft2 input = array 2 (fun k -> (input 0) @+ ((input 1) @* exp 2 k))
  in

  (* Unscaled top level: the two quarter-size conjugate-pair transforms
     are scaled sub-transforms, so their scale s(n/4,k) is put back into
     the twiddles. *)
  let rec newsplit0 sign n input =
    if n = 1 then dft1 input
    else if n = 2 then dft2 input
    else
      let u = newsplit0 sign (n / 2) (fun i -> input (i*2))
      and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
      and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n))
      and twid = array n (fun k -> s (n/4) k @* exp n (sign * k)) in
      let w = array n (fun k -> twid k @* z (k mod (n / 4)))
      and w' = array n (fun k -> conj (twid k) @* z' (k mod (n / 4))) in
      let ww = array n (fun k -> w k @+ w' k) in
      array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS sign n input =
    if n = 1 then dft1 input
    else if n = 2 then dft2 input
    else
      let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
      and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
      and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
      let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
      and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
      let ww = array n (fun k -> w k @+ w' k) in
      array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS2 sign n input =
    if n = 1 then dft1 input
    else if n = 2 then dft2 input
    else
      let u = newsplitS4 sign (n / 2) (fun i -> input (i*2))
      and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
      and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
      let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
      and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
      let ww = array n (fun k -> (w k @+ w' k) @* (sdiv2 n k)) in
      array n (fun k -> u (k mod (n / 2)) @+ ww k)

  and newsplitS4 sign n input =
    if n = 1 then dft1 input
    else if n = 2 then
      let f = dft2 input in
      array 2 (fun k -> (f k) @* (sinv 8 k))
    else
      let u = newsplitS2 sign (n / 2) (fun i -> input (i*2))
      and z = newsplitS sign (n / 4) (fun i -> input (i*4 + 1))
      and z' = newsplitS sign (n / 4) (fun i -> input ((n + i*4 - 1) mod n)) in
      let w = array n (fun k -> t n (sign * k) @* z (k mod (n / 4)))
      and w' = array n (fun k -> conj (t n (sign * k)) @* z' (k mod (n / 4))) in
      let ww = array n (fun k -> w k @+ w' k) in
      array n (fun k -> (u (k mod (n / 2)) @+ ww k) @* (sdiv4 n k))
  in
  newsplit0 sign n input

(* Symbolic DFT of length n; [sign] multiplies every exponent passed to
   [exp].  The algorithm is chosen from n as follows:
   - n listed in !Magic.rader_list: Rader's algorithm;
   - [choose_factor] finds no split (n prime, or n = 1): [dft_prime];
   - a coprime split r x (n / r): prime-factor algorithm;
   - n a multiple of 4 larger than 4: split radix.  This is the rescaled
     [newsplit] for powers of two when !Magic.newsplit is set, otherwise
     DIF or DIT according to !Magic.dif_split_radix;
   - otherwise: Cooley-Tukey with factors r and n / r. *)
and dft sign n input =
  (* The helpers below do not call each other, only the outer [dft], so
     the group is not recursive.  Each references the outer n where it
     does not take its own. *)
  let cooley_tukey sign n1 n2 input =
    let tmp1 =
      array n2 (fun i2 ->
        dft sign n1 (fun i1 -> input (i1 * n2 + i2))) in
    let tmp2 =
      array n1 (fun i1 ->
        array n2 (fun i2 ->
          exp n (sign * i1 * i2) @* tmp1 i2 i1)) in
    let tmp3 = array n1 (fun i1 -> dft sign n2 (tmp2 i1)) in
    (fun i -> tmp3 (i mod n1) (i / n1))

  (*
   * This is "exponent -1" split-radix by Dan Bernstein.
   *)
  and split_radix_dit sign n input =
    let f0 = dft sign (n / 2) (fun i -> input (i * 2))
    and f10 = dft sign (n / 4) (fun i -> input (i * 4 + 1))
    and f11 = dft sign (n / 4) (fun i -> input ((n + i * 4 - 1) mod n)) in
    let g10 = array n (fun k ->
      exp n (sign * k) @* f10 (k mod (n / 4)))
    and g11 = array n (fun k ->
      exp n (- sign * k) @* f11 (k mod (n / 4))) in
    let g1 = array n (fun k -> g10 k @+ g11 k) in
    array n (fun k -> f0 (k mod (n / 2)) @+ g1 k)

  (* Decimation in frequency: even outputs come from a half-size
     transform of x0.  Odd outputs with k mod 4 = 1 or 3 come from
     quarter-size transforms of the twiddled x1 k. *)
  and split_radix_dif sign n input =
    let n2 = n / 2 and n4 = n / 4 in
    let x0 = array n2 (fun i -> input i @+ input (i + n2))
    and x10 = array n4 (fun i -> input i @- input (i + n2))
    and x11 = array n4 (fun i ->
      input (i + n4) @- input (i + n2 + n4)) in
    let x1 k i =
      exp n (k * i * sign) @* (x10 i @+ exp 4 (k * sign) @* x11 i) in
    let f0 = dft sign n2 x0
    and f1 = array 4 (fun k -> dft sign n4 (x1 k)) in
    array n (fun k ->
      if k mod 2 = 0 then f0 (k / 2)
      else let k' = k mod 4 in f1 k' ((k - k') / 4))

  (* Twiddle-free split for gcd n1 n2 = 1.  Input index
     (i1 * n2 + i2 * n1) mod n; output i is read at
     (i mod n1, i mod n2). *)
  and prime_factor sign n1 n2 input =
    let tmp1 = array n2 (fun i2 ->
      dft sign n1 (fun i1 -> input ((i1 * n2 + i2 * n1) mod n))) in
    let tmp2 = array n1 (fun i1 ->
      dft sign n2 (fun k2 -> tmp1 k2 i1)) in
    fun i -> tmp2 (i mod n1) (i mod n2)
  in

  let algorithm sign n =
    let r = choose_factor n in
    if List.mem n !Magic.rader_list then
      (* special cases *)
      dft_rader sign n
    else if r = 1 then (* n is prime *)
      dft_prime sign n
    else if gcd r (n / r) = 1 then
      prime_factor sign r (n / r)
    else if n mod 4 = 0 && n > 4 then
      if !Magic.newsplit && is_power_of_two n then
        newsplit sign n
      else if !Magic.dif_split_radix then
        split_radix_dif sign n
      else
        split_radix_dit sign n
    else
      cooley_tukey sign r (n / r)
  in
  array n (algorithm sign n input)