Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set array sizes from bounds #2042

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion fiat-amd64/gentest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def removeprefix(s, prefix):

asm_op_names = OrderedDict()

regex = re.compile(r'fiat_(?P<name>[^_]+(_(solinas|montgomery|dettman))?)_(?P<op>(carry_)?(square|mul))')
regex = re.compile(r'fiat_(?P<name>[^_]+(_(solinas|montgomery|dettman))?)_(?P<op>(carry_)?(square|mul|from_bytes|to_bytes|add|sub|opp))')
for dirname in directories:
m = regex.match(os.path.basename(dirname))
if m:
Expand Down
81 changes: 48 additions & 33 deletions src/Assembly/Equivalence.v
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ Inductive EquivalenceCheckingError :=
| Internal_error_output_load_failed (_ : option Symbolic.error) (_ : list ((REG + idx) + idx)) (_ : symbolic_state)
| Internal_error_extra_input_arguments (t : API.type) (unused_arguments : list (idx + list idx))
| Internal_error_lingering_memory (_ : symbolic_state)
| Internal_error_LoadOutputs_length_mismatch (outputaddrs : list ((REG + idx) + idx)) (output_types : list (option nat))
| Internal_error_LoadOutputs_length_mismatch (outputaddrs : list ((REG + idx) + idx)) (output_types : list (N * (option nat)))
| Not_enough_registers (num_given num_extra_needed : nat)
| Registers_too_narrow (bad_reg : list REG)
| Callee_saved_registers_too_narrow (bad_reg : list REG)
Expand All @@ -233,6 +233,9 @@ Inductive EquivalenceCheckingError :=
| Expected_const_in_reference_code (_ : idx)
| Expected_power_of_two (w : N) (_ : idx)
| Unknown_array_length (t : base.type)
| Unknown_array_bounds {t : base.type} (bs : list (ZRange.type.base.option.interp t))
| Unknown_scalar_size (t : base.type)
| Invalid_zero_size_array (t : base.type)
| Registers_not_saved (regs : list (REG * (idx (* before *) * idx (* after *)))) (_ : symbolic_state)
| Invalid_arrow_type (t : API.type)
| Invalid_argument_type (t : API.type)
Expand Down Expand Up @@ -266,6 +269,9 @@ Definition symbolic_state_of_EquivalenceCheckingError (e : EquivalenceCheckingEr
| Expected_const_in_reference_code _
| Expected_power_of_two _ _
| Unknown_array_length _
| Unknown_scalar_size _
| Unknown_array_bounds _ _
| Invalid_zero_size_array _
| Invalid_arrow_type _
| Invalid_argument_type _
| Invalid_return_type _
Expand Down Expand Up @@ -708,6 +714,9 @@ Global Instance show_lines_EquivalenceCheckingError : ShowLines EquivalenceCheck
| Invalid_return_type t
=> ["Invalid type for return: " ++ show t]%string
| Unknown_array_length t => ["Unknown array length of type " ++ show t ++ "."]%string
| Unknown_array_bounds t bs => ["Unknown array bounds of type " ++ show t ++ ": " ++ show bs]%string
| Unknown_scalar_size t => ["Unknown scalar size of type " ++ show t ++ "."]%string
| Invalid_zero_size_array t => ["Array of type " ++ show t ++ " has zero size."]%string
| Invalid_arrow_type t => ["Invalid higher order function involving the type " ++ show t ++ "."]%string
| Invalid_higher_order_application var s d f x
=> let __ := @Compilers.ToString.PHOAS.expr.partially_show_expr in
Expand Down Expand Up @@ -780,7 +789,7 @@ Definition RevealWidth (i : idx) : symexM N :=
then symex_return w
else symex_error (Expected_power_of_two s i).

Definition type_spec := list (option nat). (* list of array lengths; None means not an array *)
Definition type_spec := list (N * option nat). (* list of element size in bytes * length; None means not an array *)

(** Convert PHOAS info about types and argument bounds into a simplified specification *)
Fixpoint simplify_base_type
Expand All @@ -789,18 +798,24 @@ Fixpoint simplify_base_type
:= match t return ZRange.type.base.option.interp t -> _ with
| base.type.unit
=> fun 'tt => Success []
| base.type.type_base base.type.Z
=> fun _ => Success [None]
| (base.type.type_base base.type.Z) as t
=> fun r
=> match ZRange.type.base.option.lift_Some r with
| Some r => Success [(Z.to_N (ZRange.type.base.bitwidth r), None)]
| None => Error (Unknown_scalar_size t)
end
| base.type.prod A B
=> fun '(bA, bB)
=> (vA <- simplify_base_type A bA;
vB <- simplify_base_type B bB;
Success (vA ++ vB))
| base.type.list (base.type.type_base base.type.Z)
| (base.type.list (base.type.type_base base.type.Z as tZ)) as t
=> fun b
=> match b with
| None => Error (Unknown_array_length t)
| Some bs => Success [Some (List.length bs)]
=> match b, option_map ZRange.type.base.bitwidth (ZRange.type.base.option.lift_Some b) with
| None, _ => Error (Unknown_array_length t)
| Some b, None => Error (@Unknown_array_bounds tZ b)
| Some nil, _ | _, Some nil => Error (Invalid_zero_size_array t)
| Some _, Some bs => Success [(Z.to_N (List.fold_right Z.max 0%Z bs), Some (List.length bs))]
end
| base.type.type_base _
| base.type.option _
Expand Down Expand Up @@ -829,39 +844,39 @@ Fixpoint simplify_input_type
Definition build_inputarray {descr:description} (len : nat) : dag.M (list idx) :=
List.foldmap (fun _ => merge_fresh_symbol) (List.seq 0 len).

Fixpoint build_inputs {descr:description} (types : type_spec) : dag.M (list (idx + list idx))
Fixpoint build_inputs {descr:description} (types : type_spec) : dag.M (list (N * (idx + list idx)))
:= match types with
| [] => dag.ret []
| None :: tys
| (sz, None) :: tys
=> (idx <- merge_fresh_symbol;
rest <- build_inputs tys;
dag.ret (inl idx :: rest))
| Some len :: tys
dag.ret ((sz, inl idx) :: rest))
| (sz, Some len) :: tys
=> (idxs <- build_inputarray len;
rest <- build_inputs tys;
dag.ret (inr idxs :: rest))
dag.ret ((sz, inr idxs) :: rest))
end%dagM.

(* we factor this out so that conversion is not slow when proving things about this *)
Definition compute_array_address {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (i : nat)
:= (offset <- Symbolic.App (zconst 64%N (8 * Z.of_nat i), nil);
Definition compute_array_address {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (i : nat)
:= (offset <- Symbolic.App (zconst 64%N (Z.of_N bytes_per_element * Z.of_nat i), nil);
Symbolic.App (add 64%N, [base; offset]))%x86symex.

Definition build_merge_array_addresses {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (items : list idx) : M (list idx)
Definition build_merge_array_addresses {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (items : list idx) : M (list idx)
:= mapM (fun '(i, idx) =>
(addr <- compute_array_address base i;
(addr <- compute_array_address (bytes_per_element:=bytes_per_element) base i;
(fun s => Success (addr, update_mem_with s (cons (addr,idx)))))
)%x86symex (List.enumerate items).

Fixpoint build_merge_base_addresses {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (items : list (idx + list idx)) (reg_available : list REG) : M (list ((REG + idx) + idx))
Fixpoint build_merge_base_addresses {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (items : list (N * (idx + list idx))) (reg_available : list REG) : M (list ((REG + idx) + idx))
:= match items, reg_available with
| [], _ | _, [] => Symbolic.ret []
| inr idxs :: xs, r :: reg_available
| (sz, inr idxs) :: xs, r :: reg_available
=> (base <- SetRegFresh r; (* note: overwrites initial value *)
addrs <- build_merge_array_addresses base idxs; (* note: overwrites initial value *)
addrs <- build_merge_array_addresses (bytes_per_element:=sz) base idxs; (* note: overwrites initial value *)
rest <- build_merge_base_addresses (dereference_scalar:=dereference_scalar) xs reg_available;
Symbolic.ret (inr base :: rest))
| inl idx :: xs, r :: reg_available =>
| (_sz, inl idx) :: xs, r :: reg_available =>
(addr <- (if dereference_scalar
then
(addr <- SetRegFresh r;
Expand Down Expand Up @@ -1273,10 +1288,10 @@ Definition symex_PHOAS_PHOAS {opts : symbolic_options_computed_opt} {t} (expr :
Definition symex_PHOAS
{opts : symbolic_options_computed_opt}
{t} (expr : API.Expr t)
(inputs : list (idx + list idx))
(inputs : list (N * (idx + list idx)))
(d : dag)
: ErrorT EquivalenceCheckingError (list (idx + list idx) * dag)
:= (input_var_data <- build_input_var t inputs;
:= (input_var_data <- build_input_var t (List.map snd inputs);
let '(input_var_data, unused_inputs) := input_var_data in
_ <- (if (List.length unused_inputs =? 0)%nat
then Success tt
Expand Down Expand Up @@ -1307,18 +1322,18 @@ Definition build_merge_stack_placeholders {opts : symbolic_options_computed_opt}
: M idx
:= (stack_placeholders <- lift_dag (build_inputarray stack_size);
stack_base <- compute_stack_base stack_size;
_ <- build_merge_array_addresses stack_base stack_placeholders;
_ <- build_merge_array_addresses (bytes_per_element:=8) stack_base stack_placeholders;
ret stack_base)%x86symex.

Definition LoadArray {opts : symbolic_options_computed_opt} {descr:description} (base : idx) (len : nat) : M (list idx)
Definition LoadArray {opts : symbolic_options_computed_opt} {descr:description} {bytes_per_element : N} (base : idx) (len : nat) : M (list idx)
:= mapM (fun i =>
(addr <- compute_array_address base i;
(addr <- compute_array_address (bytes_per_element:=bytes_per_element) base i;
Remove64 addr)%x86symex)
(seq 0 len).

Definition LoadOutputs_internal {opts : symbolic_options_computed_opt} {descr:description} {dereference_scalar:bool} (outputaddrs : list ((REG + idx) + idx)) (output_types : type_spec)
: M (list (idx + list idx))
:= (mapM (fun '(ocells, spec) =>
:= (mapM (fun '(ocells, (sz, spec)) =>
match ocells, spec with
| inl _, Some _ | inr _, None => err (error.unsupported_memory_access_size 0)
| inl addr, None
Expand All @@ -1331,7 +1346,7 @@ Definition LoadOutputs_internal {opts : symbolic_options_computed_opt} {descr:de
end;
ret (inl v))
| inr base, Some len
=> (v <- LoadArray base len;
=> (v <- LoadArray (bytes_per_element:=sz) base len;
ret (inr v))
end) (List.combine outputaddrs output_types))%N%x86symex.

Expand All @@ -1358,7 +1373,7 @@ Definition symex_asm_func_M
{dereference_output_scalars:bool}
(callee_saved_registers : list REG)
(output_types : type_spec) (stack_size : nat)
(inputs : list (idx + list idx)) (reg_available : list REG) (asm : Lines)
(inputs : list (N * (idx + list idx))) (reg_available : list REG) (asm : Lines)
: M (ErrorT EquivalenceCheckingError (list (idx + list idx)))
:= (output_placeholders <- lift_dag (build_inputs (descr:=Build_description "output_placeholders" true) output_types);
let n_outputs := List.length output_placeholders in
Expand All @@ -1369,12 +1384,12 @@ Definition symex_asm_func_M
initial_register_values <- mapM (GetReg (descr:=Build_description "initial_register_values" true)) callee_saved_registers;
_ <- SymexLines asm;
final_register_values <- mapM (GetReg (descr:=Build_description "final_register_values" true)) callee_saved_registers;
_ <- LoadArray (descr:=Build_description "load final stack" true) stack_base stack_size;
_ <- LoadArray (descr:=Build_description "load final stack" true) (bytes_per_element:=8) stack_base stack_size;
let unsaved_registers : list (REG * (idx * idx)) := List.filter (fun '(r, (init, final)) => negb (init =? final)%N) (List.combine callee_saved_registers (List.combine initial_register_values final_register_values)) in
asm_output <- LoadOutputs (descr:=Build_description "asm_output" true) (dereference_scalar:=dereference_output_scalars) outputaddrs output_types;
(* also load inputs, for the sake of the proof *)
(* reconstruct input types *)
let input_types := List.map (fun v => match v with inl _ => None | inr ls => Some (List.length ls) end) inputs in
let input_types := List.map (fun '(sz, v) => (sz, match v with inl _ => None | inr ls => Some (List.length ls) end)) inputs in
asm_input <- LoadOutputs (descr:=Build_description "asm_input <- LoadOutputs" true) (dereference_scalar:=dereference_input_scalars) inputaddrs input_types;
(fun s => Success
(match asm_output, asm_input, unsaved_registers, s.(symbolic_mem_state) with
Expand All @@ -1394,7 +1409,7 @@ Definition symex_asm_func
{opts : symbolic_options_computed_opt}
{dereference_output_scalars:bool}
(d : dag) (callee_saved_registers : list REG) (output_types : type_spec) (stack_size : nat)
(inputs : list (idx + list idx)) (reg_available : list REG) (asm : Lines)
(inputs : list (N * (idx + list idx))) (reg_available : list REG) (asm : Lines)
: ErrorT EquivalenceCheckingError (list (idx + list idx) * symbolic_state)
:= let num_reg_given := List.length reg_available in
let num_reg_needed := List.length inputs + List.length output_types in
Expand Down Expand Up @@ -1452,7 +1467,7 @@ Section check_equivalence.
Local Notation map_err_None v := (ErrorT.map_error (fun e => (None, e)) v).
Local Notation map_err_Some label v := (ErrorT.map_error (fun e => (Some label, e)) v).

Definition map_symex_asm (inputs : list (idx + list idx)) (output_types : type_spec) (d : dag)
Definition map_symex_asm (inputs : list (N * (idx + list idx))) (output_types : type_spec) (d : dag)
: ErrorT
(option (string (* fname *) * Lines (* asm lines *)) * EquivalenceCheckingError)
(list ((string (* fname *) * Lines (* asm lines *)) * (list (idx + list idx) * symbolic_state))) :=
Expand Down
Loading
Loading