protos/feast/core/DataSource.proto (5 changes: 4 additions & 1 deletion)
@@ -36,7 +36,7 @@ message DataSource {
reserved 6 to 10;

// Type of Data Source.
// Next available id: 12
// Next available id: 13
enum SourceType {
INVALID = 0;
BATCH_FILE = 1;
@@ -231,6 +231,9 @@ message DataSource {

// Date Format of date partition column (e.g. %Y-%m-%d)
string date_partition_column_format = 5;

// Table format (e.g. iceberg, delta, etc.)
string table_format = 6;
Review comment (collaborator, author): TODO, create TableFormat proto, consolidate with FileFormat proto

}

// Defines configuration for custom third-party data sources.
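The Python diff below imports TableFormat and table_format_from_dict from feast.table_format, a module not shown in this diff. As orientation only, here is a minimal sketch of the shape the code below appears to assume (a format_type enum, a free-form properties dict, to_dict, and a dict factory); names and details are inferred from usage, not taken from the PR.

# Hypothetical sketch of feast/table_format.py, inferred from how the diff
# below uses it: table_format.format_type.value, table_format.properties,
# table_format.to_dict(), and table_format_from_dict(...).
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict


class TableFormatType(Enum):
    ICEBERG = "iceberg"
    DELTA = "delta"
    HUDI = "hudi"


@dataclass
class TableFormat:
    # Spark reader name, e.g. spark.read.format("iceberg")
    format_type: TableFormatType
    # Reader options, passed through one by one via .option(key, value)
    properties: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        return {"format": self.format_type.value, "properties": self.properties}


def table_format_from_dict(d: Dict[str, Any]) -> TableFormat:
    return TableFormat(
        format_type=TableFormatType(d["format"]),
        properties=d.get("properties", {}),
    )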
@@ -1,3 +1,4 @@
import json
import logging
import traceback
import warnings
@@ -14,17 +15,17 @@
)
from feast.repo_config import RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.table_format import TableFormat, table_format_from_dict
from feast.type_map import spark_to_feast_value_type
from feast.value_type import ValueType

logger = logging.getLogger(__name__)


class SparkSourceFormat(Enum):
class SparkFileSourceFormat(Enum):
csv = "csv"
json = "json"
parquet = "parquet"
delta = "delta"
avro = "avro"


@@ -42,6 +43,7 @@ def __init__(
query: Optional[str] = None,
path: Optional[str] = None,
file_format: Optional[str] = None,
table_format: Optional[TableFormat] = None,
created_timestamp_column: Optional[str] = None,
field_mapping: Optional[Dict[str, str]] = None,
description: Optional[str] = "",
@@ -58,7 +60,9 @@ def __init__(
table: The name of a Spark table.
query: The query to be executed in Spark.
path: The path to file data.
file_format: The format of the file data.
file_format: The underlying file format (parquet, avro, csv, json).

Review comment: why not consolidate now?

table_format: The table metadata format (iceberg, delta, hudi, etc.).
Optional and separate from file_format.
created_timestamp_column: Timestamp column indicating when the row
was created, used for deduplicating rows.
field_mapping: A dictionary mapping of column names in this data
@@ -70,7 +74,7 @@ def __init__(
timestamp_field: Event timestamp field used for point-in-time joins of
feature values.
date_partition_column: The column to partition the data on for faster
retrieval. This is useful for large tables and will limit the number ofi
retrieval. This is useful for large tables and will limit the number of
"""
# If no name, use the table as the default name.
if name is None and table is None:
@@ -102,6 +106,7 @@ def __init__(
path=path,
file_format=file_format,
date_partition_column_format=date_partition_column_format,
table_format=table_format,
)

@property
@@ -132,6 +137,13 @@ def file_format(self):
"""
return self.spark_options.file_format

@property
def table_format(self):
"""
Returns the table format of this feature data source.
"""
return self.spark_options.table_format

@property
def date_partition_column_format(self):
"""
@@ -151,6 +163,7 @@ def from_proto(data_source: DataSourceProto) -> Any:
query=spark_options.query,
path=spark_options.path,
file_format=spark_options.file_format,
table_format=spark_options.table_format,
date_partition_column_format=spark_options.date_partition_column_format,
date_partition_column=data_source.date_partition_column,
timestamp_field=data_source.timestamp_field,
@@ -219,7 +232,7 @@ def get_table_query_string(self) -> str:
if spark_session is None:
raise AssertionError("Could not find an active spark session.")
try:
df = spark_session.read.format(self.file_format).load(self.path)
df = self._load_dataframe_from_path(spark_session)
except Exception:
logger.exception(
"Spark read of file source failed.\n" + traceback.format_exc()
@@ -230,6 +243,24 @@ def get_table_query_string(self) -> str:

return f"`{tmp_table_name}`"

def _load_dataframe_from_path(self, spark_session):
"""Load DataFrame from path, considering both file format and table format."""

if self.table_format is None:
# No table format specified, use standard file reading with file_format
return spark_session.read.format(self.file_format).load(self.path)

# Build reader with table format and options
reader = spark_session.read.format(self.table_format.format_type.value)

# Add table format specific options
for key, value in self.table_format.properties.items():
reader = reader.option(key, value)

# For catalog-based table formats like Iceberg, the path is actually a table name
# For file-based formats, it's still a file path
return reader.load(self.path)
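To make the intended call pattern concrete, a hedged usage sketch (names and values invented; assumes the hypothetical TableFormat sketch above). For a catalog-backed format such as Iceberg, path is a table identifier rather than a file path, as the comment above notes:

# Hypothetical usage of the new table_format parameter.
iceberg_source = SparkSource(
    name="driver_hourly_stats",
    path="my_catalog.feature_db.driver_hourly_stats",  # catalog table, not a file
    table_format=TableFormat(format_type=TableFormatType.ICEBERG),
    timestamp_field="event_timestamp",
)
# _load_dataframe_from_path then effectively runs:
#   spark_session.read.format("iceberg").load("my_catalog.feature_db.driver_hourly_stats")
# and with properties={"snapshot-id": "12345"} it would also
# apply .option("snapshot-id", "12345") before loading.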

def __eq__(self, other):
base_eq = super().__eq__(other)
if not base_eq:
@@ -245,7 +276,7 @@ def __hash__(self):


class SparkOptions:
allowed_formats = [format.value for format in SparkSourceFormat]
allowed_formats = [format.value for format in SparkFileSourceFormat]

def __init__(
self,
@@ -254,6 +285,7 @@ def __init__(
path: Optional[str],
file_format: Optional[str],
date_partition_column_format: Optional[str] = "%Y-%m-%d",
table_format: Optional[TableFormat] = None,
):
# Check that only one of the ways to load a spark dataframe can be used. We have
# to treat empty string and null the same due to proto (de)serialization.
@@ -262,11 +294,14 @@ def __init__(
"Exactly one of params(table, query, path) must be specified."
)
if path:
if not file_format:
# If table_format is specified, file_format is optional (table format determines the reader)
# If no table_format, file_format is required for basic file reading
if not table_format and not file_format:
raise ValueError(
"If 'path' is specified, then 'file_format' is required."
"If 'path' is specified without 'table_format', then 'file_format' is required."
)
if file_format not in self.allowed_formats:
# Only validate file_format if it's provided (it's optional with table_format)
if file_format and file_format not in self.allowed_formats:
raise ValueError(
f"'file_format' should be one of {self.allowed_formats}"
)
@@ -276,6 +311,7 @@ def __init__(
self._path = path
self._file_format = file_format
self._date_partition_column_format = date_partition_column_format
self._table_format = table_format
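The relaxed validation can be summarized with three hedged examples (assuming the TableFormat sketch above); each mirrors one branch of the checks in this __init__:

# Hypothetical illustrations of the validation rules.
SparkOptions(table=None, query=None, path="/data/x.parquet",
             file_format="parquet")  # ok: plain file read, file_format given and allowed
SparkOptions(table=None, query=None, path="cat.db.tbl", file_format=None,
             table_format=TableFormat(TableFormatType.ICEBERG))  # ok: table format picks the reader
SparkOptions(table=None, query=None, path="/data/x", file_format=None)
# ValueError: "If 'path' is specified without 'table_format', then 'file_format' is required."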

@property
def table(self):
@@ -317,6 +353,14 @@ def date_partition_column_format(self):
def date_partition_column_format(self, date_partition_column_format):
self._date_partition_column_format = date_partition_column_format

@property
def table_format(self):
return self._table_format

@table_format.setter
def table_format(self, table_format):
self._table_format = table_format

@classmethod
def from_proto(cls, spark_options_proto: DataSourceProto.SparkOptions):
"""
@@ -326,12 +370,20 @@ def from_proto(cls, spark_options_proto: DataSourceProto.SparkOptions):
Returns:
Returns a SparkOptions object based on the spark_options protobuf
"""
# Parse table_format if present
table_format = None
if spark_options_proto.table_format:
table_format = table_format_from_dict(
json.loads(spark_options_proto.table_format)
)

spark_options = cls(
table=spark_options_proto.table,
query=spark_options_proto.query,
path=spark_options_proto.path,
file_format=spark_options_proto.file_format,
date_partition_column_format=spark_options_proto.date_partition_column_format,
table_format=table_format,
)

return spark_options
@@ -348,6 +400,9 @@ def to_proto(self) -> DataSourceProto.SparkOptions:
path=self.path,
file_format=self.file_format,
date_partition_column_format=self.date_partition_column_format,
table_format=json.dumps(self.table_format.to_dict())
if self.table_format
else "",
)

return spark_options_proto
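Because the proto field is a plain string, the round trip goes through JSON. A hedged sketch of the to_proto/from_proto symmetry (again assuming the illustrative TableFormat above):

# Hypothetical round trip through the proto representation.
opts = SparkOptions(table=None, query=None, path="cat.db.tbl", file_format=None,
                    table_format=TableFormat(TableFormatType.DELTA))
proto = opts.to_proto()
# proto.table_format now holds the json.dumps(to_dict()) string,
# e.g. '{"format": "delta", "properties": {}}'
restored = SparkOptions.from_proto(proto)
assert restored.table_format.format_type is TableFormatType.DELTA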
@@ -364,12 +419,14 @@ def __init__(
query: Optional[str] = None,
path: Optional[str] = None,
file_format: Optional[str] = None,
table_format: Optional[TableFormat] = None,
):
self.spark_options = SparkOptions(
table=table,
query=query,
path=path,
file_format=file_format,
table_format=table_format,
)

@staticmethod
@@ -380,6 +437,7 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> SavedDatasetStorage:
query=spark_options.query,
path=spark_options.path,
file_format=spark_options.file_format,
table_format=spark_options.table_format,
)

def to_proto(self) -> SavedDatasetStorageProto:
@@ -391,4 +449,5 @@ def to_data_source(self) -> DataSource:
query=self.spark_options.query,
path=self.spark_options.path,
file_format=self.spark_options.file_format,
table_format=self.spark_options.table_format,
)