Skip to content

Commit ebf3f06

Browse files
committed
Merge branch 'main' into extract-memory-effects
2 parents 0aadff5 + a499a57 commit ebf3f06

8 files changed

+141
-74
lines changed

Arm/Insts/Common.lean

Lines changed: 27 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -398,22 +398,16 @@ dsimproc [state_simp_rules] reduceInvalidBitMasks (invalid_bit_masks _ _ _ _) :=
398398
let imm ← simp imm
399399
let M ← simp M
400400
let some ⟨immN_width, immN⟩ ← getBitVecValue? immN.expr | return .continue
401-
if h1 : ¬ (immN_width = 1) then
402-
return .continue
403-
else
404-
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
405-
if h2 : ¬ (imms_width = 6) then
406-
return .continue
407-
else
408-
let some M ← Nat.fromExpr? M.expr | return .continue
409-
have h1' : immN_width = 1 := by simp_all only [Decidable.not_not]
410-
have h2' : imms_width = 6 := by simp_all only [Decidable.not_not]
411-
return .done <|
412-
toExpr (invalid_bit_masks
413-
(BitVec.cast h1' immN)
414-
(BitVec.cast h2' imms)
415-
imm.expr.isTrue
416-
M)
401+
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
402+
if h : immN_width = 1 ∧ imms_width = 6 then
403+
let some M ← Nat.fromExpr? M.expr | return .continue
404+
return .done <|
405+
toExpr (invalid_bit_masks
406+
(BitVec.cast (by simp_all only) immN)
407+
(BitVec.cast (by simp_all only) imms)
408+
imm.expr.isTrue
409+
M)
410+
else return .continue
417411

418412
theorem Nat.lt_one_iff {n : Nat} : n < 1 ↔ n = 0 := by
419413
omega
@@ -442,8 +436,9 @@ theorem M_divisible_by_esize_of_valid_bit_masks (immN : BitVec 1) (imms : BitVec
442436
-- https://kddnewton.com/2022/08/11/aarch64-bitmask-immediates.html
443437
-- Arm Implementation:
444438
-- https://developer.arm.com/documentation/ddi0602/2023-12/Shared-Pseudocode/aarch64-functions-bitmasks?lang=en#impl-aarch64.DecodeBitMasks.5
445-
def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6)
446-
(immediate : Bool) (M : Nat) : Option (BitVec M × BitVec M) :=
439+
def decode_bit_masks (immN : BitVec 1) (imms immr : BitVec 6)
440+
(immediate : Bool) (M : Nat) :
441+
Option (BitVec M × BitVec M) :=
447442
if h0 : invalid_bit_masks immN imms immediate M then none
448443
else
449444
let len := Option.get! $ highest_set_bit $ immN ++ ~~~imms
@@ -463,36 +458,27 @@ def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6)
463458
some (BitVec.cast h wmask, BitVec.cast h tmask)
464459

465460
open Lean Meta Simp in
466-
dsimproc [state_simp_rules] reduceDecodeBitMasks (decode_bit_masks _ _ _ _ _) := fun e => do
461+
dsimproc [state_simp_rules] reduceDecodeBitMasks (decode_bit_masks _ _ _ _ _) :=
462+
fun e => do
467463
let_expr decode_bit_masks immN imms immr imm M ← e | return .continue
468464
let immN ← simp immN
469465
let imms ← simp imms
470466
let immr ← simp immr
471467
let imm ← simp imm
472468
let M ← simp M
473469
let some ⟨immN_width, immN⟩ ← getBitVecValue? immN.expr | return .continue
474-
if h1 : ¬ (immN_width = 1) then
475-
return .continue
476-
else
477-
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
478-
if h2 : ¬ (imms_width = 6) then
479-
return .continue
480-
else
481-
let some ⟨immr_width, immr⟩ ← getBitVecValue? immr.expr | return .continue
482-
if h3 : ¬ (immr_width = 6) then
483-
return .continue
484-
else
485-
let some M ← Nat.fromExpr? M.expr | return .continue
486-
have h1' : immN_width = 1 := by simp_all only [Decidable.not_not]
487-
have h2' : imms_width = 6 := by simp_all only [Decidable.not_not]
488-
have h3' : immr_width = 6 := by simp_all only [Decidable.not_not]
489-
return .done <|
490-
toExpr (decode_bit_masks
491-
(BitVec.cast h1' immN)
492-
(BitVec.cast h2' imms)
493-
(BitVec.cast h3' immr)
494-
imm.expr.isTrue
495-
M)
470+
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
471+
let some ⟨immr_width, immr⟩ ← getBitVecValue? immr.expr | return .continue
472+
if h : immN_width = 1 ∧ imms_width = 6 ∧ immr_width = 6 then
473+
let some M ← Nat.fromExpr? M.expr | return .continue
474+
return .done <|
475+
toExpr (decode_bit_masks
476+
(BitVec.cast (by simp_all only) immN)
477+
(BitVec.cast (by simp_all only) imms)
478+
(BitVec.cast (by simp_all only) immr)
479+
imm.expr.isTrue
480+
M)
481+
else return .continue
496482

497483
----------------------------------------------------------------------
498484

@@ -664,7 +650,6 @@ structure ShiftInfo where
664650
unsigned := true
665651
round := false
666652
accumulate := false
667-
h : esize > 0
668653
deriving DecidableEq, Repr
669654

670655
export ShiftInfo (esize elements shift unsigned round accumulate)

Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,18 @@ def AdvSIMDExpandImm (op : BitVec 1) (cmode : BitVec 4) (imm8 : BitVec 8) : BitV
9090
lsb imm8 7 ++ ~~~(lsb imm8 6) ++
9191
(replicate 8 $ lsb imm8 6) ++ extractLsb' 0 6 imm8 ++ BitVec.zero 48
9292

93+
open Lean Meta Simp in
94+
dsimproc [state_simp_rules] reduceAdvSIMDExpandImm (AdvSIMDExpandImm _ _ _) := fun e => do
95+
let_expr AdvSIMDExpandImm op cmode imm8 ← e | return .continue
96+
let some ⟨op_n, op⟩ ← getBitVecValue? op | return .continue
97+
let some ⟨cmode_n, cmode⟩ ← getBitVecValue? cmode | return .continue
98+
let some ⟨imm8_n, imm8⟩ ← getBitVecValue? imm8 | return .continue
99+
if h : op_n = 1 ∧ cmode_n = 4 ∧ imm8_n = 8 then
100+
return .done <| toExpr (AdvSIMDExpandImm
101+
(BitVec.cast (by simp_all only) op)
102+
(BitVec.cast (by simp_all only) cmode)
103+
(BitVec.cast (by simp_all only) imm8))
104+
else return .continue
93105

94106
private theorem mul_div_norm_form_lemma (n m : Nat) (_h1 : 0 < m) (h2 : n ∣ m) :
95107
(n * (m / n)) = n * m / n := by

Arm/Insts/DPSFP/Advanced_simd_scalar_shift_by_immediate.lean

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,14 @@ def exec_shift_right_scalar
2222
write_err (StateError.Illegal s!"Illegal {inst} encountered!") s
2323
else
2424
let esize := 8 <<< 3
25-
have h : esize > 0 := by decide
2625
let datasize := esize
2726
let (info : ShiftInfo) :=
2827
{ esize := esize,
2928
elements := 1,
3029
shift := (esize * 2) - (inst.immh ++ inst.immb).toNat,
3130
unsigned := inst.U = 0b1#1,
3231
round := (lsb inst.opcode 2) = 0b1#1,
33-
accumulate := (lsb inst.opcode 1) = 0b1#1,
34-
h := h
35-
}
32+
accumulate := (lsb inst.opcode 1) = 0b1#1 }
3633
let result := shift_right_common info datasize inst.Rn inst.Rd s
3734
-- State Update
3835
let s := write_sfp datasize inst.Rd result s
@@ -46,14 +43,11 @@ def exec_shl_scalar
4643
write_err (StateError.Illegal s!"Illegal {inst} encountered!") s
4744
else
4845
let esize := 8 <<< 3
49-
have h : esize > 0 := by decide
5046
let datasize := esize
5147
let (info : ShiftInfo) :=
5248
{ esize := esize,
5349
elements := 1,
54-
shift := (inst.immh ++ inst.immb).toNat - esize,
55-
h := h
56-
}
50+
shift := (inst.immh ++ inst.immb).toNat - esize }
5751
let result := shift_left_common info datasize inst.Rn s
5852
-- State Update
5953
let s := write_sfp datasize inst.Rd result s

Arm/Insts/DPSFP/Advanced_simd_shift_by_immediate.lean

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,18 +32,14 @@ def exec_shift_right_vector
3232
else
3333
let l := highest_set_bit inst.immh
3434
let esize := 8 <<< l
35-
have h : esize > 0 := by
36-
simp only [esize]
37-
apply zero_lt_shift_left_pos (by decide)
3835
let datasize := 64 <<< inst.Q.toNat
3936
let (info : ShiftInfo) :=
4037
{ esize := esize,
4138
elements := datasize / esize,
4239
shift := (2 * esize) - (inst.immh ++ inst.immb).toNat,
4340
unsigned := inst.U = 0b1#1,
4441
round := (lsb inst.opcode 2) = 0b1#1,
45-
accumulate := (lsb inst.opcode 1) = 0b1#1,
46-
h := h }
42+
accumulate := (lsb inst.opcode 1) = 0b1#1 }
4743
let result := shift_right_common info datasize inst.Rn inst.Rd s
4844
-- State Update
4945
let s := write_sfp datasize inst.Rd result s
@@ -58,15 +54,11 @@ def exec_shl_vector
5854
else
5955
let l := highest_set_bit inst.immh
6056
let esize := 8 <<< l
61-
have h : esize > 0 := by
62-
simp only [esize]
63-
apply zero_lt_shift_left_pos (by decide)
6457
let datasize := 64 <<< inst.Q.toNat
6558
let (info : ShiftInfo) :=
6659
{ esize := esize,
6760
elements := datasize / esize,
68-
shift := (inst.immh ++ inst.immb).toNat - esize,
69-
h := h }
61+
shift := (inst.immh ++ inst.immb).toNat - esize }
7062
let result := shift_left_common info datasize inst.Rn s
7163
-- State Update
7264
let s := write_sfp datasize inst.Rd result s

Arm/Insts/DPSFP/Advanced_simd_three_different.lean

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
44
Author(s): Yan Peng
55
-/
66
-- PMULL and PMULL2
7-
-- Polynomial arithmetic over {0,1}: https://tiny.amazon.com/5h01fjm6/devearmdocuddi0cApplApplPoly
7+
-- Polynomial arithmetic over {0,1}:
8+
-- Ref.:
9+
-- https://developer.arm.com/documentation/ddi0602/2024-09/SIMD-FP-Instructions/PMULL--PMULL2--Polynomial-multiply-long-?lang=en
810

911
import Arm.Decode
1012
import Arm.State

Proofs/AES-GCM/GCMGmultV8Sym.lean

Lines changed: 83 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
/-
22
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
33
Released under Apache 2.0 license as described in the file LICENSE.
4-
Author(s): Alex Keizer
4+
Author(s): Alex Keizer, Shilpi Goel
55
-/
6+
import Specs.GCMV8
67
import Tests.«AES-GCM».GCMGmultV8Program
78
import Tactics.Sym
89
import Tactics.Aggregate
910
import Tactics.StepThms
1011
import Tactics.CSE
12+
import Tactics.ClearNamed
1113
import Arm.Memory.SeparateAutomation
1214
import Arm.Syntax
1315

@@ -16,9 +18,33 @@ open ArmStateNotation
1618

1719
#genStepEqTheorems gcm_gmult_v8_program
1820

19-
/-
20-
xxx: GCMGmultV8 Xi HTable
21-
-/
21+
private theorem lsb_from_extractLsb'_of_append_self (x : BitVec 128) :
22+
BitVec.extractLsb' 64 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
23+
BitVec.extractLsb' 0 64 x := by
24+
bv_decide
25+
26+
private theorem msb_from_extractLsb'_of_append_self (x : BitVec 128) :
27+
BitVec.extractLsb' 0 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
28+
BitVec.extractLsb' 64 64 x := by
29+
bv_decide
30+
31+
theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
32+
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
33+
BitVec.extractLsb' start len1 x := by
34+
apply BitVec.eq_of_getLsbD_eq; intro i
35+
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
36+
decide_True, Nat.zero_add, Bool.true_and,
37+
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
38+
omega
39+
40+
theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
41+
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
42+
BitVec.extractLsb' start len1 x := by
43+
apply BitVec.eq_of_getLsbD_eq; intro i
44+
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
45+
decide_True, Nat.zero_add, Bool.true_and,
46+
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
47+
omega
2248

2349
set_option pp.deepTerms false in
2450
set_option pp.deepTerms.threshold 50 in
@@ -29,10 +55,10 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
2955
(h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min)
3056
(h_s0_sp_aligned : CheckSPAlignment s0)
3157
(h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16])
32-
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 256])
58+
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
3359
(h_mem_sep : Memory.Region.pairwiseSeparate
3460
[(read_gpr 64 0#5 s0, 16),
35-
(read_gpr 64 1#5 s0, 256)])
61+
(read_gpr 64 1#5 s0, 32)])
3662
(h_run : sf = run gcm_gmult_v8_program.length s0) :
3763
-- The final state is error-free.
3864
read_err sf = .None ∧
@@ -42,8 +68,11 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
4268
CheckSPAlignment sf ∧
4369
-- The final state returns to the address in register `x30` in `s0`.
4470
read_pc sf = r (StateField.GPR 30#5) s0 ∧
71+
-- (TODO) Delete the following conjunct because it is covered by the
72+
-- MEM_UNCHANGED_EXCEPT frame condition. We keep it around because it
73+
-- exposes the issue with `simp_mem` that @bollu will fix.
4574
-- HTable is unmodified.
46-
sf[read_gpr 64 1#5 s0, 256] = HTable ∧
75+
sf[read_gpr 64 1#5 s0, 32] = HTable ∧
4776
-- Frame conditions.
4877
-- Note that the following also covers that the Xi address in .GPR 0
4978
-- is unmodified.
@@ -52,8 +81,11 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
5281
.SFP 21, .PC]
5382
(sf, s0) ∧
5483
-- Memory frame condition.
55-
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 128)] (sf, s0) := by
56-
simp_all only [state_simp_rules, -h_run] -- prelude
84+
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 16)] (sf, s0) ∧
85+
sf[r (.GPR 0) s0, 16] = GCMV8.GCMGmultV8_alt (HTable.extractLsb' 0 128) Xi := by
86+
-- Prelude
87+
simp_all only [state_simp_rules, -h_run]
88+
simp only [Nat.reduceMul] at Xi HTable
5789
simp (config := {ground := true}) only at h_s0_pc
5890
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
5991
-- unable to be reflected
@@ -94,4 +126,46 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
94126
simp_mem (config := { useOmegaToClose := false })
95127
-- Aggregate the memory (non)effects.
96128
simp only [*]
129+
· clear_named [h_s, stepi_]
130+
clear s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19 s20 s21 s22 s23 s24 s25 s26
131+
-- Simplifying the LHS
132+
have h_HTable_low :
133+
Memory.read_bytes 16 (r (StateField.GPR 1#5) s0) s0.mem = HTable.extractLsb' 0 128 := by
134+
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
135+
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
136+
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0) 16 _ h_HTable.symm]
137+
· simp only [Nat.reduceMul, BitVec.extractLsBytes, Nat.sub_self, Nat.zero_mul]
138+
· simp_mem
139+
have h_HTable_high :
140+
(Memory.read_bytes 16 (r (StateField.GPR 1#5) s0 + 16#64) s0.mem) = HTable.extractLsb' 128 128 := by
141+
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
142+
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
143+
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
144+
repeat sorry
145+
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
146+
/-
147+
simp/ground below to reduce
148+
(BitVec.extractLsb' 0 64
149+
(shift_left_common_aux 0
150+
{ esize := 64, elements := 2, shift := 57, unsigned := true, round := false,
151+
accumulate := false }
152+
300249147283180997173565830086854304225#128 0#128))
153+
-/
154+
simp (config := {ground := true}) only
155+
simp only [msb_from_extractLsb'_of_append_self,
156+
lsb_from_extractLsb'_of_append_self,
157+
BitVec.partInstall]
158+
-- (FIXME @bollu) cse leaves the goal unchanged here, quietly, likely due to
159+
-- subexpressions occurring in dependent contexts. Maybe a message here would be helpful.
160+
generalize h_Xi_rev : rev_vector 128 64 8 Xi _ _ _ _ _ = Xi_rev
161+
-- Simplifying the RHS
162+
simp only [←h_HTable, GCMV8.GCMGmultV8_alt,
163+
GCMV8.lo, GCMV8.hi,
164+
GCMV8.gcm_polyval]
165+
repeat rw [extractLsb'_zero_extractLsb'_of_le (by decide)]
166+
repeat rw [extractLsb'_extractLsb'_zero_of_le (by decide)]
167+
168+
sorry
97169
done
170+
171+
end GCMGmultV8Program

Proofs/Popcount32.lean

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ def popcount32_program : Program :=
7070

7171
#genStepEqTheorems popcount32_program
7272

73-
set_option trace.simp_mem.info true in
7473
theorem popcount32_sym_meets_spec (s0 sf : ArmState)
7574
(h_s0_pc : read_pc s0 = 0x4005b4#64)
7675
(h_s0_program : s0.program = popcount32_program)

Specs/GCMV8.lean

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,10 +131,10 @@ def refpoly : BitVec 129 := 0x1C2000000000000000000000000000001#129
131131
private def gcm_init_H (H : BitVec 128) : BitVec 128 :=
132132
pmod (H ++ 0b0#1) refpoly (by omega)
133133

134-
private def gcm_polyval_mul (x : BitVec 128) (y : BitVec 128) : BitVec 256 :=
134+
def gcm_polyval_mul (x : BitVec 128) (y : BitVec 128) : BitVec 256 :=
135135
0b0#1 ++ pmult x y
136136

137-
private def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
137+
def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
138138
reverse $ pmod (reverse x) irrepoly (by omega)
139139

140140
/--
@@ -146,7 +146,7 @@ private def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
146146
"A New Interpretation for the GHASH Authenticator of AES-GCM"
147147
2. Lemma: reverse (pmult x y) = pmult (reverse x) (reverse y)
148148
-/
149-
private def gcm_polyval (x : BitVec 128) (y : BitVec 128) : BitVec 128 :=
149+
def gcm_polyval (x : BitVec 128) (y : BitVec 128) : BitVec 128 :=
150150
GCMV8.gcm_polyval_red $ GCMV8.gcm_polyval_mul x y
151151

152152
/-- GCMInitV8 specification:
@@ -203,6 +203,15 @@ def GCMGmultV8 (H : BitVec 128) (Xi : List (BitVec 8)) (h : 8 * Xi.length = 128)
203203
let H := (lo H) ++ (hi H)
204204
split (GCMV8.gcm_polyval H (BitVec.cast h (BitVec.flatten Xi))) 8 (by omega)
205205

206+
/-- Alternative GCMGmultV8 specification that does not use lists:
207+
H : BitVec 128 -- the first element in Htable, not the initial H input to GCMInitV8
208+
Xi : BitVec 128 -- current hash value
209+
output : BitVec 128 -- next hash value
210+
-/
211+
def GCMGmultV8_alt (H : BitVec 128) (Xi : BitVec 128) : BitVec 128 :=
212+
let H := (lo H) ++ (hi H)
213+
gcm_polyval H Xi
214+
206215
set_option maxRecDepth 8000 in
207216
example : GCMGmultV8 0x1099f4b39468565ccdd297a9df145877#128
208217
[ 0x10#8, 0x54#8, 0x43#8, 0xb0#8, 0x2c#8, 0x4b#8, 0x1f#8, 0x24#8,

0 commit comments

Comments
 (0)