diff --git a/deps/td-shim b/deps/td-shim index 3e604bfb..3f1a733e 160000 --- a/deps/td-shim +++ b/deps/td-shim @@ -1 +1 @@ -Subproject commit 3e604bfbd6118dcffc04b5a8147687053cd843cd +Subproject commit 3f1a733e3e938487e0ad3ac3fcce9442bb25e033 diff --git a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/lib.rs b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/lib.rs index 84e90353..0ced393f 100644 --- a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/lib.rs +++ b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/lib.rs @@ -68,6 +68,7 @@ pub mod tdx { tdvmcall_wrmsr, // Re-export types TdxDigest, + TargetTdUuid, }; // Export emulated functions @@ -75,7 +76,7 @@ pub mod tdx { tdcall_extend_rtmr, tdcall_servtd_rd, tdcall_servtd_wr, tdcall_sys_rd, tdcall_sys_wr, tdvmcall_get_quote, tdvmcall_migtd_receive_sync as tdvmcall_migtd_receive, tdvmcall_migtd_reportstatus, tdvmcall_migtd_send_sync as tdvmcall_migtd_send, - tdvmcall_migtd_waitforrequest, tdvmcall_setup_event_notify, + tdvmcall_migtd_waitforrequest, tdvmcall_setup_event_notify, tdcall_vm_write, tdcall_servtd_rebind_approve, }; } @@ -87,7 +88,7 @@ pub mod tdreport { // Re-export some useful constants and types from original pub use original_tdx_tdcall::tdreport::{ - TdxReport, TD_REPORT_ADDITIONAL_DATA_SIZE, TD_REPORT_SIZE, + TdxReport, TD_REPORT_ADDITIONAL_DATA_SIZE, TD_REPORT_SIZE, TdInfo, }; /// Emulated tdcall_report function for AzCVMEmu mode @@ -118,6 +119,12 @@ pub mod tdreport { Ok(tdx_report) } + + /// Emulated TD Report Verification + pub fn tdcall_verify_report(report_mac: &[u8]) -> Result<(), TdCallError> { + log::warn!("Emulated TD report verification"); + Ok(()) + } } // Add td_call emulation support diff --git a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdreport_emu.rs b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdreport_emu.rs index e9d1c8ea..0ae27836 100644 --- a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdreport_emu.rs +++ b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdreport_emu.rs @@ -25,6 +25,13 @@ pub enum QuoteError { ConversionError, } +/// Emulated TD Report Verification +#[cfg(feature = "test_mock_report")] +pub fn tdcall_verify_report(report_mac: &[u8]) -> Result<(), TdCallError> { + info!("Using mock TD report verification for test_mock_report feature"); + Ok(()) +} + /// Emulated TD report generation using mock report #[cfg(feature = "test_mock_report")] pub fn tdcall_report_emulated(_additional_data: &[u8; 64]) -> Result { @@ -423,7 +430,7 @@ fn create_td_report_from_file(quote_file_path: String) -> tdx::TdReport { } // Get report body from quote - let (report_body, servtd_hash) = if body_size == QUOTE_V5_BODY_SIZE_15 { + let (report_body, tee_tcb_svn2, servtd_hash) = if body_size == QUOTE_V5_BODY_SIZE_15 { // v5 with TD Report 1.5 (648 bytes) - includes mr_servicetd let report_v15 = unsafe { &*(quote_data[body_offset..body_offset + body_size].as_ptr() as *const SgxReport2BodyV15) @@ -445,7 +452,7 @@ fn create_td_report_from_file(quote_file_path: String) -> tdx::TdReport { rt_mr: report_v15.rt_mr, report_data: report_v15.report_data, }; - (base_body, report_v15.mr_servicetd) + (base_body, report_v15.tee_tcb_svn2, report_v15.mr_servicetd) } else { // v4 or v5 with TD Report 1.0 (584 bytes) let report = unsafe { @@ -468,7 +475,7 @@ fn create_td_report_from_file(quote_file_path: String) -> tdx::TdReport { rt_mr: report.rt_mr, report_data: report.report_data, }; - (base_body, [0u8; 48]) // SERVTD_HASH always zero for MigTD + (base_body, [0u8; 16], [0u8; 48]) // SERVTD_HASH always zero for MigTD }; // Create TD report with values from parsed quote body @@ -494,7 +501,8 @@ fn create_td_report_from_file(quote_file_path: String) -> tdx::TdReport { mrseam: report_body.mr_seam, mrsigner_seam: report_body.mrsigner_seam, attributes: report_body.seam_attributes, - reserved: [0u8; 111], + tee_tcb_svn2, + reserved: [0u8; 95], }, reserved: [0u8; 17], td_info: TdInfo { diff --git a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdx_emu.rs b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdx_emu.rs index 49066fdd..3c9d0c89 100644 --- a/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdx_emu.rs +++ b/deps/td-shim-AzCVMEmu/tdx-tdcall/src/tdx_emu.rs @@ -44,6 +44,10 @@ lazy_static! { static ref MSK_FIELDS: Mutex> = Mutex::new(HashMap::new()); /// Emulated global-scope SYS fields keyed by field_identifier static ref SYS_FIELDS: Mutex> = Mutex::new(HashMap::new()); + /// Emulated td-scope metadata fields keyed by field_identifier + static ref VM_FIELDS: Mutex> = Mutex::new(HashMap::new()); + /// Emulated rebind-session-token + static ref REBIND_SESSION_TOKEN: Mutex> = Mutex::new(HashMap::new()); /// Event notification vector for GetQuote completion static ref EVENT_NOTIFY_VECTOR: Mutex> = Mutex::new(None); /// Pending receive buffer for large transfers that span multiple GHCI transactions @@ -824,6 +828,44 @@ pub fn tdcall_sys_wr(field_identifier: u64, value: u64) -> core::result::Result< Ok(()) } +/// Emulation for TDG.VM.WR: write a TD-scope metadata field +pub fn tdcall_vm_write(field_identifier: u64, value: u64, mask: u64) -> Result { + warn!( + "AzCVMEmu: tdcall_vm_write emulated: field=0x{:x} <= 0x{:x}", + field_identifier, value + ); + SYS_FIELDS.lock().insert(field_identifier, value); + Ok(field_identifier) +} + +/// Emulation for TDG.SERVTD.REBIND.APPROVE: called by the currently bound service TD to approve +/// a new Service TD to be bound to the target TD. +pub fn tdcall_servtd_rebind_approve( + old_binding_handle: u64, + rebind_session_token: &[u8], + target_td_uuid: &[u64], +) -> Result<[u64; 4], TdCallError> { + warn!( + "AzCVMEmu: tdcall_servtd_rebind_approve emulated: old_binding_hanlde=0x{:x} target_td_uuid= 0x{:x?}", + old_binding_handle, target_td_uuid + ); + let uuid = [ + target_td_uuid[0], + target_td_uuid[1], + target_td_uuid[2], + target_td_uuid[3], + ]; + let key = ( + old_binding_handle, + uuid, + ); + let mut value = [0u8; 32]; + value.copy_from_slice(&rebind_session_token[..32]); + + REBIND_SESSION_TOKEN.lock().insert(key, value); + Ok(uuid) +} + /// Emulation for TDG.VP.VMCALL: Generate TD-Quote using vTPM or return hardcoded collateral /// This mimics the exact API signature of tdx_tdcall::tdx::tdvmcall_get_quote /// diff --git a/src/crypto/src/rustls_impl/tls.rs b/src/crypto/src/rustls_impl/tls.rs index bc4ff16f..862bb666 100644 --- a/src/crypto/src/rustls_impl/tls.rs +++ b/src/crypto/src/rustls_impl/tls.rs @@ -55,6 +55,10 @@ where pub async fn read(&mut self, data: &mut [u8]) -> Result { self.conn.read(data).await } + + pub fn peer_certs(&self) -> Option> { + self.conn.peer_certs() + } } enum TlsConnection { @@ -94,6 +98,13 @@ impl TlsConnection { } } + fn peer_certs(&self) -> Option> { + match self { + Self::Server(conn) => conn.peer_certs(), + Self::Client(conn) => conn.peer_certs(), + } + } + fn transport_mut(&mut self) -> &mut T { match self { Self::Server(conn) => &mut conn.transport, @@ -498,7 +509,7 @@ pub(crate) mod connection { } pub struct TlsServerConnection { - conn: UnbufferedServerConnection, + pub(super) conn: UnbufferedServerConnection, input: TlsBuffer, output: TlsBuffer, pub transport: T, @@ -518,6 +529,12 @@ pub(crate) mod connection { }) } + pub fn peer_certs(&self) -> Option> { + self.conn + .peer_certificates() + .map(|certs| certs.iter().map(|der| der.as_ref()).collect()) + } + pub async fn read(&mut self, data: &mut [u8]) -> Result { if self.is_handshaking { self.process_tls_status().await?; @@ -776,6 +793,12 @@ pub(crate) mod connection { }) } + pub fn peer_certs(&self) -> Option> { + self.conn + .peer_certificates() + .map(|certs| certs.iter().map(|der| der.as_ref()).collect()) + } + pub async fn read(&mut self, data: &mut [u8]) -> Result { if self.is_handshaking { self.process_tls_status().await?; diff --git a/src/migtd/src/bin/migtd/cvmemu.rs b/src/migtd/src/bin/migtd/cvmemu.rs index cb916c03..a3ea77a1 100644 --- a/src/migtd/src/bin/migtd/cvmemu.rs +++ b/src/migtd/src/bin/migtd/cvmemu.rs @@ -557,6 +557,11 @@ fn handle_pre_mig_emu() -> i32 { log::trace!(migration_request_id = report_info.mig_request_id; "ReportStatus for get TDREPORT completed.\n"); // Continue to process next request (migration) } + #[cfg(all(feature = "policy_v2"))] + WaitForRequestResponse::StartRebinding(_) + | WaitForRequestResponse::GetMigtdData(_) => { + unimplemented!(); + } } } Err(e) => { diff --git a/src/migtd/src/bin/migtd/main.rs b/src/migtd/src/bin/migtd/main.rs index 97d5a654..1bbbaa5f 100644 --- a/src/migtd/src/bin/migtd/main.rs +++ b/src/migtd/src/bin/migtd/main.rs @@ -13,9 +13,9 @@ use core::task::Poll; #[cfg(feature = "policy_v2")] use alloc::string::String; use alloc::vec::Vec; +use log::info; #[cfg(feature = "vmcall-raw")] use log::{debug, Level}; -use log::{info, LevelFilter}; use migtd::event_log::*; #[cfg(not(feature = "vmcall-raw"))] use migtd::migration::data::MigrationInformation; @@ -101,7 +101,7 @@ pub fn runtime_main() { { // Initialize logging with level filter. The actual log level is determined by // compile-time feature flags. - let _ = td_logger::init(LevelFilter::Trace); + let _ = td_logger::init(log::LevelFilter::Trace); } // Create LogArea per vCPU @@ -446,6 +446,39 @@ fn handle_pre_mig() { log::trace!(migration_request_id = wfr_info.mig_info.mig_request_id; "ReportStatus for key exchange completed\n"); REQUESTS.lock().remove(&wfr_info.mig_info.mig_request_id); } + #[cfg(feature = "policy_v2")] + WaitForRequestResponse::StartRebinding(rebinding_info) => { + use migtd::migration::rebinding::start_rebinding; + + let status = start_rebinding(&rebinding_info, &mut data) + .await + .map(|_| MigrationResult::Success) + .unwrap_or_else(|e| e); + if status == MigrationResult::Success { + log::trace!("Successfully completed key exchange\n"); + log::trace!( + migration_request_id = rebinding_info.mig_request_id; "Successfully completed rebinding\n", + ); + } else { + log::error!( + migration_request_id = rebinding_info.mig_request_id; "Failure during rebinding status code: {:x}\n", status.clone() as u8); + } + let _ = + report_status(status as u8, rebinding_info.mig_request_id, &data) + .await + .map_err(|e| { + log::error!( + migration_request_id = rebinding_info.mig_request_id; + "Failed to report status for StartRebinding: {:?}\n", + e + ); + }); + log::trace!( + migration_request_id = rebinding_info.mig_request_id; + "ReportStatus for rebinding completed\n" + ); + REQUESTS.lock().remove(&rebinding_info.mig_request_id); + } WaitForRequestResponse::GetTdReport(wfr_info) => { let status = get_tdreport( &wfr_info.reportdata, @@ -497,6 +530,26 @@ fn handle_pre_mig() { log::trace!(migration_request_id = wfr_info.mig_request_id; "ReportStatus for Enable LogArea completed\n"); REQUESTS.lock().remove(&wfr_info.mig_request_id); } + #[cfg(feature = "policy_v2")] + WaitForRequestResponse::GetMigtdData(wfr_info) => { + let status = get_migtd_data( + &wfr_info.reportdata, + &mut data, + wfr_info.mig_request_id, + ) + .await + .map(|_| MigrationResult::Success) + .unwrap_or_else(|e| e); + if status == MigrationResult::Success { + log::trace!(migration_request_id = wfr_info.mig_request_id; "Successfully completed get migtd data\n"); + } else { + log::error!(migration_request_id = wfr_info.mig_request_id; "Failure during get migtd data status code: {:x}\n", status.clone() as u8); + } + let _ = + report_status(status as u8, wfr_info.mig_request_id, &data).await; + log::trace!(migration_request_id = wfr_info.mig_request_id; "ReportStatus for get migtd data completed.\n"); + REQUESTS.lock().remove(&wfr_info.mig_request_id); + } } } #[cfg(any(feature = "test_stack_size", feature = "test_heap_size"))] diff --git a/src/migtd/src/event_log.rs b/src/migtd/src/event_log.rs index 09b3fdeb..0d5eb50c 100644 --- a/src/migtd/src/event_log.rs +++ b/src/migtd/src/event_log.rs @@ -12,7 +12,7 @@ use cc_measurement::{ }; use core::mem::size_of; use crypto::hash::digest_sha384; -use policy::{CcEvent, EventName, Report, REPORT_DATA_SIZE}; +use policy::{CcEvent, EventName}; use spin::Once; use td_payload::acpi::get_acpi_tables; use td_shim::event_log::{ @@ -218,11 +218,17 @@ pub(crate) fn parse_events(event_log: &[u8]) -> Option Result<()> { - replay_event_log_with_report(event_log, report) +pub fn verify_event_log( + event_log: &[u8], + report_rtmrs: &[[u8; SHA384_DIGEST_SIZE]; 4], +) -> Result<()> { + replay_event_log_with_report(event_log, report_rtmrs) } -fn replay_event_log_with_report(event_log: &[u8], report: &[u8]) -> Result<()> { +fn replay_event_log_with_report( + event_log: &[u8], + report_rtmrs: &[[u8; SHA384_DIGEST_SIZE]; 4], +) -> Result<()> { let mut rtmrs: [[u8; 96]; 4] = [[0; 96]; 4]; let event_log = if let Some(event_log) = CcEventLogReader::new(event_log) { @@ -250,14 +256,10 @@ fn replay_event_log_with_report(event_log: &[u8], report: &[u8]) -> Result<()> { } } - if report.len() < REPORT_DATA_SIZE { - return Err(anyhow!("Invalid report")); - } - - if report[Report::R_MIGTD_RTMR0] == rtmrs[0][0..48] - && report[Report::R_MIGTD_RTMR1] == rtmrs[1][0..48] - && report[Report::R_MIGTD_RTMR2] == rtmrs[2][0..48] - && report[Report::R_MIGTD_RTMR3] == rtmrs[3][0..48] + if report_rtmrs[0] == rtmrs[0][0..48] + && report_rtmrs[1] == rtmrs[1][0..48] + && report_rtmrs[2] == rtmrs[2][0..48] + && report_rtmrs[3] == rtmrs[3][0..48] { Ok(()) } else { diff --git a/src/migtd/src/mig_policy.rs b/src/migtd/src/mig_policy.rs index 265e84ba..1bb0ebfd 100644 --- a/src/migtd/src/mig_policy.rs +++ b/src/migtd/src/mig_policy.rs @@ -2,14 +2,17 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent +use crypto::SHA384_DIGEST_SIZE; +pub use policy::{PolicyError, Report, REPORT_DATA_SIZE}; + #[cfg(not(feature = "policy_v2"))] pub use v1::*; #[cfg(not(feature = "policy_v2"))] mod v1 { use policy::verify_policy; - pub use policy::PolicyError; + use super::{get_rtmrs_from_suppl_data, PolicyError}; use crate::{ config::get_policy, event_log::{get_event_log, parse_events, verify_event_log}, @@ -33,8 +36,11 @@ mod v1 { return Err(PolicyError::InvalidParameter); }; - verify_event_log(event_log_peer, verified_report_peer) - .map_err(|_| PolicyError::InvalidEventLog)?; + verify_event_log( + event_log_peer, + &get_rtmrs_from_suppl_data(verified_report_peer)?, + ) + .map_err(|_| PolicyError::InvalidEventLog)?; let event_log = parse_events(event_log).ok_or(PolicyError::InvalidParameter)?; let event_log_peer = parse_events(event_log_peer).ok_or(PolicyError::InvalidParameter)?; @@ -58,13 +64,30 @@ mod v2 { use alloc::{string::String, string::ToString, vec::Vec}; use attestation::verify_quote_with_collaterals; use chrono::DateTime; - use crypto::{crl::get_crl_number, pem_cert_to_der}; + use crypto::{crl::get_crl_number, hash::digest_sha384, pem_cert_to_der, SHA384_DIGEST_SIZE}; use lazy_static::lazy_static; use policy::*; use spin::Once; + use tdx_tdcall::tdreport::{tdcall_verify_report, TdInfo, TdxReport}; use crate::config::get_policy_issuer_chain; use crate::event_log::{parse_events, verify_event_log}; + use crate::mig_policy::get_rtmrs_from_suppl_data; + use crate::migration::servtd_ext::ServtdExt; + + const SERVTD_ATTR_IGNORE_ATTRIBUTES: u64 = 0x1_0000_0000; + const SERVTD_ATTR_IGNORE_XFAM: u64 = 0x2_0000_0000; + const SERVTD_ATTR_IGNORE_MRTD: u64 = 0x4_0000_0000; + const SERVTD_ATTR_IGNORE_MRCONFIGID: u64 = 0x8_0000_0000; + const SERVTD_ATTR_IGNORE_MROWNER: u64 = 0x10_0000_0000; + const SERVTD_ATTR_IGNORE_MROWNERCONFIG: u64 = 0x20_0000_0000; + const SERVTD_ATTR_IGNORE_RTMR0: u64 = 0x40_0000_0000; + const SERVTD_ATTR_IGNORE_RTMR1: u64 = 0x80_0000_0000; + const SERVTD_ATTR_IGNORE_RTMR2: u64 = 0x100_0000_0000; + const SERVTD_ATTR_IGNORE_RTMR3: u64 = 0x200_0000_0000; + + const SERVTD_TYPE_MIGTD: u16 = 0; + const TD_INFO_OFFSET: usize = 512; lazy_static! { pub static ref LOCAL_TCB_INFO: Once = Once::new(); @@ -113,6 +136,13 @@ mod v2 { .ok_or(PolicyError::InvalidParameter) } + pub fn get_init_tcb_evaluation_info( + init_report: &TdxReport, + init_policy: &VerifiedPolicy, + ) -> Result { + setup_evaluation_data_with_tdreport(init_report, init_policy) + } + /// Get reference to the global verified policy /// Returns None if the policy hasn't been initialized yet pub fn get_verified_policy() -> Option<&'static VerifiedPolicy<'static>> { @@ -195,6 +225,93 @@ mod v2 { Ok(suppl_data) } + // Authenticate the migtd-new from migtd-old side + pub fn authenticate_rebinding_new( + tdreport_dst: &[u8], + event_log_dst: &[u8], + mig_policy_dst: &[u8], + ) -> Result, PolicyError> { + let policy_issuer_chain = get_policy_issuer_chain().ok_or(PolicyError::InvalidParameter)?; + + let (evaluation_data_dst, verified_policy_dst, tdx_report) = authenticate_rebinding_common( + tdreport_dst, + event_log_dst, + mig_policy_dst, + policy_issuer_chain, + )?; + let relative_reference = get_local_tcb_evaluation_info()?; + let policy = get_verified_policy().ok_or(PolicyError::InvalidParameter)?; + + policy + .policy_data + .evaluate_policy_forward(&evaluation_data_dst, &relative_reference)?; + + // Verify the destination's policy against local policy + verified_policy_dst + .policy_data + .evaluate_against_policy(&policy.policy_data)?; + + Ok(tdx_report.as_bytes().to_vec()) + } + + // Authenticate the migtd-old from migtd-new side + pub fn authenticate_rebinding_old( + tdreport_src: &[u8], + event_log_src: &[u8], + mig_policy_src: &[u8], + init_policy: &[u8], + init_event_log: &[u8], + init_td_report: &[u8], + servtd_ext_src: &[u8], + ) -> Result, PolicyError> { + let policy_issuer_chain = get_policy_issuer_chain().ok_or(PolicyError::InvalidParameter)?; + + // Verify quote src / event log src / policy src + let (evaluation_data_src, _verified_policy_src, tdx_report) = + authenticate_rebinding_common( + tdreport_src, + event_log_src, + mig_policy_src, + policy_issuer_chain, + )?; + let policy = get_verified_policy().ok_or(PolicyError::InvalidParameter)?; + + // Verify the td report init / event log init / policy init + let servtd_ext_src_obj = + ServtdExt::read_from_bytes(servtd_ext_src).ok_or(PolicyError::InvalidParameter)?; + let init_tdreport = verify_init_tdreport(init_td_report, &servtd_ext_src_obj)?; + let _engine_svn = policy + .servtd_tcb_mapping + .get_engine_svn_by_measurements(&Measurements::new_from_bytes( + &init_tdreport.td_info.mrtd, + &init_tdreport.td_info.rtmr0, + &init_tdreport.td_info.rtmr1, + None, + None, + )) + .ok_or(PolicyError::SvnMismatch)?; + let verified_policy_init = verify_policy_and_event_log( + init_event_log, + init_policy, + policy_issuer_chain, + &get_rtmrs_from_tdreport(&init_tdreport)?, + )?; + + let relative_reference = + get_init_tcb_evaluation_info(&init_tdreport, &verified_policy_init)?; + policy + .policy_data + .evaluate_policy_common(&evaluation_data_src, &relative_reference)?; + + // If backward policy exists, evaluate the migration src based on it. + let relative_reference = get_local_tcb_evaluation_info()?; + policy + .policy_data + .evaluate_policy_backward(&evaluation_data_src, &relative_reference)?; + + Ok(tdx_report.as_bytes().to_vec()) + } + fn authenticate_remote_common<'p>( quote: &[u8], event_log: &[u8], @@ -202,33 +319,20 @@ mod v2 { policy_issuer_chain: &[u8], ) -> Result<(PolicyEvaluationInfo, VerifiedPolicy<'p>, Vec), PolicyError> { let policy = get_verified_policy().ok_or(PolicyError::InvalidParameter)?; - let unverified_policy = RawPolicyData::deserialize_from_json(mig_policy)?; // 1. Verify quote & get supplemental data let (fmspc, suppl_data) = verify_quote(quote, policy.get_collaterals()) .map_err(|_| PolicyError::QuoteVerification)?; - // 2. Verify the event log integrity - verify_event_log( + // 2. Verify the signature of the provided policy and the integrity of the event log + let verified_policy = verify_policy_and_event_log( event_log, - suppl_data - .get(..REPORT_DATA_SIZE) - .ok_or(PolicyError::QuoteVerification)?, - ) - .map_err(|_| PolicyError::InvalidEventLog)?; - - // 3. Verify the integrity of migration policy, with the issuer chains from local policy - let verified_policy = unverified_policy.verify( + mig_policy, policy_issuer_chain, - Some(policy.servtd_identity_issuer_chain.as_bytes()), - Some(policy.servtd_tcb_mapping_issuer_chain.as_bytes()), + &get_rtmrs_from_suppl_data(&suppl_data)?, )?; - // 4. Check the integrity of the policy with its event log - let events = parse_events(event_log).ok_or(PolicyError::InvalidEventLog)?; - check_policy_integrity(mig_policy, &events)?; - - // 5. Get TCB evaluation info from the collaterals + // 3. Get TCB evaluation info from the collaterals let evaluation_data = setup_evaluation_data( fmspc, &suppl_data, @@ -239,6 +343,69 @@ mod v2 { Ok((evaluation_data, verified_policy, suppl_data)) } + fn authenticate_rebinding_common<'p>( + tdreport: &[u8], + event_log: &[u8], + mig_policy: &'p [u8], + policy_issuer_chain: &[u8], + ) -> Result<(PolicyEvaluationInfo, VerifiedPolicy<'p>, TdxReport), PolicyError> { + // 1. Verify quote & get supplemental data + let tdreport_verified = + verify_tdreport(tdreport).map_err(|_| PolicyError::QuoteVerification)?; + + // 2. Verify the signature of the provided policy and the integrity of the event log + let verified_policy = verify_policy_and_event_log( + event_log, + mig_policy, + policy_issuer_chain, + &get_rtmrs_from_tdreport(&tdreport_verified)?, + )?; + + // 3. Get TCB evaluation info from the collaterals + let evaluation_data = + setup_evaluation_data_with_tdreport(&tdreport_verified, &verified_policy)?; + + Ok((evaluation_data, verified_policy, tdreport_verified)) + } + + fn get_rtmrs_from_tdreport( + td_report: &TdxReport, + ) -> Result<[[u8; SHA384_DIGEST_SIZE]; 4], PolicyError> { + let mut rtmrs = [[0u8; SHA384_DIGEST_SIZE]; 4]; + rtmrs[0].copy_from_slice(&td_report.td_info.rtmr0); + rtmrs[1].copy_from_slice(&td_report.td_info.rtmr1); + rtmrs[2].copy_from_slice(&td_report.td_info.rtmr2); + rtmrs[3].copy_from_slice(&td_report.td_info.rtmr3); + + Ok(rtmrs) + } + + fn verify_policy_and_event_log<'p>( + event_log: &[u8], + mig_policy: &'p [u8], + policy_issuer_chain: &[u8], + rtmrs: &[[u8; SHA384_DIGEST_SIZE]; 4], + ) -> Result, PolicyError> { + let policy = get_verified_policy().ok_or(PolicyError::InvalidParameter)?; + let unverified_policy = RawPolicyData::deserialize_from_json(mig_policy)?; + + // 1. Verify the event log integrity + verify_event_log(event_log, rtmrs).map_err(|_| PolicyError::InvalidEventLog)?; + + // 2. Verify the integrity of migration policy, with the issuer chains from local policy + let verified_policy = unverified_policy.verify( + policy_issuer_chain, + Some(policy.servtd_identity_issuer_chain.as_bytes()), + Some(policy.servtd_tcb_mapping_issuer_chain.as_bytes()), + )?; + + // 3. Check the integrity of the policy with its event log + let events = parse_events(event_log).ok_or(PolicyError::InvalidEventLog)?; + check_policy_integrity(mig_policy, &events)?; + + Ok(verified_policy) + } + fn verify_quote( quote: &[u8], collaterals: &Collaterals, @@ -252,6 +419,110 @@ mod v2 { Ok((fmspc, suppl_data)) } + fn verify_tdreport(tdreport: &[u8]) -> Result { + let tdx_report = + TdxReport::read_from_bytes(tdreport).ok_or(PolicyError::InvalidTdReport)?; + + // Verify the REPORTMACSTRUCT + tdcall_verify_report(tdx_report.report_mac.as_bytes()) + .map_err(|_| PolicyError::TdReportVerification)?; + + // Verify the TDINFO_STRUCT and TEE_TCB_INFO + let tdinfo_hash = digest_sha384(tdx_report.td_info.as_bytes()) + .map_err(|_| PolicyError::HashCalculation)?; + let tee_tcb_info_hash = digest_sha384(tdx_report.tee_tcb_info.as_bytes()) + .map_err(|_| PolicyError::HashCalculation)?; + + let mut validity = true; + validity &= &tdx_report.report_mac.tee_tcb_info_hash == tee_tcb_info_hash.as_slice(); + validity &= tdx_report.report_mac.tee_info_hash != [0; 48]; + validity &= &tdx_report.report_mac.tee_info_hash == tdinfo_hash.as_slice(); + + if !validity { + return Err(PolicyError::InvalidTdReport); + } + Ok(tdx_report) + } + + fn verify_servtd_hash( + servtd_report: &[u8], + servtd_attr: u64, + init_servtd_hash: &[u8], + ) -> Result { + if servtd_report.len() < TD_INFO_OFFSET + size_of::() { + return Err(PolicyError::InvalidParameter); + } + + // Extract TdInfo from the report + let mut td_report = + TdxReport::read_from_bytes(servtd_report).ok_or(PolicyError::InvalidTdReport)?; + + if (servtd_attr & SERVTD_ATTR_IGNORE_ATTRIBUTES) != 0 { + td_report.td_info.attributes.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_XFAM) != 0 { + td_report.td_info.xfam.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_MRTD) != 0 { + td_report.td_info.mrtd.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_MRCONFIGID) != 0 { + td_report.td_info.mrconfig_id.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_MROWNER) != 0 { + td_report.td_info.mrowner.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_MROWNERCONFIG) != 0 { + td_report.td_info.mrownerconfig.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR0) != 0 { + td_report.td_info.rtmr0.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR1) != 0 { + td_report.td_info.rtmr1.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR2) != 0 { + td_report.td_info.rtmr2.fill(0); + } + if (servtd_attr & SERVTD_ATTR_IGNORE_RTMR3) != 0 { + td_report.td_info.rtmr3.fill(0); + } + + let info_hash = digest_sha384(td_report.td_info.as_bytes()) + .map_err(|_| PolicyError::HashCalculation)?; + + // Calculate ServTD hash: SHA384(info_hash || type || attr) + let mut buffer = [0u8; SHA384_DIGEST_SIZE + size_of::() + size_of::()]; + let mut offset = 0; + + buffer[offset..offset + SHA384_DIGEST_SIZE].copy_from_slice(&info_hash); + offset += SHA384_DIGEST_SIZE; + + buffer[offset..offset + size_of::()].copy_from_slice(&SERVTD_TYPE_MIGTD.to_le_bytes()); + offset += size_of::(); + + buffer[offset..offset + size_of::()].copy_from_slice(&servtd_attr.to_le_bytes()); + + let calculated_hash = digest_sha384(&buffer).map_err(|_| PolicyError::HashCalculation)?; + + if calculated_hash.as_slice() != init_servtd_hash { + return Err(PolicyError::InvalidTdReport); + } + + Ok(td_report) + } + + fn verify_init_tdreport( + init_report: &[u8], + servtd_ext: &ServtdExt, + ) -> Result { + verify_servtd_hash( + init_report, + u64::from_le_bytes(servtd_ext.init_attr), + &servtd_ext.init_servtd_info_hash, + ) + } + fn setup_evaluation_data( fmspc: [u8; 6], suppl_data: &[u8], @@ -278,6 +549,7 @@ mod v2 { .map_err(|_| PolicyError::InvalidCollateral)?; Ok(PolicyEvaluationInfo { + tee_tcb_svn: None, tcb_date: Some(tcb_date.to_string()), tcb_status: Some(tcb_status.as_str().to_string()), tcb_evaluation_number: Some(tcb_evaluation_number), @@ -290,6 +562,36 @@ mod v2 { }) } + fn setup_evaluation_data_with_tdreport( + tdreport: &TdxReport, + policy: &VerifiedPolicy, + ) -> Result { + let migtd_svn = policy.servtd_tcb_mapping.get_engine_svn_by_measurements( + &Measurements::new_from_bytes( + &tdreport.td_info.mrtd, + &tdreport.td_info.rtmr0, + &tdreport.td_info.rtmr1, + None, + None, + ), + ); + + let migtd_tcb = migtd_svn.and_then(|svn| policy.servtd_identity.get_tcb_level_by_svn(svn)); + + Ok(PolicyEvaluationInfo { + tee_tcb_svn: Some(tdreport.tee_tcb_info.tee_tcb_svn), + tcb_date: None, + tcb_status: None, + tcb_evaluation_number: None, + fmspc: None, + migtd_isvsvn: migtd_svn, + migtd_tcb_date: migtd_tcb.map(|tcb| tcb.tcb_date.clone()), + migtd_tcb_status: migtd_tcb.map(|tcb| tcb.tcb_status.clone()), + pck_crl_num: None, + root_ca_crl_num: None, + }) + } + fn get_tcb_date_and_status_from_suppl_data( suppl_data: &[u8], ) -> Result<(String, String), PolicyError> { @@ -350,3 +652,19 @@ mod v2 { assert_eq!(iso_date, "2024-01-01T00:00:00Z"); } } + +fn get_rtmrs_from_suppl_data( + suppl_data: &[u8], +) -> Result<[[u8; SHA384_DIGEST_SIZE]; 4], PolicyError> { + if suppl_data.len() < REPORT_DATA_SIZE { + return Err(PolicyError::InvalidParameter); + } + + let mut rtmrs = [[0u8; SHA384_DIGEST_SIZE]; 4]; + rtmrs[0].copy_from_slice(&suppl_data[Report::R_MIGTD_RTMR0]); + rtmrs[1].copy_from_slice(&suppl_data[Report::R_MIGTD_RTMR1]); + rtmrs[2].copy_from_slice(&suppl_data[Report::R_MIGTD_RTMR2]); + rtmrs[3].copy_from_slice(&suppl_data[Report::R_MIGTD_RTMR3]); + + Ok(rtmrs) +} diff --git a/src/migtd/src/migration/data.rs b/src/migtd/src/migration/data.rs index a014f387..e5ccf16a 100644 --- a/src/migtd/src/migration/data.rs +++ b/src/migtd/src/migration/data.rs @@ -2,6 +2,9 @@ // // SPDX-License-Identifier: BSD-2-Clause-Patent +#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] +use crate::migration::rebinding::RebindingInfo; + use super::*; #[cfg(feature = "vmcall-raw")] use bitfield_struct::bitfield; @@ -254,8 +257,12 @@ pub struct RequestDataBuffer<'a> { #[cfg(feature = "vmcall-raw")] pub enum WaitForRequestResponse { StartMigration(MigrationInformation), + #[cfg(feature = "policy_v2")] + StartRebinding(RebindingInfo), GetTdReport(ReportInfo), EnableLogArea(EnableLogAreaInfo), + #[cfg(feature = "policy_v2")] + GetMigtdData(MigtdDataInfo), } pub struct MigrationInformation { diff --git a/src/migtd/src/migration/mod.rs b/src/migtd/src/migration/mod.rs index 59c5670f..86636eb5 100644 --- a/src/migtd/src/migration/mod.rs +++ b/src/migtd/src/migration/mod.rs @@ -7,6 +7,9 @@ pub mod event; pub mod logging; #[cfg(feature = "policy_v2")] pub mod pre_session_data; +#[cfg(all(feature = "main", feature = "policy_v2", feature = "vmcall-raw"))] +pub mod rebinding; +pub mod servtd_ext; #[cfg(feature = "main")] pub mod session; #[cfg(feature = "main")] @@ -112,6 +115,16 @@ pub struct ReportInfo { pub reportdata: [u8; 64], } +#[repr(C)] +#[derive(Debug, Pread, Pwrite)] +#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] +pub struct MigtdDataInfo { + // ID for the migration request, which can be used in TDG.VP.VMCALL + // + pub mig_request_id: u64, + pub reportdata: [u8; 64], +} + #[repr(C)] #[derive(Debug, Pread, Pwrite)] #[cfg(feature = "vmcall-raw")] @@ -220,7 +233,8 @@ impl From for MigrationResult { RatlsError::Crypto(_) | RatlsError::X509(_) | RatlsError::InvalidEventlog - | RatlsError::InvalidPolicy => MigrationResult::SecureSessionError, + | RatlsError::InvalidPolicy + | RatlsError::GenerateCertificate => MigrationResult::SecureSessionError, RatlsError::TdxModule(_) => MigrationResult::TdxModuleError, RatlsError::GetQuote | RatlsError::VerifyQuote => { MigrationResult::MutualAttestationError diff --git a/src/migtd/src/migration/rebinding.rs b/src/migtd/src/migration/rebinding.rs new file mode 100644 index 00000000..51e2b71b --- /dev/null +++ b/src/migtd/src/migration/rebinding.rs @@ -0,0 +1,720 @@ +// Copyright (c) 2025 Intel Corporation +// +// SPDX-License-Identifier: BSD-2-Clause-Patent + +use alloc::{boxed::Box, vec::Vec}; +use core::mem::MaybeUninit; +use core::time::Duration; +use crypto::{ + tls::SecureChannel, + x509::{Certificate, Decode}, + SHA384_DIGEST_SIZE, +}; +use ring::rand::{SecureRandom, SystemRandom}; +use tdx_tdcall::tdx::{tdcall_servtd_rebind_approve, tdcall_vm_write}; + +use crate::migration::servtd_ext::read_servtd_ext; +use crate::{event_log, migration::transport::*}; +use crypto::hash::digest_sha384; + +use crate::{ + config, + migration::pre_session_data::{ + exchange_hello_packet, receive_pre_session_data_packet, receive_start_session_packet, + send_pre_session_data_packet, send_start_session_packet, + }, +}; + +use crate::{ + driver::ticks::with_timeout, + migration::{ + servtd_ext::{write_approved_servtd_ext_hash, ServtdExt}, + MigrationResult, + }, + ratls::{self, find_extension, EXTNID_MIGTD_SERVTD_EXT}, +}; +pub use tdx_tdcall::tdx::TargetTdUuid; + +/// Rebind session token held by the Service TD. This field is written by the ServiceTD +/// executing TDG.VM.WR. +pub const TDCS_FIELD_SERVTD_REBIND_ACCEPT_TOKEN: u64 = 0x191000030000021E; +/// The intended SERVTD_ATTR for the Service TD about to be bound to the TD. +pub const TDCS_FIELD_SERVTD_REBIND_ATTR: u64 = 0x1910000300000222; + +const TLS_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds + // FIXME: Need VMM provide socket information +const MIGTD_DATA_SIGNATURE: &[u8] = b"MIGTDATA"; +const MIGTD_DATA_TYPE_INIT_MIG_POLICY: u32 = 0; +const MIGTD_DATA_TYPE_INIT_TD_REPORT: u32 = 1; +const MIGTD_DATA_TYPE_INIT_EVENT_LOG: u32 = 2; + +const MIGTD_REBIND_OP_PREPARE: u8 = 0; +const MIGTD_REBIND_OP_FINALIZE: u8 = 1; + +#[repr(C)] +pub struct RebindingToken { + pub token: [u8; 32], + pub target_td_uuid: TargetTdUuid, +} + +impl RebindingToken { + pub fn read_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < size_of::() { + return None; + } + + let mut uinit: MaybeUninit = MaybeUninit::uninit(); + // Safety: MaybeUninit has same layout with RebindingToken + Some(unsafe { + core::ptr::copy_nonoverlapping( + bytes.as_ptr(), + uinit.as_mut_ptr() as *mut u8, + size_of::(), + ); + uinit.assume_init() + }) + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self as *const _ as *const u8, size_of::()) } + } +} + +pub struct RebindingInfo { + pub mig_request_id: u64, + pub rebinding_src: u8, + pub has_init_data: u8, + pub operation: u8, + pub target_td_uuid: [u64; 4], + pub binding_handle: u64, + pub init_migtd_data: Option, +} + +impl RebindingInfo { + pub fn read_from_bytes(b: &[u8]) -> Option { + // Check the length of input and the reserved fields + if b.len() < 56 || b[11..16] != [0; 5] { + return None; + } + let mig_request_id = u64::from_le_bytes(b[..8].try_into().unwrap()); + let rebinding_src = b[8]; + let has_init_data = b[9]; + let operation = b[10]; + + let target_td_uuid: [u64; 4] = core::array::from_fn(|i| { + let offset = 16 + i * 8; + u64::from_le_bytes(b[offset..offset + 8].try_into().unwrap()) + }); + let binding_handle = u64::from_le_bytes(b[48..56].try_into().unwrap()); + + let mut init_migtd_data = None; + if has_init_data == 1 { + // Returns None if `has_init_data` is set but reading initialization data from the input buffer fails. + init_migtd_data = Some(InitData::read_from_bytes(&b[56..])?); + } + + Some(Self { + mig_request_id, + rebinding_src, + has_init_data, + operation, + target_td_uuid, + binding_handle, + init_migtd_data, + }) + } +} + +pub struct InitData { + pub init_report: Vec, + pub init_policy: Vec, + pub init_event_log: Vec, +} + +impl InitData { + pub fn read_from_bytes(b: &[u8]) -> Option { + if b.len() < 20 || &b[..8] != MIGTD_DATA_SIGNATURE { + return None; + } + + let version = u32::from_le_bytes(b[8..12].try_into().unwrap()); + let length = u32::from_le_bytes(b[12..16].try_into().unwrap()); + let num_entries = u32::from_le_bytes(b[16..20].try_into().unwrap()); + + if version != 0x00010000 || b.len() < length as usize { + return None; + } + + let mut offset = 20; + let mut init_report = None; + let mut init_policy = None; + let mut init_event_log = None; + for _ in 0..num_entries { + let entry = MigtdDateEntry::read_from_bytes(&b[offset..])?; + match entry.r#type { + MIGTD_DATA_TYPE_INIT_MIG_POLICY => init_policy = Some(entry.value), + MIGTD_DATA_TYPE_INIT_TD_REPORT => { + if entry.value.len() > 1024 { + return None; + } + init_report = Some(entry.value.to_vec()) + } + MIGTD_DATA_TYPE_INIT_EVENT_LOG => init_event_log = Some(entry.value), + _ => return None, + } + offset += entry.length as usize + 8; + } + + Some(Self { + init_report: init_report?, + init_policy: init_policy?.to_vec(), + init_event_log: init_event_log?.to_vec(), + }) + } + + pub fn write_into_bytes(&self, buf: &mut Vec) { + let start_len = buf.len(); + buf.extend_from_slice(MIGTD_DATA_SIGNATURE); + buf.extend_from_slice(&0x00010000u32.to_le_bytes()); // Version + + // Placeholder for length. + buf.extend_from_slice(&0u32.to_le_bytes()); + + buf.extend_from_slice(&3u32.to_le_bytes()); // num_entries + + // Helper to write entries + let mut write_entry = |type_: u32, value: &[u8]| { + buf.extend_from_slice(&type_.to_le_bytes()); + buf.extend_from_slice(&(value.len() as u32).to_le_bytes()); + buf.extend_from_slice(value); + }; + + write_entry(MIGTD_DATA_TYPE_INIT_MIG_POLICY, &self.init_policy); + write_entry(MIGTD_DATA_TYPE_INIT_TD_REPORT, &self.init_report); + write_entry(MIGTD_DATA_TYPE_INIT_EVENT_LOG, &self.init_event_log); + + let total_size = (buf.len() - start_len) as u32; + + // Update length field + let length_offset = start_len + 12; + buf[length_offset..length_offset + 4].copy_from_slice(&total_size.to_le_bytes()); + } + + pub fn get_from_local(report_data: &[u8; 64]) -> Option { + Some(Self { + init_report: tdx_tdcall::tdreport::tdcall_report(report_data) + .ok()? + .as_bytes() + .to_vec(), + init_policy: config::get_policy()?.to_vec(), + init_event_log: event_log::get_event_log()?.to_vec(), + }) + } +} + +pub struct MigtdDateEntry<'a> { + pub r#type: u32, + pub length: u32, + pub value: &'a [u8], +} + +impl<'a> MigtdDateEntry<'a> { + pub fn read_from_bytes(b: &'a [u8]) -> Option { + if b.len() < 8 { + return None; + } + + let r#type = u32::from_le_bytes(b[0..4].try_into().unwrap()); + let length = u32::from_le_bytes(b[4..8].try_into().unwrap()); + + if b.len() < length as usize + 8 { + return None; + } + + Some(Self { + r#type, + length, + value: &b[8..8 + length as usize], + }) + } +} + +pub(super) async fn rebinding_old_pre_session_data_exchange( + transport: &mut TransportType, + init_policy: &[u8], +) -> Result, MigrationResult> { + let version = exchange_hello_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: exchange_hello_packet error: {:?}\n", + e + ); + e + })?; + log::info!("Pre-Session-Message Version: 0x{:04x}\n", version); + + let policy = config::get_policy() + .ok_or(MigrationResult::InvalidParameter) + .map_err(|e| { + log::error!("pre_session_data_exchange: get_policy error: {:?}\n", e); + e + })?; + send_pre_session_data_packet(policy, transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + let remote_policy = receive_pre_session_data_packet(transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + + send_pre_session_data_packet(init_policy, transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + + send_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: send_start_session_packet error: {:?}\n", + e + ); + e + })?; + receive_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_start_session_packet error: {:?}\n", + e + ); + e + })?; + + Ok(remote_policy) +} + +pub(super) async fn rebinding_new_pre_session_data_exchange( + transport: &mut TransportType, +) -> Result, MigrationResult> { + let version = exchange_hello_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: exchange_hello_packet error: {:?}\n", + e + ); + e + })?; + log::info!("Pre-Session-Message Version: 0x{:04x}\n", version); + + let policy = config::get_policy() + .ok_or(MigrationResult::InvalidParameter) + .map_err(|e| { + log::error!("pre_session_data_exchange: get_policy error: {:?}\n", e); + e + })?; + send_pre_session_data_packet(policy, transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + let remote_policy = receive_pre_session_data_packet(transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + + let init_policy = receive_pre_session_data_packet(transport) + .await + .map_err(|e| { + log::error!( + "pre_session_data_exchange: send_pre_session_data_packet error: {:?}\n", + e + ); + e + })?; + + send_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: send_start_session_packet error: {:?}\n", + e + ); + e + })?; + receive_start_session_packet(transport).await.map_err(|e| { + log::error!( + "pre_session_data_exchange: receive_start_session_packet error: {:?}\n", + e + ); + e + })?; + + // FIXME: Refactor the TLS verification callback to enable easier access to pre-session data. + let mut policy_buffer = Vec::new(); + policy_buffer.extend_from_slice(&(remote_policy.len() as u32).to_le_bytes()); + policy_buffer.extend_from_slice(&remote_policy); + policy_buffer.extend_from_slice(&(init_policy.len() as u32).to_le_bytes()); + policy_buffer.extend_from_slice(&init_policy); + + Ok(policy_buffer) +} + +pub async fn start_rebinding( + info: &RebindingInfo, + data: &mut Vec, +) -> Result<(), MigrationResult> { + let mut transport = setup_transport(info.mig_request_id).await?; + + // Exchange policy firstly because of the message size limitation of TLS protocol + const PRE_SESSION_TIMEOUT: Duration = Duration::from_secs(60); // 60 seconds + if info.rebinding_src == 1 { + let local_data = + InitData::get_from_local(&[0u8; 64]).ok_or(MigrationResult::InvalidParameter)?; + let init_migtd_data = info + .init_migtd_data + .as_ref() + .or(Some(&local_data)) + .ok_or(MigrationResult::InvalidParameter)?; + let remote_policy = Box::pin(with_timeout( + PRE_SESSION_TIMEOUT, + rebinding_old_pre_session_data_exchange(&mut transport, &init_migtd_data.init_policy), + )) + .await + .map_err(|e| { + log::error!( + "start_rebinding: rebinding_old_pre_session_data_exchange timeout error: {:?}\n", + e + ); + e + })? + .map_err(|e| { + log::error!( + "start_rebinding: rebinding_old_pre_session_data_exchange error: {:?}\n", + e + ); + e + })?; + #[cfg(not(feature = "spdm_attestation"))] + match info.operation { + MIGTD_REBIND_OP_PREPARE => { + rebinding_old_prepare(transport, info, &init_migtd_data, data, remote_policy) + .await? + } + MIGTD_REBIND_OP_FINALIZE => rebinding_old_finalize(info, data).await?, + _ => return Err(MigrationResult::InvalidParameter), + } + } else { + let pre_session_data = Box::pin(with_timeout( + PRE_SESSION_TIMEOUT, + rebinding_new_pre_session_data_exchange(&mut transport), + )) + .await + .map_err(|e| { + log::error!( + "start_rebinding: rebinding_new_pre_session_data_exchange timeout error: {:?}\n", + e + ); + e + })? + .map_err(|e| { + log::error!( + "start_rebinding: rebinding_new_pre_session_data_exchange error: {:?}\n", + e + ); + e + })?; + + #[cfg(not(feature = "spdm_attestation"))] + match info.operation { + MIGTD_REBIND_OP_PREPARE => { + rebinding_new_prepare(transport, info, data, pre_session_data).await? + } + MIGTD_REBIND_OP_FINALIZE => rebinding_new_finalize(info, data).await?, + _ => return Err(MigrationResult::InvalidParameter), + } + } + + #[cfg(feature = "vmcall-raw")] + { + use crate::migration::logging::entrylog; + + entrylog( + &format!("Complete rebinding and report status\n").into_bytes(), + log::Level::Info, + info.mig_request_id, + ); + log::info!("Complete rebinding and report status\n"); + } + Ok(()) +} + +pub async fn rebinding_old_prepare( + transport: TransportType, + info: &RebindingInfo, + init_migtd_data: &InitData, + data: &mut Vec, + remote_policy: Vec, +) -> Result<(), MigrationResult> { + let servtd_ext = read_servtd_ext(info.binding_handle, &info.target_td_uuid)?; + let init_policy_hash = digest_sha384(&init_migtd_data.init_policy)?; + + // TLS client + let mut ratls_client = ratls::client_rebinding( + transport, + remote_policy, + &init_policy_hash, + &init_migtd_data.init_report, + &init_migtd_data.init_event_log, + &servtd_ext, + ) + .map_err(|_| { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice( + &format!( + "Error: rebinding_old(): Failed in ratls transport. Migration ID: {:x}\n", + info.mig_request_id, + ) + .into_bytes(), + ); + log::error!( + "rebinding_old(): Failed in ratls transport. Migration ID: {}\n", + info.mig_request_id + ); + MigrationResult::SecureSessionError + })?; + + let rebind_token = create_rebind_token(info)?; + tls_send_rebind_token(&mut ratls_client, &rebind_token).await?; + + approve_rebinding(info, &rebind_token)?; + + shutdown_transport(ratls_client.transport_mut(), info.mig_request_id).await?; + Ok(()) +} + +pub async fn rebinding_old_finalize( + _info: &RebindingInfo, + _data: &mut Vec, +) -> Result<(), MigrationResult> { + Ok(()) +} + +async fn rebinding_new_prepare( + transport: TransportType, + info: &RebindingInfo, + data: &mut Vec, + pre_session_data: Vec, +) -> Result<(), MigrationResult> { + // TLS server + let mut ratls_server = ratls::server_rebinding(transport, pre_session_data).map_err(|_| { + #[cfg(feature = "vmcall-raw")] + data.extend_from_slice( + &format!( + "Error: rebinding_new(): Failed in ratls transport. Migration ID: {:x}\n", + info.mig_request_id + ) + .into_bytes(), + ); + log::error!( + "rebinding_new(): Failed in ratls transport. Migration ID: {}\n", + info.mig_request_id + ); + MigrationResult::SecureSessionError + })?; + + let servtd_ext = get_servtd_ext_from_cert(&ratls_server.peer_certs())?; + let rebind_token = tls_receive_rebind_token(&mut ratls_server).await?; + if rebind_token.target_td_uuid != info.target_td_uuid { + return Err(MigrationResult::InvalidParameter); + } + + write_rebinding_session_token(&rebind_token.token)?; + write_servtd_rebind_attr(&servtd_ext.cur_servtd_attr)?; + write_approved_servtd_ext_hash(&servtd_ext.calculate_approved_servtd_ext_hash()?)?; + + shutdown_transport(ratls_server.transport_mut(), info.mig_request_id).await?; + Ok(()) +} + +async fn rebinding_new_finalize( + _info: &RebindingInfo, + _data: &mut Vec, +) -> Result<(), MigrationResult> { + write_rebinding_session_token(&[0u8; 32])?; + write_approved_servtd_ext_hash(&[0u8; SHA384_DIGEST_SIZE])?; + Ok(()) +} + +pub fn write_rebinding_session_token(rebind_token: &[u8]) -> Result<(), MigrationResult> { + if rebind_token.len() != 32 { + return Err(MigrationResult::InvalidParameter); + } + + for (idx, chunk) in rebind_token.chunks_exact(size_of::()).enumerate() { + let elem = u64::from_le_bytes(chunk.try_into().unwrap()); + tdcall_vm_write(TDCS_FIELD_SERVTD_REBIND_ACCEPT_TOKEN + idx as u64, elem, 0)?; + } + + Ok(()) +} + +pub fn write_servtd_rebind_attr(servtd_attr: &[u8]) -> Result<(), MigrationResult> { + if servtd_attr.len() != 8 { + return Err(MigrationResult::InvalidParameter); + } + + let elem = u64::from_le_bytes(servtd_attr.try_into().unwrap()); + tdcall_vm_write(TDCS_FIELD_SERVTD_REBIND_ATTR, elem, 0)?; + + Ok(()) +} + +pub fn approve_rebinding( + info: &RebindingInfo, + rebind_token: &RebindingToken, +) -> Result<(), MigrationResult> { + tdcall_servtd_rebind_approve( + info.binding_handle, + &rebind_token.token, + &info.target_td_uuid, + )?; + Ok(()) +} + +fn get_servtd_ext_from_cert(certs: &Option>) -> Result { + if let Some(cert_chain) = certs { + if cert_chain.is_empty() { + return Err(MigrationResult::SecureSessionError); + } + + let cert = Certificate::from_der(cert_chain[0]) + .map_err(|_| MigrationResult::SecureSessionError)?; + + let extensions = cert + .tbs_certificate + .extensions + .as_ref() + .ok_or(MigrationResult::SecureSessionError)?; + + let servtd_ext = find_extension(extensions, &EXTNID_MIGTD_SERVTD_EXT) + .ok_or(MigrationResult::SecureSessionError)?; + + ServtdExt::read_from_bytes(servtd_ext).ok_or(MigrationResult::InvalidParameter) + } else { + Err(MigrationResult::SecureSessionError) + } +} + +fn create_rebind_token(info: &RebindingInfo) -> Result { + let mut token = [0u8; 32]; + let rng = SystemRandom::new(); + rng.fill(&mut token) + .map_err(|_| MigrationResult::InvalidParameter)?; + + Ok(RebindingToken { + token, + target_td_uuid: info.target_td_uuid, + }) +} + +async fn tls_send_rebind_token( + tls_session: &mut SecureChannel, + rebind_token: &RebindingToken, +) -> Result<(), MigrationResult> { + // MigTD old send rebinding session token to peer + with_timeout( + TLS_TIMEOUT, + tls_session_write_all(tls_session, rebind_token.as_bytes()), + ) + .await + .map_err(|e| { + log::error!( + "tls_send_rebind_token: tls_session_write_all timeout error: {:?}\n", + e + ); + e + })? + .map_err(|e| { + log::error!( + "tls_send_rebind_token: tls_session_write_all error: {:?}\n", + e + ); + e + })?; + Ok(()) +} + +async fn tls_receive_rebind_token( + tls_session: &mut SecureChannel, +) -> Result { + let mut data = [0u8; size_of::()]; + // MigTD old send rebinding session token to peer + with_timeout(TLS_TIMEOUT, tls_session_read_exact(tls_session, &mut data)) + .await + .map_err(|e| { + log::error!( + "tls_receive_rebind_token: tls_session_read_exact timeout error: {:?}\n", + e + ); + e + })? + .map_err(|e| { + log::error!( + "tls_receive_rebind_token: tls_session_read_exact error: {:?}\n", + e + ); + e + })?; + + let rebind_token = + RebindingToken::read_from_bytes(&data).ok_or(MigrationResult::InvalidParameter)?; + Ok(rebind_token) +} + +async fn tls_session_write_all( + tls_session: &mut SecureChannel, + data: &[u8], +) -> Result<(), MigrationResult> { + let mut sent = 0; + while sent < data.len() { + let n = tls_session + .write(&data[sent..]) + .await + .map_err(|_| MigrationResult::SecureSessionError)?; + sent += n; + } + Ok(()) +} + +async fn tls_session_read_exact( + tls_session: &mut SecureChannel, + data: &mut [u8], +) -> Result<(), MigrationResult> { + let mut recvd = 0; + while recvd < data.len() { + let n = tls_session + .read(&mut data[recvd..]) + .await + .map_err(|_| MigrationResult::NetworkError)?; + recvd += n; + } + Ok(()) +} diff --git a/src/migtd/src/migration/servtd_ext.rs b/src/migtd/src/migration/servtd_ext.rs new file mode 100644 index 00000000..8ef77482 --- /dev/null +++ b/src/migtd/src/migration/servtd_ext.rs @@ -0,0 +1,149 @@ +// Copyright (c) 2025 Intel Corporation +// +// SPDX-License-Identifier: BSD-2-Clause-Patent + +use alloc::vec::Vec; +use core::mem::MaybeUninit; +use crypto::{hash::digest_sha384, SHA384_DIGEST_SIZE}; +use tdx_tdcall::tdx::{tdcall_servtd_rd, tdcall_vm_write}; + +use crate::migration::MigrationResult; + +/// SERVTD_EXT_STRUCT fields in target TD’s TDCS +pub const TDCS_FIELD_SERVTD_INIT_SERVTD_INFO_HASH: u64 = 0x191000030000020E; +pub const TDCS_FIELD_SERVTD_INIT_ATTR: u64 = 0x191000030000020D; +pub const TDCS_FIELD_INIT_CPUSVN: u64 = 0x1110000300000060; +pub const TDCS_FIELD_INIT_TEE_TCB_SVN: u64 = 0x1110000300000062; +pub const TDCS_FIELD_INIT_TEE_MODEL: u64 = 0x1110000200000064; +/// SERVTD_EXT_STRUCT fields in Service TDs Binding Table Entry in target TD’s TDCS +pub const TDCS_FIELD_SERVTD_INFO_HASH: u64 = 0x1910000300000207; +pub const TDCS_FIELD_SERVTD_ATTR: u64 = 0x1910000300000202; + +/// Hash of SERVTD_EXT that the new Service TD 0 (i.e., rebound Service TD or MigTD on the +/// destination platform) believes is the SERVTD_EXT for this TD. +pub const TDCS_FIELD_SERVTD_ACCEPT_SERVTD_EXT_HASH: u64 = 0x1910000300000214; + +#[repr(C)] +pub struct ServtdExt { + pub init_servtd_info_hash: [u8; 48], + pub init_attr: [u8; 8], + reserved: [u8; 8], + pub init_cpusvn: [u8; 16], + pub init_tee_tcb_svn: [u8; 16], + pub init_tee_model: [u8; 12], + pub cur_servtd_info_hash: [u8; 48], + pub cur_servtd_attr: [u8; 8], + reserved2: [u8; 104], +} + +impl ServtdExt { + pub fn read_from_bytes(bytes: &[u8]) -> Option { + if bytes.len() < size_of::() { + return None; + } + + let mut uninit = MaybeUninit::::uninit(); + // SAFETY: `MaybeUninit` has same memory layout with T. + Some(unsafe { + core::ptr::copy_nonoverlapping( + bytes.as_ptr(), + uninit.as_mut_ptr() as *mut u8, + core::mem::size_of::(), + ); + uninit.assume_init() + }) + } + + pub fn as_bytes(&self) -> &[u8] { + unsafe { core::slice::from_raw_parts(self as *const _ as *const u8, size_of::()) } + } + + pub fn calculate_approved_servtd_ext_hash(mut self) -> Result, MigrationResult> { + self.cur_servtd_attr.fill(0); + self.cur_servtd_info_hash.fill(0); + digest_sha384(self.as_bytes()).map_err(|_| MigrationResult::InvalidParameter) + } +} + +#[repr(C)] +pub struct TeeModel { + custom: u16, + platform_id: u16, + fm: u32, + reservtd: [u8; 8], +} + +pub fn read_servtd_ext( + binding_handle: u64, + target_td_uuid: &[u64], +) -> Result { + let read_field = + |field_base: u64, elem_size: usize, buf: &mut [u8]| -> Result<(), MigrationResult> { + for (idx, chunk) in buf.chunks_mut(elem_size).enumerate() { + let result = + tdcall_servtd_rd(binding_handle, field_base + idx as u64, target_td_uuid)?; + let bytes = result.content.to_le_bytes(); + chunk.copy_from_slice(&bytes[..chunk.len()]); + } + + Ok(()) + }; + + let mut init_servtd_info_hash = [0u8; 48]; + let mut init_attr = [0u8; 8]; + let mut init_cpusvn = [0u8; 16]; + let mut init_tee_tcb_svn = [0u8; 16]; + let mut init_tee_model = [0u8; 12]; + let mut cur_servtd_info_hash = [0u8; 48]; + let mut cur_servtd_attr = [0u8; 8]; + + read_field( + TDCS_FIELD_SERVTD_INIT_SERVTD_INFO_HASH, + 8, + &mut init_servtd_info_hash, + )?; + read_field(TDCS_FIELD_SERVTD_INIT_ATTR, 8, &mut init_attr)?; + read_field(TDCS_FIELD_INIT_CPUSVN, 8, &mut init_cpusvn)?; + read_field(TDCS_FIELD_INIT_TEE_TCB_SVN, 8, &mut init_tee_tcb_svn)?; + read_field(TDCS_FIELD_INIT_TEE_MODEL, 4, &mut init_tee_model)?; + read_field(TDCS_FIELD_SERVTD_INFO_HASH, 8, &mut cur_servtd_info_hash)?; + read_field(TDCS_FIELD_SERVTD_ATTR, 8, &mut cur_servtd_attr)?; + + Ok(ServtdExt { + init_servtd_info_hash, + init_attr, + init_cpusvn, + init_tee_tcb_svn, + init_tee_model, + cur_servtd_info_hash, + cur_servtd_attr, + reserved: [0u8; 8], + reserved2: [0u8; 104], + }) +} + +pub fn write_approved_servtd_ext_hash(servtd_ext_hash: &[u8]) -> Result<(), MigrationResult> { + if servtd_ext_hash.len() != SHA384_DIGEST_SIZE { + return Err(MigrationResult::InvalidParameter); + } + + for (idx, chunk) in servtd_ext_hash.chunks_exact(size_of::()).enumerate() { + let elem = u64::from_le_bytes(chunk.try_into().unwrap()); + tdcall_vm_write( + TDCS_FIELD_SERVTD_ACCEPT_SERVTD_EXT_HASH + idx as u64, + elem, + 0, + )?; + } + + Ok(()) +} + +mod test { + use super::ServtdExt; + + #[test] + fn test_structure_sizes() { + assert_eq!(size_of::(), 268) + } +} diff --git a/src/migtd/src/migration/session.rs b/src/migtd/src/migration/session.rs index 8483de85..9f37f458 100644 --- a/src/migtd/src/migration/session.rs +++ b/src/migtd/src/migration/session.rs @@ -6,6 +6,8 @@ use crate::migration::event::VMCALL_MIG_REPORTSTATUS_FLAGS; #[cfg(feature = "policy_v2")] use crate::migration::pre_session_data::pre_session_data_exchange; +#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] +use crate::migration::rebinding::RebindingInfo; use crate::migration::transport::setup_transport; use crate::migration::transport::shutdown_transport; use crate::migration::transport::TransportType; @@ -70,8 +72,10 @@ const TDX_VMCALL_VMM_SUCCESS: u8 = 1; #[derive(Debug, PartialEq, Eq)] pub enum DataStatusOperation { StartMigration = 1, + StartRebinding = 2, GetReportData = 3, EnableLogArea = 4, + GetMigtdData = 5, } #[cfg(feature = "vmcall-raw")] @@ -200,7 +204,7 @@ pub fn query() -> Result<()> { } #[cfg(feature = "vmcall-raw")] -fn process_buffer(buffer: &mut [u8]) -> RequestDataBufferHeader { +fn process_buffer(buffer: &[u8]) -> RequestDataBufferHeader { let length = size_of::(); let mut outputbuffer = RequestDataBufferHeader { datastatus: 0, @@ -214,7 +218,7 @@ fn process_buffer(buffer: &mut [u8]) -> RequestDataBufferHeader { ); return outputbuffer; } - let (header, _payload_buffer) = buffer.split_at_mut(length); // Split at 12th byte + let (header, _payload_buffer) = buffer.split_at(length); // Split at 12th byte outputbuffer = RequestDataBufferHeader { datastatus: u64::from_le_bytes(header[0..8].try_into().unwrap()), @@ -224,6 +228,21 @@ fn process_buffer(buffer: &mut [u8]) -> RequestDataBufferHeader { outputbuffer } +#[cfg(feature = "vmcall-raw")] +fn calculate_shared_page_nums(reqbufferhdrlen: usize) -> Result { + let init_data_header_size = 44; // size of MIGTD_DATA_STRUCT header + MIGTD_DATA_ENTRY_STRUCT header + let policy_size = crate::config::get_policy() + .ok_or(MigrationResult::InvalidParameter)? + .len(); + let event_log_size = crate::event_log::get_event_log() + .ok_or(MigrationResult::InvalidParameter)? + .len(); + let report_size = 1024; + let total_size = + reqbufferhdrlen + init_data_header_size + policy_size + event_log_size + report_size; + Ok((total_size + PAGE_SIZE - 1) / PAGE_SIZE) +} + #[cfg(feature = "vmcall-raw")] pub async fn wait_for_request() -> Result { let mut reqbufferhdr = RequestDataBufferHeader { @@ -231,22 +250,26 @@ pub async fn wait_for_request() -> Result { length: 0, }; let reqbufferhdrlen = size_of::(); - let mut data_buffer = SharedMemory::new(1).ok_or_else(|| { + let shared_page_nums = calculate_shared_page_nums(reqbufferhdrlen)?; + + let mut data_buffer = SharedMemory::new(shared_page_nums).ok_or_else(|| { log::error!("wait_for_request: Failed to allocate shared memory\n"); MigrationResult::OutOfResource })?; - let data_buffer = data_buffer.as_mut_bytes(); + let shared_data_buffer = data_buffer.as_mut_bytes(); - data_buffer[0..reqbufferhdrlen].copy_from_slice(&reqbufferhdr.as_bytes()); + shared_data_buffer[0..reqbufferhdrlen].copy_from_slice(&reqbufferhdr.as_bytes()); - tdx::tdvmcall_migtd_waitforrequest(data_buffer, event::VMCALL_SERVICE_VECTOR).map_err(|e| { - log::error!( - "wait_for_request: tdvmcall_migtd_waitforrequest failure {:?}\n", + tdx::tdvmcall_migtd_waitforrequest(shared_data_buffer, event::VMCALL_SERVICE_VECTOR).map_err( + |e| { + log::error!( + "wait_for_request: tdvmcall_migtd_waitforrequest failure {:?}\n", + e + ); e - ); - e - })?; + }, + )?; poll_fn(|_cx| { if VMCALL_SERVICE_FLAG.load(Ordering::SeqCst) { @@ -255,6 +278,13 @@ pub async fn wait_for_request() -> Result { return Poll::Pending; } + let data_buffer = if let Some(private_buffer) = data_buffer.copy_to_private_shadow() { + private_buffer + } else { + log::error!("wait_for_request: copy_to_private_shadow failure\n"); + return Poll::Pending + }; + reqbufferhdr = process_buffer(data_buffer); let data_status = reqbufferhdr.datastatus; let data_length = reqbufferhdr.length; @@ -304,6 +334,37 @@ pub async fn wait_for_request() -> Result { REQUESTS.lock().insert(mig_request_id); Poll::Ready(Ok(WaitForRequestResponse::StartMigration(wfr_info))) } + } else if operation == DataStatusOperation::StartRebinding as u8 { + #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] + match RebindingInfo::read_from_bytes(&data_buffer[reqbufferhdrlen..]) { + Some(rebinding_info) => { + VMCALL_MIG_REPORTSTATUS_FLAGS + .lock() + .insert(rebinding_info.mig_request_id, AtomicBool::new(false)); + + if REQUESTS.lock().contains(&rebinding_info.mig_request_id) { + Poll::Pending + } else { + REQUESTS.lock().insert(rebinding_info.mig_request_id); + Poll::Ready(Ok(WaitForRequestResponse::StartRebinding(rebinding_info))) + } + } + None => { + if data_length >= size_of::() as u32 { + let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; + let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); + log::error!(migration_request_id = mig_request_id; "wait_for_request: StartRebinding operation incorrect data received\n"); + } else { + log::error!("wait_for_request: StartRebinding operation incorrect data received\n"); + } + Poll::Pending + } + } + #[cfg(not(all(feature = "vmcall-raw", feature = "policy_v2")))] + { + log::debug!("wait_for_request: invalid operation StartRebinding received\n"); + Poll::Pending + } } else if operation == DataStatusOperation::GetReportData as u8 { let mut reportdata: [u8; 64] = [0; 64]; let mut mig_request_id: u64 = 0; @@ -376,6 +437,50 @@ pub async fn wait_for_request() -> Result { REQUESTS.lock().insert(mig_request_id); Poll::Ready(Ok(WaitForRequestResponse::EnableLogArea(wfr_info))) } + } else if operation == DataStatusOperation::GetMigtdData as u8 { + #[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] + { + let mut reportdata: [u8; 64] = [0; 64]; + let mut mig_request_id: u64 = 0; + if data_length != size_of::() as u32 + { + if data_length >= size_of::() as u32 { + let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; + let mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); + log::error!(migration_request_id = mig_request_id; "wait_for_request: StartMigration operation incorrect data length - expected {} actual {}\n", size_of::(), data_length); + } else { + log::error!("wait_for_request: StartMigration operation incorrect data length - expected {} actual {}\n", size_of::(), data_length); + } + return Poll::Pending; + } + let slice = &data_buffer[reqbufferhdrlen..reqbufferhdrlen + data_length as usize]; + mig_request_id = u64::from_le_bytes(slice[0..8].try_into().unwrap()); + + if data_length == (size_of_val(&mig_request_id) + TD_REPORT_ADDITIONAL_DATA_SIZE) as u32 + { + reportdata = slice[8..72].try_into().unwrap(); + } + + VMCALL_MIG_REPORTSTATUS_FLAGS + .lock() + .insert(mig_request_id, AtomicBool::new(false)); + + let wfr_info = MigtdDataInfo { + mig_request_id, + reportdata, + }; + if REQUESTS.lock().contains(&mig_request_id) { + Poll::Pending + } else { + REQUESTS.lock().insert(mig_request_id); + Poll::Ready(Ok(WaitForRequestResponse::GetMigtdData(wfr_info))) + } + } + #[cfg(not(all(feature = "vmcall-raw", feature = "policy_v2")))] + { + log::debug!("wait_for_request: invalid operation GetMigtdData received\n"); + Poll::Pending + } } else { Poll::Pending } @@ -538,6 +643,25 @@ pub async fn get_tdreport( Ok(()) } +#[cfg(all(feature = "vmcall-raw", feature = "policy_v2"))] +pub async fn get_migtd_data( + additional_data: &[u8; TD_REPORT_ADDITIONAL_DATA_SIZE], + data: &mut Vec, + request_id: u64, +) -> Result<()> { + use crate::migration::rebinding::InitData; + + let init_data = InitData::get_from_local(additional_data).ok_or_else(|| { + log::error!( migration_request_id = request_id; + "Failed to get init migtd data from local\n", + ); + MigrationResult::InvalidParameter + })?; + + init_data.write_into_bytes(data); + Ok(()) +} + #[cfg(feature = "vmcall-raw")] pub async fn report_status(status: u8, request_id: u64, data: &Vec) -> Result<()> { let mut reportstatus = ReportStatusResponse::new() @@ -549,7 +673,9 @@ pub async fn report_status(status: u8, request_id: u64, data: &Vec) -> Resul length: 0, }; let reqbufferhdrlen = size_of::(); - let mut data_buffer = SharedMemory::new(1).ok_or_else(|| { + + let shared_page_nums = (reqbufferhdrlen + data.len() + PAGE_SIZE - 1) / PAGE_SIZE; + let mut data_buffer = SharedMemory::new(shared_page_nums).ok_or_else(|| { log::error!(migration_request_id = request_id; "report_status: Failed to allocate shared memory for data buffer\n"); MigrationResult::OutOfResource })?; @@ -568,13 +694,13 @@ pub async fn report_status(status: u8, request_id: u64, data: &Vec) -> Resul return Err(MigrationResult::InvalidParameter); } - if data.len() > 0 && data.len() < (PAGE_SIZE - reqbufferhdrlen) { + if data.len() > 0 { reqbufferhdr.length += data.len() as u32; } let data_buffer = data_buffer.as_mut_bytes(); data_buffer[0..reqbufferhdrlen].copy_from_slice(&reqbufferhdr.as_bytes()); - if data.len() > 0 && data.len() < (PAGE_SIZE - reqbufferhdrlen) { + if data.len() > 0 { data_buffer[reqbufferhdrlen..data.len() + reqbufferhdrlen] .copy_from_slice(&data[0..data.len()]); } @@ -724,7 +850,7 @@ async fn migration_src_exchange_msk( log::error!(migration_request_id = info.mig_info.mig_request_id; "exchange_msk(): Incorrect ExchangeInformation size Size - Expected: {} Actual: {}\n", size_of::(), size); return Err(MigrationResult::NetworkError); } - shutdown_transport(ratls_client.transport_mut(), info).await?; + shutdown_transport(ratls_client.transport_mut(), info.mig_info.mig_request_id).await?; Ok(()) } @@ -785,7 +911,7 @@ async fn migration_dst_exchange_msk( log::error!(migration_request_id = info.mig_info.mig_request_id; "exchange_msk(): Incorrect ExchangeInformation size. Size - Expected: {} Actual: {}\n", size_of::(), size); return Err(MigrationResult::NetworkError); } - shutdown_transport(ratls_server.transport_mut(), info).await?; + shutdown_transport(ratls_server.transport_mut(), info.mig_info.mig_request_id).await?; Ok(()) } @@ -870,7 +996,14 @@ async fn migration_dst_exchange_msk( #[cfg(feature = "main")] pub async fn exchange_msk(info: &MigrationInformation) -> Result<()> { - let mut transport = setup_transport(info).await?; + let mut transport = setup_transport( + info.mig_info.mig_request_id, + #[cfg(any(feature = "vmcall-vsock", feature = "virtio-vsock"))] + info.mig_socket_info.mig_td_cid, + #[cfg(any(feature = "vmcall-vsock", feature = "virtio-vsock"))] + info.mig_socket_info.mig_channel_port, + ) + .await?; // Exchange policy firstly because of the message size limitation of TLS protocol #[cfg(feature = "policy_v2")] @@ -902,7 +1035,7 @@ pub async fn exchange_msk(info: &MigrationInformation) -> Result<()> { #[cfg(not(feature = "spdm_attestation"))] { let mut remote_information = ExchangeInformation::default(); - let mut exchange_information = + let exchange_information = exchange_info(&info.mig_info, info.is_src()).map_err(|e| { log::error!(migration_request_id = info.mig_info.mig_request_id; "exchange_msk: exchange_info error: {:?}\n", e); e diff --git a/src/migtd/src/migration/transport.rs b/src/migtd/src/migration/transport.rs index d0accd56..35f689db 100644 --- a/src/migtd/src/migration/transport.rs +++ b/src/migtd/src/migration/transport.rs @@ -3,7 +3,6 @@ // SPDX-License-Identifier: BSD-2-Clause-Patent use super::MigrationResult; -use crate::migration::data::MigrationInformation; type Result = core::result::Result; @@ -16,19 +15,22 @@ pub(super) type TransportType = virtio_serial::VirtioSerialPort; #[cfg(all(not(feature = "virtio-serial"), not(feature = "vmcall-raw")))] pub(super) type TransportType = vsock::stream::VsockStream; -pub(super) async fn setup_transport(info: &MigrationInformation) -> Result { +pub(super) async fn setup_transport( + mig_request_id: u64, + #[cfg(any(feature = "vmcall-vsock", feature = "virtio-vsock"))] migtd_cid: u64, + #[cfg(any(feature = "vmcall-vsock", feature = "virtio-vsock"))] mig_channel_port: u32, +) -> Result { #[cfg(feature = "vmcall-raw")] { use vmcall_raw::stream::VmcallRaw; - let mut vmcall_raw_instance = VmcallRaw::new_with_mid(info.mig_info.mig_request_id) - .map_err(|e| { - log::error!(migration_request_id = info.mig_info.mig_request_id; + let mut vmcall_raw_instance = VmcallRaw::new_with_mid(mig_request_id).map_err(|e| { + log::error!(migration_request_id = mig_request_id; "exchange_msk: Failed to create vmcall_raw_instance errorcode: {:?}\n", e); - MigrationResult::InvalidParameter - })?; + MigrationResult::InvalidParameter + })?; vmcall_raw_instance.connect().await.map_err(|e| { - log::error!(migration_request_id = info.mig_info.mig_request_id; + log::error!(migration_request_id = mig_request_id; "exchange_msk: Failed to connect vmcall_raw_instance errorcode: {:?}\n", e); MigrationResult::InvalidParameter })?; @@ -53,17 +55,11 @@ pub(super) async fn setup_transport(info: &MigrationInformation) -> Result Result Result<()> { #[cfg(feature = "vmcall-raw")] transport.shutdown().await.map_err(|e| { - log::error!(migration_request_id = info.mig_info.mig_request_id; + log::error!(migration_request_id = mig_request_id; "shutdown_transport: Failed to shutdown vmcall_raw_instance errorcode: {:?}\n", e); MigrationResult::InvalidParameter })?; diff --git a/src/migtd/src/ratls/mod.rs b/src/migtd/src/ratls/mod.rs index 3117cad4..716739e9 100644 --- a/src/migtd/src/ratls/mod.rs +++ b/src/migtd/src/ratls/mod.rs @@ -22,6 +22,7 @@ pub enum RatlsError { X509(DerError), InvalidEventlog, InvalidPolicy, + GenerateCertificate, } impl From for RatlsError { @@ -54,6 +55,16 @@ pub const EXTNID_MIGTD_EVENT_LOG: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.3"); pub const EXTNID_MIGTD_POLICY_HASH: ObjectIdentifier = ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.4"); +pub const EXTNID_MIGTD_TDREPORT: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.5"); +pub const EXTNID_MIGTD_SERVTD_EXT: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.6"); +pub const EXTNID_MIGTD_TDREPORT_INIT: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.7"); +pub const EXTNID_MIGTD_EVENT_LOG_INIT: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.8"); +pub const EXTNID_MIGTD_INIT_POLICY_HASH: ObjectIdentifier = + ObjectIdentifier::new_unwrap("1.2.840.113741.1.5.5.1.9"); // As specified in https://datatracker.ietf.org/doc/html/rfc5480#appendix-A // id-ecPublicKey OBJECT IDENTIFIER ::= { diff --git a/src/migtd/src/ratls/server_client.rs b/src/migtd/src/ratls/server_client.rs index 252596ca..4f1d05f3 100644 --- a/src/migtd/src/ratls/server_client.rs +++ b/src/migtd/src/ratls/server_client.rs @@ -14,11 +14,12 @@ use crypto::{ }, Error as CryptoError, }; +use tdx_tdcall::tdreport::TdxReport; use super::*; -#[cfg(feature = "policy_v2")] -use crate::config::get_policy; use crate::event_log::get_event_log; +#[cfg(feature = "policy_v2")] +use crate::{config::get_policy, migration::servtd_ext::ServtdExt}; use verify::*; type Result = core::result::Result; @@ -29,7 +30,7 @@ pub fn server(stream: T) -> Result( ); e })?; - let (certs, _quote) = gen_cert(&signing_key).map_err(|e| { + let (certs, _quote) = create_certificate_for_server(&signing_key).map_err(|e| { log::error!("server policy_v2 gen_cert() failed with error {:?}\n", e); e })?; let certs = vec![certs]; // Server verifies certificate of client - let config = - TlsConfig::new(certs, signing_key, verify_client_cert, remote_policy).map_err(|e| { - log::error!( - "server policy_v2 TlsConfig::new() failed with error {:?}\n", - e - ); + let config = TlsConfig::new( + certs, + signing_key, + move |cert, quote| verify_client_cert(cert, quote), + remote_policy, + ) + .map_err(|e| { + log::error!( + "server policy_v2 TlsConfig::new() failed with error {:?}\n", e - })?; + ); + e + })?; config.tls_server(stream).map_err(|e| { log::error!("server policy_v2 tls_server() failed with error {:?}\n", e); e.into() @@ -86,7 +92,7 @@ pub fn client(stream: T) -> Result( ); e })?; - let (certs, _quote) = gen_cert(&signing_key).map_err(|e| { + let (certs, _quote) = create_certificate_for_client(&signing_key).map_err(|e| { log::error!("client policy_v2 gen_cert() failed with error {:?}\n", e); e })?; @@ -139,21 +145,109 @@ pub fn client( }) } -fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { - let algorithm = AlgorithmIdentifier { - algorithm: ID_EC_PUBKEY_OID, - parameters: Some(AnyRef::new( - Tag::ObjectIdentifier, - SECP384R1_OID.as_bytes(), - )?), - }; - let eku = vec![SERVER_AUTH, CLIENT_AUTH, MIGTD_EXTENDED_KEY_USAGE] - .to_der() +// TLS server for rebinding new +#[cfg(feature = "policy_v2")] +pub fn server_rebinding( + stream: T, + remote_policy: Vec, +) -> Result> { + let signing_key = EcdsaPk::new().map_err(|e| { + log::error!( + "server rebinding EcdsaPk::new() failed with error {:?}\n", + e + ); + e + })?; + let certs = create_certificate_for_rebinding_new(&signing_key).map_err(|e| { + log::error!("server rebinding gen_cert() failed with error {:?}\n", e); + e + })?; + let certs = vec![certs]; + + let config = TlsConfig::new(certs, signing_key, verify_rebinding_old_cert, remote_policy) .map_err(|e| { - log::error!("gen_cert to_der failed with error {:?}\n", e); + log::error!( + "server rebinding TlsConfig::new() failed with error {:?}\n", + e + ); e })?; + config.tls_server(stream).map_err(|e| { + log::error!("server rebinding tls_server() failed with error {:?}\n", e); + e.into() + }) +} +// TLS client for rebinding old +#[cfg(feature = "policy_v2")] +pub fn client_rebinding( + stream: T, + remote_policy: Vec, + init_policy_hash: &[u8], + init_td_report: &[u8], + init_event_log: &[u8], + servtd_ext: &ServtdExt, +) -> Result> { + let signing_key = EcdsaPk::new().map_err(|e| { + log::error!( + "server rebinding EcdsaPk::new() failed with error {:?}\n", + e + ); + e + })?; + let certs = create_certificate_for_rebinding_old( + &signing_key, + init_policy_hash, + init_td_report, + init_event_log, + servtd_ext, + ) + .map_err(|e| { + log::error!("server rebinding gen_cert() failed with error {:?}\n", e); + e + })?; + let certs = vec![certs]; + + let config = TlsConfig::new(certs, signing_key, verify_rebinding_new_cert, remote_policy) + .map_err(|e| { + log::error!( + "server rebinding TlsConfig::new() failed with error {:?}\n", + e + ); + e + })?; + config.tls_server(stream).map_err(|e| { + log::error!("server rebinding tls_server() failed with error {:?}\n", e); + e.into() + }) +} + +fn gen_quote(public_key: &[u8]) -> Result> { + let td_report = gen_tdreport(public_key)?; + + attestation::get_quote(td_report.as_bytes()).map_err(|e| { + log::error!("Failed to get quote from TD report. Error: {:?}\n", e); + RatlsError::GetQuote + }) +} + +fn gen_tdreport(public_key: &[u8]) -> Result { + let hash = digest_sha384(public_key).map_err(|e| { + log::error!("Failed to compute SHA384 digest: {:?}\n", e); + e + })?; + + // Generate the TD Report that contains the public key hash as nonce + let mut additional_data = [0u8; 64]; + additional_data[..hash.len()].copy_from_slice(hash.as_ref()); + + tdx_tdcall::tdreport::tdcall_report(&additional_data).map_err(|e| { + log::error!("Failed to get TD report via tdcall. Error: {:?}\n", e); + e.into() + }) +} + +fn create_certificate_for_server(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { let pub_key = signing_key.public_key().map_err(|e| { log::error!( "gen_cert signing_key.public_key() failed with error {:?}\n", @@ -161,34 +255,90 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { ); e })?; - let sig_alg = AlgorithmIdentifier { - algorithm: ID_EC_SIG_OID, - parameters: None, - }; - let key_usage = BitStringRef::from_bytes(&[0x80]) + let quote = gen_quote(&pub_key).map_err(|e| { + log::error!("gen_cert gen_quote() failed with error {:?}\n", e); + e + })?; + + #[cfg(feature = "policy_v2")] + let policy_hash = { + let policy = get_policy().ok_or_else(|| { + log::error!( + "gen_cert client policy_v2 Failed to get migration policy for policy hash.\n" + ); + RatlsError::InvalidPolicy + })?; + digest_sha384(policy) + } + .map_err(|e| { + log::error!("gen_cert digest_sha384() failed with error {:?}\n", e); + e + })?; + + let eku = create_eku()?; + let key_usage = create_key_usage()?; + + let x509_builder = create_tls_tbs_common(&pub_key, &key_usage, &eku)?.add_extension( + Extension::new( + EXTNID_MIGTD_QUOTE_REPORT, + Some(false), + Some(quote.as_slice()), + ) + .map_err(|e| { + log::error!( + "gen_cert Extension::new for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", + e + ); + e + })?, + ) .map_err(|e| { log::error!( - "gen_cert BitStringRef::from_bytes() failed with error {:?}\n", + "gen_cert add_extension for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", e ); e - })? - .to_der() + })?; + + // If policy_v2 feature is enabled, add policy extension + #[cfg(feature = "policy_v2")] + let x509_builder = x509_builder + .add_extension( + Extension::new(EXTNID_MIGTD_POLICY_HASH, Some(false), Some(&policy_hash)).map_err( + |e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + }, + )?, + ) .map_err(|e| { log::error!( - "gen_cert BitStringRef::to_der() failed with error {:?}\n", + "gen_cert policy_v2 add_extension for policy hash failed with error {:?}.\n", e ); e })?; + + let x509_cert_der = sign_tls_tbs(x509_builder, &signing_key)?; + Ok((x509_cert_der, quote)) +} + +fn create_certificate_for_client(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { + let pub_key = signing_key.public_key().map_err(|e| { + log::error!( + "gen_cert signing_key.public_key() failed with error {:?}\n", + e + ); + e + })?; let quote = gen_quote(&pub_key).map_err(|e| { log::error!("gen_cert gen_quote() failed with error {:?}\n", e); e })?; - let event_log = get_event_log().ok_or_else(|| { - log::error!("gen_cert get_event_log() failed with error RatlsError::InvalidEventlog.\n"); - RatlsError::InvalidEventlog - })?; + #[cfg(feature = "policy_v2")] let policy_hash = { let policy = get_policy().ok_or_else(|| { @@ -204,28 +354,122 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { e })?; - let x509_builder = CertificateBuilder::new(sig_alg, algorithm, &pub_key) + let eku = create_eku()?; + let key_usage = create_key_usage()?; + + let x509_builder = create_tls_tbs_common(&pub_key, &key_usage, &eku)?.add_extension( + Extension::new( + EXTNID_MIGTD_QUOTE_REPORT, + Some(false), + Some(quote.as_slice()), + ) + .map_err(|e| { + log::error!( + "gen_cert Extension::new for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", + e + ); + e + })?, + ) .map_err(|e| { - log::error!("gen_cert CertificateBuilder::new failed with error {:?}\n", e); + log::error!( + "gen_cert add_extension for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", + e + ); e - })? - // 1970-01-01T00:00:00Z - .set_not_before(core::time::Duration::new(0, 0)) + })?; + + // If policy_v2 feature is enabled, add policy extension + #[cfg(feature = "policy_v2")] + let x509_builder = x509_builder + .add_extension( + Extension::new(EXTNID_MIGTD_POLICY_HASH, Some(false), Some(&policy_hash)).map_err( + |e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + }, + )?, + ) .map_err(|e| { - log::error!("gen_cert set_not_before failed with error {:?}\n", e); + log::error!( + "gen_cert policy_v2 add_extension for policy hash failed with error {:?}.\n", + e + ); e - })? - // 9999-12-31T23:59:59Z - .set_not_after(core::time::Duration::new(253402300799, 0)) + })?; + + let x509_cert_der = sign_tls_tbs(x509_builder, &signing_key)?; + Ok((x509_cert_der, quote)) +} + +#[cfg(feature = "policy_v2")] +fn create_certificate_for_rebinding_old( + signing_key: &EcdsaPk, + init_policy_hash: &[u8], + init_tdreport: &[u8], + init_event_log: &[u8], + servtd_ext: &ServtdExt, +) -> Result> { + let pub_key = signing_key.public_key().map_err(|e| { + log::error!( + "gen_cert signing_key.public_key() failed with error {:?}\n", + e + ); + e + })?; + let tdreport = gen_tdreport(&pub_key).map_err(|e| { + log::error!("gen_cert gen_tdreport() failed with error {:?}\n", e); + e + })?; + + let policy = get_policy().ok_or_else(|| { + log::error!( + "gen_cert rebinding old policy_v2 Failed to get migration policy for policy hash.\n" + ); + RatlsError::InvalidPolicy + })?; + let policy_hash = digest_sha384(policy).map_err(|e| { + log::error!("gen_cert digest_sha384() failed with error {:?}\n", e); + e + })?; + + let eku = create_eku()?; + let key_usage = create_key_usage()?; + + let x509_builder = create_tls_tbs_common(&pub_key, &key_usage, &eku)? + .add_extension( + Extension::new( + EXTNID_MIGTD_TDREPORT, + Some(false), + Some(tdreport.as_bytes()), + ) + .map_err(|e| { + log::error!( + "gen_cert Extension::new for EXTNID_MIGTD_TDREPORT failed with error {:?}\n", + e + ); + e + })?, + ) .map_err(|e| { - log::error!("gen_cert set_not_after failed with error {:?}\n", e); + log::error!( + "gen_cert add_extension for EXTNID_MIGTD_TDREPORT failed with error {:?}\n", + e + ); e - })? + })?; + + // If policy_v2 feature is enabled, add policy extension + #[cfg(feature = "policy_v2")] + let x509_builder = x509_builder .add_extension( - Extension::new(KEY_USAGE_EXTENSION, Some(true), Some(key_usage.as_slice())).map_err( + Extension::new(EXTNID_MIGTD_POLICY_HASH, Some(false), Some(&policy_hash)).map_err( |e| { log::error!( - "gen_cert Extension::new for KEY_USAGE_EXTENSION failed with error {:?}\n", + "gen_cert policy_v2 add_extension failed with error {:?}.\n", e ); e @@ -234,15 +478,20 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { ) .map_err(|e| { log::error!( - "gen_cert add_extension for KEY_USAGE_EXTENSION failed with error {:?}\n", + "gen_cert policy_v2 add_extension for policy hash failed with error {:?}.\n", e ); e })? .add_extension( - Extension::new(EXTENDED_KEY_USAGE, Some(false), Some(eku.as_slice())).map_err(|e| { + Extension::new( + EXTNID_MIGTD_SERVTD_EXT, + Some(false), + Some(servtd_ext.as_bytes()), + ) + .map_err(|e| { log::error!( - "gen_cert Extension::new for EXTENDED_KEY_USAGE failed with error {:?}\n", + "gen_cert policy_v2 add_extension failed with error {:?}.\n", e ); e @@ -250,20 +499,219 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { ) .map_err(|e| { log::error!( - "gen_cert add_extension for EXTENDED_KEY_USAGE failed with error {:?}\n", + "gen_cert policy_v2 add_extension for servtd_ext failed with error {:?}.\n", e ); e })? .add_extension( Extension::new( - EXTNID_MIGTD_QUOTE_REPORT, + EXTNID_MIGTD_TDREPORT_INIT, Some(false), - Some(quote.as_slice()), + Some(&init_tdreport), + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + })?, + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension for tdreport init failed with error {:?}.\n", + e + ); + e + })? + .add_extension( + Extension::new( + EXTNID_MIGTD_EVENT_LOG_INIT, + Some(false), + Some(&init_event_log), + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + })?, + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension for event log init failed with error {:?}.\n", + e + ); + e + })? + .add_extension( + Extension::new( + EXTNID_MIGTD_INIT_POLICY_HASH, + Some(false), + Some(&init_policy_hash), + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + })?, + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension for init policy hash failed with error {:?}.\n", + e + ); + e + })?; + + let x509_cert_der = sign_tls_tbs(x509_builder, &signing_key)?; + Ok(x509_cert_der) +} + +#[cfg(feature = "policy_v2")] +fn create_certificate_for_rebinding_new(signing_key: &EcdsaPk) -> Result> { + let pub_key = signing_key.public_key().map_err(|e| { + log::error!( + "gen_cert signing_key.public_key() failed with error {:?}\n", + e + ); + e + })?; + let tdreport = gen_tdreport(&pub_key).map_err(|e| { + log::error!("gen_cert gen_quote() failed with error {:?}\n", e); + e + })?; + + let policy_hash = { + let policy = get_policy().ok_or_else(|| { + log::error!( + "gen_cert rebinding old policy_v2 Failed to get migration policy for policy hash.\n" + ); + RatlsError::InvalidPolicy + })?; + digest_sha384(policy) + } + .map_err(|e| { + log::error!("gen_cert digest_sha384() failed with error {:?}\n", e); + e + })?; + + let eku = create_eku()?; + let key_usage = create_key_usage()?; + + let x509_builder = create_tls_tbs_common(&pub_key, &key_usage, &eku)? + .add_extension( + Extension::new( + EXTNID_MIGTD_TDREPORT, + Some(false), + Some(tdreport.as_bytes()), ) .map_err(|e| { log::error!( - "gen_cert Extension::new for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", + "gen_cert Extension::new for EXTNID_MIGTD_TDREPORT failed with error {:?}\n", + e + ); + e + })?, + ) + .map_err(|e| { + log::error!( + "gen_cert add_extension for EXTNID_MIGTD_TDREPORT failed with error {:?}\n", + e + ); + e + })?; + + let x509_builder = x509_builder + .add_extension( + Extension::new(EXTNID_MIGTD_POLICY_HASH, Some(false), Some(&policy_hash)).map_err( + |e| { + log::error!( + "gen_cert policy_v2 add_extension failed with error {:?}.\n", + e + ); + e + }, + )?, + ) + .map_err(|e| { + log::error!( + "gen_cert policy_v2 add_extension for policy hash failed with error {:?}.\n", + e + ); + e + })?; + + let x509_cert_der = sign_tls_tbs(x509_builder, &signing_key)?; + Ok(x509_cert_der) +} + +fn create_tls_tbs_common<'a>( + public_key: &'a [u8], + key_usage: &'a [u8], + eku: &'a [u8], +) -> Result> { + let algorithm = AlgorithmIdentifier { + algorithm: ID_EC_PUBKEY_OID, + parameters: Some(AnyRef::new( + Tag::ObjectIdentifier, + SECP384R1_OID.as_bytes(), + )?), + }; + let sig_alg = AlgorithmIdentifier { + algorithm: ID_EC_SIG_OID, + parameters: None, + }; + + let event_log = get_event_log().ok_or_else(|| { + log::error!("gen_cert get_event_log() failed with error RatlsError::InvalidEventlog.\n"); + RatlsError::InvalidEventlog + })?; + + let x509_builder = CertificateBuilder::new(sig_alg, algorithm, public_key) + .map_err(|e| { + log::error!( + "gen_cert CertificateBuilder::new failed with error {:?}\n", + e + ); + e + })? + // 1970-01-01T00:00:00Z + .set_not_before(core::time::Duration::new(0, 0)) + .map_err(|e| { + log::error!("gen_cert set_not_before failed with error {:?}\n", e); + e + })? + // 9999-12-31T23:59:59Z + .set_not_after(core::time::Duration::new(253402300799, 0)) + .map_err(|e| { + log::error!("gen_cert set_not_after failed with error {:?}\n", e); + e + })? + .add_extension( + Extension::new(KEY_USAGE_EXTENSION, Some(true), Some(key_usage)).map_err(|e| { + log::error!( + "gen_cert Extension::new for KEY_USAGE_EXTENSION failed with error {:?}\n", + e + ); + e + })?, + ) + .map_err(|e| { + log::error!( + "gen_cert add_extension for KEY_USAGE_EXTENSION failed with error {:?}\n", + e + ); + e + })? + .add_extension( + Extension::new(EXTENDED_KEY_USAGE, Some(false), Some(eku)).map_err(|e| { + log::error!( + "gen_cert Extension::new for EXTENDED_KEY_USAGE failed with error {:?}\n", e ); e @@ -271,7 +719,7 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { ) .map_err(|e| { log::error!( - "gen_cert add_extension for EXTNID_MIGTD_QUOTE_REPORT failed with error {:?}\n", + "gen_cert add_extension for EXTENDED_KEY_USAGE failed with error {:?}\n", e ); e @@ -293,28 +741,10 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { e })?; - // If policy_v2 feature is enabled, add policy extension - #[cfg(feature = "policy_v2")] - let x509_builder = x509_builder - .add_extension( - Extension::new(EXTNID_MIGTD_POLICY_HASH, Some(false), Some(&policy_hash)).map_err( - |e| { - log::error!( - "gen_cert policy_v2 add_extension failed with error {:?}.\n", - e - ); - e - }, - )?, - ) - .map_err(|e| { - log::error!( - "gen_cert policy_v2 add_extension for policy hash failed with error {:?}.\n", - e - ); - e - })?; + Ok(x509_builder) +} +fn sign_tls_tbs(x509_builder: CertificateBuilder, signing_key: &EcdsaPk) -> Result> { let mut x509_certificate = x509_builder.build(); let tbs = x509_certificate.tbs_certificate.to_der().map_err(|e| { log::error!( @@ -335,36 +765,41 @@ fn gen_cert(signing_key: &EcdsaPk) -> Result<(Vec, Vec)> { e })?; - Ok(( - x509_certificate.to_der().map_err(|e| { - log::error!( - "gen_cert x509_certificate.to_der failed with error {:?}.\n", - e - ); + Ok(x509_certificate.to_der().map_err(|e| { + log::error!( + "gen_cert x509_certificate.to_der failed with error {:?}.\n", e - })?, - quote, - )) -} - -fn gen_quote(public_key: &[u8]) -> Result> { - let hash = digest_sha384(public_key).map_err(|e| { - log::error!("Failed to compute SHA384 digest: {:?}\n", e); + ); e - })?; + })?) +} - // Generate the TD Report that contains the public key hash as nonce - let mut additional_data = [0u8; 64]; - additional_data[..hash.len()].copy_from_slice(hash.as_ref()); - let td_report = tdx_tdcall::tdreport::tdcall_report(&additional_data).map_err(|e| { - log::error!("Failed to get TD report via tdcall. Error: {:?}\n", e); - e - })?; +fn create_eku() -> Result> { + Ok(vec![SERVER_AUTH, CLIENT_AUTH, MIGTD_EXTENDED_KEY_USAGE] + .to_der() + .map_err(|e| { + log::error!("gen_cert to_der failed with error {:?}\n", e); + e + })?) +} - attestation::get_quote(td_report.as_bytes()).map_err(|e| { - log::error!("Failed to get quote from TD report. Error: {:?}\n", e); - RatlsError::GetQuote - }) +fn create_key_usage() -> Result> { + Ok(BitStringRef::from_bytes(&[0x80]) + .map_err(|e| { + log::error!( + "gen_cert BitStringRef::from_bytes() failed with error {:?}\n", + e + ); + e + })? + .to_der() + .map_err(|e| { + log::error!( + "gen_cert BitStringRef::to_der() failed with error {:?}\n", + e + ); + e + })?) } fn verify_server_cert(cert: &[u8], quote: &[u8]) -> core::result::Result<(), CryptoError> { @@ -384,6 +819,7 @@ mod verify { use crypto::ecdsa::ecdsa_verify; use crypto::{Error as CryptoError, Result as CryptoResult}; use policy::PolicyError; + use tdx_tdcall::tdreport::TdxReport; #[cfg(not(feature = "policy_v2"))] pub fn verify_peer_cert( @@ -523,6 +959,194 @@ mod verify { verify_signature(&cert, suppl_data.as_slice()) } + #[cfg(feature = "policy_v2")] + pub fn verify_rebinding_old_cert( + cert: &[u8], + pre_session_data: &[u8], + ) -> core::result::Result<(), CryptoError> { + let cert = Certificate::from_der(cert).map_err(|_| { + log::error!("Failed to parse certificate from DER.\n"); + CryptoError::ParseCertificate + })?; + + let extensions = cert.tbs_certificate.extensions.as_ref().ok_or_else(|| { + log::error!("Failed to get certificate extensions.\n"); + CryptoError::ParseCertificate + })?; + // Check if extensions contain `MIGTD_EXTENDED_KEY_USAGE` + check_migtd_eku(extensions).map_err(|e| { + log::error!("Failed to check MIGTD EKU: {:?}\n", e); + e + })?; + + let td_report = find_extension(extensions, &EXTNID_MIGTD_TDREPORT).ok_or_else(|| { + log::error!("Failed to find tdreport extension.\n"); + CryptoError::ParseCertificate + })?; + let event_log = find_extension(extensions, &EXTNID_MIGTD_EVENT_LOG).ok_or_else(|| { + log::error!("Failed to find event log extension.\n"); + CryptoError::ParseCertificate + })?; + let expected_policy_hash = find_extension(extensions, &EXTNID_MIGTD_POLICY_HASH) + .ok_or_else(|| { + log::error!("Failed to find expected policy hash extension.\n"); + CryptoError::ParseCertificate + })?; + let init_td_report = + find_extension(extensions, &EXTNID_MIGTD_TDREPORT_INIT).ok_or_else(|| { + log::error!("Failed to find init tdreport extension.\n"); + CryptoError::ParseCertificate + })?; + let init_event_log = + find_extension(extensions, &EXTNID_MIGTD_EVENT_LOG_INIT).ok_or_else(|| { + log::error!("Failed to find init event log extension.\n"); + CryptoError::ParseCertificate + })?; + let init_policy_hash = find_extension(extensions, &EXTNID_MIGTD_INIT_POLICY_HASH) + .ok_or_else(|| { + log::error!("Failed to find init policy hash extension.\n"); + CryptoError::ParseCertificate + })?; + let servtd_ext = find_extension(extensions, &EXTNID_MIGTD_SERVTD_EXT).ok_or_else(|| { + log::error!("Failed to find servtd ext extension.\n"); + CryptoError::ParseCertificate + })?; + + let remote_policy_size = u32::from_le_bytes( + pre_session_data + .get(..4) + .ok_or(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + ))? + .try_into() + .unwrap(), + ) as usize; + let remote_policy = pre_session_data.get(4..4 + remote_policy_size).ok_or( + CryptoError::TlsVerifyPeerCert(INVALID_MIG_POLICY_ERROR.to_string()), + )?; + let init_policy_offset = 4 + remote_policy_size; + let init_policy_size = u32::from_le_bytes( + pre_session_data + .get(init_policy_offset..4 + init_policy_offset) + .ok_or(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + ))? + .try_into() + .unwrap(), + ) as usize; + let init_policy = pre_session_data + .get(init_policy_offset + 4..init_policy_offset + 4 + init_policy_size) + .ok_or(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + ))?; + let exact_policy_hash = digest_sha384(remote_policy)?; + if expected_policy_hash != exact_policy_hash.as_slice() { + log::error!("Invalid rebinding policy.\n"); + return Err(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + )); + } + let exact_init_policy_hash = digest_sha384(init_policy)?; + if init_policy_hash != exact_init_policy_hash.as_slice() { + log::error!("Invalid init rebinding policy.\n"); + return Err(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + )); + } + + let policy_check_result = mig_policy::authenticate_rebinding_old( + td_report, + event_log, + remote_policy, + init_policy, + init_event_log, + init_td_report, + servtd_ext, + ); + + if let Err(e) = &policy_check_result { + log::error!("Policy check failed, below is the detail information:\n"); + log::error!("{:x?}\n", e); + } + + let suppl_data = policy_check_result.map_err(|e| match e { + PolicyError::InvalidPolicy => { + log::error!("Invalid rebinding policy.\n"); + CryptoError::TlsVerifyPeerCert(INVALID_MIG_POLICY_ERROR.to_string()) + } + _ => { + log::error!("Rebinding policy unsatisfied.\n"); + CryptoError::TlsVerifyPeerCert(MIG_POLICY_UNSATISFIED_ERROR.to_string()) + } + })?; + + verify_signature_with_tdreport(&cert, suppl_data.as_slice()) + } + + #[cfg(feature = "policy_v2")] + pub fn verify_rebinding_new_cert( + cert: &[u8], + policy: &[u8], + ) -> core::result::Result<(), CryptoError> { + let cert = Certificate::from_der(cert).map_err(|_| { + log::error!("Failed to parse certificate from DER.\n"); + CryptoError::ParseCertificate + })?; + + let extensions = cert.tbs_certificate.extensions.as_ref().ok_or_else(|| { + log::error!("Failed to get certificate extensions.\n"); + CryptoError::ParseCertificate + })?; + // Check if extensions contain `MIGTD_EXTENDED_KEY_USAGE` + check_migtd_eku(extensions).map_err(|e| { + log::error!("Failed to check MIGTD EKU: {:?}\n", e); + e + })?; + + let td_report = find_extension(extensions, &EXTNID_MIGTD_TDREPORT).ok_or_else(|| { + log::error!("Failed to find quote report extension.\n"); + CryptoError::ParseCertificate + })?; + let event_log = find_extension(extensions, &EXTNID_MIGTD_EVENT_LOG).ok_or_else(|| { + log::error!("Failed to find event log extension.\n"); + CryptoError::ParseCertificate + })?; + let expected_policy_hash = find_extension(extensions, &EXTNID_MIGTD_POLICY_HASH) + .ok_or_else(|| { + log::error!("Failed to find expected policy hash extension.\n"); + CryptoError::ParseCertificate + })?; + + let exact_policy_hash = digest_sha384(policy)?; + if expected_policy_hash != exact_policy_hash.as_slice() { + log::error!("Invalid migration policy.\n"); + return Err(CryptoError::TlsVerifyPeerCert( + INVALID_MIG_POLICY_ERROR.to_string(), + )); + } + + let policy_check_result = + mig_policy::authenticate_rebinding_new(td_report, event_log, policy); + + if let Err(e) = &policy_check_result { + log::error!("Policy check failed, below is the detail information:\n"); + log::error!("{:x?}\n", e); + } + + let suppl_data = policy_check_result.map_err(|e| match e { + PolicyError::InvalidPolicy => { + log::error!("Invalid rebinding policy.\n"); + CryptoError::TlsVerifyPeerCert(INVALID_MIG_POLICY_ERROR.to_string()) + } + _ => { + log::error!("Rebinding policy unsatisfied.\n"); + CryptoError::TlsVerifyPeerCert(MIG_POLICY_UNSATISFIED_ERROR.to_string()) + } + })?; + + verify_signature_with_tdreport(&cert, suppl_data.as_slice()) + } + fn verify_signature(cert: &Certificate, verified_report: &[u8]) -> CryptoResult<()> { let public_key = cert .tbs_certificate @@ -548,6 +1172,32 @@ mod verify { ecdsa_verify(public_key, &tbs, signature) } + #[cfg(feature = "policy_v2")] + fn verify_signature_with_tdreport(cert: &Certificate, tdreport: &[u8]) -> CryptoResult<()> { + let public_key = cert + .tbs_certificate + .subject_public_key_info + .subject_public_key + .as_bytes() + .ok_or_else(|| { + log::error!("Failed to get public key bytes from certificate.\n"); + CryptoError::ParseCertificate + })?; + let tbs = cert.tbs_certificate.to_der().map_err(|e| { + log::error!("Failed to get tbs_certificate der: {:?}\n", e); + e + })?; + let signature = cert.signature_value.as_bytes().ok_or_else(|| { + log::error!("Failed to get signature bytes from certificate.\n"); + CryptoError::ParseCertificate + })?; + verify_public_key_with_tdreport(tdreport, public_key).map_err(|e| { + log::error!("Public key verification failed: {:?}\n", e); + e + })?; + ecdsa_verify(public_key, &tbs, signature) + } + fn verify_public_key(verified_report: &[u8], public_key: &[u8]) -> CryptoResult<()> { if cfg!(feature = "AzCVMEmu") { // In AzCVMEmu mode, REPORTDATA is constructed differently. @@ -574,6 +1224,37 @@ mod verify { )) } } + + #[cfg(feature = "policy_v2")] + fn verify_public_key_with_tdreport(tdreport: &[u8], public_key: &[u8]) -> CryptoResult<()> { + if cfg!(feature = "AzCVMEmu") { + // In AzCVMEmu mode, REPORTDATA is constructed differently. + // Bypass public key hash check in this development environment. + log::warn!( + "AzCVMEmu mode: Skipping public key verification in TD report. This is NOT secure for production use.\n" + ); + return Ok(()); + } + const PUBLIC_KEY_HASH_SIZE: usize = 48; + + let tdx_report = TdxReport::read_from_bytes(tdreport).ok_or( + CryptoError::TlsVerifyPeerCert(MISMATCH_PUBLIC_KEY.to_string()), + )?; + let report_data = &tdx_report.report_mac.report_data[..PUBLIC_KEY_HASH_SIZE]; + let digest = digest_sha384(public_key).map_err(|e| { + log::error!("Failed to compute SHA384 digest: {:?}\n", e); + e + })?; + + if report_data == digest.as_slice() { + Ok(()) + } else { + log::error!("Public key verification failed in TD report.\n"); + Err(CryptoError::TlsVerifyPeerCert( + MISMATCH_PUBLIC_KEY.to_string(), + )) + } + } } // Only for test to bypass the quote verification @@ -605,6 +1286,20 @@ mod verify { // success for test purpose. Ok(()) } + + pub fn verify_rebinding_old_cert( + cert: &[u8], + pre_session_data: &[u8], + ) -> core::result::Result<(), CryptoError> { + Ok(()) + } + + pub fn verify_rebinding_new_cert( + cert: &[u8], + policy: &[u8], + ) -> core::result::Result<(), CryptoError> { + Ok(()) + } } fn check_migtd_eku(extensions: &Extensions) -> core::result::Result<(), CryptoError> { @@ -626,7 +1321,10 @@ fn check_migtd_eku(extensions: &Extensions) -> core::result::Result<(), CryptoEr Err(CryptoError::ParseCertificate) } -fn find_extension<'a>(extensions: &'a Extensions, id: &ObjectIdentifier) -> Option<&'a [u8]> { +pub(crate) fn find_extension<'a>( + extensions: &'a Extensions, + id: &ObjectIdentifier, +) -> Option<&'a [u8]> { extensions.get().iter().find_map(|extn| { if &extn.extn_id == id { extn.extn_value.map(|v| v.as_bytes()) diff --git a/src/policy/src/lib.rs b/src/policy/src/lib.rs index 51791f3d..a68c08cb 100644 --- a/src/policy/src/lib.rs +++ b/src/policy/src/lib.rs @@ -27,6 +27,7 @@ pub enum PolicyError { InvalidParameter, InvalidPolicy, InvalidEventLog, + InvalidTdReport, PlatformNotFound(String), PlatformNotMatch(String, String), UnqualifiedPlatformInfo, @@ -47,6 +48,7 @@ pub enum PolicyError { CrlEvaluation, HashCalculation, QuoteVerification, + TdReportVerification, QuoteGeneration, GetTdxReport, } diff --git a/src/policy/src/v2/policy.rs b/src/policy/src/v2/policy.rs index 40ca7e30..0928e6cc 100644 --- a/src/policy/src/v2/policy.rs +++ b/src/policy/src/v2/policy.rs @@ -133,9 +133,12 @@ impl PartialEq for ServtdTcbStatus { impl Eq for ServtdTcbStatus {} -/// Contains all required data to be evaluated against a policy +/// Contains all required data to be evaluated against a rebinding policy #[derive(Debug, Clone, Default)] pub struct PolicyEvaluationInfo { + /// The TEE_TCB_SVN of MigTD + pub tee_tcb_svn: Option<[u8; 16]>, + /// The date of the Trusted Computing Base (TCB) in ISO-8601 format, e.g. "2023-06-19T00:00:00Z" pub tcb_date: Option, @@ -846,6 +849,7 @@ mod test { let global = include_str!("../../test/policy_v2/global.json"); let global_policy = serde_json::from_str::(global).unwrap(); let mut value = PolicyEvaluationInfo { + tee_tcb_svn: None, tcb_date: Some("2025-09-01T00:00:00Z".to_string()), tcb_status: Some("UpToDate".to_string()), tcb_evaluation_number: Some(15), diff --git a/src/policy/src/v2/servtd_collateral.rs b/src/policy/src/v2/servtd_collateral.rs index 6434c3f8..e746ac3c 100644 --- a/src/policy/src/v2/servtd_collateral.rs +++ b/src/policy/src/v2/servtd_collateral.rs @@ -161,6 +161,22 @@ pub struct Measurements { } impl Measurements { + pub fn new_from_bytes( + mrtd: &[u8], + rtmr0: &[u8], + rtmr1: &[u8], + rtmr2: Option<&[u8]>, + rtmr3: Option<&[u8]>, + ) -> Self { + Measurements { + mrtd: bytes_to_hex_string(mrtd), + rtmr0: bytes_to_hex_string(rtmr0), + rtmr1: bytes_to_hex_string(rtmr1), + rtmr2: rtmr2.map(|b| bytes_to_hex_string(b)), + rtmr3: rtmr3.map(|b| bytes_to_hex_string(b)), + } + } + fn to_ascii_uppercase(&self) -> Self { Measurements { mrtd: self.mrtd.to_ascii_uppercase(),