@@ -45,6 +45,7 @@ Require Import Crypto.Util.ListUtil.FoldMap. Import FoldMap.List.
45
45
Require Import Crypto.Util.ListUtil.IndexOf. Import IndexOf.List.
46
46
Require Import Crypto.Util.ListUtil.Forall.
47
47
Require Import Crypto.Util.ListUtil.Permutation.
48
+ Require Import Crypto.Util.ListUtil.Injective. Import Injective.List.
48
49
Require Import Crypto.Util.ListUtil.Partition.
49
50
Require Import Crypto.Util.ListUtil.Filter.
50
51
Require Import Crypto.Util.ListUtil.PermutationCompat. Import ListUtil.PermutationCompat.Coq.Sorting.Permutation.
@@ -1937,6 +1938,28 @@ Section WithDag.
1937
1938
(_:interp_op ctx op args' = Some n)
1938
1939
: eval_node (op, args) n.
1939
1940
1941
+ Definition expr_dag_beq_fueled_step (expr_dag_beq_fueled : expr -> expr -> bool) :=
1942
+ fix expr_dag_beq (e1 e2 : expr) : bool :=
1943
+ match e1, e2 with
1944
+ | ExprRef i, ExprRef j => N.eqb i j
1945
+ | ExprApp (op1, args1), ExprApp (op2, args2) =>
1946
+ op_beq op1 op2 && list_beq _ expr_dag_beq args1 args2
1947
+ | _, _ => match reveal_expr 1 e1, reveal_expr 1 e2 with
1948
+ | ExprRef i, ExprRef j => N.eqb i j
1949
+ | ExprApp (op1, args1), ExprApp (op2, args2) =>
1950
+ op_beq op1 op2 && list_beq _ expr_dag_beq_fueled args1 args2
1951
+ | _, _ => false
1952
+ end end .
1953
+
1954
+ Fixpoint expr_dag_beq_fueled (fuel : nat) (e1 e2 : expr) : bool :=
1955
+ expr_dag_beq_fueled_step
1956
+ match fuel with
1957
+ | O => expr_beq
1958
+ | S fuel => expr_dag_beq_fueled fuel
1959
+ end
1960
+ e1 e2.
1961
+
1962
+ Definition expr_dag_beq e1 e2 := expr_dag_beq_fueled (S (N.to_nat (dag.size dag))) e1 e2.
1940
1963
1941
1964
Section eval_ind.
1942
1965
Context (P : expr -> Z -> Prop )
@@ -2052,6 +2075,165 @@ Section WithDag.
2052
2075
all: eapply Forall2_weaken; [ | eassumption ]; cbv beta.
2053
2076
all: intros; destruct_head'_and; eauto. }
2054
2077
Qed .
2078
+
2079
+ Lemma reveal0 : forall e, reveal 0 e = ExprRef e.
2080
+ Proof using Type . reflexivity. Qed .
2081
+
2082
+ Lemma reveal_expr0 : forall e, reveal_expr 0 e = e.
2083
+ Proof using Type . destruct e; cbn [reveal_expr]; break_innermost_match; now rewrite ?reveal0. Qed .
2084
+
2085
+ Lemma expr_dag_beq_fueled_step_refl
2086
+ expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e, expr_dag_beq_fueled e e = true)
2087
+ : forall e, expr_dag_beq_fueled_step expr_dag_beq_fueled e e = true.
2088
+ Proof using Type .
2089
+ induction e; cbn; break_innermost_match; rewrite ?Bool.andb_true_iff; try split.
2090
+ all: try apply unreflect_bool; try reflexivity.
2091
+ cbn in *.
2092
+ let H := match goal with H : Forall _ _ |- _ => H end in
2093
+ induction H; cbn; trivial.
2094
+ rewrite Bool.andb_true_iff.
2095
+ split; eauto .
2096
+ Qed .
2097
+
2098
+ Lemma expr_dag_beq_fueled_refl : forall fuel e, expr_dag_beq_fueled fuel e e = true.
2099
+ Proof using Type .
2100
+ induction fuel as [|fuel IH]; cbn.
2101
+ all: apply expr_dag_beq_fueled_step_refl; eauto.
2102
+ intros; apply unreflect_bool; reflexivity.
2103
+ Qed .
2104
+
2105
+ Lemma expr_dag_beq_fueled_step_sym_iff
2106
+ expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e1 e2, expr_dag_beq_fueled e1 e2 = true <-> expr_dag_beq_fueled e2 e1 = true)
2107
+ : forall e1 e2, expr_dag_beq_fueled_step expr_dag_beq_fueled e1 e2 = true <-> expr_dag_beq_fueled_step expr_dag_beq_fueled e2 e1 = true.
2108
+ Proof using Type .
2109
+ clear ctx.
2110
+ induction e1, e2; cbn; break_innermost_match; try tauto.
2111
+ all: rewrite ?Bool.andb_true_iff.
2112
+ all: try (split; intro; reflect_hyps; subst; try apply unreflect_bool; try reflexivity; assumption).
2113
+ all: match goal with
2114
+ | [ |- ?a /\ ?b <-> ?a' /\ ?b' ]
2115
+ => cut ((a <-> a') /\ (b <-> b')); [ tauto | ]
2116
+ end .
2117
+ all: split; [ split; intros; reflect_hyps; subst; apply unreflect_bool; reflexivity | ].
2118
+ all: cbv [reveal_step] in *; break_innermost_match_hyps.
2119
+ all: lazymatch goal with
2120
+ | [ H : ExprApp _ = ExprApp _ |- _ ] => inversion H; clear H
2121
+ | [ H : ExprRef _ = ExprApp _ |- _ ] => inversion H
2122
+ | [ H : ExprApp _ = ExprRef _ |- _ ] => inversion H
2123
+ | _ => idtac
2124
+ end .
2125
+ all: subst.
2126
+ all: try erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0).
2127
+ all: cbn [snd] in *.
2128
+ all: match goal with
2129
+ | [ |- list_beq _ _ ?x ?y = true <-> list_beq _ _ ?y ?x = true ]
2130
+ => generalize y; induction x as [|?? IH']; let y := fresh in intro y; destruct y
2131
+ end .
2132
+ all: cbn [list_beq snd]; try tauto.
2133
+ all: try match goal with H : Forall _ (_ :: _) |- _ => inversion H; clear H end.
2134
+ all: rewrite !Bool.andb_true_iff, IH'.
2135
+ all: cbn [snd] in *.
2136
+ all: try (eapply Forall_weaken; [ | eassumption ]).
2137
+ all: eauto.
2138
+ all: lazymatch goal with
2139
+ | [ |- ?a /\ ?b <-> ?a' /\ ?b' ]
2140
+ => cut ((a <-> a') /\ (b <-> b')); [ tauto | ]
2141
+ end .
2142
+ all: try (split; try tauto; eauto).
2143
+ Qed .
2144
+
2145
+ Lemma expr_dag_beq_fueled_sym_iff : forall fuel e1 e2, expr_dag_beq_fueled fuel e1 e2 = true <-> expr_dag_beq_fueled fuel e2 e1 = true.
2146
+ Proof using Type .
2147
+ induction fuel as [|fuel IH]; cbn.
2148
+ all: apply expr_dag_beq_fueled_step_sym_iff; eauto.
2149
+ intros; split; intros; reflect_hyps; apply unreflect_bool; subst; reflexivity.
2150
+ Qed .
2151
+
2152
+ Lemma expr_dag_beq_fueled_sym : forall fuel e1 e2, expr_dag_beq_fueled fuel e1 e2 = expr_dag_beq_fueled fuel e2 e1.
2153
+ Proof using Type .
2154
+ intros fuel e1 e2.
2155
+ generalize (@expr_dag_beq_fueled_sym_iff fuel e1 e2).
2156
+ do 2 destruct expr_dag_beq_fueled; intros; repeat split; destruct_head' iff; try tauto.
2157
+ symmetry; tauto.
2158
+ Qed .
2159
+
2160
+ Lemma eval_expr_dag_beq_fueled_step_impl
2161
+ expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e v, eval e v -> forall e', expr_dag_beq_fueled e e' = true -> eval e' v)
2162
+
2163
+ : forall e v, eval e v -> forall e', expr_dag_beq_fueled_step expr_dag_beq_fueled e e' = true -> eval e' v.
2164
+ Proof using Type .
2165
+ induction 1, e'; cbn; cbv [reveal_step].
2166
+ all: break_innermost_match; intros; reflect_hyps; subst.
2167
+ all: try now exfalso.
2168
+ all: repeat first [ progress subst
2169
+ | progress inversion_option
2170
+ | progress inversion_pair
2171
+ | progress destruct_head'_and
2172
+ | exfalso; assumption
2173
+ | rewrite Bool.andb_true_iff in *
2174
+ | progress cbv [reveal_step] in *
2175
+ | progress reflect_hyps
2176
+ | progress break_innermost_match_hyps
2177
+ | erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0) ].
2178
+ all: try solve [ econstructor; try eassumption; eapply Forall2_weaken; [ | eassumption ]; cbv beta; try tauto ].
2179
+ all: econstructor; try eassumption.
2180
+ all: rewrite ?Forall2_map_l in *.
2181
+ all: try match goal with H : dag.lookup _ _ = Some _ |- _ => clear H end .
2182
+ all: try match goal with H : interp_op _ _ _ = Some _ |- _ => clear H end.
2183
+ all: lazymatch goal with
2184
+ | [ H : Forall2 _ ?x ?y |- Forall2 _ ?x' ?y' ]
2185
+ => (revert dependent x' + idtac); (revert dependent y' + idtac); induction H; cbn in *; intros
2186
+ end .
2187
+ all: subst.
2188
+ all: break_innermost_match_hyps; try congruence.
2189
+ all: try solve [ constructor ].
2190
+ all: lazymatch goal with
2191
+ | [ H : List.map _ ?x = [] |- _ ] => is_var x; destruct x
2192
+ | [ H : List.map _ ?x = _ :: _ |- _ ] => is_var x; destruct x
2193
+ | _ => idtac
2194
+ end .
2195
+ all: cbn in *.
2196
+ all: inversion_list.
2197
+ all: repeat first [ progress subst
2198
+ | progress inversion_option
2199
+ | progress inversion_pair
2200
+ | progress destruct_head'_and
2201
+ | exfalso; assumption
2202
+ | rewrite Bool.andb_true_iff in *
2203
+ | progress cbv [reveal_step] in *
2204
+ | progress reflect_hyps
2205
+ | progress break_innermost_match_hyps
2206
+ | solve [ constructor; eauto ]
2207
+ | erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0) ].
2208
+ Qed .
2209
+
2210
+ Lemma eval_expr_dag_beq_fueled_impl : forall fuel e v, eval e v -> forall e', expr_dag_beq_fueled fuel e e' = true -> eval e' v.
2211
+ Proof using Type .
2212
+ induction fuel as [|fuel IH]; cbn.
2213
+ all: apply eval_expr_dag_beq_fueled_step_impl; eauto.
2214
+ intros; intros; reflect_hyps; subst; eauto .
2215
+ Qed .
2216
+
2217
+ Lemma eval_expr_dag_beq_fueled : forall fuel e e', expr_dag_beq_fueled fuel e e' = true -> forall v, eval e v <-> eval e' v.
2218
+ Proof using Type .
2219
+ split; intro; eapply eval_expr_dag_beq_fueled_impl; eauto .
2220
+ eapply expr_dag_beq_fueled_sym_iff; eauto .
2221
+ Qed .
2222
+
2223
+ Lemma expr_dag_beq_refl : forall e, expr_dag_beq e e = true.
2224
+ Proof using Type . apply expr_dag_beq_fueled_refl. Qed .
2225
+
2226
+ Lemma expr_dag_beq_sym_iff : forall e1 e2, expr_dag_beq e1 e2 = true <-> expr_dag_beq e2 e1 = true.
2227
+ Proof using Type . apply expr_dag_beq_fueled_sym_iff. Qed .
2228
+
2229
+ Lemma expr_dag_beq_sym : forall e1 e2, expr_dag_beq e1 e2 = expr_dag_beq e2 e1.
2230
+ Proof using Type . apply expr_dag_beq_fueled_sym. Qed .
2231
+
2232
+ Lemma eval_expr_dag_beq_impl : forall e v, eval e v -> forall e', expr_dag_beq e e' = true -> eval e' v.
2233
+ Proof using Type . apply eval_expr_dag_beq_fueled_impl. Qed .
2234
+
2235
+ Lemma eval_expr_dag_beq : forall e e', expr_dag_beq e e' = true -> forall v, eval e v <-> eval e' v.
2236
+ Proof using Type . apply eval_expr_dag_beq_fueled. Qed .
2055
2237
End WithDag.
2056
2238
2057
2239
Definition merge_node {descr : description} (n : node idx) : dag.M idx
@@ -3352,7 +3534,7 @@ Proof. t; cbn in *; rewrite ?Z.shiftl_mul_pow2, ?Z.land_0_r by lia; repeat (lia
3352
3534
Definition split_consts (d : dag) (o : op) (i : Z) : list expr -> list (expr * Z)
3353
3535
:= List.map
3354
3536
(fun e
3355
- => match e with
3537
+ => match reveal_expr_at_least d 2 e with
3356
3538
| ExprApp (o', args)
3357
3539
=> if op_beq o' o
3358
3540
then
@@ -3387,7 +3569,7 @@ Definition group_consts (d : dag) (ls : list (expr * Z)) : list (expr * list Z)
3387
3569
| [] => None
3388
3570
| (e, z) :: xs => Some (e, z :: List.map snd xs)
3389
3571
end )
3390
- (List.groupAllBy (fun x y => expr_beq (fst x) (fst y)) ls).
3572
+ (List.groupAllBy (fun x y => expr_dag_beq d (fst x) (fst y)) ls).
3391
3573
3392
3574
(* o is like add *)
3393
3575
(* spec: if interp0_op o zs is always Some _, then Forall2 (fun '(e, zs) '(e', z) => e = e' /\ interp0_op o zs = Some z) input output *)
@@ -3405,7 +3587,7 @@ Definition app_consts (d : dag) (o : op) (ls : list (expr * Z)) : list expr
3405
3587
:= List.map (fun '(e, z) => let z := ExprApp (const z, []) in
3406
3588
let default := ExprApp (o, [e; z]) in
3407
3589
if associative o
3408
- then match e with
3590
+ then match reveal_expr_at_least d 1 e with
3409
3591
| ExprApp (o', args)
3410
3592
=> if op_beq o' o
3411
3593
then ExprApp (o, args ++ [z])
@@ -3414,7 +3596,7 @@ Definition app_consts (d : dag) (o : op) (ls : list (expr * Z)) : list expr
3414
3596
ls.
3415
3597
3416
3598
Definition combine_consts_pre [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr :=
3417
- fun e => match e with ExprApp (o, args) =>
3599
+ fun e => match reveal_expr_at_least d 1 e with ExprApp (o, args) =>
3418
3600
if commutative o && associative o && op_always_interps o then match combines_to o with
3419
3601
| Some o' => match identity o' with
3420
3602
| Some idv =>
@@ -3424,11 +3606,11 @@ Definition combine_consts_pre [opts : symbolic_options_computed_opt] (d : dag) :
3424
3606
Definition cleanup_combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr :=
3425
3607
let simp_outside := List.fold_left (fun e f => f e) [flatten_associative d] in
3426
3608
let simp_inside := List.fold_left (fun e f => f e) [constprop d;drop_identity d;unary_truncate d;truncate_small d] in
3427
- fun e => simp_outside match e with ExprApp (o, args) =>
3609
+ fun e => simp_outside match reveal_expr_at_least d 1 e with ExprApp (o, args) =>
3428
3610
ExprApp (o, List.map simp_inside args)
3429
3611
| _ => e end .
3430
3612
3431
- Definition combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr := fun e => cleanup_combine_consts d (combine_consts_pre d (reveal_expr_at_least d 3 e) ).
3613
+ Definition combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr := fun e => cleanup_combine_consts d (combine_consts_pre d e ).
3432
3614
#[local] Instance describe_combine_consts : description_of Rewrite.combine_consts
3433
3615
:= "Rewrites expressions like (x + x * 5) into (x * 6)".
3434
3616
0 commit comments