(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)

(*************************************************************
 * Conversion of the dag to an assignment list
 *************************************************************)
(*
 * This function is messy.  The main problem is that we want to
 * inline dag nodes conditionally, depending on how many times they
 * are used.  The Right Thing to do would be to modify the
 * state monad to propagate some of the state backwards, so that
 * we know whether a given node will be used again in the future.
 * This modification is trivial in a lazy language, but it is
 * messy in a strict language like ML.
 *
 * In this implementation, we just do the obvious thing, i.e., visit
 * the dag twice, the first to count the node usages, and the second to
 * produce the output.
 *)

open Monads.StateMonad
open Monads.MemoMonad
open Expr

(* Allocator for new temporary variables. *)
let fresh = Variable.make_temporary

(* Node-keyed table helpers.  Both hash with Expr.hash; the lookup also
 * passes physical equality (==) to Assoctable.lookup (presumably as the
 * key-equality predicate -- confirm against Assoctable). *)
let node_insert x = Assoctable.insert Expr.hash x
let node_lookup x = Assoctable.lookup Expr.hash (==) x
let empty = Assoctable.empty

(*
 * The monadic state threaded through this file is the triple
 * (al, visited, visited'), where
 *   al        is the list of (variable, expression) pairs emitted so
 *             far, most recent first (to_assignments reverses it),
 *   visited   maps a node to the number of times the counting pass
 *             (visitM) has reached it,
 *   visited'  is the table handed to memoizing by expr_of_nodeM.
 * The fetch/store functions below read or replace one component and
 * leave the other two untouched.
 *)
let fetchAl =
  fetchState >>= (fun (al, _, _) -> returnM al)

let storeAl al =
  fetchState >>= (fun (_, visited, visited') ->
    storeState (al, visited, visited'))

let fetchVisited = fetchState >>= (fun (_, v, _) -> returnM v)

let storeVisited visited =
  fetchState >>= (fun (al, _, visited') ->
    storeState (al, visited, visited'))

let fetchVisited' = fetchState >>= (fun (_, _, v') -> returnM v')

let storeVisited' visited' =
  fetchState >>= (fun (al, visited, _) ->
    storeState (al, visited, visited'))

(* Lookup / insert on the visited' table, in the shape that memoizing
 * is called with in expr_of_nodeM. *)
let lookupVisitedM' key =
  fetchVisited' >>= fun table ->
    returnM (node_lookup key table)

let insertVisitedM' key value =
  fetchVisited' >>= fun table ->
    storeVisited' (node_insert key value table)

(*
 * counting f x records one more use of node x in the visited table.
 *
 * First encounter (no entry yet): run f x, then store a count of 1.
 * Later encounters: store count + 1 without running f x again, with one
 * exception: for Uminus y, f is applied to the child y before the count
 * is bumped, because Uminus nodes are always inlined by expr_of_nodeM.
 *)
let counting f x =
  fetchVisited >>= (fun v ->
    match node_lookup x v with
      Some count ->
        let incr_cnt =
          fetchVisited >>= (fun v' ->
            storeVisited (node_insert x (count + 1) v'))
        in
        begin
          match x with
            (* Uminus is always inlined.  Visit child *)
            Uminus y -> f y >> incr_cnt
          | _ -> incr_cnt
        end
    | None ->
        f x >> fetchVisited >>= (fun v' ->
          storeVisited (node_insert x 1 v')))

(* Emit the assignment (v, x) onto the assignment list and yield Load v
 * as the expression that now stands for x. *)
let with_varM v x =
  fetchAl >>= (fun al -> storeAl ((v, x) :: al)) >> returnM (Load v)

(* Inlining emits nothing: the expression itself is the result. *)
let inlineM = returnM

(* Bind x to a fresh temporary, unless x already is a load of one. *)
let with_tempM x = match x with
| Load v when Variable.is_temporary v -> inlineM x (* avoid trivial moves *)
| _ -> with_varM (fresh ()) x

(* declare a temporary only if node is used more than once *)
(* The usage count comes from the visited table filled in by visitM; a
 * node without an entry was never counted, which is reported with
 * failwith. *)
let with_temp_maybeM node x =
  fetchVisited >>= (fun v ->
    match node_lookup node v with
      Some count ->
        if (count = 1 && !Magic.inline_single) then
          inlineM x
        else
          with_tempM x
    | None ->
        failwith "with_temp_maybeM")

(* Result of trying to recognise a two-term sum as a fused multiply-add. *)
type fma =
    NO_FMA
  | FMA of expr * expr * expr   (* FMA (a, b, c) => a + b * c *)
  | FMS of expr * expr * expr   (* FMS (a, b, c) => -a + b * c *)
  | FNMS of expr * expr * expr  (* FNMS (a, b, c) => a - b * c *)

(* Both factors of the candidate product must pass good: NaN I and
 * NaN CONJ pass, any other NaN fails, and so does a Times with a NaN
 * as either factor. *)
let good_for_fma (a, b) =
  let good = function
    | NaN I -> true
    | NaN CONJ -> true
    | NaN _ -> false
    | Times(NaN _, _) -> false
    | Times(_, NaN _) -> false
    | _ -> true
  in good a && good b

(* Classify the term list of a Plus node.  Only two-element lists can
 * match, in either operand order; everything is NO_FMA when
 * Magic.enable_fma is off.  The negated forms are listed before the
 * plain FMA cases, so they take precedence. *)
let build_fma l =
  if (not !Magic.enable_fma) then NO_FMA
  else match l with
  | [a; Uminus (Times (b, c))] when good_for_fma (b, c) -> FNMS (a, b, c)
  | [Uminus (Times (b, c)); a] when good_for_fma (b, c) -> FNMS (a, b, c)
  | [Uminus a; Times (b, c)] when good_for_fma (b, c) -> FMS (a, b, c)
  | [Times (b, c); Uminus a] when good_for_fma (b, c) -> FMS (a, b, c)
  | [a; Times (b, c)] when good_for_fma (b, c) -> FMA (a, b, c)
  | [Times (b, c); a] when good_for_fma (b, c) -> FMA (a, b, c)
  | _ -> NO_FMA

(* The three operands of the fma recognised in l, if any, regardless of
 * which flavour it is. *)
let children_fma l = match build_fma l with
| FMA (a, b, c) -> Some (a, b, c)
| FMS (a, b, c) -> Some (a, b, c)
| FNMS (a, b, c) -> Some (a, b, c)
| NO_FMA -> None


(* First pass: walk the dag below x and count node usages through
 * counting. *)
let rec visitM x =
  counting (function
    | Load _ -> returnM ()
    | Num _ -> returnM ()
    | NaN _ -> returnM ()
    | Store (_, x) -> visitM x
    | Plus a -> (match children_fma a with
        None -> mapM visitM a >> returnM ()
      | Some (a, b, c) ->
          (* visit fma's arguments twice to make sure they are not inlined *)
          visitM a >> visitM a >>
          visitM b >> visitM b >>
          visitM c >> visitM c)
    | Times (a, b) -> visitM a >> visitM b
    | CTimes (a, b) -> visitM a >> visitM b
    | CTimesJ (a, b) -> visitM a >> visitM b
    | Uminus a -> visitM a)
    x

let visit_rootsM = mapM visitM


(* Second pass: translate node x into an expression, appending
 * assignments to the assignment list as temporaries are introduced.
 * Results are memoized in the visited' table. *)
let rec expr_of_nodeM x =
  memoizing lookupVisitedM' insertVisitedM'
    (function x -> match x with
    | Load v ->
        if (Variable.is_temporary v) then
          inlineM (Load v)
        else if (Variable.is_locative v && !Magic.inline_loads) then
          inlineM (Load v)
        else if (Variable.is_constant v && !Magic.inline_loads_constants) then
          inlineM (Load v)
        else
          with_tempM (Load v)
    | Num a ->
        if !Magic.inline_constants then
          inlineM (Num a)
        else
          with_temp_maybeM x (Num a)
    | NaN a -> inlineM (NaN a)
    | Store (v, x) ->
        (* With Magic.trivial_stores the stored value first goes through
         * with_tempM, so the store itself is a plain move. *)
        expr_of_nodeM x >>=
        (if !Magic.trivial_stores then with_tempM else inlineM) >>=
        with_varM v

    | Plus a ->
        begin
          match build_fma a with
            FMA (a, b, c) ->
              expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                  expr_of_nodeM c >>= fun c' ->
                    with_temp_maybeM x (Plus [a'; Times (b', c')])
          | FMS (a, b, c) ->
              expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                  expr_of_nodeM c >>= fun c' ->
                    with_temp_maybeM x
                      (Plus [Times (b', c'); Uminus a'])
          | FNMS (a, b, c) ->
              expr_of_nodeM a >>= fun a' ->
                expr_of_nodeM b >>= fun b' ->
                  expr_of_nodeM c >>= fun c' ->
                    with_temp_maybeM x
                      (Plus [a'; Uminus (Times (b', c'))])
          | NO_FMA ->
              mapM expr_of_nodeM a >>= fun a' ->
                with_temp_maybeM x (Plus a')
        end
    (* With Magic.generate_bytw, a Load as the first factor of a complex
     * product is kept as is instead of being translated. *)
    | CTimes (Load _ as a, b) when !Magic.generate_bytw ->
        expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimes (a, b'))
    | CTimes (a, b) ->
        expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
            with_tempM (CTimes (a', b'))
    | CTimesJ (Load _ as a, b) when !Magic.generate_bytw ->
        expr_of_nodeM b >>= fun b' ->
          with_tempM (CTimesJ (a, b'))
    | CTimesJ (a, b) ->
        expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
            with_tempM (CTimesJ (a', b'))
    | Times (a, b) ->
        expr_of_nodeM a >>= fun a' ->
          expr_of_nodeM b >>= fun b' ->
            begin
              match a' with
                (* strength reduction: a product whose first factor
                 * satisfies Number.is_two becomes b + b *)
                Num a'' when !Magic.strength_reduce_mul && Number.is_two a'' ->
                  (inlineM b' >>= fun b'' ->
                    with_temp_maybeM x (Plus [b''; b'']))
              | _ -> with_temp_maybeM x (Times (a', b'))
            end
    | Uminus a ->
        expr_of_nodeM a >>= fun a' ->
          inlineM (Uminus a'))
    x

let expr_of_rootsM = mapM expr_of_nodeM

(* Run both passes over the roots, then yield the accumulated
 * assignment list (still most recent first). *)
let peek_alistM roots =
  visit_rootsM roots >> expr_of_rootsM roots >> fetchAl

let wrap_assign (a, b) = Expr.Assign (a, b)

(* Entry point: turn dag into a list of Expr.Assign values.  The list is
 * reversed so that it comes out in the order the assignments were
 * emitted. *)
let to_assignments dag =
  let () = Util.info "begin to_alist" in
  let al = List.rev (runM ([], empty, empty) peek_alistM dag) in
  let res = List.map wrap_assign al in
  let () = Util.info "end to_alist" in
  res


(* dump alist in `dot' format *)
(* print receives each chunk of output text; alist is a list of
 * Expr.Assign values. *)
let dump print alist =
  let vs v = "\"" ^ (Variable.unparse v) ^ "\"" in
  begin
    print "digraph G {\n";
    print "\tsize=\"6,6\";\n";

    (* all input nodes have the same rank *)
    print "{ rank = same;\n";
    List.iter (fun (Expr.Assign (_, x)) ->
      List.iter (fun y ->
        if (Variable.is_locative y) then print("\t" ^ (vs y) ^ ";\n"))
        (Expr.find_vars x))
      alist;
    print "}\n";

    (* all output nodes have the same rank *)
    print "{ rank = same;\n";
    List.iter (fun (Expr.Assign (v, _)) ->
      if (Variable.is_locative v) then print("\t" ^ (vs v) ^ ";\n"))
      alist;
    print "}\n";

    (* edges *)
    List.iter (fun (Expr.Assign (v, x)) ->
      List.iter (fun y -> print("\t" ^ (vs y) ^ " -> " ^ (vs v) ^ ";\n"))
        (Expr.find_vars x))
      alist;

    print "}\n";
  end