(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-11 Matteo Frigo
 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)

(*************************************************************
 * Conversion of the dag to an assignment list
 *************************************************************)
(*
 * This function is messy.  The main problem is that we want to
 * inline dag nodes conditionally, depending on how many times they
 * are used.  The Right Thing to do would be to modify the
 * state monad to propagate some of the state backwards, so that
 * we know whether a given node will be used again in the future.
 * This modification is trivial in a lazy language, but it is
 * messy in a strict language like ML.
 *
 * In this implementation, we just do the obvious thing, i.e., visit
 * the dag twice: the first pass counts node usages, and the second
 * pass produces the output.
*)

open Monads.StateMonad
open Monads.MemoMonad
open Expr

(* The monad state is a triple (al, visited, visited'):
 *   al       -- assignments (variable, expression) emitted so far,
 *               most recent first (with_varM conses onto it);
 *   visited  -- per-node use counts filled in by the counting pass;
 *   visited' -- memo table used by [memoizing] in the output pass. *)

let fresh = Variable.make_temporary

(* Dag nodes are keyed by physical identity ((==)) and hashed with
 * Expr.hash. *)
let node_insert x = Assoctable.insert Expr.hash x
let node_lookup x = Assoctable.lookup Expr.hash (==) x
let empty = Assoctable.empty

(* Accessors for the individual components of the state triple. *)
let fetchAl =
  fetchState >>= (fun (al, _, _) -> returnM al)

let storeAl al =
  fetchState >>= (fun (_, visited, visited') ->
    storeState (al, visited, visited'))

let fetchVisited = fetchState >>= (fun (_, v, _) -> returnM v)

let storeVisited visited =
  fetchState >>= (fun (al, _, visited') ->
    storeState (al, visited, visited'))

let fetchVisited' = fetchState >>= (fun (_, _, v') -> returnM v')

let storeVisited' visited' =
  fetchState >>= (fun (al, visited, _) ->
    storeState (al, visited, visited'))

(* Lookup / insert on the output-pass memo table; these are the two
 * functions handed to [memoizing] in expr_of_nodeM. *)
let lookupVisitedM' key =
  fetchVisited' >>= fun table ->
  returnM (node_lookup key table)

let insertVisitedM' key value =
  fetchVisited' >>= fun table ->
  storeVisited' (node_insert key value table)

(* [counting f x] records one more use of node [x] in the visited table.
 * On the first visit, [f x] is run and the count of [x] is set to 1.
 * On later visits the count is only incremented, except that when [x]
 * is [Uminus y], [f y] is run first (negations are always inlined). *)
let counting f x =
  fetchVisited >>= (fun v ->
    match node_lookup x v with
    | Some count ->
        let incr_cnt =
          fetchVisited >>= (fun v' ->
            storeVisited (node_insert x (count + 1) v'))
        in
        begin
          match x with
          (* Uminus is always inlined.  Visit child *)
          | Uminus y -> f y >> incr_cnt
          | _ -> incr_cnt
        end
    | None ->
        f x >> fetchVisited >>= (fun v' ->
          storeVisited (node_insert x 1 v')))

(* Append the assignment (v, x) to the assignment list and return
 * [Load v], so that the caller refers to the variable instead of [x]. *)
let with_varM v x =
  fetchAl >>= (fun al -> storeAl ((v, x) :: al)) >> returnM (Load v)

(* Use the expression in place, without introducing a variable. *)
let inlineM = returnM

(* Bind [x] to a fresh temporary, unless [x] already is a load of a
 * temporary. *)
let with_tempM x = match x with
  | Load v when Variable.is_temporary v -> inlineM x (* avoid trivial moves *)
  | _ -> with_varM (fresh ()) x

(* Declare a temporary only if node is used more than once: expression
 * [x], computed for dag node [node], is inlined when the counting pass
 * recorded exactly one use and !Magic.inline_single is set; otherwise it
 * is bound to a temporary.  Fails if [node] has no recorded count
 * (presumably a node the counting pass never reached). *)
let with_temp_maybeM node x =
  fetchVisited >>= (fun v ->
    match node_lookup node v with
    | Some count ->
        if count = 1 && !Magic.inline_single then
          inlineM x
        else
          with_tempM x
    | None ->
        failwith "with_temp_maybeM")

type fma =
  | NO_FMA
  | FMA of expr * expr * expr  (* FMA (a, b, c) => a + b * c *)
  | FMS of expr * expr * expr  (* FMS (a, b, c) => -a + b * c *)
  | FNMS of expr * expr * expr (* FNMS (a, b, c) => a - b * c *)

(* Whether the factors (a, b) of a product may take part in a fused
 * multiply-add.  A factor is rejected if it is a NaN other than I or
 * CONJ, or a product with a NaN as either factor. *)
let good_for_fma (a, b) =
  let good = function
    | NaN I -> true
    | NaN CONJ -> true
    | NaN _ -> false
    | Times (NaN _, _) -> false
    | Times (_, NaN _) -> false
    | _ -> true
  in
  good a && good b

(* Recognize a two-term sum (the operand list of a Plus node) that can
 * be emitted as a fused multiply-add.  Returns NO_FMA when
 * !Magic.enable_fma is unset or the list has none of the shapes below. *)
let build_fma l =
  if not !Magic.enable_fma then NO_FMA
  else match l with
    | [a; Uminus (Times (b, c))] when good_for_fma (b, c) -> FNMS (a, b, c)
    | [Uminus (Times (b, c)); a] when good_for_fma (b, c) -> FNMS (a, b, c)
    | [Uminus a; Times (b, c)] when good_for_fma (b, c) -> FMS (a, b, c)
    | [Times (b, c); Uminus a] when good_for_fma (b, c) -> FMS (a, b, c)
    | [a; Times (b, c)] when good_for_fma (b, c) -> FMA (a, b, c)
    | [Times (b, c); a] when good_for_fma (b, c) -> FMA (a, b, c)
    | _ -> NO_FMA

(* The three operands of the fused form of [l], if build_fma finds one. *)
let children_fma l = match build_fma l with
  | FMA (a, b, c) -> Some (a, b, c)
  | FMS (a, b, c) -> Some (a, b, c)
  | FNMS (a, b, c) -> Some (a, b, c)
  | NO_FMA -> None

(* First pass: walk the dag below [x], counting uses of every node. *)
let rec visitM x =
  counting (function
    | Load _ -> returnM ()
    | Num _ -> returnM ()
    | NaN _ -> returnM ()
    | Store (_, x) -> visitM x
    | Plus a ->
        (match children_fma a with
         | None -> mapM visitM a >> returnM ()
         | Some (a, b, c) ->
             (* visit fma's arguments twice to make sure they are not inlined *)
             visitM a >> visitM a >>
             visitM b >> visitM b >>
             visitM c >> visitM c)
    | Times (a, b) -> visitM a >> visitM b
    | CTimes (a, b) -> visitM a >> visitM b
    | CTimesJ (a, b) -> visitM a >> visitM b
    | Uminus a -> visitM a)
    x

let visit_rootsM = mapM visitM

(* Second pass: translate dag node [x] into an expression, appending
 * assignments to the state whenever a variable is introduced.  Results
 * are memoized per node through lookupVisitedM' / insertVisitedM', so a
 * shared node is translated once.  Whether a node is inlined or bound to
 * a variable depends on the Magic flags and on the counts from visitM. *)
let rec expr_of_nodeM x =
  memoizing lookupVisitedM' insertVisitedM'
    (function x -> match x with
      | Load v ->
          if Variable.is_temporary v then
            inlineM (Load v)
          else if Variable.is_locative v && !Magic.inline_loads then
            inlineM (Load v)
          else if Variable.is_constant v && !Magic.inline_loads_constants then
            inlineM (Load v)
          else
            with_tempM (Load v)
      | Num a ->
          if !Magic.inline_constants then
            inlineM (Num a)
          else
            with_temp_maybeM x (Num a)
      | NaN a -> inlineM (NaN a)
      | Store (v, x) ->
          (* optionally copy the value into a temporary before storing *)
          expr_of_nodeM x >>=
          (if !Magic.trivial_stores then with_tempM else inlineM) >>=
          with_varM v
      | Plus a ->
          begin
            match build_fma a with
            | FMA (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x (Plus [a'; Times (b', c')])
            | FMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x
                  (Plus [Times (b', c'); Uminus a'])
            | FNMS (a, b, c) ->
                expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                expr_of_nodeM c >>= fun c' ->
                with_temp_maybeM x
                  (Plus [a'; Uminus (Times (b', c'))])
            | NO_FMA ->
                mapM expr_of_nodeM a >>= fun a' ->
                with_temp_maybeM x (Plus a')
          end
      | CTimes (Load _ as a, b) when !Magic.generate_bytw ->
          (* the Load operand is kept as-is rather than translated *)
          expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimes (a, b'))
      | CTimes (a, b) ->
          expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimes (a', b'))
      | CTimesJ (Load _ as a, b) when !Magic.generate_bytw ->
          (* the Load operand is kept as-is rather than translated *)
          expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimesJ (a, b'))
      | CTimesJ (a, b) ->
          expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimesJ (a', b'))
      | Times (a, b) ->
          expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
          begin
            match a' with
            (* with strength reduction enabled, 2 * b is emitted as b + b *)
            | Num a'' when !Magic.strength_reduce_mul && Number.is_two a'' ->
                (inlineM b' >>= fun b'' ->
                 with_temp_maybeM x (Plus [b''; b'']))
            | _ -> with_temp_maybeM x (Times (a', b'))
          end
      | Uminus a ->
          expr_of_nodeM a >>= fun a' ->
          inlineM (Uminus a'))
    x

let expr_of_rootsM = mapM expr_of_nodeM

(* Run both passes over [roots] (counting, then translation) and return
 * the accumulated assignment list, most recent assignment first. *)
let peek_alistM roots =
  visit_rootsM roots >> expr_of_rootsM roots >> fetchAl
let wrap_assign (a, b) = Expr.Assign (a, b)

(* Convert [dag] (passed to peek_alistM as its list of roots) into a
 * list of Expr.Assign statements, in the order in which the assignments
 * were generated.  The state starts with an empty assignment list and
 * empty count / memo tables. *)
let to_assignments dag =
  let () = Util.info "begin to_alist" in
  let al = List.rev (runM ([], empty, empty) peek_alistM dag) in
  let res = List.map wrap_assign al in
  let () = Util.info "end to_alist" in
  res

(* [dump print alist] writes [alist] as a Graphviz dot digraph, emitting
 * all text through [print].
 * - Locative variables read by any assignment share one rank (inputs).
 * - Locative variables assigned to share another rank (outputs).
 * - Each assignment of x to v contributes an edge y -> v for every
 *   variable y returned by Expr.find_vars x. *)
let dump print alist =
  let vs v = "\"" ^ Variable.unparse v ^ "\"" in
  (* one node statement inside a rank group *)
  let node v = print ("\t" ^ vs v ^ ";\n") in
  (* wrap the output of [body] in a rank = same group *)
  let same_rank body =
    print "{ rank = same;\n";
    body ();
    print "}\n"
  in
  begin
    print "digraph G {\n";
    print "\tsize=\"6,6\";\n";

    (* all input nodes have the same rank *)
    same_rank (fun () ->
      List.iter (fun (Expr.Assign (_, x)) ->
          List.iter (fun y -> if Variable.is_locative y then node y)
            (Expr.find_vars x))
        alist);

    (* all output nodes have the same rank *)
    same_rank (fun () ->
      List.iter (fun (Expr.Assign (v, _)) ->
          if Variable.is_locative v then node v)
        alist);

    (* edges *)
    List.iter (fun (Expr.Assign (v, x)) ->
        List.iter (fun y -> print ("\t" ^ vs y ^ " -> " ^ vs v ^ ";\n"))
          (Expr.find_vars x))
      alist;

    print "}\n"
  end