(*
 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
 * Copyright (c) 2003, 2007-14 Matteo Frigo
 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
 *
 *)


open Util
open Expr

(* Tables keyed by dag nodes.  Keys are hashed with [Expr.hash] and compared
 * with physical equality (==), so two structurally equal but distinct nodes
 * are distinct keys. *)
let node_insert x = Assoctable.insert Expr.hash x
let node_lookup x = Assoctable.lookup Expr.hash (==) x

(*************************************************************
 * Algebraic simplifier/elimination of common subexpressions
 *************************************************************)
module AlgSimp : sig
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monadic state is a pair (simp, cse):
   *   simp  memoizes [algsimpM] on the (physical) input nodes;
   *   cse   is the hash-consing table consulted by [makeNode].
   * Each accessor below reads or replaces one component and leaves the
   * other untouched. *)
  let fetchSimp =
    fetchState >>= fun (s, _) -> returnM s
  let storeSimp s =
    fetchState >>= (fun (_, c) -> storeState (s, c))
  let lookupSimpM key =
    fetchSimp >>= fun table ->
    returnM (node_lookup key table)
  let insertSimpM key value =
    fetchSimp >>= fun table ->
    storeSimp (node_insert key value table)

  (* [subset a b]: every element of [a] occurs physically (==) in [b]. *)
  let subset a b =
    List.for_all (fun x -> List.exists (fun y -> x == y) b) a

  (* Shallow structural equality used for CSE.  Children are compared with
   * (==), which relies on them having already been canonicalized by
   * [makeNode] (the simplifiers build nodes bottom-up).  Times and CTimes
   * are treated as commutative, CTimesJ is not; Plus compares its term
   * lists as sets (order and repeats ignored). *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num a, Num b) -> Number.equal a b
    | (NaN a, NaN b) -> a == b
    | (Load a, Load b) -> Variable.same a b
    | (Times (a, a'), Times (b, b')) ->
      ((a == b) && (a' == b')) ||
      ((a == b') && (a' == b))
    | (CTimes (a, a'), CTimes (b, b')) ->
      ((a == b) && (a' == b')) ||
      ((a == b') && (a' == b))
    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
    | (Plus a, Plus b) -> subset a b && subset b a
    | (Uminus a, Uminus b) -> (a == b)
    | _ -> false

  (* Hash and equality for the CSE table.  With [Magic.randomized_cse] set,
   * hashing is delegated to [Oracle.hash], and two nodes are also merged
   * when [Oracle.likely_equal] says so. *)
  let hashCSE x =
    if (!Magic.randomized_cse) then
      Oracle.hash x
    else
      Expr.hash x

  let equalCSE a b =
    if (!Magic.randomized_cse) then
      (structurallyEqualCSE a b || Oracle.likely_equal a b)
    else
      structurallyEqualCSE a b

  let fetchCSE =
    fetchState >>= fun (_, c) -> returnM c
  let storeCSE c =
    fetchState >>= (fun (s, _) -> storeState (s, c))
  let lookupCSEM key =
    fetchCSE >>= fun table ->
    returnM (Assoctable.lookup hashCSE equalCSE key table)
  let insertCSEM key value =
    fetchCSE >>= fun table ->
    storeCSE (Assoctable.insert hashCSE key value table)

  (* Hash-cons [x]: returns the node already recorded as equal to [x] in the
   * CSE table, recording [x] itself otherwise.
   * memoize both x and Uminus x (unless x is already negated), so that a
   * later negation of the result finds the same node. *)
  let identityM x =
    let memo x = memoizing lookupCSEM insertCSEM returnM x in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  (* Every node produced by the simplifiers goes through here. *)
  let makeNode = identityM

  (* Simplifiers for the various kinds of nodes.  Each takes already
   * simplified operands and returns a canonical node. *)

  (* Constant [n].  Negative constants are represented as
   * [Uminus (Num |n|)], and every zero as [Num Number.zero]. *)
  let rec snumM = function
    | n when Number.is_zero n ->
      makeNode (Num (Number.zero))
    | n when Number.negative n ->
      makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* Negation: -(-x) = x and -0 = 0. *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Multiplication: pull negations outward, push [NaN I] into the second
   * operand of CTimes/CTimesJ, fold constant products, apply the 0, 1 and
   * -1 identities, and move a known constant to the left operand. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) ->
      stimesM (NaN I, b) >>= fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) ->
      stimesM (NaN I, b) >>= fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
      snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
      stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  (* CTimes / CTimesJ nodes: only negations are pulled outward. *)
  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Constant folding on a list of summands.  Numeric terms ([Num _] or
   * [Uminus (Num _)]) are moved toward the end of the list and adjacent
   * ones are combined; a final constant equal to zero is dropped. *)
  and reduce_sumM x = match x with
    | [] -> returnM []
    | [Num a] ->
      if (Number.is_zero a) then
        returnM []
      else returnM x
    | [Uminus (Num a)] ->
      if (Number.is_zero a) then
        returnM []
      else returnM x
    | (Num a) :: (Num b) :: s ->
      snumM (Number.add a b) >>= fun x ->
      reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
      snumM (Number.sub a b) >>= fun x ->
      reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
      snumM (Number.sub b a) >>= fun x ->
      reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
      snumM (Number.add a b) >>=
      suminusM >>= fun x ->
      reduce_sumM (x :: s)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
      reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A coefficient can be factored out unless it is a (possibly negated)
   * NaN.  Only the coefficient (the first component) is inspected. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true
  and collectible (a, _) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x
   *
   * [which] maps the operands of a Times to (coefficient, factor): the
   * identity treats the first operand as the coefficient, a swap treats the
   * second one as such.  The factor of the first summand is taken as the
   * pivot.  Every summand that, read either way round, has that pivot
   * (physically) as its factor contributes its coefficient.  The remaining
   * summands are processed recursively. *)
  and collectM which x =
    (* Split a summand into (coefficient, factor); a summand that is not a
     * collectible product is treated as [1 * itself]. *)
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
      | Uminus x ->
        findCoeffM which x >>= fun (coeff, b) ->
        suminusM coeff >>= fun mcoeff ->
        returnM (mcoeff, b)
      | x -> snumM Number.one >>= fun one -> returnM (one, x)
    (* Partition summands into coefficients of [xpr] (w) and the rest (wo). *)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | a :: b ->
        separateM xpr b >>= fun (w, wo) ->
        (* try first factor *)
        findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
        if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
        else
          (* try second factor *)
          findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
          if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
          else returnM (w, a :: wo)
    in match x with
    | [] -> returnM x
    | [_] -> returnM x
    | a :: _ ->
      findCoeffM which a >>= fun (_, xpr) ->
      separateM xpr x >>= fun (w, wo) ->
      collectM which wo >>= fun wo' ->
      splusM w >>= fun w' ->
      stimesM (w', xpr) >>= fun t' ->
      returnM (t':: wo')

  (* The sequence of rewrites applied to a list of summands. *)
  and mangleSumM x = returnM x
    >>= reduce_sumM
    >>= collectM (fun (a, b) -> (a, b))
    >>= collectM (fun (a, b) -> (b, a))
    >>= reduce_sumM
    >>= deepCollectM !Magic.deep_collect_depth
    >>= reduce_sumM

  (* push all Uminuses to the end: non-negated terms keep their relative
   * order and are followed by the negated terms in reverse order of
   * appearance.  Linear time (no repeated [@ [x]]). *)
  and reorder_uminus l =
    let (negated, others) = List.partition negative l in
    others @ List.rev negated

  (* Final node for a list of already-mangled summands. *)
  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* Build [Plus l].  When [Magic.enable_fma] is set and at least two terms
   * are constant multiples [c * x] (possibly negated), the largest such
   * constant m (per [Number.greater]) is factored out:
   *   others + m * (sum of (c/m) * x)
   * presumably so the backend can emit fused multiply-adds. *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num _, _) -> true
      | Uminus (Times (Num _, _)) -> true
      | _ -> false
    (* Returns (constant multiples, other terms, largest constant seen). *)
    and separate = function
      | [] -> ([], [], Number.zero)
      | ((Times (Num a, _)) as this) :: c ->
        let (x, y, max) = separate c in
        let newmax = if (Number.greater a max) then a else max in
        (this :: x, y, newmax)
      | ((Uminus (Times (Num a, _))) as this) :: c ->
        let (x, y, max) = separate c in
        let newmax = if (Number.greater a max) then a else max in
        (this :: x, y, newmax)
      | this :: c ->
        let (x, y, max) = separate c in
        (x, this :: y, max)
    in fun l ->
      if !Magic.enable_fma && count is_multiplication l >= 2 then
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
        snumM max >>= fun max' ->
        mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
        stimesM (max', pw') >>= fun mw' ->
        splusM (wo @ [mw'])
      else
        makeNode (Plus l)


  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  (* Mechanics: [findTerms] lists the subterms reachable from each summand
   * through negations, constant multiplications and nested sums (each Plus
   * or constant multiplication consumes one unit of the [maxdepth] budget;
   * a Plus is only entered while the budget is positive).  Subterms found
   * more than once (physically) form [d].  Each summand x is then split by
   * [splitDuplicates] into (x', xd) with x = x' + xd, where xd carries the
   * occurrences of [d]; both parts are re-summed with [mangleSumM] so that
   * the duplicates can be collected. *)
  and deepCollectM maxdepth l =
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> (findTerms (depth - 1) b)
      | Plus l when depth > 0 ->
        x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    and duplicates = function
      | [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
        else duplicates b

    (* [splitDuplicates depth d x] returns (x', xd) with x = x' + xd. *)
    in let rec splitDuplicates depth d x =
      if (List.memq x d) then
        (snumM (Number.zero) >>= fun zero ->
         returnM (zero, x))
      else match x with
        | Times (a, b) ->
          (* With a = a' + xa and b = b' + xb:
           *   a * b = a' * b' + (a * xb + xa * b - xa * xb)
           * The correction term xa * xb must be subtracted, not added.  It
           * is zero whenever either operand has no duplicate part. *)
          splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
          splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
          stimesM (a', b') >>= fun ab ->
          stimesM (a, xb) >>= fun xb' ->
          stimesM (xa, b) >>= fun xa' ->
          stimesM (xa, xb) >>= fun xab ->
          suminusM xab >>= fun mxab ->
          splusM [xa'; xb'; mxab] >>= fun x ->
          returnM (ab, x)
        | Uminus a ->
          splitDuplicates depth d a >>= fun (x, y) ->
          suminusM x >>= fun ux ->
          suminusM y >>= fun uy ->
          returnM (ux, uy)
        | Plus l when depth > 0 ->
          mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
          let (l', d') = List.split ld in
          splusM l' >>= fun p ->
          splusM d' >>= fun d'' ->
          returnM (p, d'')
        | x ->
          snumM (Number.zero) >>= fun zero' ->
          returnM (x, zero')

    in let l' = List.flatten (List.map (findTerms maxdepth) l)
    in match duplicates l' with
    | [] -> returnM l
    | d ->
      mapM (splitDuplicates maxdepth d) l >>= fun ld ->
      let (l', d') = List.split ld in
      splusM l' >>= fun l'' ->
      let rec flattenPlusM = function
        | Plus l -> returnM l
        | Uminus x ->
          flattenPlusM x >>= mapM suminusM
        | x -> returnM [x]
      in
      mapM flattenPlusM d' >>= fun d'' ->
      splusM (List.flatten d'') >>= fun d''' ->
      mangleSumM [l''; d''']

  (* Simplify a sum.  After [mangleSumM], the overall sign is normalized:
   * when every term is negated, or when the FMA heuristics or the Oracle
   * ask for it, the sum is built from the negated terms and the result is
   * negated. *)
  and splusM l =
    (* With FMA enabled, decide the sign for some two-term shapes:
     * [Some true] = flip, [Some false] = keep, [None] = no opinion. *)
    let fma_heuristics x =
      if !Magic.enable_fma then
        (match x with
         | [Uminus (Times _); Times _] -> Some false
         | [Times _; Uminus (Times _)] -> Some false
         | [Uminus _; Times _] -> Some true
         | [Times _; Uminus (Plus _)] -> Some true
         | [_; Uminus (Times _)] -> Some false
         | [Uminus (Times _); _] -> Some false
         | _ -> None)
      else
        None
    in
    mangleSumM l >>= fun l' ->
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      mapM suminusM l' >>= splusM >>= suminusM
    else match fma_heuristics l' with
      | Some true -> mapM suminusM l' >>= splusM >>= suminusM
      | Some false -> canonicalizeM l'
      | None ->
        (* Ask the Oracle for the canonical form *)
        if (not !Magic.randomized_cse) &&
           Oracle.should_flip_sign (Plus l') then
          mapM suminusM l' >>= splusM >>= suminusM
        else
          canonicalizeM l'

  (* monadic style algebraic simplifier for the dag: simplify the children,
   * then rebuild the node with the matching simplifier.  Results are
   * memoized per input node in the [simp] table. *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | NaN _ as x -> makeNode x
        | Plus a ->
          mapM algsimpM a >>= splusM
        | Times (a, b) ->
          (algsimpM a >>= fun a' ->
           algsimpM b >>= fun b' ->
           stimesM (a', b'))
        | CTimes (a, b) ->
          (algsimpM a >>= fun a' ->
           algsimpM b >>= fun b' ->
           sctimesM (a', b'))
        | CTimesJ (a, b) ->
          (algsimpM a >>= fun a' ->
           algsimpM b >>= fun b' ->
           sctimesjM (a', b'))
        | Uminus a ->
          algsimpM a >>= suminusM
        | Store (v, a) ->
          algsimpM a >>= fun a' ->
          makeNode (Store (v, a'))
        | Load _ as x -> makeNode x)
      x

  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  (* Simplify a list of roots, starting from empty simp and CSE tables. *)
  let algsimp = runM initialTable simp_roots
end

(*************************************************************
 * Network transposition algorithm
 *************************************************************)
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monadic state is a single node table mapping a node to its dual. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table ->
    returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table ->
    storeDuals (node_insert key value table)

  (* Depth-first walk from the roots.  Returns every reachable node (each
   * once, most recently discovered first) together with a table mapping
   * each node to the list of its parents (one entry per occurrence as a
   * child). *)
  let rec visit visited vtable parent_table = function
    | [] -> (visited, parent_table)
    | node :: rest ->
      match node_lookup node vtable with
      | Some _ -> visit visited vtable parent_table rest
      | None ->
        let children = match node with
          | Store (_, n) -> [n]
          | Plus l -> l
          | Times (a, b) -> [a; b]
          | CTimes (a, b) -> [a; b]
          | CTimesJ (a, b) -> [a; b]
          | Uminus x -> [x]
          | _ -> []
        in let rec loop t = function
          | [] -> t
          | a :: rest ->
            (match node_lookup a t with
             | None -> loop (node_insert a [node] t) rest
             | Some c -> loop (node_insert a (node :: c) t) rest)
        in
        (visit
           (node :: visited)
           (node_insert node () vtable)
           (loop parent_table children)
           (children @ rest))

  (* Build the dual (transposed) network.  [dualM n] is the node replacing
   * [n] in the transposed dag:
   *   - a constant Load stays a Load;
   *   - any other Load becomes a Store of the sum of the contributions
   *     flowing back from its parents;
   *   - a Store becomes a Load of the same variable;
   *   - every other node becomes that sum itself.
   * [termM node parent] is [parent]'s contribution to [node]'s dual; a
   * Times/CTimes/CTimesJ parent contributes only when [node] is its second
   * operand.
   * Raises [Failure] if the dual of a non-Store node with no recorded
   * parent is requested. *)
  let make_transposer parent_table =
    let rec termM node candidate_parent =
      match candidate_parent with
      | Store (_, n) when n == node ->
        dualM candidate_parent >>= fun x' -> returnM [x']
      | Plus (l) when List.memq node l ->
        dualM candidate_parent >>= fun x' -> returnM [x']
      | Times (a, b) when b == node ->
        dualM candidate_parent >>= fun x' ->
        returnM [makeTimes (a, x')]
      | CTimes (a, b) when b == node ->
        dualM candidate_parent >>= fun x' ->
        returnM [CTimes (a, x')]
      | CTimesJ (a, b) when b == node ->
        dualM candidate_parent >>= fun x' ->
        returnM [CTimesJ (a, x')]
      | Uminus n when n == node ->
        dualM candidate_parent >>= fun x' ->
        returnM [makeUminus x']
      | _ -> returnM []

    and dualExpressionM this_node =
      mapM (termM this_node)
        (match node_lookup this_node parent_table with
         | Some a -> a
         | None -> failwith "bug in dualExpressionM"
        ) >>= fun l ->
      returnM (makePlus (List.flatten l))

    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as x ->
            if (Variable.is_constant v) then
              returnM (Load v)
            else
              (dualExpressionM x >>= fun d ->
               returnM (Store (v, d)))
          | Store (v, _) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node

    in dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose the network [dag] (a list of roots): compute the dual of
   * every reachable node and return the duals that are Stores. *)
  let transpose dag =
    let () = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let mapTransposerM = mapM transposerM in
    let duals = runM Assoctable.empty mapTransposerM all_nodes in
    let roots = List.filter is_store duals in
    let () = Util.info "end transpose" in
    roots
end


(*************************************************************
 * Various dag statistics
 *************************************************************)
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* (loads, stores, additions, multiplications, negations, constants) *)
  type complexity = int * int * int * int * int * int

  (* All nodes reachable from the roots, each once.
   * NOTE(review): unlike Transpose.visit, this walk does not descend into
   * the operands of CTimes/CTimesJ, so nodes reachable only through them
   * are not counted.  Confirm whether that is intended before changing it:
   * the weight decides when [algsimp]'s simplification loop stops. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
      match node_lookup node vtable with
      | Some _ -> visit visited vtable rest
      | None ->
        let children = match node with
          | Store (_, n) -> [n]
          | Plus l -> l
          | Times (a, b) -> [a; b]
          | Uminus x -> [x]
          | _ -> []
        in visit (node :: visited)
          (node_insert node () vtable)
          (children @ rest)

  (* Count operations over the reachable nodes.  A Plus with k terms counts
   * as k - 1 additions; CTimes, CTimesJ and NaN nodes are not counted. *)
  let complexity dag =
    let rec loop (load, store, plus, times, uminus, num) = function
      | [] -> (load, store, plus, times, uminus, num)
      | node :: rest ->
        loop
          (match node with
           | Load _ -> (load + 1, store, plus, times, uminus, num)
           | Store _ -> (load, store + 1, plus, times, uminus, num)
           | Plus x ->
             (load, store, plus + (List.length x - 1), times, uminus, num)
           | Times _ -> (load, store, plus, times + 1, uminus, num)
           | Uminus _ -> (load, store, plus, times, uminus + 1, num)
           | Num _ -> (load, store, plus, times, uminus, num + 1)
           | CTimes _ -> (load, store, plus, times, uminus, num)
           | CTimesJ _ -> (load, store, plus, times, uminus, num)
           | NaN _ -> (load, store, plus, times, uminus, num))
          rest
    in
    loop (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Scalar cost: additions weigh 10, multiplications 20, all else 1. *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = weight a = weight b
  let leq_complexity a b = weight a <= weight b

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n

end

(* simplify the dag
 *
 * Runs [AlgSimp.algsimp] once.  If [Magic.network_transposition] is set, it
 * then repeatedly applies a round of two transpose + simplify passes
 * ([@@] must be a function-composition operator here, presumably Util's,
 * since the composed pipeline is applied to [v]).  It compares the
 * weighted complexity before and after each round.
 * NOTE(review): the loop returns as soon as a round does not increase the
 * complexity, and keeps iterating (from the worse dag) only when a round
 * made it worse.  This reads inverted relative to "repeat while
 * improving".  Confirm the intent before changing it, as it affects the
 * generated code. *)
let algsimp v =
  let rec simplification_loop v =
    let () = Util.info "simplification step" in
    let complexity = Stats.complexity v in
    let () = Util.info ("complexity = " ^ (Stats.to_string complexity)) in
    let v = (AlgSimp.algsimp @@ Transpose.transpose @@
             AlgSimp.algsimp @@ Transpose.transpose) v in
    let complexity' = Stats.complexity v in
    let () = Util.info ("complexity = " ^ (Stats.to_string complexity')) in
    if (Stats.leq_complexity complexity' complexity) then
      let () = Util.info "end algsimp" in
      v
    else
      simplification_loop v

  in
  let () = Util.info "begin algsimp" in
  let v = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop v else v