Skip to content

Commit

Permalink
Merge branch 'main' into extract-memory-effects
Browse files Browse the repository at this point in the history
  • Loading branch information
alexkeizer committed Oct 9, 2024
2 parents 0aadff5 + a499a57 commit ebf3f06
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 74 deletions.
69 changes: 27 additions & 42 deletions Arm/Insts/Common.lean
Original file line number Diff line number Diff line change
Expand Up @@ -398,22 +398,16 @@ dsimproc [state_simp_rules] reduceInvalidBitMasks (invalid_bit_masks _ _ _ _) :=
let imm ← simp imm
let M ← simp M
let some ⟨immN_width, immN⟩ ← getBitVecValue? immN.expr | return .continue
if h1 : ¬ (immN_width = 1) then
return .continue
else
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
if h2 : ¬ (imms_width = 6) then
return .continue
else
let some M ← Nat.fromExpr? M.expr | return .continue
have h1' : immN_width = 1 := by simp_all only [Decidable.not_not]
have h2' : imms_width = 6 := by simp_all only [Decidable.not_not]
return .done <|
toExpr (invalid_bit_masks
(BitVec.cast h1' immN)
(BitVec.cast h2' imms)
imm.expr.isTrue
M)
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
if h : immN_width = 1 ∧ imms_width = 6 then
let some M ← Nat.fromExpr? M.expr | return .continue
return .done <|
toExpr (invalid_bit_masks
(BitVec.cast (by simp_all only) immN)
(BitVec.cast (by simp_all only) imms)
imm.expr.isTrue
M)
else return .continue

theorem Nat.lt_one_iff {n : Nat} : n < 1 ↔ n = 0 := by
omega
Expand Down Expand Up @@ -442,8 +436,9 @@ theorem M_divisible_by_esize_of_valid_bit_masks (immN : BitVec 1) (imms : BitVec
-- https://kddnewton.com/2022/08/11/aarch64-bitmask-immediates.html
-- Arm Implementation:
-- https://developer.arm.com/documentation/ddi0602/2023-12/Shared-Pseudocode/aarch64-functions-bitmasks?lang=en#impl-aarch64.DecodeBitMasks.5
def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6)
(immediate : Bool) (M : Nat) : Option (BitVec M × BitVec M) :=
def decode_bit_masks (immN : BitVec 1) (imms immr : BitVec 6)
(immediate : Bool) (M : Nat) :
Option (BitVec M × BitVec M) :=
if h0 : invalid_bit_masks immN imms immediate M then none
else
let len := Option.get! $ highest_set_bit $ immN ++ ~~~imms
Expand All @@ -463,36 +458,27 @@ def decode_bit_masks (immN : BitVec 1) (imms : BitVec 6) (immr : BitVec 6)
some (BitVec.cast h wmask, BitVec.cast h tmask)

open Lean Meta Simp in
dsimproc [state_simp_rules] reduceDecodeBitMasks (decode_bit_masks _ _ _ _ _) := fun e => do
dsimproc [state_simp_rules] reduceDecodeBitMasks (decode_bit_masks _ _ _ _ _) :=
fun e => do
let_expr decode_bit_masks immN imms immr imm M ← e | return .continue
let immN ← simp immN
let imms ← simp imms
let immr ← simp immr
let imm ← simp imm
let M ← simp M
let some ⟨immN_width, immN⟩ ← getBitVecValue? immN.expr | return .continue
if h1 : ¬ (immN_width = 1) then
return .continue
else
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
if h2 : ¬ (imms_width = 6) then
return .continue
else
let some ⟨immr_width, immr⟩ ← getBitVecValue? immr.expr | return .continue
if h3 : ¬ (immr_width = 6) then
return .continue
else
let some M ← Nat.fromExpr? M.expr | return .continue
have h1' : immN_width = 1 := by simp_all only [Decidable.not_not]
have h2' : imms_width = 6 := by simp_all only [Decidable.not_not]
have h3' : immr_width = 6 := by simp_all only [Decidable.not_not]
return .done <|
toExpr (decode_bit_masks
(BitVec.cast h1' immN)
(BitVec.cast h2' imms)
(BitVec.cast h3' immr)
imm.expr.isTrue
M)
let some ⟨imms_width, imms⟩ ← getBitVecValue? imms.expr | return .continue
let some ⟨immr_width, immr⟩ ← getBitVecValue? immr.expr | return .continue
if h : immN_width = 1 ∧ imms_width = 6 ∧ immr_width = 6 then
let some M ← Nat.fromExpr? M.expr | return .continue
return .done <|
toExpr (decode_bit_masks
(BitVec.cast (by simp_all only) immN)
(BitVec.cast (by simp_all only) imms)
(BitVec.cast (by simp_all only) immr)
imm.expr.isTrue
M)
else return .continue

----------------------------------------------------------------------

Expand Down Expand Up @@ -664,7 +650,6 @@ structure ShiftInfo where
unsigned := true
round := false
accumulate := false
h : esize > 0
deriving DecidableEq, Repr

export ShiftInfo (esize elements shift unsigned round accumulate)
Expand Down
12 changes: 12 additions & 0 deletions Arm/Insts/DPSFP/Advanced_simd_modified_immediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,18 @@ def AdvSIMDExpandImm (op : BitVec 1) (cmode : BitVec 4) (imm8 : BitVec 8) : BitV
lsb imm8 7 ++ ~~~(lsb imm8 6) ++
(replicate 8 $ lsb imm8 6) ++ extractLsb' 0 6 imm8 ++ BitVec.zero 48

open Lean Meta Simp in
dsimproc [state_simp_rules] reduceAdvSIMDExpandImm (AdvSIMDExpandImm _ _ _) := fun e => do
let_expr AdvSIMDExpandImm op cmode imm8 ← e | return .continue
let some ⟨op_n, op⟩ ← getBitVecValue? op | return .continue
let some ⟨cmode_n, cmode⟩ ← getBitVecValue? cmode | return .continue
let some ⟨imm8_n, imm8⟩ ← getBitVecValue? imm8 | return .continue
if h : op_n = 1 ∧ cmode_n = 4 ∧ imm8_n = 8 then
return .done <| toExpr (AdvSIMDExpandImm
(BitVec.cast (by simp_all only) op)
(BitVec.cast (by simp_all only) cmode)
(BitVec.cast (by simp_all only) imm8))
else return .continue

private theorem mul_div_norm_form_lemma (n m : Nat) (_h1 : 0 < m) (h2 : n ∣ m) :
(n * (m / n)) = n * m / n := by
Expand Down
10 changes: 2 additions & 8 deletions Arm/Insts/DPSFP/Advanced_simd_scalar_shift_by_immediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,14 @@ def exec_shift_right_scalar
write_err (StateError.Illegal s!"Illegal {inst} encountered!") s
else
let esize := 8 <<< 3
have h : esize > 0 := by decide
let datasize := esize
let (info : ShiftInfo) :=
{ esize := esize,
elements := 1,
shift := (esize * 2) - (inst.immh ++ inst.immb).toNat,
unsigned := inst.U = 0b1#1,
round := (lsb inst.opcode 2) = 0b1#1,
accumulate := (lsb inst.opcode 1) = 0b1#1,
h := h
}
accumulate := (lsb inst.opcode 1) = 0b1#1 }
let result := shift_right_common info datasize inst.Rn inst.Rd s
-- State Update
let s := write_sfp datasize inst.Rd result s
Expand All @@ -46,14 +43,11 @@ def exec_shl_scalar
write_err (StateError.Illegal s!"Illegal {inst} encountered!") s
else
let esize := 8 <<< 3
have h : esize > 0 := by decide
let datasize := esize
let (info : ShiftInfo) :=
{ esize := esize,
elements := 1,
shift := (inst.immh ++ inst.immb).toNat - esize,
h := h
}
shift := (inst.immh ++ inst.immb).toNat - esize }
let result := shift_left_common info datasize inst.Rn s
-- State Update
let s := write_sfp datasize inst.Rd result s
Expand Down
12 changes: 2 additions & 10 deletions Arm/Insts/DPSFP/Advanced_simd_shift_by_immediate.lean
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,14 @@ def exec_shift_right_vector
else
let l := highest_set_bit inst.immh
let esize := 8 <<< l
have h : esize > 0 := by
simp only [esize]
apply zero_lt_shift_left_pos (by decide)
let datasize := 64 <<< inst.Q.toNat
let (info : ShiftInfo) :=
{ esize := esize,
elements := datasize / esize,
shift := (2 * esize) - (inst.immh ++ inst.immb).toNat,
unsigned := inst.U = 0b1#1,
round := (lsb inst.opcode 2) = 0b1#1,
accumulate := (lsb inst.opcode 1) = 0b1#1,
h := h }
accumulate := (lsb inst.opcode 1) = 0b1#1 }
let result := shift_right_common info datasize inst.Rn inst.Rd s
-- State Update
let s := write_sfp datasize inst.Rd result s
Expand All @@ -58,15 +54,11 @@ def exec_shl_vector
else
let l := highest_set_bit inst.immh
let esize := 8 <<< l
have h : esize > 0 := by
simp only [esize]
apply zero_lt_shift_left_pos (by decide)
let datasize := 64 <<< inst.Q.toNat
let (info : ShiftInfo) :=
{ esize := esize,
elements := datasize / esize,
shift := (inst.immh ++ inst.immb).toNat - esize,
h := h }
shift := (inst.immh ++ inst.immb).toNat - esize }
let result := shift_left_common info datasize inst.Rn s
-- State Update
let s := write_sfp datasize inst.Rd result s
Expand Down
4 changes: 3 additions & 1 deletion Arm/Insts/DPSFP/Advanced_simd_three_different.lean
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Yan Peng
-/
-- PMULL and PMULL2
-- Polynomial arithmetic over {0,1}: https://tiny.amazon.com/5h01fjm6/devearmdocuddi0cApplApplPoly
-- Polynomial arithmetic over {0,1}:
-- Ref.:
-- https://developer.arm.com/documentation/ddi0602/2024-09/SIMD-FP-Instructions/PMULL--PMULL2--Polynomial-multiply-long-?lang=en

import Arm.Decode
import Arm.State
Expand Down
92 changes: 83 additions & 9 deletions Proofs/AES-GCM/GCMGmultV8Sym.lean
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
/-
Copyright (c) 2024 Amazon.com, Inc. or its affiliates. All Rights Reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author(s): Alex Keizer
Author(s): Alex Keizer, Shilpi Goel
-/
import Specs.GCMV8
import Tests.«AES-GCM».GCMGmultV8Program
import Tactics.Sym
import Tactics.Aggregate
import Tactics.StepThms
import Tactics.CSE
import Tactics.ClearNamed
import Arm.Memory.SeparateAutomation
import Arm.Syntax

Expand All @@ -16,9 +18,33 @@ open ArmStateNotation

#genStepEqTheorems gcm_gmult_v8_program

/-
xxx: GCMGmultV8 Xi HTable
-/
private theorem lsb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 64 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 0 64 x := by
bv_decide

private theorem msb_from_extractLsb'_of_append_self (x : BitVec 128) :
BitVec.extractLsb' 0 64 (BitVec.extractLsb' 64 128 (x ++ x)) =
BitVec.extractLsb' 64 64 x := by
bv_decide

theorem extractLsb'_zero_extractLsb'_of_le (h : len1 ≤ len2) :
BitVec.extractLsb' 0 len1 (BitVec.extractLsb' start len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

theorem extractLsb'_extractLsb'_zero_of_le (h : start + len1 ≤ len2):
BitVec.extractLsb' start len1 (BitVec.extractLsb' 0 len2 x) =
BitVec.extractLsb' start len1 x := by
apply BitVec.eq_of_getLsbD_eq; intro i
simp only [BitVec.getLsbD_extractLsb', Fin.is_lt,
decide_True, Nat.zero_add, Bool.true_and,
Bool.and_iff_right_iff_imp, decide_eq_true_eq]
omega

set_option pp.deepTerms false in
set_option pp.deepTerms.threshold 50 in
Expand All @@ -29,10 +55,10 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
(h_s0_pc : read_pc s0 = gcm_gmult_v8_program.min)
(h_s0_sp_aligned : CheckSPAlignment s0)
(h_Xi : Xi = s0[read_gpr 64 0#5 s0, 16])
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 256])
(h_HTable : HTable = s0[read_gpr 64 1#5 s0, 32])
(h_mem_sep : Memory.Region.pairwiseSeparate
[(read_gpr 64 0#5 s0, 16),
(read_gpr 64 1#5 s0, 256)])
(read_gpr 64 1#5 s0, 32)])
(h_run : sf = run gcm_gmult_v8_program.length s0) :
-- The final state is error-free.
read_err sf = .None ∧
Expand All @@ -42,8 +68,11 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
CheckSPAlignment sf ∧
-- The final state returns to the address in register `x30` in `s0`.
read_pc sf = r (StateField.GPR 30#5) s0 ∧
-- (TODO) Delete the following conjunct because it is covered by the
-- MEM_UNCHANGED_EXCEPT frame condition. We keep it around because it
-- exposes the issue with `simp_mem` that @bollu will fix.
-- HTable is unmodified.
sf[read_gpr 64 1#5 s0, 256] = HTable ∧
sf[read_gpr 64 1#5 s0, 32] = HTable ∧
-- Frame conditions.
-- Note that the following also covers that the Xi address in .GPR 0
-- is unmodified.
Expand All @@ -52,8 +81,11 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
.SFP 21, .PC]
(sf, s0) ∧
-- Memory frame condition.
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 128)] (sf, s0) := by
simp_all only [state_simp_rules, -h_run] -- prelude
MEM_UNCHANGED_EXCEPT [(r (.GPR 0) s0, 16)] (sf, s0) ∧
sf[r (.GPR 0) s0, 16] = GCMV8.GCMGmultV8_alt (HTable.extractLsb' 0 128) Xi := by
-- Prelude
simp_all only [state_simp_rules, -h_run]
simp only [Nat.reduceMul] at Xi HTable
simp (config := {ground := true}) only at h_s0_pc
-- ^^ Still needed, because `gcm_gmult_v8_program.min` is somehow
-- unable to be reflected
Expand Down Expand Up @@ -94,4 +126,46 @@ theorem gcm_gmult_v8_program_run_27 (s0 sf : ArmState)
simp_mem (config := { useOmegaToClose := false })
-- Aggregate the memory (non)effects.
simp only [*]
· clear_named [h_s, stepi_]
clear s1 s2 s3 s4 s5 s6 s7 s8 s9 s10 s11 s12 s13 s14 s15 s16 s17 s18 s19 s20 s21 s22 s23 s24 s25 s26
-- Simplifying the LHS
have h_HTable_low :
Memory.read_bytes 16 (r (StateField.GPR 1#5) s0) s0.mem = HTable.extractLsb' 0 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0) 16 _ h_HTable.symm]
· simp only [Nat.reduceMul, BitVec.extractLsBytes, Nat.sub_self, Nat.zero_mul]
· simp_mem
have h_HTable_high :
(Memory.read_bytes 16 (r (StateField.GPR 1#5) s0 + 16#64) s0.mem) = HTable.extractLsb' 128 128 := by
-- (FIXME @bollu) use `simp_mem` instead of the rw below.
rw [@Memory.read_bytes_eq_extractLsBytes_sub_of_mem_subset'
32 (r (StateField.GPR 1#5) s0) HTable (r (StateField.GPR 1#5) s0 + 16#64) 16 _ h_HTable.symm]
repeat sorry
simp only [h_HTable_high, h_HTable_low, ←h_Xi]
/-
simp/ground below to reduce
(BitVec.extractLsb' 0 64
(shift_left_common_aux 0
{ esize := 64, elements := 2, shift := 57, unsigned := true, round := false,
accumulate := false }
300249147283180997173565830086854304225#128 0#128))
-/
simp (config := {ground := true}) only
simp only [msb_from_extractLsb'_of_append_self,
lsb_from_extractLsb'_of_append_self,
BitVec.partInstall]
-- (FIXME @bollu) cse leaves the goal unchanged here, quietly, likely due to
-- subexpressions occurring in dep. contexts. Maybe a message here would be helpful.
generalize h_Xi_rev : rev_vector 128 64 8 Xi _ _ _ _ _ = Xi_rev
-- Simplifying the RHS
simp only [←h_HTable, GCMV8.GCMGmultV8_alt,
GCMV8.lo, GCMV8.hi,
GCMV8.gcm_polyval]
repeat rw [extractLsb'_zero_extractLsb'_of_le (by decide)]
repeat rw [extractLsb'_extractLsb'_zero_of_le (by decide)]

sorry
done

end GCMGmultV8Program
1 change: 0 additions & 1 deletion Proofs/Popcount32.lean
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ def popcount32_program : Program :=

#genStepEqTheorems popcount32_program

set_option trace.simp_mem.info true in
theorem popcount32_sym_meets_spec (s0 sf : ArmState)
(h_s0_pc : read_pc s0 = 0x4005b4#64)
(h_s0_program : s0.program = popcount32_program)
Expand Down
15 changes: 12 additions & 3 deletions Specs/GCMV8.lean
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def refpoly : BitVec 129 := 0x1C2000000000000000000000000000001#129
private def gcm_init_H (H : BitVec 128) : BitVec 128 :=
pmod (H ++ 0b0#1) refpoly (by omega)

private def gcm_polyval_mul (x : BitVec 128) (y : BitVec 128) : BitVec 256 :=
def gcm_polyval_mul (x : BitVec 128) (y : BitVec 128) : BitVec 256 :=
0b0#1 ++ pmult x y

private def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
reverse $ pmod (reverse x) irrepoly (by omega)

/--
Expand All @@ -146,7 +146,7 @@ private def gcm_polyval_red (x : BitVec 256) : BitVec 128 :=
"A New Interpretation for the GHASH Authenticator of AES-GCM"
2. Lemma: reverse (pmult x y) = pmult (reverse x) (reverse y)
-/
private def gcm_polyval (x : BitVec 128) (y : BitVec 128) : BitVec 128 :=
def gcm_polyval (x : BitVec 128) (y : BitVec 128) : BitVec 128 :=
GCMV8.gcm_polyval_red $ GCMV8.gcm_polyval_mul x y

/-- GCMInitV8 specification:
Expand Down Expand Up @@ -203,6 +203,15 @@ def GCMGmultV8 (H : BitVec 128) (Xi : List (BitVec 8)) (h : 8 * Xi.length = 128)
let H := (lo H) ++ (hi H)
split (GCMV8.gcm_polyval H (BitVec.cast h (BitVec.flatten Xi))) 8 (by omega)

/-- Alternative GCMGmultV8 specification that does not use lists:
H : BitVec 128 -- the first element in Htable, not the initial H input to GCMInitV8
Xi : BitVec 128 -- current hash value
output : BitVec 128 -- next hash value
-/
def GCMGmultV8_alt (H : BitVec 128) (Xi : BitVec 128) : BitVec 128 :=
let H := (lo H) ++ (hi H)
gcm_polyval H Xi

set_option maxRecDepth 8000 in
example : GCMGmultV8 0x1099f4b39468565ccdd297a9df145877#128
[ 0x10#8, 0x54#8, 0x43#8, 0xb0#8, 0x2c#8, 0x4b#8, 0x1f#8, 0x24#8,
Expand Down

0 comments on commit ebf3f06

Please sign in to comment.