Very high memory usage for unnest + group by #20788

@psuszyns

Description

Describe the bug

Unnesting several columns followed by a group by results in extremely high memory usage. A 341 MB Parquet file containing 3 array columns, with 20000 records and 2000 elements in each array column, processed by the query outlined in the 'To Reproduce' section, results in more than 53 GB of RAM usage.
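
In essence, the pattern is the following (a condensed form of the full query in the 'To Reproduce' section, which additionally derives a new per-element column before aggregating):

SELECT row_idx, array_agg(val_a) AS values_a
FROM (
    SELECT ROW_NUMBER() OVER () AS row_idx, unnest(values_a) AS val_a
    FROM test_data
)
GROUP BY row_idx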

To Reproduce

"""
Minimal DataFusion example demonstrating memory explosion with:
  row_index + unnest + group_by

This example creates a Parquet file with list columns and shows how the 
unnest + group_by pattern causes unbounded memory growth even when the 
input is processed in streaming mode.

Usage:
    python datafusion_unnest_memory_issue.py generate  # Generate test data
    python datafusion_unnest_memory_issue.py run       # Run the problematic query
    python datafusion_unnest_memory_issue.py monitor   # Run with memory monitoring
"""

import sys
import os
import psutil
import threading
import pyarrow as pa
import pyarrow.parquet as pq
import numpy as np
import datafusion
from datafusion import col, literal, functions as f
import time
import tracemalloc

# Configuration
NUM_ROWS = 20_000       # Number of records - for 20k records more than 64 GB of memory is recommended
LIST_SIZE = 2000         # Number of elements per list
OUTPUT_DIR = "datafusion_test_data"
PARQUET_FILE = f"{OUTPUT_DIR}/test_lists_{NUM_ROWS}rows_{LIST_SIZE}elements.parquet"
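
# With these defaults, unnesting produces 20_000 * 2_000 = 40_000_000 intermediate
# rows across the three list columns, all of which the GROUP BY in the query below
# has to buffer before it can emit any output.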


def generate_test_data():
    """
    Generate a Parquet file with multiple list columns.
    
    Schema:
        metadata: string   - some scalar data
        values_a: list<int32>    - list column A
        values_b: list<float32>  - list column B
        values_c: list<float32>  - list column C
    
    Each row has LIST_SIZE elements in each list column.
    """
    print(f"Generating test data: {NUM_ROWS} rows x {LIST_SIZE} list elements")
    print(f"Expected intermediate rows after unnest: {NUM_ROWS * LIST_SIZE}")
    
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    
    # Generate data in chunks
    chunk_size = 10000
    writer = None
    
    for chunk_start in range(0, NUM_ROWS, chunk_size):
        chunk_end = min(chunk_start + chunk_size, NUM_ROWS)
        chunk_rows = chunk_end - chunk_start
        
        # Generate scalar columns
        metadata = [f"row_{i}" for i in range(chunk_start, chunk_end)]
        
        # Generate list columns - each row has LIST_SIZE elements
        # Using different value patterns to make data realistic
        np.random.seed(42 + chunk_start)
        
        values_a = [
            np.random.randint(0, 100, size=LIST_SIZE).tolist()
            for _ in range(chunk_rows)
        ]
        values_b = [
            np.random.uniform(0, 100, size=LIST_SIZE).astype(np.float32).tolist()
            for _ in range(chunk_rows)
        ]
        values_c = [
            np.random.uniform(0, 200, size=LIST_SIZE).astype(np.float32).tolist()
            for _ in range(chunk_rows)
        ]
        
        # Create Arrow arrays
        table = pa.table({
            "metadata": pa.array(metadata, type=pa.string()),
            "values_a": pa.array(values_a, type=pa.list_(pa.int32())),
            "values_b": pa.array(values_b, type=pa.list_(pa.float32())),
            "values_c": pa.array(values_c, type=pa.list_(pa.float32())),
        })
        
        if writer is None:
            writer = pq.ParquetWriter(PARQUET_FILE, table.schema)
        
        writer.write_table(table)
        print(f"  Written rows {chunk_start} - {chunk_end}")
    
    writer.close()
    
    file_size_mb = os.path.getsize(PARQUET_FILE) / (1024 * 1024)
    print(f"\nGenerated: {PARQUET_FILE}")
    print(f"File size: {file_size_mb:.1f} MB")
    print(f"Total rows: {NUM_ROWS}")
    print(f"Elements per list: {LIST_SIZE}")
    print(f"Expected unnested rows: {NUM_ROWS * LIST_SIZE}")


def run_problematic_query():
    """
    Demonstrate the memory issue with unnest + group_by pattern. This pattern requires:
    1. Adding a row index (to group by later)
    2. Unnesting multiple list columns
    3. Grouping by the row index
    
    The problem: Step 3 must buffer ALL unnested rows before it can emit results,
    causing memory to grow with input size x list size.
    """
    ctx = datafusion.SessionContext()
    
    print(f"\nRegistering Parquet file: {PARQUET_FILE}")
    ctx.register_parquet("test_data", PARQUET_FILE)
    
    # Show input stats
    result = ctx.sql("SELECT COUNT(*) as cnt FROM test_data").collect()
    input_rows = result[0]["cnt"][0].as_py()
    print(f"Input rows: {input_rows}")
    
    # Run the problematic query
    print("\n--- Unnest + Group By (MEMORY EXPLOSION EXPECTED) ---")
    print("This query unnests list columns and groups by row index.")
    print(f"Expected intermediate rows: {int(input_rows) * LIST_SIZE}")
    print("\nStarting query... (watch memory usage)")
    
    start = time.time()
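    # Note: tracemalloc only tracks Python-heap allocations; memory allocated by
    # DataFusion's native engine shows up only in the psutil numbers reported by
    # run_with_memory_monitoring().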
    tracemalloc.start()
    
    # The problematic pattern using SQL:
    problematic_query = """
    WITH indexed AS (
        SELECT 
            ROW_NUMBER() OVER () as row_idx,
            metadata,
            values_a,
            values_b,
            values_c
        FROM test_data
    ),
    unnested AS (
        SELECT 
            row_idx,
            metadata,
            unnest(values_a) as val_a,
            unnest(values_b) as val_b,
            unnest(values_c) as val_c
        FROM indexed
    ),
    transformed AS (
        SELECT
            row_idx,
            metadata,
            val_a,
            val_b,
            val_c,
            -- Example transformation: create a new column based on val_a, val_b and val_c
            CASE WHEN val_c > 100 THEN val_a * val_b ELSE val_a + val_b END AS val_d
        FROM unnested
    )
    SELECT
        row_idx,
        metadata,
        array_agg(val_a ORDER BY row_idx) AS values_a,
        array_agg(val_b ORDER BY row_idx) AS values_b,
        array_agg(val_c ORDER BY row_idx) AS values_c,
        array_agg(val_d ORDER BY row_idx) AS values_d
    FROM transformed
    GROUP BY row_idx, metadata
    ORDER BY row_idx
    """
    
    try:
        print("\nExecuting query...")
        result = ctx.sql(problematic_query)

        # we want to stream the results to the output file
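        # (streaming applies only to the Parquet writer: as described in the
        # docstring above, the GROUP BY still buffers all unnested rows and the
        # array_agg state before producing its first output batch, so writing
        # incrementally does not reduce peak memory)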
        output_file = f"{OUTPUT_DIR}/query_result.parquet"
        result.write_parquet(output_file)
        
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        
        print(f"\nQuery completed!")
        print(f"Time: {time.time() - start:.2f}s")
        print(f"Peak memory: {peak / 1024 / 1024:.1f} MB")
                
    except Exception as e:
        current, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        
        print(f"\nQuery failed after {time.time() - start:.2f}s")
        print(f"Peak memory before failure: {peak / 1024 / 1024:.1f} MB")
        print(f"Error: {e}")


def run_with_memory_monitoring():
    """
    Run the query with detailed memory monitoring using psutil.
    """
    
    process = psutil.Process(os.getpid())
    max_memory = [0]
    stop_monitoring = [False]
    
    def monitor_memory():
        while not stop_monitoring[0]:
            mem = process.memory_info().rss / 1024 / 1024
            max_memory[0] = max(max_memory[0], mem)
            time.sleep(0.1)
    
    monitor_thread = threading.Thread(target=monitor_memory)
    monitor_thread.start()
    
    # Capture the baseline before entering the try block so the finally clause
    # below can always report it.
    initial_mem = process.memory_info().rss / 1024 / 1024
    print(f"Initial memory: {initial_mem:.1f} MB")
    
    try:
        run_problematic_query()
        
    finally:
        stop_monitoring[0] = True
        monitor_thread.join()
        
        final_mem = process.memory_info().rss / 1024 / 1024
        print(f"\n=== Memory Summary ===")
        print(f"Initial: {initial_mem:.1f} MB")
        print(f"Peak: {max_memory[0]:.1f} MB")
        print(f"Final: {final_mem:.1f} MB")
        print(f"Growth: {max_memory[0] - initial_mem:.1f} MB")


def show_usage():
    print(__doc__)
    print("\nCommands:")
    print("  generate  - Generate test Parquet file")
    print("  run       - Run the problematic query")
    print("  monitor   - Run with memory monitoring (requires psutil)")


if __name__ == "__main__":
    if len(sys.argv) < 2:
        show_usage()
        sys.exit(1)
    
    command = sys.argv[1]
    
    if command == "generate":
        generate_test_data()
    elif command == "run":
        if not os.path.exists(PARQUET_FILE):
            print(f"Test data not found. Run 'python {sys.argv[0]} generate' first.")
            sys.exit(1)
        run_problematic_query()
    elif command == "monitor":
        if not os.path.exists(PARQUET_FILE):
            print(f"Test data not found. Run 'python {sys.argv[0]} generate' first.")
            sys.exit(1)
        run_with_memory_monitoring()
    else:
        show_usage()
        sys.exit(1)

Expected behavior

Executing the above query should be possible with constant memory usage in a streaming fashion.
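
For what it's worth, capping DataFusion's memory pool at least turns the unbounded growth into an explicit error rather than exhausting host RAM. A minimal sketch, assuming a datafusion-python release that exposes RuntimeEnvBuilder and with_fair_spill_pool (PARQUET_FILE, OUTPUT_DIR and problematic_query refer to the reproduction script above):

from datafusion import RuntimeEnvBuilder, SessionConfig, SessionContext

# Cap the memory pool at 4 GB; with the pool in place the aggregation is expected
# to fail with a resources-exhausted error instead of consuming all available RAM.
runtime = RuntimeEnvBuilder().with_fair_spill_pool(4 * 1024**3)
ctx = SessionContext(SessionConfig(), runtime)
ctx.register_parquet("test_data", PARQUET_FILE)
ctx.sql(problematic_query).write_parquet(f"{OUTPUT_DIR}/query_result.parquet")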

Additional context

The given example is a toy, but it closely resembles a real use case when working with the bioinformatics VCF format.
