diff --git a/README.md b/README.md index d1a14ad..fed8efd 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,26 @@ spatialbench-cli --scale-factor 1 --mb-per-file 256 --output-dir sf1-parquet spatialbench-cli --scale-factor 10 --mb-per-file 256 --output-dir sf10-parquet ``` +#### Generate Data Directly to S3 + +You can generate data directly to Amazon S3 or S3-compatible storage by providing an S3 URI as the output directory: + +```bash +# Set AWS credentials +export AWS_ACCESS_KEY_ID="your-access-key" +export AWS_SECRET_ACCESS_KEY="your-secret-key" +export AWS_REGION="us-west-2" # Must match your bucket's region + +# Generate to S3 +spatialbench-cli --scale-factor 10 --mb-per-file 256 --output-dir s3://my-bucket/spatialbench/sf10 + +# For S3-compatible services (MinIO, etc.) +export AWS_ENDPOINT="http://localhost:9000" +spatialbench-cli --scale-factor 1 --output-dir s3://my-bucket/data +``` + +The S3 writer uses streaming multipart upload, buffering data in 32MB chunks before uploading parts. This ensures memory-efficient generation even for large datasets. All output formats (Parquet, CSV, TBL) are supported, and the generated files are byte-for-byte identical to local generation. + #### Custom Spider Configuration You can override these defaults at runtime by passing a YAML file via the `--config` flag: diff --git a/docs/datasets-generators.md b/docs/datasets-generators.md index 5392ebe..c007f1b 100644 --- a/docs/datasets-generators.md +++ b/docs/datasets-generators.md @@ -106,6 +106,14 @@ You can generate the tables for Scale Factor 1 with the following command: spatialbench-cli -s 1 --format=parquet --output-dir sf1-parquet ``` +You can also generate data directly to Amazon S3 by providing an S3 URI: + +``` +spatialbench-cli -s 1 --format=parquet --output-dir s3://my-bucket/sf1-parquet +``` + +See the [Quickstart](quickstart.md#generate-data-directly-to-s3) for details on configuring AWS credentials. + Here are the contents of the `sf1-parquet` directory: * `building.parquet` diff --git a/docs/quickstart.md b/docs/quickstart.md index a4f75bb..f6dea28 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -84,6 +84,26 @@ spatialbench-cli --scale-factor 10 --mb-per-file 512 spatialbench-cli --scale-factor 1 --output-dir data/sf1 ``` +### Generate Data Directly to S3 + +You can generate data directly to Amazon S3 or S3-compatible storage by providing an S3 URI as the output directory: + +```shell +# Set AWS credentials +export AWS_ACCESS_KEY_ID="your-access-key" +export AWS_SECRET_ACCESS_KEY="your-secret-key" +export AWS_REGION="us-west-2" # Must match your bucket's region + +# Generate to S3 +spatialbench-cli --scale-factor 10 --mb-per-file 256 --output-dir s3://my-bucket/spatialbench/sf10 + +# For S3-compatible services (MinIO, etc.) +export AWS_ENDPOINT="http://localhost:9000" +spatialbench-cli --scale-factor 1 --output-dir s3://my-bucket/data +``` + +The S3 writer uses streaming multipart upload, buffering data in 32 MB chunks before uploading parts. All standard AWS environment variables are supported, including `AWS_SESSION_TOKEN` for temporary credentials. + ## Configuring Spatial Distributions SpatialBench uses a spatial data generator to generate synthetic points and polygons using realistic spatial distributions. 
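For readers new to `object_store`, the streaming behaviour described in the documentation above reduces to a small pattern: fill a fixed-size buffer, ship each full buffer as one multipart part, and complete the upload at the end. The sketch below illustrates that pattern against `object_store`'s in-memory backend; the tiny part size, path, and payload are made up for illustration, and the real implementation (including the background upload task and the small-file PUT fallback) is `spatialbench-cli/src/s3_writer.rs` in the diff that follows. It assumes only the `object_store`, `bytes`, and `tokio` dependencies the crate already uses.

```rust
use bytes::Bytes;
use object_store::memory::InMemory;
use object_store::path::Path;
use object_store::{MultipartUpload, ObjectStore};

/// Illustrative part size only; the CLI buffers PARQUET_BUFFER_SIZE (32 MB)
/// per part. On real S3 every part except the last must be at least 5 MB.
const PART_SIZE: usize = 8;

#[tokio::main]
async fn main() -> object_store::Result<()> {
    // InMemory stands in for the AmazonS3 client the CLI builds from
    // the AWS_* environment variables.
    let store = InMemory::new();
    let path = Path::from("spatialbench/sf1/example.bin");

    // Start the multipart upload eagerly, then stream parts as the
    // fixed-size buffer fills up.
    let mut upload = store.put_multipart(&path).await?;
    let mut buffer: Vec<u8> = Vec::with_capacity(PART_SIZE);

    for byte in 0u8..20 {
        buffer.push(byte);
        if buffer.len() >= PART_SIZE {
            // Ship the full buffer as one part and start a fresh one, so
            // peak memory stays at roughly one part regardless of file size.
            let part = Bytes::from(std::mem::replace(
                &mut buffer,
                Vec::with_capacity(PART_SIZE),
            ));
            upload.put_part(part.into()).await?;
        }
    }

    // Send the final (possibly short) part, then complete the upload
    // to make the object visible.
    if !buffer.is_empty() {
        upload.put_part(Bytes::from(buffer).into()).await?;
    }
    upload.complete().await?;

    assert_eq!(store.get(&path).await?.bytes().await?.len(), 20);
    Ok(())
}
```
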
diff --git a/spatialbench-cli/Cargo.toml b/spatialbench-cli/Cargo.toml index 8c1ff93..c98f08a 100644 --- a/spatialbench-cli/Cargo.toml +++ b/spatialbench-cli/Cargo.toml @@ -43,10 +43,11 @@ serde = { version = "1.0.219", features = ["derive"] } anyhow = "1.0.99" serde_yaml = "0.9.33" datafusion = "50.2" -object_store = { version = "0.12.4", features = ["http"] } +object_store = { version = "0.12.4", features = ["http", "aws"] } arrow-array = "56" arrow-schema = "56" url = "2.5.7" +bytes = "1.10.1" [dev-dependencies] assert_cmd = "2.0" diff --git a/spatialbench-cli/src/main.rs b/spatialbench-cli/src/main.rs index 425e5d7..705706c 100644 --- a/spatialbench-cli/src/main.rs +++ b/spatialbench-cli/src/main.rs @@ -28,6 +28,7 @@ mod output_plan; mod parquet; mod plan; mod runner; +mod s3_writer; mod spatial_config_file; mod statistics; mod tbl; @@ -252,8 +253,9 @@ impl Cli { debug!("Logging configured from environment variables"); } - // Create output directory if it doesn't exist and we are not writing to stdout. - if !self.stdout { + // Create output directory if it doesn't exist and we are not writing to stdout + // or to S3 (where local directories are meaningless). + if !self.stdout && !self.output_dir.to_string_lossy().starts_with("s3://") { fs::create_dir_all(&self.output_dir)?; } @@ -386,21 +388,26 @@ impl Cli { } } -impl IntoSize for BufWriter { - fn into_size(self) -> Result { - // we can't get the size of stdout, so just return 0 +impl AsyncFinalize for BufWriter { + async fn finalize(self) -> Result { Ok(0) } } -impl IntoSize for BufWriter { - fn into_size(self) -> Result { +impl AsyncFinalize for BufWriter { + async fn finalize(self) -> Result { let file = self.into_inner()?; let metadata = file.metadata()?; Ok(metadata.len() as usize) } } +impl AsyncFinalize for s3_writer::S3Writer { + async fn finalize(self) -> Result { + self.finish().await + } +} + /// Wrapper around a buffer writer that counts the number of buffers and bytes written struct WriterSink { statistics: WriteStatistics, diff --git a/spatialbench-cli/src/output_plan.rs b/spatialbench-cli/src/output_plan.rs index cdab006..ac60014 100644 --- a/spatialbench-cli/src/output_plan.rs +++ b/spatialbench-cli/src/output_plan.rs @@ -20,21 +20,33 @@ //! * [`OutputPlanGenerator`]: plans the output files to be generated use crate::plan::GenerationPlan; +use crate::s3_writer::{build_s3_client, parse_s3_uri}; use crate::{OutputFormat, Table}; use log::debug; +use object_store::ObjectStore; use parquet::basic::Compression; use std::collections::HashSet; use std::fmt::{Display, Formatter}; use std::io; use std::path::PathBuf; +use std::sync::Arc; /// Where a partition will be output -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub enum OutputLocation { /// Output to a file File(PathBuf), /// Output to stdout Stdout, + /// Output to S3 with a shared client + S3 { + /// The full S3 URI for this object (e.g. `s3://bucket/path/to/file.parquet`) + uri: String, + /// The object path within the bucket (e.g. `path/to/file.parquet`) + path: String, + /// Shared S3 client for the bucket + client: Arc, + }, } impl Display for OutputLocation { @@ -48,12 +60,13 @@ impl Display for OutputLocation { write!(f, "{}", file.to_string_lossy()) } OutputLocation::Stdout => write!(f, "Stdout"), + OutputLocation::S3 { uri, .. 
} => write!(f, "{}", uri), } } } /// Describes an output partition (file) that will be generated -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub struct OutputPlan { /// The table table: Table, @@ -151,6 +164,8 @@ pub struct OutputPlanGenerator { /// Output directories that have been created so far /// (used to avoid creating the same directory multiple times) created_directories: HashSet, + /// Shared S3 client, lazily created on first S3 output location + s3_client: Option>, } impl OutputPlanGenerator { @@ -171,6 +186,7 @@ impl OutputPlanGenerator { output_dir, output_plans: Vec::new(), created_directories: HashSet::new(), + s3_client: None, } } @@ -282,17 +298,48 @@ impl OutputPlanGenerator { OutputFormat::Parquet => "parquet", }; - let mut output_path = self.output_dir.clone(); - if let Some(part) = part { - // If a partition is specified, create a subdirectory for it - output_path.push(table.to_string()); - self.ensure_directory_exists(&output_path)?; - output_path.push(format!("{table}.{part}.{extension}")); + // Check if output_dir is an S3 URI + let output_dir_str = self.output_dir.to_string_lossy(); + if output_dir_str.starts_with("s3://") { + // Handle S3 path + let base_uri = output_dir_str.trim_end_matches('/'); + let s3_uri = if let Some(part) = part { + format!("{base_uri}/{table}/{table}.{part}.{extension}") + } else { + format!("{base_uri}/{table}.{extension}") + }; + + // Lazily build the S3 client on first use, then reuse it + let client = if let Some(ref client) = self.s3_client { + Arc::clone(client) + } else { + let (bucket, _) = parse_s3_uri(&s3_uri)?; + let client = build_s3_client(&bucket)?; + self.s3_client = Some(Arc::clone(&client)); + client + }; + + let (_, path) = parse_s3_uri(&s3_uri)?; + + Ok(OutputLocation::S3 { + uri: s3_uri, + path, + client, + }) } else { - // No partition specified, output to a single file - output_path.push(format!("{table}.{extension}")); + // Handle local filesystem path + let mut output_path = self.output_dir.clone(); + if let Some(part) = part { + // If a partition is specified, create a subdirectory for it + output_path.push(table.to_string()); + self.ensure_directory_exists(&output_path)?; + output_path.push(format!("{table}.{part}.{extension}")); + } else { + // No partition specified, output to a single file + output_path.push(format!("{table}.{extension}")); + } + Ok(OutputLocation::File(output_path)) } - Ok(OutputLocation::File(output_path)) } } diff --git a/spatialbench-cli/src/parquet.rs b/spatialbench-cli/src/parquet.rs index 45ffcbd..fa9d84f 100644 --- a/spatialbench-cli/src/parquet.rs +++ b/spatialbench-cli/src/parquet.rs @@ -33,9 +33,17 @@ use std::io::Write; use std::sync::Arc; use tokio::sync::mpsc::{Receiver, Sender}; -pub trait IntoSize { - /// Convert the object into a size - fn into_size(self) -> Result; +/// Finalize a writer after all Parquet data has been written. +/// +/// This is called from the async context (outside `spawn_blocking`) so +/// that implementations like [`S3Writer`](crate::s3_writer::S3Writer) can +/// `.await` their upload without competing with the tokio runtime for +/// threads — avoiding deadlocks under concurrent plans. +/// +/// For local files and stdout the implementation is trivially synchronous. +pub trait AsyncFinalize: Write + Send + 'static { + /// Finalize the writer and return the total bytes written. 
+ fn finalize(self) -> impl std::future::Future> + Send; } /// Converts a set of RecordBatchIterators into a Parquet file @@ -44,7 +52,7 @@ pub trait IntoSize { /// /// Note the input is an iterator of [`RecordBatchIterator`]; The batches /// produced by each iterator is encoded as its own row group. -pub async fn generate_parquet( +pub async fn generate_parquet( writer: W, iter_iter: I, num_threads: usize, @@ -119,9 +127,8 @@ where row_group_writer.close().unwrap(); statistics.increment_chunks(1); } - let size = writer.into_inner()?.into_size()?; - statistics.increment_bytes(size); - Ok(()) as Result<(), io::Error> + let inner = writer.into_inner()?; + Ok((inner, statistics)) as Result<(W, WriteStatistics), io::Error> }); // now, drive the input stream and send results to the writer task @@ -135,8 +142,14 @@ where // signal the writer task that we are done drop(tx); - // Wait for the writer task to finish - writer_task.await??; + // Wait for the blocking writer task to return the underlying writer + let (inner, mut statistics) = writer_task.await??; + + // Finalize in the async context so S3 uploads can .await without + // competing for tokio runtime threads (prevents deadlock under + // concurrent plans). + let size = inner.finalize().await?; + statistics.increment_bytes(size); Ok(()) } diff --git a/spatialbench-cli/src/plan.rs b/spatialbench-cli/src/plan.rs index 0235439..f0d8c5e 100644 --- a/spatialbench-cli/src/plan.rs +++ b/spatialbench-cli/src/plan.rs @@ -77,6 +77,16 @@ pub struct GenerationPlan { pub const DEFAULT_PARQUET_ROW_GROUP_BYTES: i64 = 128 * 1024 * 1024; +/// Buffer size for Parquet writing (32MB) +/// +/// This buffer size is used for: +/// - Local file writing with BufWriter +/// - S3 multipart upload parts +/// +/// The 32MB size provides good performance and is well above the AWS S3 +/// minimum part size requirement of 5MB for multipart uploads. 
+pub const PARQUET_BUFFER_SIZE: usize = 32 * 1024 * 1024; + impl GenerationPlan { /// Returns a GenerationPlan number of parts to generate /// @@ -207,7 +217,7 @@ impl GenerationPlan { }) } - /// Return the number of part(ititions) this plan will generate + /// Return the number of part(ition)s this plan will generate pub fn chunk_count(&self) -> usize { self.part_list.clone().count() } diff --git a/spatialbench-cli/src/runner.rs b/spatialbench-cli/src/runner.rs index 5d8a8ac..9d49c7e 100644 --- a/spatialbench-cli/src/runner.rs +++ b/spatialbench-cli/src/runner.rs @@ -21,6 +21,7 @@ use crate::csv::*; use crate::generate::{generate_in_chunks, Source}; use crate::output_plan::{OutputLocation, OutputPlan}; use crate::parquet::generate_parquet; +use crate::s3_writer::S3Writer; use crate::tbl::*; use crate::{OutputFormat, Table, WriterSink}; use log::{debug, info}; @@ -32,6 +33,7 @@ use spatialbench_arrow::{ }; use std::io; use std::io::BufWriter; +use std::sync::Arc; use tokio::task::{JoinError, JoinSet}; /// Runs multiple [`OutputPlan`]s in parallel, managing the number of threads @@ -218,6 +220,12 @@ where })?; Ok(()) } + OutputLocation::S3 { uri, path, client } => { + info!("Writing to S3: {}", uri); + let s3_writer = S3Writer::with_client(Arc::clone(client), path); + let sink = WriterSink::new(s3_writer); + generate_in_chunks(sink, sources, num_threads).await + } } } @@ -228,7 +236,7 @@ where { match plan.output_location() { OutputLocation::Stdout => { - let writer = BufWriter::with_capacity(32 * 1024 * 1024, io::stdout()); // 32MB buffer + let writer = BufWriter::with_capacity(crate::plan::PARQUET_BUFFER_SIZE, io::stdout()); generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await } OutputLocation::File(path) => { @@ -242,7 +250,7 @@ where let file = std::fs::File::create(&temp_path).map_err(|err| { io::Error::other(format!("Failed to create {temp_path:?}: {err}")) })?; - let writer = BufWriter::with_capacity(32 * 1024 * 1024, file); // 32MB buffer + let writer = BufWriter::with_capacity(crate::plan::PARQUET_BUFFER_SIZE, file); generate_parquet(writer, sources, num_threads, plan.parquet_compression()).await?; // rename the temp file to the final path std::fs::rename(&temp_path, path).map_err(|e| { @@ -252,6 +260,11 @@ where })?; Ok(()) } + OutputLocation::S3 { uri, path, client } => { + info!("Writing parquet to S3: {}", uri); + let s3_writer = S3Writer::with_client(Arc::clone(client), path); + generate_parquet(s3_writer, sources, num_threads, plan.parquet_compression()).await + } } } diff --git a/spatialbench-cli/src/s3_writer.rs b/spatialbench-cli/src/s3_writer.rs new file mode 100644 index 0000000..bbc3232 --- /dev/null +++ b/spatialbench-cli/src/s3_writer.rs @@ -0,0 +1,536 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +//! S3 writer that streams multipart uploads instead of buffering in memory. +//! +//! Data is buffered in 32 MB chunks (matching [`PARQUET_BUFFER_SIZE`]). When a +//! chunk is full it is sent through an [`mpsc`] channel to a background tokio +//! task that uploads it immediately via `MultipartUpload::put_part`. This +//! keeps peak memory usage roughly constant regardless of total file size. +//! +//! [`mpsc`]: tokio::sync::mpsc + +use crate::plan::PARQUET_BUFFER_SIZE; +use bytes::Bytes; +use log::{debug, info}; +use object_store::aws::AmazonS3Builder; +use object_store::path::Path as ObjectPath; +use object_store::ObjectStore; +use std::io::{self, Write}; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use url::Url; + +/// Parse an S3 URI into its (bucket, path) components. +/// +/// The URI should be in the format: `s3://bucket/path/to/object` +pub fn parse_s3_uri(uri: &str) -> Result<(String, String), io::Error> { + let url = Url::parse(uri).map_err(|e| { + io::Error::new( + io::ErrorKind::InvalidInput, + format!("Invalid S3 URI: {}", e), + ) + })?; + + if url.scheme() != "s3" { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + format!("Expected s3:// URI, got: {}", url.scheme()), + )); + } + + let bucket = url + .host_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "S3 URI missing bucket name"))? + .to_string(); + + let path = url.path().trim_start_matches('/').to_string(); + + Ok((bucket, path)) +} + +/// Build an S3 [`ObjectStore`] client for the given bucket using environment variables. +/// +/// Uses [`AmazonS3Builder::from_env`] which reads all standard AWS environment +/// variables including `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, +/// `AWS_DEFAULT_REGION`, `AWS_REGION`, `AWS_SESSION_TOKEN`, `AWS_ENDPOINT`, etc. +pub fn build_s3_client(bucket: &str) -> Result, io::Error> { + debug!("Building S3 client for bucket: {}", bucket); + let client = AmazonS3Builder::from_env() + .with_bucket_name(bucket) + .build() + .map_err(|e| io::Error::other(format!("Failed to create S3 client: {}", e)))?; + info!("S3 client created successfully for bucket: {}", bucket); + Ok(Arc::new(client)) +} + +/// Message sent from the writer thread to the background upload task. +enum UploadMessage { + /// A completed part ready for upload. + Part(Bytes), + /// All parts have been sent; the upload should be completed. + Finish, +} + +/// A writer that streams data to S3 via multipart upload. +/// +/// Internally, a background tokio task is spawned that starts the multipart +/// upload eagerly and uploads each part as it arrives through a channel. +/// The [`Write`] implementation buffers data in 32 MB chunks and sends +/// completed chunks to the background task via [`mpsc::Sender::blocking_send`] +/// (safe because all callers run inside [`tokio::task::spawn_blocking`]). +/// +/// On [`finish`](S3Writer::finish), any remaining buffered data is sent as the +/// final part, the channel is closed, and we wait for the background task to +/// call `complete()` on the multipart upload. If any part upload fails, the +/// multipart upload is aborted to avoid orphaned uploads accruing S3 storage +/// costs. +/// +/// For small files (< 5 MB total) a simple PUT is used instead of multipart. 
+pub struct S3Writer { + /// The S3 client (kept for the small-file PUT fallback) + client: Arc, + /// The object path in S3 + path: ObjectPath, + /// Current buffer for accumulating data before sending as a part + buffer: Vec, + /// Total bytes written through [`Write::write`] + total_bytes: usize, + /// Channel to send parts to the background upload task. + /// + /// Set to `None` after the first part is sent (at which point the + /// background task is spawned and this is replaced by `upload_tx`). + /// Before any parts are sent this is `None` and parts accumulate in + /// `pending_parts` for the small-file optimization. + upload_tx: Option>, + /// Receives the final result (total bytes) from the background upload task. + result_rx: Option>>, + /// Parts accumulated before we decide whether to use simple PUT or + /// multipart upload. Once we exceed [`PARQUET_BUFFER_SIZE`] total, we switch + /// to the streaming multipart path. + pending_parts: Vec, + /// Whether the streaming multipart upload has been started. + multipart_started: bool, +} + +impl S3Writer { + /// Create a new S3 writer for the given S3 URI, building a fresh client. + /// + /// Prefer [`S3Writer::with_client`] when writing multiple files to reuse + /// the same client. + /// + /// Authentication is handled through standard AWS environment variables + /// via [`AmazonS3Builder::from_env`]. + pub fn new(uri: &str) -> Result { + let (bucket, path) = parse_s3_uri(uri)?; + let client = build_s3_client(&bucket)?; + Ok(Self::with_client(client, &path)) + } + + /// Create a new S3 writer using an existing [`ObjectStore`] client. + /// + /// This avoids creating a new client per file, which is important when + /// generating many partitioned files. + pub fn with_client(client: Arc, path: &str) -> Self { + debug!("Creating S3 writer for path: {}", path); + Self { + client, + path: ObjectPath::from(path), + buffer: Vec::with_capacity(PARQUET_BUFFER_SIZE), + total_bytes: 0, + upload_tx: None, + result_rx: None, + pending_parts: Vec::new(), + multipart_started: false, + } + } + + /// Start the background multipart upload task, draining any pending parts. + /// + /// This is called lazily when we accumulate enough data to exceed the + /// simple-PUT threshold. From this point on, every completed buffer is + /// sent directly to the background task for immediate upload. + fn start_multipart_upload(&mut self) { + debug_assert!(!self.multipart_started, "multipart upload already started"); + self.multipart_started = true; + + // Channel capacity of 2: one part being uploaded, one buffered and ready. + // This keeps memory bounded while allowing overlap between buffering and + // uploading. + let (tx, rx) = mpsc::channel::(2); + let (result_tx, result_rx) = oneshot::channel(); + + let client = Arc::clone(&self.client); + let path = self.path.clone(); + let pending = std::mem::take(&mut self.pending_parts); + + tokio::spawn(async move { + let result = run_multipart_upload(client, path, pending, rx).await; + // Ignore send error — the receiver may have been dropped if the + // writer was abandoned. + let _ = result_tx.send(result); + }); + + self.upload_tx = Some(tx); + self.result_rx = Some(result_rx); + } + + /// Send a completed buffer chunk to the background upload task. 
+ fn send_part(&mut self, part: Bytes) -> io::Result<()> { + if let Some(tx) = &self.upload_tx { + tx.blocking_send(UploadMessage::Part(part)) + .map_err(|_| io::Error::other("Background upload task terminated unexpectedly"))?; + } + Ok(()) + } + + /// Complete the upload by sending any remaining data and waiting for the + /// background task to finish. + /// + /// For small files (total data < [`PARQUET_BUFFER_SIZE`] and fits in a single + /// part), a simple PUT is used instead of multipart upload. + /// + /// This method must be called from an async context (it is typically called + /// via [`block_on`](tokio::runtime::Handle::block_on) from inside + /// [`spawn_blocking`](tokio::task::spawn_blocking)). + pub async fn finish(mut self) -> Result { + let total = self.total_bytes; + debug!("Completing S3 upload: {} bytes total", total); + + // Flush any remaining buffer data + if !self.buffer.is_empty() { + let remaining = Bytes::from(std::mem::take(&mut self.buffer)); + + if self.multipart_started { + // Send as the last part + if let Some(tx) = &self.upload_tx { + tx.send(UploadMessage::Part(remaining)).await.map_err(|_| { + io::Error::other("Background upload task terminated unexpectedly") + })?; + } + } else { + self.pending_parts.push(remaining); + } + } + + if self.multipart_started { + // Signal the background task that we are done + if let Some(tx) = self.upload_tx.take() { + let _ = tx.send(UploadMessage::Finish).await; + } + // Wait for the background task result + if let Some(rx) = self.result_rx.take() { + rx.await.map_err(|_| { + io::Error::other("Upload task dropped without sending result") + })??; + } + } else { + // Small file path — use a simple PUT + let data: Vec = self + .pending_parts + .into_iter() + .flat_map(|b| b.to_vec()) + .collect(); + + if data.is_empty() { + debug!("No data to upload"); + return Ok(0); + } + + debug!("Using simple PUT for small file: {} bytes", data.len()); + self.client + .put(&self.path, Bytes::from(data).into()) + .await + .map_err(|e| io::Error::other(format!("Failed to upload to S3: {}", e)))?; + } + + info!("Successfully uploaded {} bytes to S3", total); + Ok(total) + } + + /// Get the total bytes written so far + #[allow(dead_code)] // used by zone module in a later commit + pub fn total_bytes(&self) -> usize { + self.total_bytes + } + + /// Get the buffer size (for compatibility) + #[allow(dead_code)] // used by zone module in a later commit + pub fn buffer_size(&self) -> usize { + self.total_bytes + } +} + +/// Background task that runs the multipart upload. +/// +/// Starts the upload, drains any pre-accumulated pending parts, then +/// continuously receives new parts from the channel and uploads them. On +/// any upload error the multipart upload is aborted to avoid orphaned +/// uploads accruing S3 storage costs. 
+async fn run_multipart_upload( + client: Arc, + path: ObjectPath, + pending_parts: Vec, + mut rx: mpsc::Receiver, +) -> Result<(), io::Error> { + debug!("Starting multipart upload for {:?}", path); + let mut upload = client + .put_multipart(&path) + .await + .map_err(|e| io::Error::other(format!("Failed to start multipart upload: {}", e)))?; + + let mut part_number: usize = 0; + + // Upload any parts that were accumulated before the task started + for part_data in pending_parts { + part_number += 1; + debug!( + "Uploading pending part {} ({} bytes)", + part_number, + part_data.len() + ); + if let Err(e) = upload.put_part(part_data.into()).await { + debug!("Part upload failed, aborting multipart upload"); + let _ = upload.abort().await; + return Err(io::Error::other(format!( + "Failed to upload part {}: {}", + part_number, e + ))); + } + } + + // Receive and upload parts from the channel + while let Some(msg) = rx.recv().await { + match msg { + UploadMessage::Part(part_data) => { + part_number += 1; + debug!("Uploading part {} ({} bytes)", part_number, part_data.len()); + if let Err(e) = upload.put_part(part_data.into()).await { + debug!("Part upload failed, aborting multipart upload"); + let _ = upload.abort().await; + return Err(io::Error::other(format!( + "Failed to upload part {}: {}", + part_number, e + ))); + } + } + UploadMessage::Finish => { + break; + } + } + } + + // Complete the multipart upload + debug!("Completing multipart upload ({} parts)", part_number); + if let Err(e) = upload.complete().await { + debug!("Multipart complete failed, aborting"); + // complete() consumes the upload, so we can't abort here — the upload + // will be cleaned up by S3's lifecycle rules for incomplete uploads. + return Err(io::Error::other(format!( + "Failed to complete multipart upload: {}", + e + ))); + } + + debug!( + "Multipart upload completed successfully ({} parts)", + part_number + ); + Ok(()) +} + +impl Write for S3Writer { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.total_bytes += buf.len(); + self.buffer.extend_from_slice(buf); + + // When buffer reaches our target part size (32MB), send it as a part + if self.buffer.len() >= PARQUET_BUFFER_SIZE { + let part_data = Bytes::from(std::mem::replace( + &mut self.buffer, + Vec::with_capacity(PARQUET_BUFFER_SIZE), + )); + + if self.multipart_started { + // Stream directly to the background upload task + self.send_part(part_data)?; + } else { + // Accumulate until we know whether this will be a small file + self.pending_parts.push(part_data); + + // We now have at least 32MB, which exceeds the 5MB simple PUT + // threshold — switch to streaming multipart upload + self.start_multipart_upload(); + } + } + + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + // No-op: all data will be uploaded in finish() + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use object_store::memory::InMemory; + + // ---- parse_s3_uri tests ---- + + #[test] + fn parse_s3_uri_valid() { + let (bucket, path) = parse_s3_uri("s3://my-bucket/path/to/file.parquet").unwrap(); + assert_eq!(bucket, "my-bucket"); + assert_eq!(path, "path/to/file.parquet"); + } + + #[test] + fn parse_s3_uri_nested_path() { + let (bucket, path) = parse_s3_uri("s3://bucket/a/b/c/d/file.parquet").unwrap(); + assert_eq!(bucket, "bucket"); + assert_eq!(path, "a/b/c/d/file.parquet"); + } + + #[test] + fn parse_s3_uri_no_path() { + let (bucket, path) = parse_s3_uri("s3://bucket").unwrap(); + assert_eq!(bucket, "bucket"); + assert_eq!(path, ""); + } + + #[test] + 
fn parse_s3_uri_trailing_slash() { + let (bucket, path) = parse_s3_uri("s3://bucket/prefix/").unwrap(); + assert_eq!(bucket, "bucket"); + assert_eq!(path, "prefix/"); + } + + #[test] + fn parse_s3_uri_wrong_scheme() { + let err = parse_s3_uri("https://bucket/path").unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + assert!(err.to_string().contains("Expected s3://")); + } + + #[test] + fn parse_s3_uri_invalid_uri() { + let err = parse_s3_uri("not a uri at all").unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::InvalidInput); + assert!(err.to_string().contains("Invalid S3 URI")); + } + + // ---- S3Writer tests using InMemory object store ---- + + #[tokio::test] + async fn write_small_file() { + let store = Arc::new(InMemory::new()); + let mut writer = S3Writer::with_client(store.clone(), "output/test.parquet"); + + let data = b"hello world"; + writer.write_all(data).unwrap(); + + let total = writer.finish().await.unwrap(); + assert_eq!(total, data.len()); + + // Verify the data arrived in the store + let result = store + .get(&ObjectPath::from("output/test.parquet")) + .await + .unwrap(); + let stored = result.bytes().await.unwrap(); + assert_eq!(stored.as_ref(), data); + } + + #[tokio::test] + async fn write_empty_file() { + let store = Arc::new(InMemory::new()); + let writer = S3Writer::with_client(store.clone(), "output/empty.parquet"); + + let total = writer.finish().await.unwrap(); + assert_eq!(total, 0); + + // Nothing should be written to the store + let result = store.get(&ObjectPath::from("output/empty.parquet")).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn write_large_file_triggers_multipart() { + let store = Arc::new(InMemory::new()); + let mut writer = S3Writer::with_client(store.clone(), "output/large.parquet"); + + // Write more than PARQUET_BUFFER_SIZE (32MB) to trigger multipart + let chunk = vec![0xABu8; 1024 * 1024]; // 1MB chunks + let num_chunks = 34; // 34MB total > 32MB threshold + for _ in 0..num_chunks { + writer.write_all(&chunk).unwrap(); + } + + let total = writer.finish().await.unwrap(); + assert_eq!(total, num_chunks * chunk.len()); + + // Verify the data arrived in the store and is correct size + let result = store + .get(&ObjectPath::from("output/large.parquet")) + .await + .unwrap(); + let stored = result.bytes().await.unwrap(); + assert_eq!(stored.len(), num_chunks * chunk.len()); + // Verify all bytes are correct + assert!(stored.iter().all(|&b| b == 0xAB)); + } + + #[tokio::test] + async fn write_multiple_small_writes() { + let store = Arc::new(InMemory::new()); + let mut writer = S3Writer::with_client(store.clone(), "output/multi.parquet"); + + // Simulate many small writes (like a Parquet encoder would produce) + for i in 0u8..100 { + writer.write_all(&[i]).unwrap(); + } + + let total = writer.finish().await.unwrap(); + assert_eq!(total, 100); + + let result = store + .get(&ObjectPath::from("output/multi.parquet")) + .await + .unwrap(); + let stored = result.bytes().await.unwrap(); + let expected: Vec = (0u8..100).collect(); + assert_eq!(stored.as_ref(), expected.as_slice()); + } + + #[tokio::test] + async fn total_bytes_tracks_writes() { + let store = Arc::new(InMemory::new()); + let mut writer = S3Writer::with_client(store, "output/track.parquet"); + + assert_eq!(writer.total_bytes(), 0); + + writer.write_all(&[1, 2, 3]).unwrap(); + assert_eq!(writer.total_bytes(), 3); + + writer.write_all(&[4, 5]).unwrap(); + assert_eq!(writer.total_bytes(), 5); + } +} diff --git a/spatialbench-cli/src/zone/config.rs 
b/spatialbench-cli/src/zone/config.rs index 9594c8b..a6d28bb 100644 --- a/spatialbench-cli/src/zone/config.rs +++ b/spatialbench-cli/src/zone/config.rs @@ -77,4 +77,91 @@ impl ZoneDfArgs { self.output_dir.join("zone.parquet") } } + + /// Whether the output directory is an S3 URI (starts with `s3://`) + pub fn is_s3(&self) -> bool { + self.output_dir.to_string_lossy().starts_with("s3://") + } + + /// Compute the S3 object key for this zone output. + /// + /// Returns the full S3 URI (e.g. `s3://bucket/prefix/zone.parquet`). + pub fn output_s3_uri(&self) -> String { + let base = self.output_dir.to_string_lossy(); + let base = base.trim_end_matches('/'); + if self.parts.unwrap_or(1) > 1 { + format!("{}/zone/zone.{}.parquet", base, self.part.unwrap_or(1)) + } else { + format!("{}/zone.parquet", base) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn default_args(output_dir: &str) -> ZoneDfArgs { + ZoneDfArgs::new( + 1.0, + PathBuf::from(output_dir), + None, + None, + None, + 128 * 1024 * 1024, + ParquetCompression::ZSTD(Default::default()), + ) + } + + #[test] + fn is_s3_with_s3_uri() { + let args = default_args("s3://my-bucket/output"); + assert!(args.is_s3()); + } + + #[test] + fn is_s3_with_local_path() { + let args = default_args("/tmp/output"); + assert!(!args.is_s3()); + } + + #[test] + fn is_s3_with_relative_path() { + let args = default_args("./output"); + assert!(!args.is_s3()); + } + + #[test] + fn output_s3_uri_single_file() { + let args = default_args("s3://bucket/prefix"); + assert_eq!(args.output_s3_uri(), "s3://bucket/prefix/zone.parquet"); + } + + #[test] + fn output_s3_uri_single_file_trailing_slash() { + let args = default_args("s3://bucket/prefix/"); + assert_eq!(args.output_s3_uri(), "s3://bucket/prefix/zone.parquet"); + } + + #[test] + fn output_s3_uri_with_partitions() { + let mut args = default_args("s3://bucket/prefix"); + args.parts = Some(10); + args.part = Some(3); + assert_eq!( + args.output_s3_uri(), + "s3://bucket/prefix/zone/zone.3.parquet" + ); + } + + #[test] + fn output_s3_uri_partition_defaults_to_part_1() { + let mut args = default_args("s3://bucket/prefix"); + args.parts = Some(5); + // part not set — should default to 1 + assert_eq!( + args.output_s3_uri(), + "s3://bucket/prefix/zone/zone.1.parquet" + ); + } } diff --git a/spatialbench-cli/src/zone/mod.rs b/spatialbench-cli/src/zone/mod.rs index 4071454..727add1 100644 --- a/spatialbench-cli/src/zone/mod.rs +++ b/spatialbench-cli/src/zone/mod.rs @@ -59,7 +59,7 @@ pub async fn generate_zone_parquet_single(args: ZoneDfArgs) -> Result<()> { let batches = df.collect().await?; let writer = ParquetWriter::new(&args, &stats, schema); - writer.write(&batches)?; + writer.write(&batches).await?; Ok(()) } @@ -106,7 +106,7 @@ pub async fn generate_zone_parquet_multi(args: ZoneDfArgs) -> Result<()> { ); let writer = ParquetWriter::new(&part_args, &stats, schema.clone()); - writer.write(&partitioned_batches)?; + writer.write(&partitioned_batches).await?; } Ok(()) diff --git a/spatialbench-cli/src/zone/writer.rs b/spatialbench-cli/src/zone/writer.rs index b22671a..5afa038 100644 --- a/spatialbench-cli/src/zone/writer.rs +++ b/spatialbench-cli/src/zone/writer.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use crate::s3_writer::S3Writer; use anyhow::Result; use arrow_array::RecordBatch; use arrow_schema::SchemaRef; @@ -52,7 +53,16 @@ impl ParquetWriter { } } - pub fn write(&self, batches: &[RecordBatch]) -> Result<()> { + pub async fn write(&self, batches: &[RecordBatch]) -> Result<()> { + if self.args.is_s3() { + self.write_s3(batches).await + } else { + self.write_local(batches) + } + } + + /// Write batches to a local file using a temp-file + rename pattern. + fn write_local(&self, batches: &[RecordBatch]) -> Result<()> { // Create parent directory of output file (handles both zone/ subdirectory and base dir) let parent_dir = self .output_path @@ -108,4 +118,38 @@ impl ParquetWriter { Ok(()) } + + /// Write batches to S3 using [`S3Writer`]. + /// + /// S3 writes are atomic (via multipart upload `complete()`), so no + /// temp-file or rename is needed. + async fn write_s3(&self, batches: &[RecordBatch]) -> Result<()> { + let uri = self.args.output_s3_uri(); + info!("Writing zone parquet to S3: {}", uri); + + let t0 = Instant::now(); + let s3_writer = S3Writer::new(&uri)?; + let mut writer = ArrowWriter::try_new( + s3_writer, + Arc::clone(&self.schema), + Some(self.props.clone()), + )?; + + for batch in batches { + writer.write(batch)?; + } + + let s3_writer = writer.into_inner()?; + let size = s3_writer.finish().await?; + + let duration = t0.elapsed(); + let total_rows: usize = batches.iter().map(|b| b.num_rows()).sum(); + + info!( + "Zone -> {} (part {:?}/{:?}). write={:?}, total_rows={}, bytes={}", + uri, self.args.part, self.args.parts, duration, total_rows, size + ); + + Ok(()) + } }
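
As a closing illustration of how the new pieces compose, here is a hedged sketch of an additional round-trip test that could sit alongside the existing tests in `spatialbench-cli/src/s3_writer.rs`: it drives `S3Writer` through parquet's `ArrowWriter` against the in-memory store, then calls `finish()` from the async context the same way the `AsyncFinalize` impl and `zone/writer.rs` do. It is not part of this PR; it only uses the `arrow-array`, `arrow-schema`, `parquet`, and `object_store` dependencies the crate already declares, and the test name is made up.

```rust
use std::sync::Arc;

use arrow_array::{ArrayRef, Int64Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use object_store::memory::InMemory;
use object_store::path::Path as ObjectPath;
use object_store::ObjectStore;
use parquet::arrow::ArrowWriter;

use super::S3Writer; // assumes this lives in s3_writer.rs's existing `tests` module

#[tokio::test]
async fn parquet_round_trip_through_s3_writer() {
    let store = Arc::new(InMemory::new());
    let writer = S3Writer::with_client(store.clone(), "output/trip.parquet");

    // A tiny batch; ArrowWriter drives S3Writer through its `Write` impl.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Int64Array::from(vec![1i64, 2, 3])) as ArrayRef],
    )
    .unwrap();

    let mut arrow_writer = ArrowWriter::try_new(writer, schema, None).unwrap();
    arrow_writer.write(&batch).unwrap();

    // Recover the S3Writer and finish() in the async context, as the
    // AsyncFinalize impl in main.rs and zone/writer.rs do.
    let writer = arrow_writer.into_inner().unwrap();
    let bytes_written = writer.finish().await.unwrap();

    // The object is now visible in the store with the size finish() reported.
    let stored = store
        .get(&ObjectPath::from("output/trip.parquet"))
        .await
        .unwrap()
        .bytes()
        .await
        .unwrap();
    assert_eq!(stored.len(), bytes_written);
}
```

Because `InMemory` accepts arbitrarily small parts, the same shape of test can be pushed past the 32 MB buffer threshold (as `write_large_file_triggers_multipart` already does) to exercise the streaming multipart path rather than the simple-PUT fallback.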