diff --git a/framework/base/src/contract_base/contract_base_trait.rs b/framework/base/src/contract_base/contract_base_trait.rs index 342b9ea61d..e6543cfe09 100644 --- a/framework/base/src/contract_base/contract_base_trait.rs +++ b/framework/base/src/contract_base/contract_base_trait.rs @@ -19,9 +19,7 @@ pub trait ContractBase: Sized { /// Gateway into the call value retrieval functionality. /// The payment annotations should normally be the ones to handle this, /// but the developer is also given direct access to the API. - fn call_value(&self) -> CallValueWrapper { - CallValueWrapper::new() - } + fn call_value(&self) -> CallValueWrapper; /// Gateway to the functionality related to sending transactions from the current contract. #[inline] diff --git a/framework/base/src/contract_base/universal_contract_obj.rs b/framework/base/src/contract_base/universal_contract_obj.rs index d0624c0526..3f6d0205ca 100644 --- a/framework/base/src/contract_base/universal_contract_obj.rs +++ b/framework/base/src/contract_base/universal_contract_obj.rs @@ -1,8 +1,22 @@ -use core::marker::PhantomData; +use core::{cell::UnsafeCell, marker::PhantomData}; -use crate::api::VMApi; +use crate::api::{const_handles, RawHandle, VMApi}; -use super::ContractBase; +use super::{CallValueWrapper, ContractBase}; + +pub struct ContractObjData { + pub call_value_egld_handle: RawHandle, + pub call_value_multi_esdt_handle: RawHandle, +} + +impl Default for ContractObjData { + fn default() -> Self { + ContractObjData { + call_value_egld_handle: const_handles::UNINITIALIZED_HANDLE, + call_value_multi_esdt_handle: const_handles::UNINITIALIZED_HANDLE, + } + } +} /// A unique empty structure that automatically implements all smart contract traits. /// @@ -18,8 +32,12 @@ where A: VMApi, { _phantom: PhantomData, + pub data: UnsafeCell, } +unsafe impl Sync for UniversalContractObj where A: VMApi {} +unsafe impl Send for UniversalContractObj where A: VMApi {} + impl UniversalContractObj where A: VMApi, @@ -27,6 +45,7 @@ where pub fn new() -> Self { Self { _phantom: PhantomData, + data: UnsafeCell::new(ContractObjData::default()), } } } @@ -45,4 +64,8 @@ where A: VMApi, { type Api = A; + + fn call_value(&self) -> CallValueWrapper<'_, Self::Api> { + CallValueWrapper::new(&self.data) + } } diff --git a/framework/base/src/contract_base/wrappers/call_value_wrapper.rs b/framework/base/src/contract_base/wrappers/call_value_wrapper.rs index 3fb85e1658..be66b7ad7c 100644 --- a/framework/base/src/contract_base/wrappers/call_value_wrapper.rs +++ b/framework/base/src/contract_base/wrappers/call_value_wrapper.rs @@ -1,10 +1,11 @@ -use core::marker::PhantomData; +use core::{cell::UnsafeCell, marker::PhantomData}; use crate::{ api::{ const_handles, use_raw_handle, CallValueApi, CallValueApiImpl, ErrorApi, ErrorApiImpl, - HandleConstraints, ManagedTypeApi, StaticVarApiImpl, + ManagedTypeApi, }, + contract_base::ContractObjData, err_msg, types::{ BigUint, ConstDecimals, EgldOrEsdtTokenIdentifier, EgldOrEsdtTokenPayment, @@ -13,35 +14,34 @@ use crate::{ }, }; -#[derive(Default)] -pub struct CallValueWrapper +pub struct CallValueWrapper<'a, A> where A: CallValueApi + ErrorApi + ManagedTypeApi, { _phantom: PhantomData, + pub data_cell: &'a UnsafeCell, } -impl CallValueWrapper +impl<'a, A> CallValueWrapper<'a, A> where A: CallValueApi + ErrorApi + ManagedTypeApi, { - pub fn new() -> Self { + pub fn new(data_cell: &'a UnsafeCell) -> Self { CallValueWrapper { _phantom: PhantomData, + data_cell, } } /// Retrieves the EGLD call value from the VM. /// Will return 0 in case of an ESDT transfer (cannot have both EGLD and ESDT transfer simultaneously). pub fn egld_value(&self) -> ManagedRef<'static, A, BigUint> { - let mut call_value_handle: A::BigIntHandle = - use_raw_handle(A::static_var_api_impl().get_call_value_egld_handle()); - if call_value_handle == const_handles::UNINITIALIZED_HANDLE { - call_value_handle = use_raw_handle(const_handles::CALL_VALUE_EGLD); - A::static_var_api_impl().set_call_value_egld_handle(call_value_handle.get_raw_handle()); - A::call_value_api_impl().load_egld_value(call_value_handle.clone()); + let data = unsafe { &mut *self.data_cell.get() }; + if data.call_value_egld_handle == const_handles::UNINITIALIZED_HANDLE { + data.call_value_egld_handle = const_handles::CALL_VALUE_EGLD; + A::call_value_api_impl().load_egld_value(use_raw_handle(data.call_value_egld_handle)); } - unsafe { ManagedRef::wrap_handle(call_value_handle) } + unsafe { ManagedRef::wrap_handle(use_raw_handle(data.call_value_egld_handle)) } } /// Returns the EGLD call value from the VM as ManagedDecimal @@ -55,15 +55,13 @@ where /// Will return 0 results if nothing was transfered, or just EGLD. /// Fully managed underlying types, very efficient. pub fn all_esdt_transfers(&self) -> ManagedRef<'static, A, ManagedVec>> { - let mut call_value_handle: A::ManagedBufferHandle = - use_raw_handle(A::static_var_api_impl().get_call_value_multi_esdt_handle()); - if call_value_handle == const_handles::UNINITIALIZED_HANDLE { - call_value_handle = use_raw_handle(const_handles::CALL_VALUE_MULTI_ESDT); - A::static_var_api_impl() - .set_call_value_multi_esdt_handle(call_value_handle.get_raw_handle()); - A::call_value_api_impl().load_all_esdt_transfers(call_value_handle.clone()); + let data = unsafe { &mut *self.data_cell.get() }; + if data.call_value_multi_esdt_handle == const_handles::UNINITIALIZED_HANDLE { + data.call_value_multi_esdt_handle = const_handles::CALL_VALUE_MULTI_ESDT; + A::call_value_api_impl() + .load_all_esdt_transfers(use_raw_handle(data.call_value_multi_esdt_handle)); } - unsafe { ManagedRef::wrap_handle(call_value_handle) } + unsafe { ManagedRef::wrap_handle(use_raw_handle(data.call_value_multi_esdt_handle)) } } /// Verify and casts the received multi ESDT transfer in to an array. diff --git a/framework/base/src/io/call_value_init.rs b/framework/base/src/io/call_value_init.rs index 93c1bd4f67..6f77c474c3 100644 --- a/framework/base/src/io/call_value_init.rs +++ b/framework/base/src/io/call_value_init.rs @@ -3,8 +3,8 @@ use crate::{ const_handles, use_raw_handle, CallValueApi, CallValueApiImpl, ErrorApi, ErrorApiImpl, ManagedBufferApiImpl, ManagedTypeApi, }, - contract_base::CallValueWrapper, err_msg, + imports::ContractBase, types::{ BigUint, EgldOrEsdtTokenIdentifier, EsdtTokenPayment, ManagedRef, ManagedType, ManagedVec, }, @@ -38,11 +38,12 @@ where /// Called initially in the generated code whenever `#[payable("")]` annotation is provided. /// /// Was never really used, expected to be deprecated/removed. -pub fn payable_single_specific_token(expected_tokend_identifier: &str) +pub fn payable_single_specific_token(obj: &O, expected_tokend_identifier: &str) where A: CallValueApi + ManagedTypeApi + ErrorApi, + O: ContractBase, { - let transfers = CallValueWrapper::::new().all_esdt_transfers(); + let transfers = obj.call_value().all_esdt_transfers(); if transfers.len() != 1 { A::error_api_impl().signal_error(err_msg::SINGLE_ESDT_EXPECTED.as_bytes()); } @@ -62,37 +63,39 @@ where } /// Initializes an argument annotated with `#[payment_amount]` or `#[payment]`. -pub fn arg_payment_amount() -> BigUint +pub fn arg_payment_amount(obj: &O) -> BigUint where A: CallValueApi + ManagedTypeApi, + O: ContractBase, { - CallValueWrapper::::new().egld_or_single_esdt().amount + obj.call_value().egld_or_single_esdt().amount } /// Initializes an argument annotated with `#[payment_token]`. -pub fn arg_payment_token() -> EgldOrEsdtTokenIdentifier +pub fn arg_payment_token(obj: &O) -> EgldOrEsdtTokenIdentifier where A: CallValueApi + ManagedTypeApi, + O: ContractBase, { - CallValueWrapper::::new() - .egld_or_single_esdt() - .token_identifier + obj.call_value().egld_or_single_esdt().token_identifier } /// Initializes an argument annotated with `#[payment_nonce]`. -pub fn arg_payment_nonce() -> u64 +pub fn arg_payment_nonce(obj: &O) -> u64 where A: CallValueApi + ManagedTypeApi, + O: ContractBase, { - CallValueWrapper::::new() - .egld_or_single_esdt() - .token_nonce + obj.call_value().egld_or_single_esdt().token_nonce } /// Initializes an argument annotated with `#[payment_multi]`. -pub fn arg_payment_multi() -> ManagedRef<'static, A, ManagedVec>> +pub fn arg_payment_multi( + obj: &O, +) -> ManagedRef<'static, A, ManagedVec>> where A: CallValueApi + ManagedTypeApi, + O: ContractBase, { - CallValueWrapper::::new().all_esdt_transfers() + obj.call_value().all_esdt_transfers() } diff --git a/framework/derive/src/generate/payable_gen.rs b/framework/derive/src/generate/payable_gen.rs index 0e67a442a6..aed25f01e3 100644 --- a/framework/derive/src/generate/payable_gen.rs +++ b/framework/derive/src/generate/payable_gen.rs @@ -31,7 +31,7 @@ fn call_value_init_snippet(mpm: MethodPayableMetadata) -> proc_macro2::TokenStre }, MethodPayableMetadata::SingleEsdtToken(token_identifier) => { quote! { - multiversx_sc::io::call_value_init::payable_single_specific_token::(#token_identifier); + multiversx_sc::io::call_value_init::payable_single_specific_token::(&*self, #token_identifier); } }, MethodPayableMetadata::AnyToken => { @@ -51,7 +51,7 @@ fn opt_payment_arg_snippet( .map(|arg| { let pat = &arg.pat; quote! { - let #pat = multiversx_sc::io::call_value_init::#init_fn_name::(); + let #pat = multiversx_sc::io::call_value_init::#init_fn_name::(&*self); } }) .unwrap_or_default() diff --git a/framework/derive/src/generate/snippets.rs b/framework/derive/src/generate/snippets.rs index 64c590ad5f..76db98872e 100644 --- a/framework/derive/src/generate/snippets.rs +++ b/framework/derive/src/generate/snippets.rs @@ -13,6 +13,10 @@ pub fn impl_contract_base() -> proc_macro2::TokenStream { A: multiversx_sc::api::VMApi, { type Api = A; + + fn call_value(&self) -> multiversx_sc::contract_base::CallValueWrapper<'_, Self::Api> { + multiversx_sc::contract_base::CallValueWrapper::new(&self.0.data) + } } } } diff --git a/framework/scenario/tests/contract_without_macros.rs b/framework/scenario/tests/contract_without_macros.rs index d0da5fac52..56f12102d1 100644 --- a/framework/scenario/tests/contract_without_macros.rs +++ b/framework/scenario/tests/contract_without_macros.rs @@ -604,6 +604,10 @@ mod sample_adder { A: multiversx_sc::api::VMApi, { type Api = A; + + fn call_value(&self) -> multiversx_sc::contract_base::CallValueWrapper<'_, Self::Api> { + multiversx_sc::contract_base::CallValueWrapper::new(&self.0.data) + } } impl super::module_1::AutoImpl for ContractObj where A: multiversx_sc::api::VMApi {}