Mercurial > hg > sv-dependency-builds
diff src/fftw-3.3.3/genfft/algsimp.ml @ 10:37bf6b4a2645
Add FFTW3
author | Chris Cannam |
---|---|
date | Wed, 20 Mar 2013 15:35:50 +0000 |
parents | |
children |
line wrap: on
line diff
(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-11 Matteo Frigo
 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 *
 *)


open Util
open Expr

(* Association-table helpers keyed on expression nodes.  Both hash with
 * Expr.hash; lookup compares keys by physical identity (==), so two
 * structurally equal but distinct nodes are different keys. *)
let node_insert x = Assoctable.insert Expr.hash x
let node_lookup x = Assoctable.lookup Expr.hash (==) x

(*************************************************************
 * Algebraic simplifier/elimination of common subexpressions
 *************************************************************)
module AlgSimp : sig
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monad state is a pair (simp, cse): [simp] memoizes algsimpM on
   * input nodes, [cse] memoizes the nodes built by makeNode. *)
  let fetchSimp =
    fetchState >>= fun (s, _) -> returnM s
  let storeSimp s =
    fetchState >>= (fun (_, c) -> storeState (s, c))
  let lookupSimpM key =
    fetchSimp >>= fun table ->
      returnM (node_lookup key table)
  let insertSimpM key value =
    fetchSimp >>= fun table ->
      storeSimp (node_insert key value table)

  (* [subset a b] holds when every element of [a] is physically present
   * in [b]. *)
  let subset a b =
    List.for_all (fun x -> List.memq x b) a

  (* Shallow equality used by the CSE table: the two nodes must have the
   * same constructor and physically identical children.  Times and
   * CTimes accept their operands in either order, CTimesJ does not, and
   * Plus compares its operand lists as sets. *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num a, Num b) -> Number.equal a b
    | (NaN a, NaN b) -> a == b
    | (Load a, Load b) -> Variable.same a b
    | (Times (a, a'), Times (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimes (a, a'), CTimes (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
    | (Plus a, Plus b) -> subset a b && subset b a
    | (Uminus a, Uminus b) -> (a == b)
    | _ -> false

  (* Hash and equality for the CSE table.  When Magic.randomized_cse is
   * set the Oracle's hash is used and Oracle.likely_equal is accepted
   * in addition to structural equality. *)
  let hashCSE x =
    if (!Magic.randomized_cse) then
      Oracle.hash x
    else
      Expr.hash x

  let equalCSE a b =
    if (!Magic.randomized_cse) then
      (structurallyEqualCSE a b || Oracle.likely_equal a b)
    else
      structurallyEqualCSE a b

  let fetchCSE =
    fetchState >>= fun (_, c) -> returnM c
  let storeCSE c =
    fetchState >>= (fun (s, _) -> storeState (s, c))
  let lookupCSEM key =
    fetchCSE >>= fun table ->
      returnM (Assoctable.lookup hashCSE equalCSE key table)
  let insertCSEM key value =
    fetchCSE >>= fun table ->
      storeCSE (Assoctable.insert hashCSE key value table)

  (* memoize both x and Uminus x (unless x is already negated) *)
  let identityM x =
    let memo x = memoizing lookupCSEM insertCSEM returnM x in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  let makeNode = identityM

  (* simplifiers for various kinds of nodes *)

  (* Build a numeric constant: zero is normalized to Number.zero and a
   * negative value n is represented as the negation of Num (-n). *)
  let rec snumM = function
    | n when Number.is_zero n ->
        makeNode (Num (Number.zero))
    | n when Number.negative n ->
        makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* Negate a node: a double negation cancels and -0 is 0. *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Build a product.  Minus signs are pulled out of either operand,
   * numeric constants are folded, multiplication by 0, 1 and -1 is
   * simplified, and a known constant is moved to the first operand.
   * The clauses are order-sensitive. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  (* Build a CTimes node, pulling minus signs out of either operand. *)
  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  (* Build a CTimesJ node, pulling minus signs out of either operand. *)
  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Fold the numeric constants of a list of summands.  Adjacent
   * constants are combined; a constant followed by a non-constant is
   * swapped behind it, so constants drift to the end of the list, where
   * a lone zero is dropped. *)
  and reduce_sumM x = match x with
    | [] -> returnM []
    | [Num a] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | [Uminus (Num a)] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun x ->
          reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun x ->
          reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun x ->
          reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun x ->
          reduce_sumM (x :: s)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A coefficient is collectible unless it is a NaN node, possibly
   * under minus signs.  Only the first component of the pair is
   * inspected. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true
  and collectible (a, _) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x *)
  (* [which] selects which operand of a product is treated as the
   * coefficient and which as the common factor. *)
  and collectM which x =
    (* Split a term into (coefficient, factor); a term that is not a
     * suitable product gets coefficient one. *)
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
      | Uminus x ->
          findCoeffM which x >>= fun (coeff, b) ->
            suminusM coeff >>= fun mcoeff ->
              returnM (mcoeff, b)
      | x -> snumM Number.one >>= fun one -> returnM (one, x)
    (* Partition the terms into the coefficients of those whose factor
     * is physically [xpr] and the terms left untouched. *)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | a :: b ->
          separateM xpr b >>= fun (w, wo) ->
            (* try first factor *)
            findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
              if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
              else
                (* try second factor *)
                findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
                  if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
                  else returnM (w, a :: wo)
    in match x with
    | [] -> returnM x
    | [_] -> returnM x
    | a :: _ ->
        findCoeffM which a >>= fun (_, xpr) ->
          separateM xpr x >>= fun (w, wo) ->
            collectM which wo >>= fun wo' ->
              splusM w >>= fun w' ->
                stimesM (w', xpr) >>= fun t' ->
                  returnM (t' :: wo')

  (* The full pipeline applied to a list of summands before a Plus node
   * is built. *)
  and mangleSumM x = returnM x
      >>= reduce_sumM
      >>= collectM (fun (a, b) -> (a, b))
      >>= collectM (fun (a, b) -> (b, a))
      >>= reduce_sumM
      >>= deepCollectM !Magic.deep_collect_depth
      >>= reduce_sumM

  and reorder_uminus = function (* push all Uminuses to the end *)
    | [] -> []
    | ((Uminus _) as a') :: b -> (reorder_uminus b) @ [a']
    | a :: b -> a :: (reorder_uminus b)

  (* Turn a simplified list of summands into a node: the empty sum is
   * zero and a singleton is the term itself. *)
  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* Build the Plus node for a list of summands.  When Magic.enable_fma
   * is set and at least two summands are (possibly negated) products
   * with a numeric coefficient, those summands are scaled by 1/max,
   * summed, and the sum is multiplied back by max, where max is the
   * greatest coefficient found (starting from Number.zero). *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num _, _) -> true
      | Uminus (Times (Num _, _)) -> true
      | _ -> false
    (* Returns (products, other terms, greatest coefficient). *)
    and separate = function
      | [] -> ([], [], Number.zero)
      | ((Times (Num a, _)) as this) :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | ((Uminus (Times (Num a, _))) as this) :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, max) = separate c in
          (x, this :: y, max)
    in fun l ->
      if !Magic.enable_fma && count is_multiplication l >= 2 then
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
          snumM max >>= fun max' ->
            mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
              stimesM (max', pw') >>= fun mw' ->
                splusM (wo @ [mw'])
      else
        makeNode (Plus l)


  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   * ((c_1 * a + ...) + ...) + (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   * (a + b) + (a - b)
   * (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  and deepCollectM maxdepth l =
    (* Candidate terms reachable from x through minus signs, numeric
     * coefficients and (down to [depth]) nested sums. *)
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> (findTerms (depth - 1) b)
      | Plus l when depth > 0 ->
          x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    (* Elements that occur (physically) again later in the list. *)
    and duplicates = function
      | [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
                  else duplicates b

    (* Split x into (rest, dup) with x = rest + dup, where dup gathers
     * the parts of x built from the duplicated terms d. *)
    in let rec splitDuplicates depth d x =
      if (List.memq x d) then
        snumM (Number.zero) >>= fun zero ->
          returnM (zero, x)
      else match x with
      | Times (a, b) ->
          splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
            splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
              stimesM (a', b') >>= fun ab ->
                stimesM (a, xb) >>= fun xb' ->
                  stimesM (xa, b) >>= fun xa' ->
                    (* With a = a' + xa and b = b' + xb:
                     *   a*b - a'*b' = a*xb + xa*b - xa*xb.
                     * The cross term must be subtracted; adding it
                     * counted xa*xb three times.  When xa or xb is
                     * zero the cross term is zero either way. *)
                    stimesM (xa, xb) >>= suminusM >>= fun xab ->
                      splusM [xa'; xb'; xab] >>= fun x ->
                        returnM (ab, x)
      | Uminus a ->
          splitDuplicates depth d a >>= fun (x, y) ->
            suminusM x >>= fun ux ->
              suminusM y >>= fun uy ->
                returnM (ux, uy)
      | Plus l when depth > 0 ->
          mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
            let (l', d') = List.split ld in
            splusM l' >>= fun p ->
              splusM d' >>= fun d'' ->
                returnM (p, d'')
      | x ->
          snumM (Number.zero) >>= fun zero' ->
            returnM (x, zero')

    in let l' = List.flatten (List.map (findTerms maxdepth) l)
    in match duplicates l' with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun ld ->
          let (l', d') = List.split ld in
          splusM l' >>= fun l'' ->
            (* Expose the summands of a (possibly negated) sum. *)
            let rec flattenPlusM = function
              | Plus l -> returnM l
              | Uminus x ->
                  flattenPlusM x >>= mapM suminusM
              | x -> returnM [x]
            in
            mapM flattenPlusM d' >>= fun d'' ->
              splusM (List.flatten d'') >>= fun d''' ->
                mangleSumM [l''; d''']

  (* Build a simplified sum.  The summands are first run through
   * mangleSumM; the sign convention of the result is then chosen so
   * that a sum of only negative terms becomes the negation of a
   * positive sum, with heuristics and the Oracle deciding mixed
   * cases. *)
  and splusM l =
    (* Some true: negate the whole sum; Some false: keep it as is;
     * None: no opinion.  The clauses are order-sensitive. *)
    let fma_heuristics x =
      if !Magic.enable_fma then
        match x with
        | [Uminus (Times _); Times _] -> Some false
        | [Times _; Uminus (Times _)] -> Some false
        | [Uminus (_); Times _] -> Some true
        | [Times _; Uminus (Plus _)] -> Some true
        | [_; Uminus (Times _)] -> Some false
        | [Uminus (Times _); _] -> Some false
        | _ -> None
      else
        None
    in
    mangleSumM l >>= fun l' ->
      (* no terms are negative. Don't do anything *)
      if not (List.exists negative l') then
        canonicalizeM l'
      (* all terms are negative. Negate them all and collect the minus sign *)
      else if List.for_all negative l' then
        mapM suminusM l' >>= splusM >>= suminusM
      else match fma_heuristics l' with
      | Some true -> mapM suminusM l' >>= splusM >>= suminusM
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse) &&
            Oracle.should_flip_sign (Plus l') then
            mapM suminusM l' >>= splusM >>= suminusM
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag *)
  (* Children are simplified first and the node is rebuilt with the
   * matching simplifier; results are memoized per input node. *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | NaN _ as x -> makeNode x
        | Plus a ->
            mapM algsimpM a >>= splusM
        | Times (a, b) ->
            (algsimpM a >>= fun a' ->
              algsimpM b >>= fun b' ->
                stimesM (a', b'))
        | CTimes (a, b) ->
            (algsimpM a >>= fun a' ->
              algsimpM b >>= fun b' ->
                sctimesM (a', b'))
        | CTimesJ (a, b) ->
            (algsimpM a >>= fun a' ->
              algsimpM b >>= fun b' ->
                sctimesjM (a', b'))
        | Uminus a ->
            algsimpM a >>= suminusM
        | Store (v, a) ->
            algsimpM a >>= fun a' ->
              makeNode (Store (v, a'))
        | Load _ as x -> makeNode x)
      x

  (* Both memo tables start out empty. *)
  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp = runM initialTable simp_roots
end

(*************************************************************
 * Network transposition algorithm
 *************************************************************)
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monad state is the table of already computed duals. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table ->
      returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table ->
      storeDuals (node_insert key value table)

  (* Walk the dag from the given work list.  Returns the list of all
   * distinct nodes reached together with a table mapping each node to
   * the list of its parents.  [vtable] records the nodes already
   * expanded. *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
        match node_lookup node vtable with
        | Some _ -> visit visited vtable parent_table rest
        | None ->
            let children = match node with
              | Store (_, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | CTimes (a, b) -> [a; b]
              | CTimesJ (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            (* Register [node] as a parent of each of its children. *)
            in let rec loop t = function
              | [] -> t
              | a :: rest ->
                  (match node_lookup a t with
                   | None -> loop (node_insert a [node] t) rest
                   | Some c -> loop (node_insert a (node :: c) t) rest)
            in
            (visit
               (node :: visited)
               (node_insert node () vtable)
               (loop parent_table children)
               (children @ rest))

  (* Build the memoized function computing the dual of a node, given
   * the parent table produced by [visit]. *)
  let make_transposer parent_table =
    (* Contribution of one parent to the dual of [node].  For the
     * product constructors only the second operand is transposed
     * through. *)
    let rec termM node candidate_parent =
      match candidate_parent with
      | Store (_, n) when n == node ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Plus (l) when List.memq node l ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Times (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
            returnM [makeTimes (a, x')]
      | CTimes (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
            returnM [CTimes (a, x')]
      | CTimesJ (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
            returnM [CTimesJ (a, x')]
      | Uminus n when n == node ->
          dualM candidate_parent >>= fun x' ->
            returnM [makeUminus x']
      | _ -> returnM []

    (* Sum of the contributions of all parents of [this_node].  Fails
     * if the node has no entry in the parent table. *)
    and dualExpressionM this_node =
      mapM (termM this_node)
        (match node_lookup this_node parent_table with
         | Some a -> a
         | None -> failwith "bug in dualExpressionM"
        ) >>= fun l ->
          returnM (makePlus (List.flatten l))

    (* A Store becomes a Load of the same variable; a Load of a
     * non-constant variable becomes a Store of its dual expression;
     * constant loads are left alone. *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as x ->
              if (Variable.is_constant v) then
                returnM (Load v)
              else
                (dualExpressionM x >>= fun d ->
                  returnM (Store (v, d)))
          | Store (v, _) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node

    in dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose a dag: compute the dual of every reachable node and keep
   * the resulting Store nodes as the new roots. *)
  let transpose dag =
    let () = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let mapTransposerM = mapM transposerM in
    let duals = runM Assoctable.empty mapTransposerM all_nodes in
    let roots = List.filter is_store duals in
    let () = Util.info "end transpose" in
    roots
end


(*************************************************************
 * Various dag statistics
 *************************************************************)
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* (loads, stores, additions, multiplications, negations, constants) *)
  type complexity = int * int * int * int * int * int

  (* List of the distinct nodes reachable from the work list.
   * NOTE(review): unlike Transpose.visit, this does not descend into
   * the operands of CTimes / CTimesJ nodes -- confirm whether that is
   * intentional before relying on the counts for such dags. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        match node_lookup node vtable with
        | Some _ -> visit visited vtable rest
        | None ->
            let children = match node with
              | Store (_, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            in visit (node :: visited)
                 (node_insert node () vtable)
                 (children @ rest)

  (* Count the nodes of each kind.  A Plus of k operands counts as
   * k - 1 additions; CTimes, CTimesJ and NaN nodes are not counted. *)
  let complexity dag =
    let count (load, store, plus, times, uminus, num) = function
      | Load _ -> (load + 1, store, plus, times, uminus, num)
      | Store _ -> (load, store + 1, plus, times, uminus, num)
      | Plus x -> (load, store, plus + (List.length x - 1), times, uminus, num)
      | Times _ -> (load, store, plus, times + 1, uminus, num)
      | Uminus _ -> (load, store, plus, times, uminus + 1, num)
      | Num _ -> (load, store, plus, times, uminus, num + 1)
      | CTimes _ -> (load, store, plus, times, uminus, num)
      | CTimesJ _ -> (load, store, plus, times, uminus, num)
      | NaN _ -> (load, store, plus, times, uminus, num)
    in
    List.fold_left count (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Scalar cost: additions weigh 10, multiplications 20, everything
   * else 1. *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = weight a = weight b
  let leq_complexity a b = weight a <= weight b

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n

end

(* simplify the dag *)
(* The dag is simplified once; if Magic.network_transposition is set it
 * then goes through rounds of transposition and simplification. *)
let algsimp v =
  let rec simplification_loop v =
    let () = Util.info "simplification step" in
    let complexity = Stats.complexity v in
    let () = Util.info ("complexity = " ^ (Stats.to_string complexity)) in
    (* NOTE(review): the parenthesized value is applied to v, so `@@`
     * here has to be a function-composition operator (presumably the
     * one brought in by `open Util`), not Stdlib's application
     * operator -- confirm the composition order there. *)
    let v = (AlgSimp.algsimp @@ Transpose.transpose @@
             AlgSimp.algsimp @@ Transpose.transpose) v in
    let complexity' = Stats.complexity v in
    let () = Util.info ("complexity = " ^ (Stats.to_string complexity')) in
    (* NOTE(review): the loop stops as soon as a round does not increase
     * the weight and repeats only when the weight grew -- confirm this
     * is the intended stopping rule. *)
    if (Stats.leq_complexity complexity' complexity) then
      let () = Util.info "end algsimp" in
      v
    else
      simplification_loop v

  in
  let () = Util.info "begin algsimp" in
  let v = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop v else v