Mercurial > hg > sv-dependency-builds
diff src/fftw-3.3.3/genfft/to_alist.ml @ 10:37bf6b4a2645
Add FFTW3
author | Chris Cannam |
---|---|
date | Wed, 20 Mar 2013 15:35:50 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/fftw-3.3.3/genfft/to_alist.ml Wed Mar 20 15:35:50 2013 +0000 @@ -0,0 +1,288 @@ +(* + * Copyright (c) 1997-1999 Massachusetts Institute of Technology + * Copyright (c) 2003, 2007-11 Matteo Frigo + * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology + * + * This program is free software; you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation; either version 2 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + * + *) + +(************************************************************* + * Conversion of the dag to an assignment list + *************************************************************) +(* + * This function is messy. The main problem is that we want to + * inline dag nodes conditionally, depending on how many times they + * are used. The Right Thing to do would be to modify the + * state monad to propagate some of the state backwards, so that + * we know whether a given node will be used again in the future. + * This modification is trivial in a lazy language, but it is + * messy in a strict language like ML. + * + * In this implementation, we just do the obvious thing, i.e., visit + * the dag twice, the first to count the node usages, and the second to + * produce the output. + *) + +open Monads.StateMonad +open Monads.MemoMonad +open Expr + +let fresh = Variable.make_temporary +let node_insert x = Assoctable.insert Expr.hash x +let node_lookup x = Assoctable.lookup Expr.hash (==) x +let empty = Assoctable.empty + +let fetchAl = + fetchState >>= (fun (al, _, _) -> returnM al) + +let storeAl al = + fetchState >>= (fun (_, visited, visited') -> + storeState (al, visited, visited')) + +let fetchVisited = fetchState >>= (fun (_, v, _) -> returnM v) + +let storeVisited visited = + fetchState >>= (fun (al, _, visited') -> + storeState (al, visited, visited')) + +let fetchVisited' = fetchState >>= (fun (_, _, v') -> returnM v') +let storeVisited' visited' = + fetchState >>= (fun (al, visited, _) -> + storeState (al, visited, visited')) +let lookupVisitedM' key = + fetchVisited' >>= fun table -> + returnM (node_lookup key table) +let insertVisitedM' key value = + fetchVisited' >>= fun table -> + storeVisited' (node_insert key value table) + +let counting f x = + fetchVisited >>= (fun v -> + match node_lookup x v with + Some count -> + let incr_cnt = + fetchVisited >>= (fun v' -> + storeVisited (node_insert x (count + 1) v')) + in + begin + match x with + (* Uminus is always inlined. Visit child *) + Uminus y -> f y >> incr_cnt + | _ -> incr_cnt + end + | None -> + f x >> fetchVisited >>= (fun v' -> + storeVisited (node_insert x 1 v'))) + +let with_varM v x = + fetchAl >>= (fun al -> storeAl ((v, x) :: al)) >> returnM (Load v) + +let inlineM = returnM + +let with_tempM x = match x with +| Load v when Variable.is_temporary v -> inlineM x (* avoid trivial moves *) +| _ -> with_varM (fresh ()) x + +(* declare a temporary only if node is used more than once *) +let with_temp_maybeM node x = + fetchVisited >>= (fun v -> + match node_lookup node v with + Some count -> + if (count = 1 && !Magic.inline_single) then + inlineM x + else + with_tempM x + | None -> + failwith "with_temp_maybeM") +type fma = + NO_FMA + | FMA of expr * expr * expr (* FMA (a, b, c) => a + b * c *) + | FMS of expr * expr * expr (* FMS (a, b, c) => -a + b * c *) + | FNMS of expr * expr * expr (* FNMS (a, b, c) => a - b * c *) + +let good_for_fma (a, b) = + let good = function + | NaN I -> true + | NaN CONJ -> true + | NaN _ -> false + | Times(NaN _, _) -> false + | Times(_, NaN _) -> false + | _ -> true + in good a && good b + +let build_fma l = + if (not !Magic.enable_fma) then NO_FMA + else match l with + | [a; Uminus (Times (b, c))] when good_for_fma (b, c) -> FNMS (a, b, c) + | [Uminus (Times (b, c)); a] when good_for_fma (b, c) -> FNMS (a, b, c) + | [Uminus a; Times (b, c)] when good_for_fma (b, c) -> FMS (a, b, c) + | [Times (b, c); Uminus a] when good_for_fma (b, c) -> FMS (a, b, c) + | [a; Times (b, c)] when good_for_fma (b, c) -> FMA (a, b, c) + | [Times (b, c); a] when good_for_fma (b, c) -> FMA (a, b, c) + | _ -> NO_FMA + +let children_fma l = match build_fma l with +| FMA (a, b, c) -> Some (a, b, c) +| FMS (a, b, c) -> Some (a, b, c) +| FNMS (a, b, c) -> Some (a, b, c) +| NO_FMA -> None + + +let rec visitM x = + counting (function + | Load v -> returnM () + | Num a -> returnM () + | NaN a -> returnM () + | Store (v, x) -> visitM x + | Plus a -> (match children_fma a with + None -> mapM visitM a >> returnM () + | Some (a, b, c) -> + (* visit fma's arguments twice to make sure they are not inlined *) + visitM a >> visitM a >> + visitM b >> visitM b >> + visitM c >> visitM c) + | Times (a, b) -> visitM a >> visitM b + | CTimes (a, b) -> visitM a >> visitM b + | CTimesJ (a, b) -> visitM a >> visitM b + | Uminus a -> visitM a) + x + +let visit_rootsM = mapM visitM + + +let rec expr_of_nodeM x = + memoizing lookupVisitedM' insertVisitedM' + (function x -> match x with + | Load v -> + if (Variable.is_temporary v) then + inlineM (Load v) + else if (Variable.is_locative v && !Magic.inline_loads) then + inlineM (Load v) + else if (Variable.is_constant v && !Magic.inline_loads_constants) then + inlineM (Load v) + else + with_tempM (Load v) + | Num a -> + if !Magic.inline_constants then + inlineM (Num a) + else + with_temp_maybeM x (Num a) + | NaN a -> inlineM (NaN a) + | Store (v, x) -> + expr_of_nodeM x >>= + (if !Magic.trivial_stores then with_tempM else inlineM) >>= + with_varM v + + | Plus a -> + begin + match build_fma a with + FMA (a, b, c) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + expr_of_nodeM c >>= fun c' -> + with_temp_maybeM x (Plus [a'; Times (b', c')]) + | FMS (a, b, c) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + expr_of_nodeM c >>= fun c' -> + with_temp_maybeM x + (Plus [Times (b', c'); Uminus a']) + | FNMS (a, b, c) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + expr_of_nodeM c >>= fun c' -> + with_temp_maybeM x + (Plus [a'; Uminus (Times (b', c'))]) + | NO_FMA -> + mapM expr_of_nodeM a >>= fun a' -> + with_temp_maybeM x (Plus a') + end + | CTimes (Load _ as a, b) when !Magic.generate_bytw -> + expr_of_nodeM b >>= fun b' -> + with_tempM (CTimes (a, b')) + | CTimes (a, b) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + with_tempM (CTimes (a', b')) + | CTimesJ (Load _ as a, b) when !Magic.generate_bytw -> + expr_of_nodeM b >>= fun b' -> + with_tempM (CTimesJ (a, b')) + | CTimesJ (a, b) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + with_tempM (CTimesJ (a', b')) + | Times (a, b) -> + expr_of_nodeM a >>= fun a' -> + expr_of_nodeM b >>= fun b' -> + begin + match a' with + Num a'' when !Magic.strength_reduce_mul && Number.is_two a'' -> + (inlineM b' >>= fun b'' -> + with_temp_maybeM x (Plus [b''; b''])) + | _ -> with_temp_maybeM x (Times (a', b')) + end + | Uminus a -> + expr_of_nodeM a >>= fun a' -> + inlineM (Uminus a')) + x + +let expr_of_rootsM = mapM expr_of_nodeM + +let peek_alistM roots = + visit_rootsM roots >> expr_of_rootsM roots >> fetchAl + +let wrap_assign (a, b) = Expr.Assign (a, b) + +let to_assignments dag = + let () = Util.info "begin to_alist" in + let al = List.rev (runM ([], empty, empty) peek_alistM dag) in + let res = List.map wrap_assign al in + let () = Util.info "end to_alist" in + res + + +(* dump alist in `dot' format *) +let dump print alist = + let vs v = "\"" ^ (Variable.unparse v) ^ "\"" in + begin + print "digraph G {\n"; + print "\tsize=\"6,6\";\n"; + + (* all input nodes have the same rank *) + print "{ rank = same;\n"; + List.iter (fun (Expr.Assign (v, x)) -> + List.iter (fun y -> + if (Variable.is_locative y) then print("\t" ^ (vs y) ^ ";\n")) + (Expr.find_vars x)) + alist; + print "}\n"; + + (* all output nodes have the same rank *) + print "{ rank = same;\n"; + List.iter (fun (Expr.Assign (v, x)) -> + if (Variable.is_locative v) then print("\t" ^ (vs v) ^ ";\n")) + alist; + print "}\n"; + + (* edges *) + List.iter (fun (Expr.Assign (v, x)) -> + List.iter (fun y -> print("\t" ^ (vs y) ^ " -> " ^ (vs v) ^ ";\n")) + (Expr.find_vars x)) + alist; + + print "}\n"; + end +