Skip to content

Commit 2f6d0aa

Browse files
committed
WIP on equality test
1 parent af3fd80 commit 2f6d0aa

File tree

1 file changed

+188
-6
lines changed

1 file changed

+188
-6
lines changed

src/Assembly/Symbolic.v

+188-6
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Require Import Crypto.Util.ListUtil.FoldMap. Import FoldMap.List.
4545
Require Import Crypto.Util.ListUtil.IndexOf. Import IndexOf.List.
4646
Require Import Crypto.Util.ListUtil.Forall.
4747
Require Import Crypto.Util.ListUtil.Permutation.
48+
Require Import Crypto.Util.ListUtil.Injective. Import Injective.List.
4849
Require Import Crypto.Util.ListUtil.Partition.
4950
Require Import Crypto.Util.ListUtil.Filter.
5051
Require Import Crypto.Util.ListUtil.PermutationCompat. Import ListUtil.PermutationCompat.Coq.Sorting.Permutation.
@@ -1937,6 +1938,28 @@ Section WithDag.
19371938
(_:interp_op ctx op args' = Some n)
19381939
: eval_node (op, args) n.
19391940

1941+
Definition expr_dag_beq_fueled_step (expr_dag_beq_fueled : expr -> expr -> bool) :=
1942+
fix expr_dag_beq (e1 e2 : expr) : bool :=
1943+
match e1, e2 with
1944+
| ExprRef i, ExprRef j => N.eqb i j
1945+
| ExprApp (op1, args1), ExprApp (op2, args2) =>
1946+
op_beq op1 op2 && list_beq _ expr_dag_beq args1 args2
1947+
| _, _ => match reveal_expr 1 e1, reveal_expr 1 e2 with
1948+
| ExprRef i, ExprRef j => N.eqb i j
1949+
| ExprApp (op1, args1), ExprApp (op2, args2) =>
1950+
op_beq op1 op2 && list_beq _ expr_dag_beq_fueled args1 args2
1951+
| _, _ => false
1952+
end end.
1953+
1954+
Fixpoint expr_dag_beq_fueled (fuel : nat) (e1 e2 : expr) : bool :=
1955+
expr_dag_beq_fueled_step
1956+
match fuel with
1957+
| O => expr_beq
1958+
| S fuel => expr_dag_beq_fueled fuel
1959+
end
1960+
e1 e2.
1961+
1962+
Definition expr_dag_beq e1 e2 := expr_dag_beq_fueled (S (N.to_nat (dag.size dag))) e1 e2.
19401963

19411964
Section eval_ind.
19421965
Context (P : expr -> Z -> Prop)
@@ -2052,6 +2075,165 @@ Section WithDag.
20522075
all: eapply Forall2_weaken; [ | eassumption ]; cbv beta.
20532076
all: intros; destruct_head'_and; eauto. }
20542077
Qed.
2078+
2079+
Lemma reveal0 : forall e, reveal 0 e = ExprRef e.
2080+
Proof using Type. reflexivity. Qed.
2081+
2082+
Lemma reveal_expr0 : forall e, reveal_expr 0 e = e.
2083+
Proof using Type. destruct e; cbn [reveal_expr]; break_innermost_match; now rewrite ?reveal0. Qed.
2084+
2085+
Lemma expr_dag_beq_fueled_step_refl
2086+
expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e, expr_dag_beq_fueled e e = true)
2087+
: forall e, expr_dag_beq_fueled_step expr_dag_beq_fueled e e = true.
2088+
Proof using Type.
2089+
induction e; cbn; break_innermost_match; rewrite ?Bool.andb_true_iff; try split.
2090+
all: try apply unreflect_bool; try reflexivity.
2091+
cbn in *.
2092+
let H := match goal with H : Forall _ _ |- _ => H end in
2093+
induction H; cbn; trivial.
2094+
rewrite Bool.andb_true_iff.
2095+
split; eauto.
2096+
Qed.
2097+
2098+
Lemma expr_dag_beq_fueled_refl : forall fuel e, expr_dag_beq_fueled fuel e e = true.
2099+
Proof using Type.
2100+
induction fuel as [|fuel IH]; cbn.
2101+
all: apply expr_dag_beq_fueled_step_refl; eauto.
2102+
intros; apply unreflect_bool; reflexivity.
2103+
Qed.
2104+
2105+
Lemma expr_dag_beq_fueled_step_sym_iff
2106+
expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e1 e2, expr_dag_beq_fueled e1 e2 = true <-> expr_dag_beq_fueled e2 e1 = true)
2107+
: forall e1 e2, expr_dag_beq_fueled_step expr_dag_beq_fueled e1 e2 = true <-> expr_dag_beq_fueled_step expr_dag_beq_fueled e2 e1 = true.
2108+
Proof using Type.
2109+
clear ctx.
2110+
induction e1, e2; cbn; break_innermost_match; try tauto.
2111+
all: rewrite ?Bool.andb_true_iff.
2112+
all: try (split; intro; reflect_hyps; subst; try apply unreflect_bool; try reflexivity; assumption).
2113+
all: match goal with
2114+
| [ |- ?a /\ ?b <-> ?a' /\ ?b' ]
2115+
=> cut ((a <-> a') /\ (b <-> b')); [ tauto | ]
2116+
end.
2117+
all: split; [ split; intros; reflect_hyps; subst; apply unreflect_bool; reflexivity | ].
2118+
all: cbv [reveal_step] in *; break_innermost_match_hyps.
2119+
all: lazymatch goal with
2120+
| [ H : ExprApp _ = ExprApp _ |- _ ] => inversion H; clear H
2121+
| [ H : ExprRef _ = ExprApp _ |- _ ] => inversion H
2122+
| [ H : ExprApp _ = ExprRef _ |- _ ] => inversion H
2123+
| _ => idtac
2124+
end.
2125+
all: subst.
2126+
all: try erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0).
2127+
all: cbn [snd] in *.
2128+
all: match goal with
2129+
| [ |- list_beq _ _ ?x ?y = true <-> list_beq _ _ ?y ?x = true ]
2130+
=> generalize y; induction x as [|?? IH']; let y := fresh in intro y; destruct y
2131+
end.
2132+
all: cbn [list_beq snd]; try tauto.
2133+
all: try match goal with H : Forall _ (_ :: _) |- _ => inversion H; clear H end.
2134+
all: rewrite !Bool.andb_true_iff, IH'.
2135+
all: cbn [snd] in *.
2136+
all: try (eapply Forall_weaken; [ | eassumption ]).
2137+
all: eauto.
2138+
all: lazymatch goal with
2139+
| [ |- ?a /\ ?b <-> ?a' /\ ?b' ]
2140+
=> cut ((a <-> a') /\ (b <-> b')); [ tauto | ]
2141+
end.
2142+
all: try (split; try tauto; eauto).
2143+
Qed.
2144+
2145+
Lemma expr_dag_beq_fueled_sym_iff : forall fuel e1 e2, expr_dag_beq_fueled fuel e1 e2 = true <-> expr_dag_beq_fueled fuel e2 e1 = true.
2146+
Proof using Type.
2147+
induction fuel as [|fuel IH]; cbn.
2148+
all: apply expr_dag_beq_fueled_step_sym_iff; eauto.
2149+
intros; split; intros; reflect_hyps; apply unreflect_bool; subst; reflexivity.
2150+
Qed.
2151+
2152+
Lemma expr_dag_beq_fueled_sym : forall fuel e1 e2, expr_dag_beq_fueled fuel e1 e2 = expr_dag_beq_fueled fuel e2 e1.
2153+
Proof using Type.
2154+
intros fuel e1 e2.
2155+
generalize (@expr_dag_beq_fueled_sym_iff fuel e1 e2).
2156+
do 2 destruct expr_dag_beq_fueled; intros; repeat split; destruct_head' iff; try tauto.
2157+
symmetry; tauto.
2158+
Qed.
2159+
2160+
Lemma eval_expr_dag_beq_fueled_step_impl
2161+
expr_dag_beq_fueled (Hexpr_dag_beq_fueled : forall e v, eval e v -> forall e', expr_dag_beq_fueled e e' = true -> eval e' v)
2162+
2163+
: forall e v, eval e v -> forall e', expr_dag_beq_fueled_step expr_dag_beq_fueled e e' = true -> eval e' v.
2164+
Proof using Type.
2165+
induction 1, e'; cbn; cbv [reveal_step].
2166+
all: break_innermost_match; intros; reflect_hyps; subst.
2167+
all: try now exfalso.
2168+
all: repeat first [ progress subst
2169+
| progress inversion_option
2170+
| progress inversion_pair
2171+
| progress destruct_head'_and
2172+
| exfalso; assumption
2173+
| rewrite Bool.andb_true_iff in *
2174+
| progress cbv [reveal_step] in *
2175+
| progress reflect_hyps
2176+
| progress break_innermost_match_hyps
2177+
| erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0) ].
2178+
all: try solve [ econstructor; try eassumption; eapply Forall2_weaken; [ | eassumption ]; cbv beta; try tauto ].
2179+
all: econstructor; try eassumption.
2180+
all: rewrite ?Forall2_map_l in *.
2181+
all: try match goal with H : dag.lookup _ _ = Some _ |- _ => clear H end.
2182+
all: try match goal with H : interp_op _ _ _ = Some _ |- _ => clear H end.
2183+
all: lazymatch goal with
2184+
| [ H : Forall2 _ ?x ?y |- Forall2 _ ?x' ?y' ]
2185+
=> (revert dependent x' + idtac); (revert dependent y' + idtac); induction H; cbn in *; intros
2186+
end.
2187+
all: subst.
2188+
all: break_innermost_match_hyps; try congruence.
2189+
all: try solve [ constructor ].
2190+
all: lazymatch goal with
2191+
| [ H : List.map _ ?x = [] |- _ ] => is_var x; destruct x
2192+
| [ H : List.map _ ?x = _ :: _ |- _ ] => is_var x; destruct x
2193+
| _ => idtac
2194+
end.
2195+
all: cbn in *.
2196+
all: inversion_list.
2197+
all: repeat first [ progress subst
2198+
| progress inversion_option
2199+
| progress inversion_pair
2200+
| progress destruct_head'_and
2201+
| exfalso; assumption
2202+
| rewrite Bool.andb_true_iff in *
2203+
| progress cbv [reveal_step] in *
2204+
| progress reflect_hyps
2205+
| progress break_innermost_match_hyps
2206+
| solve [ constructor; eauto ]
2207+
| erewrite (@map_ext _ _ (reveal_expr 0)), map_id in * by (intros; apply reveal_expr0) ].
2208+
Qed.
2209+
2210+
Lemma eval_expr_dag_beq_fueled_impl : forall fuel e v, eval e v -> forall e', expr_dag_beq_fueled fuel e e' = true -> eval e' v.
2211+
Proof using Type.
2212+
induction fuel as [|fuel IH]; cbn.
2213+
all: apply eval_expr_dag_beq_fueled_step_impl; eauto.
2214+
intros; intros; reflect_hyps; subst; eauto.
2215+
Qed.
2216+
2217+
Lemma eval_expr_dag_beq_fueled : forall fuel e e', expr_dag_beq_fueled fuel e e' = true -> forall v, eval e v <-> eval e' v.
2218+
Proof using Type.
2219+
split; intro; eapply eval_expr_dag_beq_fueled_impl; eauto.
2220+
eapply expr_dag_beq_fueled_sym_iff; eauto.
2221+
Qed.
2222+
2223+
Lemma expr_dag_beq_refl : forall e, expr_dag_beq e e = true.
2224+
Proof using Type. apply expr_dag_beq_fueled_refl. Qed.
2225+
2226+
Lemma expr_dag_beq_sym_iff : forall e1 e2, expr_dag_beq e1 e2 = true <-> expr_dag_beq e2 e1 = true.
2227+
Proof using Type. apply expr_dag_beq_fueled_sym_iff. Qed.
2228+
2229+
Lemma expr_dag_beq_sym : forall e1 e2, expr_dag_beq e1 e2 = expr_dag_beq e2 e1.
2230+
Proof using Type. apply expr_dag_beq_fueled_sym. Qed.
2231+
2232+
Lemma eval_expr_dag_beq_impl : forall e v, eval e v -> forall e', expr_dag_beq e e' = true -> eval e' v.
2233+
Proof using Type. apply eval_expr_dag_beq_fueled_impl. Qed.
2234+
2235+
Lemma eval_expr_dag_beq : forall e e', expr_dag_beq e e' = true -> forall v, eval e v <-> eval e' v.
2236+
Proof using Type. apply eval_expr_dag_beq_fueled. Qed.
20552237
End WithDag.
20562238

20572239
Definition merge_node {descr : description} (n : node idx) : dag.M idx
@@ -3352,7 +3534,7 @@ Proof. t; cbn in *; rewrite ?Z.shiftl_mul_pow2, ?Z.land_0_r by lia; repeat (lia
33523534
Definition split_consts (d : dag) (o : op) (i : Z) : list expr -> list (expr * Z)
33533535
:= List.map
33543536
(fun e
3355-
=> match e with
3537+
=> match reveal_expr_at_least d 2 e with
33563538
| ExprApp (o', args)
33573539
=> if op_beq o' o
33583540
then
@@ -3387,7 +3569,7 @@ Definition group_consts (d : dag) (ls : list (expr * Z)) : list (expr * list Z)
33873569
| [] => None
33883570
| (e, z) :: xs => Some (e, z :: List.map snd xs)
33893571
end)
3390-
(List.groupAllBy (fun x y => expr_beq (fst x) (fst y)) ls).
3572+
(List.groupAllBy (fun x y => expr_dag_beq d (fst x) (fst y)) ls).
33913573

33923574
(* o is like add *)
33933575
(* spec: if interp0_op o zs is always Some _, then Forall2 (fun '(e, zs) '(e', z) => e = e' /\ interp0_op o zs = Some z) input output *)
@@ -3405,7 +3587,7 @@ Definition app_consts (d : dag) (o : op) (ls : list (expr * Z)) : list expr
34053587
:= List.map (fun '(e, z) => let z := ExprApp (const z, []) in
34063588
let default := ExprApp (o, [e; z]) in
34073589
if associative o
3408-
then match e with
3590+
then match reveal_expr_at_least d 1 e with
34093591
| ExprApp (o', args)
34103592
=> if op_beq o' o
34113593
then ExprApp (o, args ++ [z])
@@ -3414,7 +3596,7 @@ Definition app_consts (d : dag) (o : op) (ls : list (expr * Z)) : list expr
34143596
ls.
34153597

34163598
Definition combine_consts_pre [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr :=
3417-
fun e => match e with ExprApp (o, args) =>
3599+
fun e => match reveal_expr_at_least d 1 e with ExprApp (o, args) =>
34183600
if commutative o && associative o && op_always_interps o then match combines_to o with
34193601
| Some o' => match identity o' with
34203602
| Some idv =>
@@ -3424,11 +3606,11 @@ Definition combine_consts_pre [opts : symbolic_options_computed_opt] (d : dag) :
34243606
Definition cleanup_combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr :=
34253607
let simp_outside := List.fold_left (fun e f => f e) [flatten_associative d] in
34263608
let simp_inside := List.fold_left (fun e f => f e) [constprop d;drop_identity d;unary_truncate d;truncate_small d] in
3427-
fun e => simp_outside match e with ExprApp (o, args) =>
3609+
fun e => simp_outside match reveal_expr_at_least d 1 e with ExprApp (o, args) =>
34283610
ExprApp (o, List.map simp_inside args)
34293611
| _ => e end.
34303612

3431-
Definition combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr := fun e => cleanup_combine_consts d (combine_consts_pre d (reveal_expr_at_least d 3 e)).
3613+
Definition combine_consts [opts : symbolic_options_computed_opt] (d : dag) : expr -> expr := fun e => cleanup_combine_consts d (combine_consts_pre d e).
34323614
#[local] Instance describe_combine_consts : description_of Rewrite.combine_consts
34333615
:= "Rewrites expressions like (x + x * 5) into (x * 6)".
34343616

0 commit comments

Comments
 (0)