annotate src/fftw-3.3.3/genfft/algsimp.ml @ 83:ae30d91d2ffe

Replace these with versions built using an older toolset (so as to avoid ABI incompatibilities when linking on Ubuntu 14.04 for packaging purposes)
author Chris Cannam
date Fri, 07 Feb 2020 11:51:13 +0000
parents 37bf6b4a2645
children
rev   line source
Chris@10 1 (*
Chris@10 2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
Chris@10 3 * Copyright (c) 2003, 2007-11 Matteo Frigo
Chris@10 4 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
Chris@10 5 *
Chris@10 6 * This program is free software; you can redistribute it and/or modify
Chris@10 7 * it under the terms of the GNU General Public License as published by
Chris@10 8 * the Free Software Foundation; either version 2 of the License, or
Chris@10 9 * (at your option) any later version.
Chris@10 10 *
Chris@10 11 * This program is distributed in the hope that it will be useful,
Chris@10 12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
Chris@10 13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Chris@10 14 * GNU General Public License for more details.
Chris@10 15 *
Chris@10 16 * You should have received a copy of the GNU General Public License
Chris@10 17 * along with this program; if not, write to the Free Software
Chris@10 18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
Chris@10 19 *
Chris@10 20 *)
Chris@10 21
Chris@10 22
Chris@10 23 open Util
Chris@10 24 open Expr
Chris@10 25
(* Association-table helpers keyed on DAG nodes: keys are hashed with
   Expr.hash and compared by physical identity (==), so two structurally
   equal but distinct nodes are different keys. *)
let node_insert key value table = Assoctable.insert Expr.hash key value table
let node_lookup key table = Assoctable.lookup Expr.hash ( == ) key table
Chris@10 28
Chris@10 29 (*************************************************************
Chris@10 30 * Algebraic simplifier/elimination of common subexpressions
Chris@10 31 *************************************************************)
module AlgSimp : sig
  val algsimp : expr list -> expr list
end = struct

  open Monads.StateMonad
  open Monads.MemoMonad
  open Assoctable

  (* The monadic state is a pair (simp, cse):
     - [simp] memoizes [algsimpM]; it is accessed through node_lookup /
       node_insert, i.e. keyed by physical identity of the input node;
     - [cse] is the common-subexpression table used by [makeNode]; it is
       keyed by [hashCSE] / [equalCSE]. *)

  let fetchSimp =
    fetchState >>= fun (s, _) -> returnM s
  let storeSimp s =
    fetchState >>= (fun (_, c) -> storeState (s, c))
  let lookupSimpM key =
    fetchSimp >>= fun table ->
    returnM (node_lookup key table)
  let insertSimpM key value =
    fetchSimp >>= fun table ->
    storeSimp (node_insert key value table)

  (* [subset a b]: every element of [a] occurs in [b] (physical identity). *)
  let subset a b =
    List.for_all (fun x -> List.memq x b) a

  (* Shallow structural equality used for CSE.  Operands are compared by
     physical identity -- presumably sufficient because operands are
     themselves built through [makeNode] and hence already shared.
     Times and CTimes are treated as commutative, CTimesJ is not, and the
     operand lists of Plus are compared as sets (multiplicity ignored). *)
  let structurallyEqualCSE a b =
    match (a, b) with
    | (Num a, Num b) -> Number.equal a b
    | (NaN a, NaN b) -> a == b
    | (Load a, Load b) -> Variable.same a b
    | (Times (a, a'), Times (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimes (a, a'), CTimes (b, b')) ->
        ((a == b) && (a' == b')) ||
        ((a == b') && (a' == b))
    | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
    | (Plus a, Plus b) -> subset a b && subset b a
    | (Uminus a, Uminus b) -> (a == b)
    | _ -> false

  (* With Magic.randomized_cse set, hashing is delegated to Oracle.hash and
     equality additionally accepts Oracle.likely_equal (presumably a
     numerical-evaluation test); otherwise both are purely structural. *)
  let hashCSE x =
    if (!Magic.randomized_cse) then
      Oracle.hash x
    else
      Expr.hash x

  let equalCSE a b =
    if (!Magic.randomized_cse) then
      (structurallyEqualCSE a b || Oracle.likely_equal a b)
    else
      structurallyEqualCSE a b

  let fetchCSE =
    fetchState >>= fun (_, c) -> returnM c
  let storeCSE c =
    fetchState >>= (fun (s, _) -> storeState (s, c))
  let lookupCSEM key =
    fetchCSE >>= fun table ->
    returnM (Assoctable.lookup hashCSE equalCSE key table)
  let insertCSEM key value =
    fetchCSE >>= fun table ->
    storeCSE (Assoctable.insert hashCSE key value table)

  (* Return the canonical shared instance of [x] from the CSE table.
     Memoize both x and Uminus x (unless x is already negated), so that a
     later request for the negation of x finds the same node. *)
  let identityM x =
    let memo x = memoizing lookupCSEM insertCSEM returnM x in
    match x with
    | Uminus _ -> memo x
    | _ -> memo x >>= fun x' -> memo (Uminus x') >> returnM x'

  let makeNode = identityM

  (* Simplifiers for the various kinds of nodes.  Each one returns a
     canonical, CSE-shared node. *)

  (* Numeric constants: zero is normalized to Number.zero, and a negative
     constant -c is represented as Uminus (Num c). *)
  let rec snumM = function
    | n when Number.is_zero n ->
        makeNode (Num (Number.zero))
    | n when Number.negative n ->
        makeNode (Num (Number.negate n)) >>= suminusM
    | n -> makeNode (Num n)

  (* Negation: -(-x) = x and -0 = 0. *)
  and suminusM = function
    | Uminus x -> makeNode x
    | Num a when (Number.is_zero a) -> snumM Number.zero
    | a -> makeNode (Uminus a)

  (* Products: pull negations outward, fold constant factors, apply the
     0 / 1 / -1 identities, and put known constants on the left. *)
  and stimesM = function
    | (Uminus a, b) -> stimesM (a, b) >>= suminusM
    | (a, Uminus b) -> stimesM (a, b) >>= suminusM
    | (NaN I, CTimes (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesM (a, ib)
    | (NaN I, CTimesJ (a, b)) -> stimesM (NaN I, b) >>=
        fun ib -> sctimesjM (a, ib)
    | (Num a, Num b) -> snumM (Number.mul a b)
    | (Num a, Times (Num b, c)) ->
        snumM (Number.mul a b) >>= fun x -> stimesM (x, c)
    | (Num a, _) when Number.is_zero a -> snumM Number.zero
    | (Num a, b) when Number.is_one a -> makeNode b
    | (Num a, b) when Number.is_mone a -> suminusM b
    | (a, b) when is_known_constant b && not (is_known_constant a) ->
        stimesM (b, a)
    | (a, b) -> makeNode (Times (a, b))

  (* Complex products: only negations are pulled outward. *)
  and sctimesM = function
    | (Uminus a, b) -> sctimesM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimes (a, b))

  and sctimesjM = function
    | (Uminus a, b) -> sctimesjM (a, b) >>= suminusM
    | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
    | (a, b) -> makeNode (CTimesJ (a, b))

  (* Fold the numeric constants of a list of summands into a single
     constant, dropping it if it is zero.  Constant terms are swapped
     toward the end of the list so adjacent constants can be combined. *)
  and reduce_sumM x = match x with
    | [] -> returnM []
    | [Num a] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | [Uminus (Num a)] ->
        if (Number.is_zero a) then
          returnM []
        else returnM x
    | (Num a) :: (Num b) :: s ->
        snumM (Number.add a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Num a) :: (Uminus (Num b)) :: s ->
        snumM (Number.sub a b) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Num b) :: s ->
        snumM (Number.sub b a) >>= fun x ->
        reduce_sumM (x :: s)
    | (Uminus (Num a)) :: (Uminus (Num b)) :: s ->
        snumM (Number.add a b) >>=
        suminusM >>= fun x ->
        reduce_sumM (x :: s)
    | ((Num _) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | ((Uminus (Num _)) as a) :: b :: s -> reduce_sumM (b :: a :: s)
    | a :: s ->
        reduce_sumM s >>= fun s' -> returnM (a :: s')

  (* A (coefficient, term) pair may be collected unless the coefficient is
     (possibly negated) NaN. *)
  and collectible1 = function
    | NaN _ -> false
    | Uminus x -> collectible1 x
    | _ -> true
  and collectible (a, _) = collectible1 a

  (* Collect common factors: ax + bx -> (a+b)x.  [which] selects whether the
     coefficient is the first or the second operand of a Times. *)
  and collectM which x =
    let rec findCoeffM which = function
      | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
      | Uminus x ->
          findCoeffM which x >>= fun (coeff, b) ->
          suminusM coeff >>= fun mcoeff ->
          returnM (mcoeff, b)
      | x -> snumM Number.one >>= fun one -> returnM (one, x)
    (* Split the summands into the coefficients of [xpr] and the rest. *)
    and separateM xpr = function
      | [] -> returnM ([], [])
      | a :: b ->
          separateM xpr b >>= fun (w, wo) ->
          (* try first factor *)
          findCoeffM (fun (a, b) -> (a, b)) a >>= fun (c, x) ->
          if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
          else
            (* try second factor *)
            findCoeffM (fun (a, b) -> (b, a)) a >>= fun (c, x) ->
            if (xpr == x) && collectible (c, x) then returnM (c :: w, wo)
            else returnM (w, a :: wo)
    in match x with
    | [] -> returnM x
    | [_] -> returnM x
    | a :: _ ->
        findCoeffM which a >>= fun (_, xpr) ->
        separateM xpr x >>= fun (w, wo) ->
        collectM which wo >>= fun wo' ->
        splusM w >>= fun w' ->
        stimesM (w', xpr) >>= fun t' ->
        returnM (t':: wo')

  (* Full simplification pipeline applied to a list of summands. *)
  and mangleSumM x = returnM x
    >>= reduce_sumM
    >>= collectM (fun (a, b) -> (a, b))
    >>= collectM (fun (a, b) -> (b, a))
    >>= reduce_sumM
    >>= deepCollectM !Magic.deep_collect_depth
    >>= reduce_sumM

  (* Push all Uminus terms to the end (they end up in reverse order). *)
  and reorder_uminus = function
    | [] -> []
    | ((Uminus _) as a' :: b) -> (reorder_uminus b) @ [a']
    | (a :: b) -> a :: (reorder_uminus b)

  and canonicalizeM = function
    | [] -> snumM Number.zero
    | [a] -> makeNode a (* one term *)
    | a -> generateFusedMultAddM (reorder_uminus a)

  (* With FMA enabled and at least two summands of the form c*x or -(c*x),
     rewrite the sum as  rest + max * sum_i (t_i / max), where max is the
     largest such coefficient c; otherwise build a plain Plus node. *)
  and generateFusedMultAddM =
    let rec is_multiplication = function
      | Times (Num _, _) -> true
      | Uminus (Times (Num _, _)) -> true
      | _ -> false
    and separate = function
      | [] -> ([], [], Number.zero)
      | (Times (Num a, _)) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | (Uminus (Times (Num a, _))) as this :: c ->
          let (x, y, max) = separate c in
          let newmax = if (Number.greater a max) then a else max in
          (this :: x, y, newmax)
      | this :: c ->
          let (x, y, max) = separate c in
          (x, this :: y, max)
    in fun l ->
      if !Magic.enable_fma && count is_multiplication l >= 2 then
        let (w, wo, max) = separate l in
        snumM (Number.div Number.one max) >>= fun invmax' ->
        snumM max >>= fun max' ->
        mapM (fun x -> stimesM (invmax', x)) w >>= splusM >>= fun pw' ->
        stimesM (max', pw') >>= fun mw' ->
        splusM (wo @ [mw'])
      else
        makeNode (Plus l)


  and negative = function
    | Uminus _ -> true
    | _ -> false

  (*
   * simplify patterns of the form
   *
   *  ((c_1 * a + ...) + ...) +  (c_2 * a + ...)
   *
   * The pattern includes arbitrary coefficients and minus signs.
   * A common case of this pattern is the butterfly
   *   (a + b) + (a - b)
   *   (a + b) - (a - b)
   *
   * Terms occurring more than once (physical identity) up to [maxdepth]
   * levels deep are "duplicates".  Each summand x is split into (p, q)
   * with x = p + q, where q collects the parts involving duplicates; the
   * sum is then rebuilt as (sum of p) + (sum of q) and re-simplified so
   * the duplicates can be combined.
   *)
  (* this whole procedure needs much more thought *)
  and deepCollectM maxdepth l =
    let rec findTerms depth x = match x with
      | Uminus x -> findTerms depth x
      | Times (Num _, b) -> (findTerms (depth - 1) b)
      | Plus l when depth > 0 ->
          x :: List.flatten (List.map (findTerms (depth - 1)) l)
      | x -> [x]
    and duplicates = function
      | [] -> []
      | a :: b -> if List.memq a b then a :: duplicates b
                  else duplicates b

    in let is_zero_num = function
      | Num n -> Number.is_zero n
      | _ -> false

    in let rec splitDuplicates depth d x =
      if (List.memq x d) then
        snumM (Number.zero) >>= fun zero ->
        returnM (zero, x)
      else match x with
        | Times (a, b) ->
            splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
            splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
            stimesM (a', b') >>= fun ab ->
            stimesM (a, xb) >>= fun xb' ->
            stimesM (xa, b) >>= fun xa' ->
            stimesM (xa, xb) >>= fun xab ->
            (* (a'+xa)(b'+xb) - a'b' = a*xb + xa*b - xa*xb: both a*xb and
               xa*b contain xa*xb, so it must be subtracted once.  When
               either duplicate part is the zero constant the correction
               vanishes and xab is kept unchanged. *)
            (if is_zero_num xa || is_zero_num xb then returnM xab
             else suminusM xab) >>= fun xab' ->
            splusM [xa'; xb'; xab'] >>= fun x ->
            returnM (ab, x)
        | Uminus a ->
            splitDuplicates depth d a >>= fun (x, y) ->
            suminusM x >>= fun ux ->
            suminusM y >>= fun uy ->
            returnM (ux, uy)
        | Plus l when depth > 0 ->
            mapM (splitDuplicates (depth - 1) d) l >>= fun ld ->
            let (l', d') = List.split ld in
            splusM l' >>= fun p ->
            splusM d' >>= fun d'' ->
            returnM (p, d'')
        | x ->
            snumM (Number.zero) >>= fun zero' ->
            returnM (x, zero')

    in let l' = List.flatten (List.map (findTerms maxdepth) l)
    in match duplicates l' with
    | [] -> returnM l
    | d ->
        mapM (splitDuplicates maxdepth d) l >>= fun ld ->
        let (l', d') = List.split ld in
        splusM l' >>= fun l'' ->
        let rec flattenPlusM = function
          | Plus l -> returnM l
          | Uminus x ->
              flattenPlusM x >>= mapM suminusM
          | x -> returnM [x]
        in
        mapM flattenPlusM d' >>= fun d'' ->
        splusM (List.flatten d'') >>= fun d''' ->
        mangleSumM [l''; d''']

  (* Simplify a sum and choose a canonical sign for it: if all (or, by the
     FMA heuristics / Oracle, preferably) terms are negative, build
     -(sum of negated terms) instead. *)
  and splusM l =
    let fma_heuristics x =
      if !Magic.enable_fma then
        match x with
        | [Uminus (Times _); Times _] -> Some false
        | [Times _; Uminus (Times _)] -> Some false
        | [Uminus (_); Times _] -> Some true
        | [Times _; Uminus (Plus _)] -> Some true
        | [_; Uminus (Times _)] -> Some false
        | [Uminus (Times _); _] -> Some false
        | _ -> None
      else
        None
    in
    mangleSumM l >>= fun l' ->
    (* no terms are negative.  Don't do anything *)
    if not (List.exists negative l') then
      canonicalizeM l'
    (* all terms are negative.  Negate them all and collect the minus sign *)
    else if List.for_all negative l' then
      mapM suminusM l' >>= splusM >>= suminusM
    else match fma_heuristics l' with
      | Some true -> mapM suminusM l' >>= splusM >>= suminusM
      | Some false -> canonicalizeM l'
      | None ->
          (* Ask the Oracle for the canonical form *)
          if (not !Magic.randomized_cse) &&
             Oracle.should_flip_sign (Plus l') then
            mapM suminusM l' >>= splusM >>= suminusM
          else
            canonicalizeM l'

  (* monadic style algebraic simplifier for the dag: simplify operands
     first, then rebuild the node through the simplifiers above; results
     are memoized per input node in the [simp] table. *)
  let rec algsimpM x =
    memoizing lookupSimpM insertSimpM
      (function
        | Num a -> snumM a
        | NaN _ as x -> makeNode x
        | Plus a ->
            mapM algsimpM a >>= splusM
        | Times (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             stimesM (a', b'))
        | CTimes (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesM (a', b'))
        | CTimesJ (a, b) ->
            (algsimpM a >>= fun a' ->
             algsimpM b >>= fun b' ->
             sctimesjM (a', b'))
        | Uminus a ->
            algsimpM a >>= suminusM
        | Store (v, a) ->
            algsimpM a >>= fun a' ->
            makeNode (Store (v, a'))
        | Load _ as x -> makeNode x)
      x

  let initialTable = (empty, empty)
  let simp_roots = mapM algsimpM
  let algsimp = runM initialTable simp_roots
end
Chris@10 393
Chris@10 394 (*************************************************************
Chris@10 395 * Network transposition algorithm
Chris@10 396 *************************************************************)
(* Network transposition: builds the transposed linear network of a DAG.
   A Load of a non-constant variable v becomes a Store into v of the sum of
   the duals of its parents, a Store (v, _) becomes a Load v, and every
   other node's dual is the sum of the contributions of its parents. *)
module Transpose = struct
  open Monads.StateMonad
  open Monads.MemoMonad
  open Littlesimp

  (* The monadic state is a single table mapping original nodes (by
     physical identity, via node_lookup/node_insert) to their duals; it is
     the memoization table of [dualM]. *)
  let fetchDuals = fetchState
  let storeDuals = storeState

  let lookupDualsM key =
    fetchDuals >>= fun table ->
    returnM (node_lookup key table)

  let insertDualsM key value =
    fetchDuals >>= fun table ->
    storeDuals (node_insert key value table)

  (* Depth-first walk from the worklist.  Returns (visited, parent_table):
     every distinct node reached (distinctness by physical identity), and a
     table mapping each node to the list of nodes that use it as an
     operand.  A parent is recorded once per occurrence of the child in its
     operand list (e.g. twice for Plus [a; a]). *)
  let rec visit visited vtable parent_table = function
      [] -> (visited, parent_table)
    | node :: rest ->
        match node_lookup node vtable with
        | Some _ -> visit visited vtable parent_table rest
        | None ->
            let children = match node with
              | Store (v, n) -> [n]
              | Plus l -> l
              | Times (a, b) -> [a; b]
              | CTimes (a, b) -> [a; b]
              | CTimesJ (a, b) -> [a; b]
              | Uminus x -> [x]
              | _ -> []
            (* add [node] to the parent list of each child *)
            in let rec loop t = function
                [] -> t
              | a :: rest ->
                  (match node_lookup a t with
                     None -> loop (node_insert a [node] t) rest
                   | Some c -> loop (node_insert a (node :: c) t) rest)
            in
            (visit
               (node :: visited)
               (node_insert node () vtable)
               (loop parent_table children)
               (children @ rest))

  (* Returns [dualM], the memoized function computing the dual of a node of
     the DAG whose parent relation is [parent_table]. *)
  let make_transposer parent_table =
    (* Contribution of [candidate_parent] to the dual of [node]: the dual
       of the parent, scaled/negated as the parent scales/negates [node].
       For Times/CTimes/CTimesJ only the second operand is treated as the
       linear input.  NOTE(review): this presumably relies on the constant
       factor being the first operand (AlgSimp.stimesM moves known
       constants to the left) -- confirm for CTimes/CTimesJ. *)
    let rec termM node candidate_parent =
      match candidate_parent with
      | Store (_, n) when n == node ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Plus (l) when List.memq node l ->
          dualM candidate_parent >>= fun x' -> returnM [x']
      | Times (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeTimes (a, x')]
      | CTimes (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [CTimes (a, x')]
      | CTimesJ (a, b) when b == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [CTimesJ (a, x')]
      | Uminus n when n == node ->
          dualM candidate_parent >>= fun x' ->
          returnM [makeUminus x']
      | _ -> returnM []

    (* Sum of the contributions of all parents of [this_node].  Fails if
       [this_node] has no entry in [parent_table]. *)
    and dualExpressionM this_node =
      mapM (termM this_node)
        (match node_lookup this_node parent_table with
         | Some a -> a
         | None -> failwith "bug in dualExpressionM"
        ) >>= fun l ->
      returnM (makePlus (List.flatten l))

    (* Memoized dual of a node: constant loads are their own dual, a load
       of a variable v becomes a store into v, and a store into v becomes a
       load of v. *)
    and dualM this_node =
      memoizing lookupDualsM insertDualsM
        (function
          | Load v as x ->
              if (Variable.is_constant v) then
                returnM (Load v)
              else
                (dualExpressionM x >>= fun d ->
                 returnM (Store (v, d)))
          | Store (v, x) -> returnM (Load v)
          | x -> dualExpressionM x)
        this_node

    in dualM

  let is_store = function
    | Store _ -> true
    | _ -> false

  (* Transpose [dag]: compute the dual of every reachable node and return
     those duals that are Store nodes (the roots of the transposed DAG). *)
  let transpose dag =
    let _ = Util.info "begin transpose" in
    let (all_nodes, parent_table) =
      visit [] Assoctable.empty Assoctable.empty dag in
    let transposerM = make_transposer parent_table in
    let mapTransposerM = mapM transposerM in
    let duals = runM Assoctable.empty mapTransposerM all_nodes in
    let roots = List.filter is_store duals in
    let _ = Util.info "end transpose" in
    roots
end
Chris@10 499
Chris@10 500
Chris@10 501 (*************************************************************
Chris@10 502 * Various dag statistics
Chris@10 503 *************************************************************)
module Stats : sig
  type complexity
  val complexity : Expr.expr list -> complexity
  val same_complexity : complexity -> complexity -> bool
  val leq_complexity : complexity -> complexity -> bool
  val to_string : complexity -> string
end = struct
  (* Operation counts, in order:
     (loads, stores, additions, multiplications, negations, constants). *)
  type complexity = int * int * int * int * int * int

  (* Nodes that [visit] descends into from [node].
     NOTE(review): CTimes/CTimesJ operands are not listed (Transpose.visit
     does list them), so nodes reachable only through a complex product are
     never counted -- confirm whether this is intended. *)
  let operands_of node =
    match node with
    | Store (_, n) -> [n]
    | Plus terms -> terms
    | Times (a, b) -> [a; b]
    | Uminus x -> [x]
    | _ -> []

  (* Depth-first traversal from the worklist [pending]; returns every
     distinct node reached (distinctness by physical identity, via
     node_lookup), most recently reached first. *)
  let rec visit seen seen_table pending =
    match pending with
    | [] -> seen
    | node :: rest ->
        (match node_lookup node seen_table with
         | Some _ -> visit seen seen_table rest
         | None ->
             let seen_table = node_insert node () seen_table in
             visit (node :: seen) seen_table (operands_of node @ rest))

  (* Count operations over the distinct nodes of [dag].  A Plus with k
     operands costs k - 1 additions; CTimes, CTimesJ and NaN cost nothing. *)
  let complexity dag =
    let tally (ld, st, add, mul, neg, num) = function
      | Load _ -> (ld + 1, st, add, mul, neg, num)
      | Store _ -> (ld, st + 1, add, mul, neg, num)
      | Plus terms -> (ld, st, add + (List.length terms - 1), mul, neg, num)
      | Times _ -> (ld, st, add, mul + 1, neg, num)
      | Uminus _ -> (ld, st, add, mul, neg + 1, num)
      | Num _ -> (ld, st, add, mul, neg, num + 1)
      | CTimes _ | CTimesJ _ | NaN _ -> (ld, st, add, mul, neg, num)
    in
    List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)

  (* Scalar cost: loads, stores, negations and constants weigh 1,
     additions 10 and multiplications 20. *)
  let weight (l, s, p, t, u, n) =
    let memory_ops = l + s in
    let arith_ops = (10 * p) + (20 * t) in
    memory_ops + arith_ops + u + n

  let same_complexity a b =
    let wa = weight a and wb = weight b in
    wa = wb

  let leq_complexity a b =
    let wa = weight a and wb = weight b in
    wa <= wb

  let to_string (l, s, p, t, u, n) =
    Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n" l s p t u n
end
Chris@10 559
Chris@10 560 (* simplify the dag *)
let algsimp v =
  (* One round: transpose, simplify, transpose back, simplify again
     (composed with the file's [@@] operator, as in the original). *)
  let transpose_and_simplify =
    AlgSimp.algsimp @@ Transpose.transpose @@
    AlgSimp.algsimp @@ Transpose.transpose
  in
  (* Repeat rounds, logging the complexity before and after each one.
     NOTE(review): the loop stops as soon as a round does not increase the
     weighted complexity and only continues when it grew; there is no
     iteration cap -- confirm this termination rule is intended. *)
  let rec simplification_loop v =
    Util.info "simplification step";
    let before = Stats.complexity v in
    Util.info ("complexity = " ^ Stats.to_string before);
    let v = transpose_and_simplify v in
    let after = Stats.complexity v in
    Util.info ("complexity = " ^ Stats.to_string after);
    if Stats.leq_complexity after before then begin
      Util.info "end algsimp";
      v
    end else
      simplification_loop v
  in
  Util.info "begin algsimp";
  let v = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop v else v
Chris@10 580