annotate src/fftw-3.3.5/genfft/algsimp.ml @ 148:b4bfdf10c4b3

Update Win64 capnp builds to v0.6
author Chris Cannam <cannam@all-day-breakfast.com>
date Mon, 22 May 2017 18:56:49 +0100
parents 7867fa7e1b6b
children
rev   line source
cannam@127 1 (*
cannam@127 2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
cannam@127 3 * Copyright (c) 2003, 2007-14 Matteo Frigo
cannam@127 4 * Copyright (c) 2003, 2007-14 Massachusetts Institute of Technology
cannam@127 5 *
cannam@127 6 * This program is free software; you can redistribute it and/or modify
cannam@127 7 * it under the terms of the GNU General Public License as published by
cannam@127 8 * the Free Software Foundation; either version 2 of the License, or
cannam@127 9 * (at your option) any later version.
cannam@127 10 *
cannam@127 11 * This program is distributed in the hope that it will be useful,
cannam@127 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
cannam@127 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
cannam@127 14 * GNU General Public License for more details.
cannam@127 15 *
cannam@127 16 * You should have received a copy of the GNU General Public License
cannam@127 17 * along with this program; if not, write to the Free Software
cannam@127 18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
cannam@127 19 *
cannam@127 20 *)
cannam@127 21
cannam@127 22
cannam@127 23 open Util
cannam@127 24 open Expr
cannam@127 25
(* Node-keyed association tables.  Keys are hashed with Expr.hash, and
   node_lookup passes physical equality (==) as the key comparison, so a
   structurally equal but separately allocated node is a different key.
   The explicit [key] parameter is deliberate: eta-expansion keeps both
   functions polymorphic in the stored value type (OCaml value
   restriction).  This file stores expr, expr list and unit values
   through them. *)
let node_insert key = Assoctable.insert Expr.hash key
let node_lookup key = Assoctable.lookup Expr.hash ( == ) key
cannam@127 28
cannam@127 29 (*************************************************************
cannam@127 30 * Algebraic simplifier/elimination of common subexpressions
cannam@127 31 *************************************************************)
module AlgSimp : sig
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monad state is a pair (simp, cse) of node tables, both empty at
     the start (see initialTable at the end of this module):
     - simp memoizes algsimpM: original node -> simplified node;
     - cse is the common-subexpression table consulted by makeNode. *)

  let fetchSimp =
    fetchState >>= fun (s, _) -> returnM s
  let storeSimp s =
    fetchState >>= (fun (_, c) -> storeState (s, c))
  let lookupSimpM key =
    fetchSimp >>= fun table ->
    returnM (node_lookup key table)
  let insertSimpM key value =
    fetchSimp >>= fun table ->
    storeSimp (node_insert key value table)

  (* [subset a b]: every element of [a] occurs in [b], comparing by
     physical identity. *)
  let subset a b =
    List.for_all (fun x -> List.exists (fun y -> x == y) b) a

  (* Shallow equality used for CSE.  Operands are compared by physical
     identity (==), so two nodes match only when their operands are
     already shared.  Times and CTimes are treated as commutative, CTimesJ
     is not, and Plus compares its operand lists as sets (order and
     multiplicity ignored).  Any other pairing (e.g. two Stores) is
     never equal. *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num a, Num b) -> Number.equal a b
    | (NaN a, NaN b) -> a == b
    | (Load a, Load b) -> Variable.same a b
    | (Times (a, a'), Times (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimes (a, a'), CTimes (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
    | (Plus a, Plus b) -> subset a b && subset b a
    | (Uminus a, Uminus b) -> (a == b)
    | _ -> false

  (* With Magic.randomized_cse set, the CSE table hashes with Oracle.hash
     and additionally treats as equal any pair Oracle.likely_equal
     accepts; otherwise it uses Expr.hash and the shallow test above. *)
  let hashCSE x =
    if (!Magic.randomized_cse) then
      Oracle.hash x
    else
      Expr.hash x

  let equalCSE a b =
    if (!Magic.randomized_cse) then
      (structurallyEqualCSE a b || Oracle.likely_equal a b)
    else
      structurallyEqualCSE a b

  let fetchCSE =
    fetchState >>= fun (_, c) -> returnM c
  let storeCSE c =
    fetchState >>= (fun (s, _) -> storeState (s, c))
  let lookupCSEM key =
    fetchCSE >>= fun table ->
    returnM (Assoctable.lookup hashCSE equalCSE key table)
  let insertCSEM key value =
    fetchCSE >>= fun table ->
    storeCSE (Assoctable.insert hashCSE key value table)

  (* memoize both x and Uminus x (unless x is already negated).
     memoizing consults the CSE table with returnM as the computation, so
     presumably a node equalCSE-equal to x is returned in place of x, and
     x itself is registered otherwise. *)
  let identityM x =
    let memo x = memoizing lookupCSEM insertCSEM returnM x in
    match x with
      Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  (* Every node the simplifier builds goes through the CSE table. *)
  let makeNode = identityM

  (* simplifiers for various kinds of nodes *)

  (* Numeric constant.  Zero is normalized to Number.zero; a negative
     constant becomes Uminus of its negation, so Num nodes built here are
     never negative. *)
  let rec snumM = function
      n when Number.is_zero n ->
        makeNode (Num (Number.zero))
    | n when Number.negative n ->
        makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* Negation: double negation cancels, -0 is 0, otherwise wrap. *)
  and suminusM = function
      Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Product.  Negations are pulled out of either operand.  NaN I times a
     CTimes/CTimesJ is pushed into the complex product's second operand.
     Constant products and chains c1 * (c2 * x) are folded, and 0 * x,
     1 * x and -1 * x are simplified.  A known constant on the right is
     swapped to the left when the left operand is not one. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, b) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  (* Complex products: only negations are pulled out. *)
  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Fold the numeric constants in a list of summands.  Two adjacent
     constants (plain or negated) are combined with snumM.  A constant
     followed by a non-constant is swapped past it, so constants migrate
     toward the end of the list and meet each other.  A lone zero
     constant left at the end is dropped.  Non-constant terms keep their
     relative order. *)
  and reduce_sumM x = match x with
      [] -> returnM []
    | [Num a] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | [Uminus (Num a)] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun x ->
        reduce_sumM (x :: s)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A coefficient may take part in factor collection unless it is a NaN
     node (possibly negated).  Only the coefficient (first component) of
     the pair is inspected. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true
  and collectible (a, b) = collectible1 a

  (* collect common factors: ax + bx -> (a+b)x
     [which] selects which Times operand is the coefficient and which the
     factor; mangleSumM runs this once with each orientation.  The factor
     xpr is taken from the first term.  Every term whose factor (tried on
     either side) is physically xpr contributes its coefficient to w, the
     rest go to wo.  The result is (sum w) * xpr followed by wo with its
     own factors collected recursively.  A term that is not a collectible
     product counts as 1 * term. *)
  and collectM which x =
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
      | Uminus x ->
          findCoeffM which x >>= fun (coeff, b) ->
          suminusM coeff >>= fun mcoeff ->
          returnM (mcoeff, b)
      | x -> snumM Number.one >>= fun one -> returnM (one, x)
    and separateM xpr = function
        [] -> returnM ([], [])
      | a :: b ->
          separateM xpr b >>= fun (w, wo) ->
          (* try first factor *)
          findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
          if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
          else
            (* try second factor *)
            findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
            if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
            else returnM (w, a :: wo)
    in match x with
      [] -> returnM x
    | [a] -> returnM x
    | a :: b ->
        findCoeffM which a >>= fun (_, xpr) ->
        separateM xpr x >>= fun (w, wo) ->
        collectM which wo >>= fun wo' ->
        splusM w >>= fun w' ->
        stimesM (w', xpr) >>= fun t' ->
        returnM (t':: wo')

  (* Simplification pipeline over a list of summands: fold constants,
     collect common factors in both Times orientations, fold again, apply
     deepCollectM up to Magic.deep_collect_depth, fold again. *)
  and mangleSumM x = returnM x
      >>= reduce_sumM
      >>= collectM (fun (a, b) -> (a, b))
      >>= collectM (fun (a, b) -> (b, a))
      >>= reduce_sumM
      >>= deepCollectM !Magic.deep_collect_depth
      >>= reduce_sumM

  and reorder_uminus = function (* push all Uminuses to the end *)
      [] -> []
    | ((Uminus _) as a' :: b) -> (reorder_uminus b) @ [a']
    | (a :: b) -> a :: (reorder_uminus b)

  (* Turn a simplified list of summands into one node: no terms -> 0, one
     term -> that term, otherwise a (possibly FMA-restructured) sum with
     the negated terms moved last. *)
  and canonicalizeM = function
      [] -> snumM Number.zero
    | [a] -> makeNode a (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* With Magic.enable_fma set and at least two terms of the form c*x or
     -(c*x) (Num coefficient), factor out the largest such coefficient m:
       sum_i c_i x_i + rest = rest + m * sum_i (c_i / m) x_i
     This branch is gated on the FMA flag, presumably to expose
     multiply-add patterns to the backend.  Otherwise build a plain Plus. *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num a, b) -> true
      | Uminus (Times (Num a, b)) -> true
      | _ -> false
    (* split into (products with a Num coefficient, other terms, largest
       such coefficient; Number.zero if there is none) *)
    and separate = function
        [] -> ([], [], Number.zero)
      | (Times (Num a, b)) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | (Uminus (Times (Num a, b))) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, max) = separate c in
          (x, this :: y, max)
    in fun l ->
      if !Magic.enable_fma && count is_multiplication l >= 2 then
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
        snumM max >>= fun max' ->
        mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
        stimesM (max', pw') >>= fun mw' ->
        splusM (wo @ [mw'])
      else
        makeNode (Plus l)


  and negative = function
      Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *)
  (* this whole procedure needs much more thought *)
  (* findTerms lists the subterms reachable from each summand through
     Uminus, Num-coefficient products and (up to [maxdepth] levels of)
     nested Plus.  Terms found more than once (by identity) are the
     duplicates d.  splitDuplicates rewrites each summand x as a pair
     (p, q) with x = p + q, where q gathers the parts involving duplicates.
     The two partial sums are then rebuilt and re-simplified together so
     the duplicates can be collected. *)
  and deepCollectM maxdepth l =
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> (findTerms (depth - 1) b)
      | Plus l when depth > 0 ->
          x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    and duplicates = function
        [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
                  else duplicates b

    in let rec splitDuplicates depth d x =
      if (List.memq x d) then
        snumM (Number.zero) >>= fun zero ->
        returnM (zero, x)
      else match x with
        | Times (a, b) ->
            (* With a = a' + xa and b = b' + xb:
                 a * b = a' * b' + (a * xb + xa * b - xa * xb)
               The cross term xa * xb is contained in both a * xb and
               xa * b, so it must be subtracted once.  (It vanishes when
               either operand has no duplicate part.) *)
            splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
            splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
            stimesM (a', b') >>= fun ab ->
            stimesM (a, xb) >>= fun xb' ->
            stimesM (xa, b) >>= fun xa' ->
            stimesM (xa, xb) >>= suminusM >>= fun mxab ->
            splusM [xa'; xb'; mxab] >>= fun x ->
            returnM (ab, x)
        | Uminus a ->
            splitDuplicates depth d a >>= fun (x, y) ->
            suminusM x >>= fun ux ->
            suminusM y >>= fun uy ->
            returnM (ux, uy)
        | Plus l when depth > 0 ->
            mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
            let (l', d') = List.split ld in
            splusM l' >>= fun p ->
            splusM d' >>= fun d'' ->
            returnM (p, d'')
        | x ->
            snumM (Number.zero) >>= fun zero' ->
            returnM (x, zero')

    in let l' = List.flatten (List.map (findTerms maxdepth) l)
    in match duplicates l' with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun ld ->
        let (l', d') = List.split ld in
        splusM l' >>= fun l'' ->
        let rec flattenPlusM = function
          | Plus l -> returnM l
          | Uminus x ->
              flattenPlusM x >>= mapM suminusM
          | x -> returnM [x]
        in
        mapM flattenPlusM d' >>= fun d'' ->
        splusM (List.flatten d'') >>= fun d''' ->
        mangleSumM [l''; d''']

  (* Simplified sum of [l].  After mangleSumM, the overall sign is
     normalized, presumably so that a sum and its negation share one
     representation.  If no term is negated, build it directly.  If every
     term is negated, build -(sum of the negations).  Otherwise, with FMA
     enabled, fma_heuristics decides for some two-term shapes, and failing
     that Oracle.should_flip_sign decides (only when randomized CSE is
     off). *)
  and splusM l =
    let fma_heuristics x =
      if !Magic.enable_fma then
        match x with
        | [Uminus (Times _); Times _] -> Some false
        | [Times _; Uminus (Times _)] -> Some false
        | [Uminus (_); Times _] -> Some true
        | [Times _; Uminus (Plus _)] -> Some true
        | [_; Uminus (Times _)] -> Some false
        | [Uminus (Times _); _] -> Some false
        | _ -> None
      else
        None
    in
    mangleSumM l >>= fun l' ->
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      mapM suminusM l' >>= splusM >>= suminusM
    else match fma_heuristics l' with
      | Some true -> mapM suminusM l' >>= splusM >>= suminusM
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse) &&
             Oracle.should_flip_sign (Plus l') then
            mapM suminusM l' >>= splusM >>= suminusM
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag.
     Simplifies bottom-up; results are memoized in the simp table, so each
     shared input node is simplified only once. *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | NaN _ as x -> makeNode x
        | Plus a ->
            mapM algsimpM a >>= splusM
        | Times (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             stimesM (a', b'))
        | CTimes (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesM (a', b'))
        | CTimesJ (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesjM (a', b'))
        | Uminus a ->
            algsimpM a >>= suminusM
        | Store (v, a) ->
            algsimpM a >>= fun a' ->
            makeNode (Store (v, a'))
        | Load _ as x -> makeNode x)
      x

  (* Both the simp and the CSE tables start empty, and all roots are
     simplified in one run, so sharing and CSE span the whole root list. *)
  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp = runM initialTable simp_roots
end
cannam@127 393
cannam@127 394 (*************************************************************
cannam@127 395 * Network transposition algorithm
cannam@127 396 *************************************************************)
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monad state is a single node table memoizing dualM
     (original node -> its dual in the transposed network). *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table ->
    returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table ->
    storeDuals (node_insert key value table)

  (* Depth-first walk over the worklist of nodes.  [vtable] records nodes
     already seen (keyed by identity via node_lookup).  Returns
     (visited, parent_table): every distinct node reached, and a table
     mapping each node to the list of nodes that use it as a direct
     operand. *)
  let rec visit visited vtable parent_table = function
      [] -> (visited, parent_table)
    | node :: rest ->
        match node_lookup node vtable with
        | Some _ -> visit visited vtable parent_table rest
        | None ->
            let children = match node with
              | Store (v, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | CTimes (a, b) -> [a; b]
              | CTimesJ (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            (* record [node] as a parent of each of its children *)
            in let rec loop t = function
                [] -> t
              | a :: rest ->
                  (match node_lookup a t with
                     None -> loop (node_insert a [node] t) rest
                   | Some c -> loop (node_insert a (node :: c) t) rest)
            in
            (visit
               (node :: visited)
               (node_insert node () vtable)
               (loop parent_table children)
               (children @ rest))

  (* Build dualM, which maps a node of the original dag to its dual in the
     transposed network, using [parent_table] as produced by visit. *)
  let make_transposer parent_table =
    (* Contribution that the dual of [candidate_parent] sends back to
       [node].  Store and Plus parents pass their dual through unchanged.
       Uminus negates it.  Times/CTimes/CTimesJ multiply it by the first
       operand, but only when [node] is the second operand.  The first
       operand is treated as a fixed coefficient — presumably a constant,
       since AlgSimp moves known constants to the left (TODO confirm).
       Any other relation contributes nothing. *)
    let rec termM node candidate_parent =
      match candidate_parent with
      | Store (_, n) when n == node ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Plus (l) when List.memq node l ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Times (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeTimes (a, x')]
      | CTimes (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [CTimes (a, x')]
      | CTimesJ (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [CTimesJ (a, x')]
      | Uminus n when n == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeUminus x']
      | _ -> returnM []

    (* Dual of an interior node: the sum (makePlus) of the contributions
       from all of its parents.  Fails if the node has no parent-table
       entry. *)
    and dualExpressionM this_node =
      mapM (termM this_node)
        (match node_lookup this_node parent_table with
         | Some a -> a
         | None -> failwith "bug in dualExpressionM"
        ) >>= fun l ->
      returnM (makePlus (List.flatten l))

    (* Memoized dual.  A constant Load is its own dual.  A non-constant
       Load v becomes Store (v, <sum of contributions>).  Store (v, _)
       becomes Load v.  Every other node uses dualExpressionM. *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as x ->
              if (Variable.is_constant v) then
                returnM (Load v)
              else
                (dualExpressionM x >>= fun d ->
                 returnM (Store (v, d)))
          | Store (v, x) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node

    in dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose a dag given by its roots: dualize every reachable node
     (sharing one memo table) and return the Store nodes among the duals,
     which are the roots of the transposed dag. *)
  let transpose dag =
    let _ = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let mapTransposerM = mapM transposerM in
    let duals = runM Assoctable.empty mapTransposerM all_nodes in
    let roots = List.filter is_store duals in
    let _ = Util.info "end transpose" in
    roots
end
cannam@127 499
cannam@127 500
cannam@127 501 (*************************************************************
cannam@127 502 * Various dag statistics
cannam@127 503 *************************************************************)
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* (loads, stores, additions, multiplications, negations, constants) *)
  type complexity = int * int * int * int * int * int

  (* Depth-first walk returning every distinct node (by identity, via
     node_lookup) reachable from the roots.  CTimes/CTimesJ operands are
     traversed as well, as in Transpose.visit, so that work beneath a
     complex multiplication is counted.  The children match lists every
     constructor, so a new node kind is flagged by the compiler instead
     of being silently treated as a leaf. *)
  let rec visit visited vtable = function
    | [] -> visited
    | node :: rest ->
        (match node_lookup node vtable with
         | Some _ -> visit visited vtable rest
         | None ->
             let children =
               match node with
               | Store (_, n) -> [n]
               | Plus l -> l
               | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
               | Uminus x -> [x]
               | Load _ | Num _ | NaN _ -> []
             in
             visit (node :: visited)
               (node_insert node () vtable)
               (children @ rest))

  (* Operation counts over the distinct nodes of [dag].  An n-ary Plus
     counts as n - 1 additions.  CTimes, CTimesJ and NaN nodes themselves
     are not counted. *)
  let complexity dag =
    let tally (load, store, plus, times, uminus, num) = function
      | Load _ -> (load + 1, store, plus, times, uminus, num)
      | Store _ -> (load, store + 1, plus, times, uminus, num)
      | Plus x -> (load, store, plus + (List.length x - 1), times, uminus, num)
      | Times _ -> (load, store, plus, times + 1, uminus, num)
      | Uminus _ -> (load, store, plus, times, uminus + 1, num)
      | Num _ -> (load, store, plus, times, uminus, num + 1)
      | CTimes _ | CTimesJ _ | NaN _ ->
          (load, store, plus, times, uminus, num)
    in
    List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Scalar cost: an addition weighs 10, a multiplication 20, and every
     load, store, negation and constant weighs 1. *)
  let weight (l, s, p, t, u, n) =
    l + s + 10 * p + 20 * t + u + n

  let same_complexity a b = weight a = weight b
  let leq_complexity a b = weight a <= weight b

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
      l s p t u n

end
cannam@127 559
cannam@127 560 (* simplify the dag *)
let algsimp v =
  (* One network-transposition round: transpose, simplify, transpose back,
     simplify again.  The Stats complexity is logged before and after.
     The loop stops as soon as a round does not raise the weight
     (Stats.leq_complexity); otherwise it starts another round from the
     new dag.
     NOTE(review): a round that makes the dag heavier is kept rather than
     discarded, and nothing bounds the number of rounds — confirm this is
     intended. *)
  let rec simplification_loop roots =
    Util.info "simplification step";
    let before = Stats.complexity roots in
    Util.info ("complexity = " ^ Stats.to_string before);
    let roots =
      AlgSimp.algsimp
        (Transpose.transpose
           (AlgSimp.algsimp (Transpose.transpose roots)))
    in
    let after = Stats.complexity roots in
    Util.info ("complexity = " ^ Stats.to_string after);
    if not (Stats.leq_complexity after before) then simplification_loop roots
    else begin
      Util.info "end algsimp";
      roots
    end
  in
  Util.info "begin algsimp";
  let simplified = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop simplified
  else simplified
cannam@127 580