Chris@10
|
1 (*
|
Chris@10
|
2 * Copyright (c) 1997-1999 Massachusetts Institute of Technology
|
Chris@10
|
3 * Copyright (c) 2003, 2007-11 Matteo Frigo
|
Chris@10
|
4 * Copyright (c) 2003, 2007-11 Massachusetts Institute of Technology
|
Chris@10
|
5 *
|
Chris@10
|
6 * This program is free software; you can redistribute it and/or modify
|
Chris@10
|
7 * it under the terms of the GNU General Public License as published by
|
Chris@10
|
8 * the Free Software Foundation; either version 2 of the License, or
|
Chris@10
|
9 * (at your option) any later version.
|
Chris@10
|
10 *
|
Chris@10
|
11 * This program is distributed in the hope that it will be useful,
|
Chris@10
|
12 * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
Chris@10
|
13 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
Chris@10
|
14 * GNU General Public License for more details.
|
Chris@10
|
15 *
|
Chris@10
|
16 * You should have received a copy of the GNU General Public License
|
Chris@10
|
17 * along with this program; if not, write to the Free Software
|
Chris@10
|
18 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
Chris@10
|
19 *
|
Chris@10
|
20 *)
|
Chris@10
|
21
|
Chris@10
|
22
|
Chris@10
|
23 open Util
|
Chris@10
|
24 open Expr
|
Chris@10
|
25
|
Chris@10
|
(* Association-table helpers specialised to expression nodes: keys are
   hashed with [Expr.hash], and lookups compare keys by physical identity. *)
let node_insert key value table = Assoctable.insert Expr.hash key value table
let node_lookup key table = Assoctable.lookup Expr.hash ( == ) key table
|
Chris@10
|
28
|
Chris@10
|
29 (*************************************************************
|
Chris@10
|
30 * Algebraic simplifier/elimination of common subexpressions
|
Chris@10
|
31 *************************************************************)
|
Chris@10
|
32 module AlgSimp : sig
|
Chris@10
|
33 val algsimp : expr list -> expr list
|
Chris@10
|
34 end = struct
|
Chris@10
|
35
|
Chris@10
|
36 open Monads.StateMonad
|
Chris@10
|
37 open Monads.MemoMonad
|
Chris@10
|
38 open Assoctable
|
Chris@10
|
39
|
Chris@10
|
(* The monad state is a pair of tables.  The four helpers below read and
   update the first component (the simplification table) only. *)

(* Fetch the simplification table. *)
let fetchSimp =
  fetchState >>= fun state -> returnM (fst state)

(* Replace the simplification table, leaving the second component alone. *)
let storeSimp simp_table =
  fetchState >>= fun state -> storeState (simp_table, snd state)

(* Look [key] up in the simplification table. *)
let lookupSimpM key =
  fetchSimp >>= fun simp_table -> returnM (node_lookup key simp_table)

(* Bind [key] to [value] in the simplification table. *)
let insertSimpM key value =
  fetchSimp >>= fun simp_table ->
  storeSimp (node_insert key value simp_table)
|
Chris@10
|
50
|
Chris@10
|
(* [subset a b] holds when every element of [a] is physically equal
   to some element of [b]. *)
let subset a b =
  let in_b elt = List.memq elt b in
  List.for_all in_b a
|
Chris@10
|
53
|
Chris@10
|
(* [structurallyEqualCSE a b] decides whether two nodes have the same top
   constructor and equivalent operands:
   - numbers are compared with [Number.equal], loads with [Variable.same];
   - operands of compound nodes are compared by physical identity;
   - [Times] and [CTimes] accept their operands in either order, whereas
     [CTimesJ] requires the same order;
   - two [Plus] nodes match when each term list is a subset of the other;
   - every other combination (including [Store]) is unequal.
   The deprecated boolean keyword [or] has been replaced by [||]. *)
let structurallyEqualCSE a b =
  match (a, b) with
  | (Num a, Num b) -> Number.equal a b
  | (NaN a, NaN b) -> a == b
  | (Load a, Load b) -> Variable.same a b
  | (Times (a, a'), Times (b, b')) ->
      ((a == b) && (a' == b')) ||
      ((a == b') && (a' == b))
  | (CTimes (a, a'), CTimes (b, b')) ->
      ((a == b) && (a' == b')) ||
      ((a == b') && (a' == b))
  | (CTimesJ (a, a'), CTimesJ (b, b')) -> ((a == b) && (a' == b'))
  | (Plus a, Plus b) -> subset a b && subset b a
  | (Uminus a, Uminus b) -> (a == b)
  | _ -> false
|
Chris@10
|
69
|
Chris@10
|
(* Hash used for the CSE table: [Oracle.hash] when
   [Magic.randomized_cse] is set, [Expr.hash] otherwise. *)
let hashCSE x =
  let hash = if !Magic.randomized_cse then Oracle.hash else Expr.hash in
  hash x

(* Equality used for the CSE table: structural equality, extended with
   [Oracle.likely_equal] only when [Magic.randomized_cse] is set. *)
let equalCSE a b =
  structurallyEqualCSE a b
  || (!Magic.randomized_cse && Oracle.likely_equal a b)
|
Chris@10
|
81
|
Chris@10
|
(* Accessors for the second component of the monad state (the CSE table).
   Lookups and insertions use [hashCSE] and [equalCSE]. *)

(* Fetch the CSE table. *)
let fetchCSE =
  fetchState >>= fun state -> returnM (snd state)

(* Replace the CSE table, leaving the first component alone. *)
let storeCSE cse_table =
  fetchState >>= fun state -> storeState (fst state, cse_table)

(* Look [key] up in the CSE table. *)
let lookupCSEM key =
  fetchCSE >>= fun cse_table ->
  returnM (Assoctable.lookup hashCSE equalCSE key cse_table)

(* Bind [key] to [value] in the CSE table. *)
let insertCSEM key value =
  fetchCSE >>= fun cse_table ->
  storeCSE (Assoctable.insert hashCSE key value cse_table)
|
Chris@10
|
92
|
Chris@10
|
(* Run [x] through the CSE memo table.  Unless [x] is already a [Uminus],
   the negation of the memoized result is run through the table as well;
   the value returned is the memoized result for [x] itself. *)
let identityM x =
  let memo node = memoizing lookupCSEM insertCSEM returnM node in
  match x with
  | Uminus _ -> memo x
  | _ ->
      memo x >>= (fun shared ->
        (memo (Uminus shared)) >> (returnM shared))

(* Every node built by the simplifiers below goes through [identityM]. *)
let makeNode = identityM
|
Chris@10
|
101
|
Chris@10
|
102 (* simplifiers for various kinds of nodes *)
|
Chris@10
|
(* Build a node for the constant [n]: zero maps to [Number.zero], and a
   negative constant becomes the negation (via [suminusM]) of the node for
   its negated value. *)
let rec snumM n =
  if Number.is_zero n then makeNode (Num Number.zero)
  else if Number.negative n then
    makeNode (Num (Number.negate n)) >>= suminusM
  else makeNode (Num n)
|
Chris@10
|
109
|
Chris@10
|
(* Negate a node: a double negation cancels, minus zero is zero, and
   anything else is wrapped in [Uminus]. *)
and suminusM node =
  match node with
  | Uminus inner -> makeNode inner
  | Num c when Number.is_zero c -> snumM Number.zero
  | _ -> makeNode (Uminus node)
|
Chris@10
|
114
|
Chris@10
|
(* Simplifying product of a pair of nodes.  Cases are tried in order:
   signs are pulled out of either factor; [NaN I] is pushed into the second
   operand of a [CTimes] / [CTimesJ]; constants are folded; multiplication
   by 0, 1 and -1 is simplified; a known constant is moved to the first
   position; otherwise a plain [Times] node is built. *)
and stimesM pair =
  match pair with
  | (Uminus a, b) | (a, Uminus b) ->
      stimesM (a, b) >>= suminusM
  | (NaN I, CTimes (a, b)) ->
      stimesM (NaN I, b) >>= fun ib -> sctimesM (a, ib)
  | (NaN I, CTimesJ (a, b)) ->
      stimesM (NaN I, b) >>= fun ib -> sctimesjM (a, ib)
  | (Num c1, Num c2) -> snumM (Number.mul c1 c2)
  | (Num c1, Times (Num c2, rest)) ->
      (* fold the two constants, then retry with the remaining factor *)
      snumM (Number.mul c1 c2) >>= fun folded -> stimesM (folded, rest)
  | (Num c, _) when Number.is_zero c -> snumM Number.zero
  | (Num c, other) when Number.is_one c -> makeNode other
  | (Num c, other) when Number.is_mone c -> suminusM other
  | (a, b) when is_known_constant b && not (is_known_constant a) ->
      stimesM (b, a)
  | (a, b) -> makeNode (Times (a, b))
|
Chris@10
|
131
|
Chris@10
|
(* [CTimes] constructor that pulls a [Uminus] out of either operand. *)
and sctimesM pair =
  match pair with
  | (Uminus a, b) | (a, Uminus b) -> sctimesM (a, b) >>= suminusM
  | (a, b) -> makeNode (CTimes (a, b))

(* [CTimesJ] constructor that pulls a [Uminus] out of either operand. *)
and sctimesjM pair =
  match pair with
  | (Uminus a, b) | (a, Uminus b) -> sctimesjM (a, b) >>= suminusM
  | (a, b) -> makeNode (CTimesJ (a, b))
|
Chris@10
|
141
|
Chris@10
|
(* Fold the constant terms of a sum.  A constant (possibly negated) at the
   head is combined with a constant that follows it, or swapped behind a
   non-constant neighbour, so constants drift to the end of the list and
   merge there; a lone zero constant is dropped. *)
and reduce_sumM terms =
  match terms with
  | [] -> returnM []
  | [Num c] | [Uminus (Num c)] ->
      if Number.is_zero c then returnM [] else returnM terms
  | Num a :: Num b :: rest ->
      snumM (Number.add a b) >>= fun folded ->
      reduce_sumM (folded :: rest)
  | Num a :: Uminus (Num b) :: rest ->
      snumM (Number.sub a b) >>= fun folded ->
      reduce_sumM (folded :: rest)
  | Uminus (Num a) :: Num b :: rest ->
      snumM (Number.sub b a) >>= fun folded ->
      reduce_sumM (folded :: rest)
  | Uminus (Num a) :: Uminus (Num b) :: rest ->
      snumM (Number.add a b) >>=
      suminusM >>= fun folded ->
      reduce_sumM (folded :: rest)
  | ((Num _ | Uminus (Num _)) as const) :: other :: rest ->
      (* move the constant one position to the right *)
      reduce_sumM (other :: const :: rest)
  | head :: rest ->
      reduce_sumM rest >>= fun rest' -> returnM (head :: rest')
|
Chris@10
|
169
|
Chris@10
|
(* A coefficient can be collected unless it is a [NaN], possibly hidden
   under any number of [Uminus]. *)
and collectible1 node =
  match node with
  | Uminus inner -> collectible1 inner
  | NaN _ -> false
  | _ -> true

(* Only the first component of the pair is inspected. *)
and collectible (coeff, _) = collectible1 coeff
|
Chris@10
|
175
|
Chris@10
|
(* collect common factors: ax + bx -> (a+b)x

   [which] selects, from the two operands of a product, the pair
   (coefficient, common factor) used to pick the factor of the first
   term.  All terms of [x] sharing that factor (by physical identity, in
   either operand position) are merged into one product; the remaining
   terms are processed recursively. *)
and collectM which x =
  (* split a term into (coefficient, factor); a term that is not a
     collectible product gets coefficient one *)
  let rec findCoeffM which = function
    | Times (a, b) when collectible (which (a, b)) -> returnM (which (a, b))
    | Uminus inner ->
        findCoeffM which inner >>= fun (coeff, factor) ->
        suminusM coeff >>= fun neg_coeff ->
        returnM (neg_coeff, factor)
    | other -> snumM Number.one >>= fun one -> returnM (one, other)
  (* partition terms into (coefficients of [xpr], terms without [xpr]) *)
  and separateM xpr = function
    | [] -> returnM ([], [])
    | term :: rest ->
        separateM xpr rest >>= fun (coeffs, others) ->
        let matches (c, factor) = (xpr == factor) && collectible (c, factor) in
        (* try first factor *)
        findCoeffM (fun (a, b) -> (a, b)) term >>= fun ((c, _) as first) ->
        if matches first then returnM (c :: coeffs, others)
        else
          (* try second factor *)
          findCoeffM (fun (a, b) -> (b, a)) term >>= fun ((c, _) as second) ->
          if matches second then returnM (c :: coeffs, others)
          else returnM (coeffs, term :: others)
  in
  match x with
  | [] | [_] -> returnM x
  | first :: _ ->
      findCoeffM which first >>= fun (_, xpr) ->
      separateM xpr x >>= fun (coeffs, others) ->
      collectM which others >>= fun others' ->
      splusM coeffs >>= fun coeff_sum ->
      stimesM (coeff_sum, xpr) >>= fun merged ->
      returnM (merged :: others')
|
Chris@10
|
207
|
Chris@10
|
(* Run the list of terms through the sum-simplification passes, in order:
   constant folding, factor collection on each operand position, constant
   folding, deep collection (depth taken from [Magic.deep_collect_depth]),
   and a final constant folding. *)
and mangleSumM x =
  let passes =
    [ reduce_sumM;
      collectM (fun (a, b) -> (a, b));
      collectM (fun (a, b) -> (b, a));
      reduce_sumM;
      deepCollectM !Magic.deep_collect_depth;
      reduce_sumM ]
  in
  List.fold_left (fun acc pass -> acc >>= pass) (returnM x) passes
|
Chris@10
|
215
|
Chris@10
|
(* Push all [Uminus] terms to the end.  Non-negated terms keep their
   relative order; the negated ones end up in reverse order. *)
and reorder_uminus terms =
  let is_uminus = function Uminus _ -> true | _ -> false in
  let (negated, plain) = List.partition is_uminus terms in
  plain @ List.rev negated
|
Chris@10
|
220
|
Chris@10
|
(* Turn a list of terms into a node: zero for the empty list, the term
   itself for a singleton, and otherwise a sum built from the terms with
   the negated ones moved last. *)
and canonicalizeM terms =
  match terms with
  | [] -> snumM Number.zero
  | [single] -> makeNode single
  | _ -> generateFusedMultAddM (reorder_uminus terms)
|
Chris@10
|
225
|
Chris@10
|
(* Build the node for a sum of terms [l].  When [Magic.enable_fma] is set
   and at least two terms are constant multiples (possibly negated), the
   largest such constant is factored out: each of those terms is multiplied
   by its inverse, the results are summed and multiplied back by the
   constant, and the product is added to the remaining terms.  Otherwise a
   plain [Plus] node is made. *)
and generateFusedMultAddM l =
  let is_multiplication = function
    | Times (Num _, _) | Uminus (Times (Num _, _)) -> true
    | _ -> false
  in
  (* split into (constant multiples, other terms, largest constant);
     the scan runs from the right and the maximum starts at zero *)
  let separate terms =
    List.fold_right
      (fun term (mults, others, largest) ->
        match term with
        | Times (Num c, _) | Uminus (Times (Num c, _)) ->
            let largest' = if Number.greater c largest then c else largest in
            (term :: mults, others, largest')
        | _ -> (mults, term :: others, largest))
      terms
      ([], [], Number.zero)
  in
  if !Magic.enable_fma && count is_multiplication l >= 2 then
    let (mults, others, largest) = separate l in
    snumM (Number.div Number.one largest) >>= fun inv_largest ->
    snumM largest >>= fun largest_node ->
    mapM (fun m -> stimesM (inv_largest, m)) mults >>= splusM
    >>= fun scaled_sum ->
    stimesM (largest_node, scaled_sum) >>= fun product ->
    splusM (others @ [product])
  else
    makeNode (Plus l)
|
Chris@10
|
254
|
Chris@10
|
255
|
Chris@10
|
(* A term counts as negative exactly when its top constructor is
   [Uminus]. *)
and negative node =
  match node with
  | Uminus _ -> true
  | _ -> false
|
Chris@10
|
259
|
Chris@10
|
260 (*
|
Chris@10
|
261 * simplify patterns of the form
|
Chris@10
|
262 *
|
Chris@10
|
263 * ((c_1 * a + ...) + ...) + (c_2 * a + ...)
|
Chris@10
|
264 *
|
Chris@10
|
265 * The pattern includes arbitrary coefficients and minus signs.
|
Chris@10
|
266 * A common case of this pattern is the butterfly
|
Chris@10
|
267 * (a + b) + (a - b)
|
Chris@10
|
268 * (a + b) - (a - b)
|
Chris@10
|
269 *)
|
Chris@10
|
270 (* this whole procedure needs much more thought *)
|
Chris@10
|
(* [deepCollectM maxdepth l] looks, up to [maxdepth] levels inside the
   terms of [l], for subterms that occur more than once (by physical
   identity).  When some exist, every term is split into a part free of
   those duplicates and a part made of them; the two parts are summed
   separately and the pair of sums is simplified again by [mangleSumM].
   Without duplicates [l] is returned unchanged. *)
and deepCollectM maxdepth l =
  (* candidate subterms of [x]: signs and constant factors are looked
     through, sums are opened while [depth] is positive *)
  let rec findTerms depth x =
    match x with
    | Uminus y -> findTerms depth y
    | Times (Num _, b) -> findTerms (depth - 1) b
    | Plus terms when depth > 0 ->
        x :: List.flatten (List.map (findTerms (depth - 1)) terms)
    | _ -> [x]
  (* elements that occur again later in the list *)
  and duplicates = function
    | [] -> []
    | a :: rest ->
        if List.memq a rest then a :: duplicates rest
        else duplicates rest
  in
  (* [splitDuplicates depth d x] returns a pair (rest, dup) whose sum is
     [x], where [dup] gathers the contributions of the duplicates [d]. *)
  let rec splitDuplicates depth d x =
    if List.memq x d then
      snumM Number.zero >>= fun zero ->
      returnM (zero, x)
    else
      match x with
      | Times (a, b) ->
          splitDuplicates (depth - 1) d a >>= fun (a', xa) ->
          splitDuplicates (depth - 1) d b >>= fun (b', xb) ->
          stimesM (a', b') >>= fun ab ->
          (* With a = a' + xa and b = b' + xb, the product minus a'b' is
             a'xb + xa b.  The previous code summed a xb + xa b + xa xb,
             which counts xa xb three times when both factors contain
             duplicates; the two forms agree when xa or xb is zero. *)
          stimesM (a', xb) >>= fun xb' ->
          stimesM (xa, b) >>= fun xa' ->
          splusM [xa'; xb'] >>= fun dup ->
          returnM (ab, dup)
      | Uminus a ->
          splitDuplicates depth d a >>= fun (rest, dup) ->
          suminusM rest >>= fun neg_rest ->
          suminusM dup >>= fun neg_dup ->
          returnM (neg_rest, neg_dup)
      | Plus terms when depth > 0 ->
          mapM (splitDuplicates (depth - 1) d) terms >>= fun parts ->
          let (rests, dups) = List.split parts in
          splusM rests >>= fun rest_sum ->
          splusM dups >>= fun dup_sum ->
          returnM (rest_sum, dup_sum)
      | other ->
          snumM Number.zero >>= fun zero ->
          returnM (other, zero)
  in
  let candidates = List.flatten (List.map (findTerms maxdepth) l) in
  match duplicates candidates with
  | [] -> returnM l
  | d ->
      mapM (splitDuplicates maxdepth d) l >>= fun parts ->
      let (rests, dups) = List.split parts in
      splusM rests >>= fun rest_sum ->
      (* open one level of [Plus], distributing an enclosing sign *)
      let rec flattenPlusM = function
        | Plus terms -> returnM terms
        | Uminus inner -> flattenPlusM inner >>= mapM suminusM
        | other -> returnM [other]
      in
      mapM flattenPlusM dups >>= fun flat_dups ->
      splusM (List.flatten flat_dups) >>= fun dup_sum ->
      mangleSumM [rest_sum; dup_sum]
|
Chris@10
|
328
|
Chris@10
|
(* Simplifying sum of a list of terms.  After [mangleSumM], the sum is
   either built directly by [canonicalizeM] or built from the negated terms
   and negated again.  The sign is flipped when all terms are negative,
   never when none is; for mixed signs the FMA heuristics decide, and when
   they have no opinion the Oracle is consulted (only if
   [Magic.randomized_cse] is off). *)
and splusM l =
  let fma_heuristics terms =
    if !Magic.enable_fma then
      match terms with
      | [Uminus (Times _); Times _]
      | [Times _; Uminus (Times _)] -> Some false
      | [Uminus _; Times _]
      | [Times _; Uminus (Plus _)] -> Some true
      | [_; Uminus (Times _)]
      | [Uminus (Times _); _] -> Some false
      | _ -> None
    else
      None
  in
  mangleSumM l >>= fun terms ->
  let flip_sign =
    if not (List.exists negative terms) then false
    else if List.for_all negative terms then true
    else
      match fma_heuristics terms with
      | Some decision -> decision
      | None ->
          (* Ask the Oracle for the canonical form *)
          (not !Magic.randomized_cse) && Oracle.should_flip_sign (Plus terms)
  in
  if flip_sign then mapM suminusM terms >>= splusM >>= suminusM
  else canonicalizeM terms
|
Chris@10
|
360
|
Chris@10
|
(* monadic style algebraic simplifier for the dag: each node is rebuilt
   from its simplified children through the simplifying constructors, and
   results are memoized in the simplification table *)
let rec algsimpM x =
  (* simplify both operands, then combine them with [combineM] *)
  let binaryM combineM (a, b) =
    algsimpM a >>= fun a' ->
    algsimpM b >>= fun b' ->
    combineM (a', b')
  in
  let simplifyM = function
    | (NaN _ | Load _) as leaf -> makeNode leaf
    | Num c -> snumM c
    | Uminus a -> algsimpM a >>= suminusM
    | Plus terms -> mapM algsimpM terms >>= splusM
    | Times (a, b) -> binaryM stimesM (a, b)
    | CTimes (a, b) -> binaryM sctimesM (a, b)
    | CTimesJ (a, b) -> binaryM sctimesjM (a, b)
    | Store (v, a) ->
        algsimpM a >>= fun a' ->
        makeNode (Store (v, a'))
  in
  memoizing lookupSimpM insertSimpM simplifyM x
|
Chris@10
|
388
|
Chris@10
|
(* Both tables of the monad state start out empty. *)
let initialTable = (empty, empty)

(* Simplify every root of the dag, in list order. *)
let simp_roots roots = mapM algsimpM roots

(* Entry point of the module: run the simplifier from the empty state. *)
let algsimp roots = runM initialTable simp_roots roots
|
Chris@10
|
392 end
|
Chris@10
|
393
|
Chris@10
|
394 (*************************************************************
|
Chris@10
|
395 * Network transposition algorithm
|
Chris@10
|
396 *************************************************************)
|
Chris@10
|
397 module Transpose = struct
|
Chris@10
|
398 open Monads.StateMonad
|
Chris@10
|
399 open Monads.MemoMonad
|
Chris@10
|
400 open Littlesimp
|
Chris@10
|
401
|
Chris@10
|
(* In this module the monad state is a single table keyed by nodes; it
   serves as the memo table of the transposer. *)
let fetchDuals = fetchState
let storeDuals = storeState

(* Look [key] up in the table of duals. *)
let lookupDualsM key =
  fetchDuals >>= fun duals -> returnM (node_lookup key duals)

(* Bind [key] to [value] in the table of duals. *)
let insertDualsM key value =
  fetchDuals >>= fun duals -> storeDuals (node_insert key value duals)
|
Chris@10
|
412
|
Chris@10
|
(* Walk the dag from the given work list.  Returns the list of distinct
   nodes reached (most recently visited first) together with a table that
   maps each child to the list of its parents.
   [visited] and [parent_table] are accumulators; [vtable] records the
   nodes already expanded. *)
let rec visit visited vtable parent_table pending =
  match pending with
  | [] -> (visited, parent_table)
  | node :: rest ->
      (match node_lookup node vtable with
       | Some _ -> visit visited vtable parent_table rest
       | None ->
           let children =
             match node with
             | Store (_, n) -> [n]
             | Plus l -> l
             | Times (a, b) | CTimes (a, b) | CTimesJ (a, b) -> [a; b]
             | Uminus x -> [x]
             | _ -> []
           in
           (* register [node] as a parent of [child] *)
           let add_parent table child =
             let parents =
               match node_lookup child table with
               | None -> [node]
               | Some known -> node :: known
             in
             node_insert child parents table
           in
           let parent_table' =
             List.fold_left add_parent parent_table children in
           let vtable' = node_insert node () vtable in
           visit (node :: visited) vtable' parent_table' (children @ rest))
|
Chris@10
|
439
|
Chris@10
|
(* [make_transposer parent_table] returns the monadic function computing
   the dual of a node.  The dual of a node is the sum of the contributions
   of its parents; a [Store] becomes a [Load] of the same variable, and a
   non-constant [Load] becomes a [Store] of its dual expression.  Results
   are memoized in the table of duals. *)
let make_transposer parent_table =
  (* contribution of [candidate_parent] to the dual of [node]: a list with
     zero or one element.  For products only the second operand position
     is considered. *)
  let rec termM node candidate_parent =
    let contribution wrap =
      dualM candidate_parent >>= fun parent_dual ->
      returnM [wrap parent_dual]
    in
    match candidate_parent with
    | Store (_, n) when n == node -> contribution (fun d -> d)
    | Plus l when List.memq node l -> contribution (fun d -> d)
    | Times (a, b) when b == node -> contribution (fun d -> makeTimes (a, d))
    | CTimes (a, b) when b == node -> contribution (fun d -> CTimes (a, d))
    | CTimesJ (a, b) when b == node -> contribution (fun d -> CTimesJ (a, d))
    | Uminus n when n == node -> contribution (fun d -> makeUminus d)
    | _ -> returnM []

  (* sum of the contributions of all recorded parents of [this_node];
     fails if the node has no entry in [parent_table] *)
  and dualExpressionM this_node =
    let parents =
      match node_lookup this_node parent_table with
      | Some ps -> ps
      | None -> failwith "bug in dualExpressionM"
    in
    mapM (termM this_node) parents >>= fun contributions ->
    returnM (makePlus (List.flatten contributions))

  and dualM this_node =
    let computeM = function
      | Load v as x ->
          if Variable.is_constant v then
            returnM (Load v)
          else
            dualExpressionM x >>= fun d ->
            returnM (Store (v, d))
      | Store (v, _) -> returnM (Load v)
      | x -> dualExpressionM x
    in
    memoizing lookupDualsM insertDualsM computeM this_node

  in dualM
|
Chris@10
|
483
|
Chris@10
|
(* True exactly for [Store] nodes. *)
let is_store node =
  match node with
  | Store _ -> true
  | _ -> false
|
Chris@10
|
487
|
Chris@10
|
(* [transpose dag] computes the transposed network of [dag]: all nodes
   reachable from [dag] are collected together with their parents, the
   dual of every node is computed, and the duals that are [Store] nodes
   are returned as the roots of the new dag.
   The calls to [Util.info] are bound with [let ()] rather than [let _],
   so the compiler checks that they return unit. *)
let transpose dag =
  let () = Util.info "begin transpose" in
  let (all_nodes, parent_table) =
    visit [] Assoctable.empty Assoctable.empty dag in
  let transposerM = make_transposer parent_table in
  let mapTransposerM = mapM transposerM in
  let duals = runM Assoctable.empty mapTransposerM all_nodes in
  let roots = List.filter is_store duals in
  let () = Util.info "end transpose" in
  roots
|
Chris@10
|
498 end
|
Chris@10
|
499
|
Chris@10
|
500
|
Chris@10
|
501 (*************************************************************
|
Chris@10
|
502 * Various dag statistics
|
Chris@10
|
503 *************************************************************)
|
Chris@10
|
504 module Stats : sig
|
Chris@10
|
505 type complexity
|
Chris@10
|
506 val complexity : Expr.expr list -> complexity
|
Chris@10
|
507 val same_complexity : complexity -> complexity -> bool
|
Chris@10
|
508 val leq_complexity : complexity -> complexity -> bool
|
Chris@10
|
509 val to_string : complexity -> string
|
Chris@10
|
510 end = struct
|
Chris@10
|
511 type complexity = int * int * int * int * int * int
|
Chris@10
|
(* Collect the distinct nodes reachable from the work list, most recently
   visited first.  [vtable] records the nodes already expanded.
   Children are followed through [Store], [Plus], [Times] and [Uminus]
   only; the operands of [CTimes] and [CTimesJ] are not traversed here. *)
let rec visit visited vtable pending =
  match pending with
  | [] -> visited
  | node :: rest ->
      let already_seen =
        match node_lookup node vtable with
        | Some _ -> true
        | None -> false
      in
      if already_seen then visit visited vtable rest
      else
        let children =
          match node with
          | Store (_, n) -> [n]
          | Plus l -> l
          | Times (a, b) -> [a; b]
          | Uminus x -> [x]
          | _ -> []
        in
        visit (node :: visited)
          (node_insert node () vtable)
          (children @ rest)
|
Chris@10
|
527
|
Chris@10
|
(* Count the operations of a dag as a tuple
   (loads, stores, additions, multiplications, negations, constants).
   A [Plus] with n terms counts as n - 1 additions; [CTimes], [CTimesJ]
   and [NaN] nodes do not contribute to any counter. *)
let complexity dag =
  let tally (load, store, plus, times, uminus, num) node =
    match node with
    | Load _ -> (load + 1, store, plus, times, uminus, num)
    | Store _ -> (load, store + 1, plus, times, uminus, num)
    | Plus terms ->
        (load, store, plus + (List.length terms - 1), times, uminus, num)
    | Times _ -> (load, store, plus, times + 1, uminus, num)
    | Uminus _ -> (load, store, plus, times, uminus + 1, num)
    | Num _ -> (load, store, plus, times, uminus, num + 1)
    | CTimes _ | CTimesJ _ | NaN _ ->
        (load, store, plus, times, uminus, num)
  in
  List.fold_left tally (0, 0, 0, 0, 0, 0) (visit [] Assoctable.empty dag)
|
Chris@10
|
547
|
Chris@10
|
(* Scalar cost of a complexity tuple: additions weigh 10,
   multiplications 20, every other operation 1. *)
let weight (loads, stores, adds, muls, negs, consts) =
  loads + stores + (10 * adds) + (20 * muls) + negs + consts

(* Two complexities are the same when their weights are equal. *)
let same_complexity a b =
  let wa = weight a and wb = weight b in
  wa = wb

(* [leq_complexity a b] holds when [a] weighs no more than [b]. *)
let leq_complexity a b =
  let wa = weight a and wb = weight b in
  wa <= wb
|
Chris@10
|
553
|
Chris@10
|
(* One-line rendering of a complexity tuple, terminated by a newline. *)
let to_string (loads, stores, adds, muls, negs, consts) =
  Printf.sprintf "ld=%d st=%d add=%d mul=%d uminus=%d num=%d\n"
    loads stores adds muls negs consts
|
Chris@10
|
557
|
Chris@10
|
558 end
|
Chris@10
|
559
|
Chris@10
|
(* simplify the dag

   The dag is first run through [AlgSimp.algsimp].  When
   [Magic.network_transposition] is set, rounds made of two
   transpose-and-simplify passes follow; the loop stops after the first
   round whose result does not weigh more than its input.
   NOTE(review): the loop repeats only while a round made the dag heavier;
   confirm that this is the intended termination condition. *)
let algsimp v =
  let rec simplification_loop dag =
    Util.info "simplification step";
    let before = Stats.complexity dag in
    Util.info ("complexity = " ^ (Stats.to_string before));
    let dag' =
      (AlgSimp.algsimp @@ Transpose.transpose @@
       AlgSimp.algsimp @@ Transpose.transpose) dag in
    let after = Stats.complexity dag' in
    Util.info ("complexity = " ^ (Stats.to_string after));
    if Stats.leq_complexity after before then begin
      Util.info "end algsimp";
      dag'
    end else
      simplification_loop dag'
  in
  Util.info "begin algsimp";
  let simplified = AlgSimp.algsimp v in
  if !Magic.network_transposition then simplification_loop simplified
  else simplified
|
Chris@10
|
580
|