Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions sdk/python/feast/infra/compute_engines/local/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,16 +374,25 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
for entity in self.feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(
input_table, self.feature_view, join_key_to_value_type
batch_size = (
context.repo_config.materialization_config.online_write_batch_size
)

online_store.online_write_batch(
config=context.repo_config,
table=self.feature_view,
data=rows_to_write,
progress=lambda x: None,
# Single batch if None (backward compatible), otherwise use configured batch_size
batches = (
[input_table]
if batch_size is None
else input_table.to_batches(max_chunksize=batch_size)
)
for batch in batches:
rows_to_write = _convert_arrow_to_proto(
batch, self.feature_view, join_key_to_value_type
)
online_store.online_write_batch(
config=context.repo_config,
table=self.feature_view,
data=rows_to_write,
progress=lambda x: None,
)
Comment thread
cutoutsy marked this conversation as resolved.

if self.feature_view.offline:
offline_store = context.offline_store
Expand Down
31 changes: 22 additions & 9 deletions sdk/python/feast/infra/compute_engines/ray/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,32 @@ def write_to_online_store(
for entity in feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(
arrow_table, feature_view, join_key_to_value_type
batch_size = repo_config.materialization_config.online_write_batch_size
# Single batch if None (backward compatible), otherwise use configured batch_size
batches = (
[arrow_table]
if batch_size is None
else arrow_table.to_batches(max_chunksize=batch_size)
)

if rows_to_write:
online_store.online_write_batch(
config=repo_config,
table=feature_view,
data=rows_to_write,
progress=lambda x: None,
total_rows = 0
for batch in batches:
rows_to_write = _convert_arrow_to_proto(
batch, feature_view, join_key_to_value_type
)

if rows_to_write:
online_store.online_write_batch(
config=repo_config,
table=feature_view,
data=rows_to_write,
progress=lambda x: None,
)
total_rows += len(rows_to_write)

if total_rows > 0:
logger.debug(
f"Successfully wrote {len(rows_to_write)} rows to online store for {feature_view.name}"
f"Successfully wrote {total_rows} rows to online store for {feature_view.name}"
)
else:
logger.warning(f"No rows to write for {feature_view.name}")
Expand Down
50 changes: 33 additions & 17 deletions sdk/python/feast/infra/compute_engines/spark/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,16 +47,24 @@ def map_in_arrow(
for entity in feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)

online_store.online_write_batch(
config=repo_config,
table=feature_view,
data=rows_to_write,
progress=lambda x: None,
batch_size = repo_config.materialization_config.online_write_batch_size
# Single batch if None (backward compatible), otherwise use configured batch_size
sub_batches = (
[table]
if batch_size is None
else table.to_batches(max_chunksize=batch_size)
)
for sub_batch in sub_batches:
rows_to_write = _convert_arrow_to_proto(
sub_batch, feature_view, join_key_to_value_type
)

online_store.online_write_batch(
config=repo_config,
table=feature_view,
data=rows_to_write,
progress=lambda x: None,
)
if mode == "offline":
offline_store.offline_write_batch(
config=repo_config,
Expand Down Expand Up @@ -95,15 +103,23 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts):
for entity in feature_view.entity_columns
}

rows_to_write = _convert_arrow_to_proto(
table, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
batch_size = repo_config.materialization_config.online_write_batch_size
# Single batch if None (backward compatible), otherwise use configured batch_size
sub_batches = (
[table]
if batch_size is None
else table.to_batches(max_chunksize=batch_size)
)
for sub_batch in sub_batches:
rows_to_write = _convert_arrow_to_proto(
sub_batch, feature_view, join_key_to_value_type
)
online_store.online_write_batch(
repo_config,
feature_view,
rows_to_write,
lambda x: None,
)

yield pd.DataFrame(
[pd.Series(range(1, 2))]
Expand Down
6 changes: 6 additions & 0 deletions sdk/python/feast/repo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,12 @@ class MaterializationConfig(BaseModel):
""" bool: If true, feature retrieval jobs will only pull the latest feature values for each entity.
If false, feature retrieval jobs will pull all feature values within the specified time range. """

online_write_batch_size: Optional[int] = Field(default=None, gt=0)
""" int: Number of rows to write to online store per batch during materialization.
If None (default), all rows are written in a single batch for backward compatibility.
Set to a positive integer (e.g., 10000) to enable batched writes.
Supported compute engines: local, spark, ray. """


class OpenLineageConfig(FeastBaseModel):
"""Configuration for OpenLineage integration.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
LocalOutputNode,
LocalTransformationNode,
)
from feast.repo_config import MaterializationConfig

backend = PandasBackend()
now = pd.Timestamp.utcnow()
Expand All @@ -37,9 +38,11 @@

def create_context(node_outputs):
# Setup execution context
repo_config = MagicMock()
repo_config.materialization_config = MaterializationConfig()
return ExecutionContext(
project="test_proj",
repo_config=MagicMock(),
repo_config=repo_config,
offline_store=MagicMock(),
online_store=MagicMock(),
entity_defs=MagicMock(),
Expand Down Expand Up @@ -214,3 +217,52 @@ def test_local_output_node():
node.inputs[0].name = "source"
result = node.execute(context)
assert result.num_rows == 4


def test_local_output_node_online_write_default_batch():
"""Test that online_write_batch is called once when batch_size is None (default)."""
# Create a feature view with online=True
feature_view = MagicMock()
feature_view.online = True
feature_view.offline = False
feature_view.entity_columns = []

# Create context with default materialization config (batch_size=None)
context = create_context(
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
)

node = LocalOutputNode("output", feature_view)
node.add_input(MagicMock())
node.inputs[0].name = "source"

node.execute(context)

# Verify online_write_batch was called exactly once (all rows in single batch)
assert context.online_store.online_write_batch.call_count == 1


def test_local_output_node_online_write_batched():
"""Test that online_write_batch is called multiple times when batch_size is configured."""
# Create a feature view with online=True
feature_view = MagicMock()
feature_view.online = True
feature_view.offline = False
feature_view.entity_columns = []

# Create context with batch_size=2 (sample_df has 4 rows, so expect 2 batches)
context = create_context(
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
)
context.repo_config.materialization_config = MaterializationConfig(
online_write_batch_size=2
)

node = LocalOutputNode("output", feature_view)
node.add_input(MagicMock())
node.inputs[0].name = "source"

node.execute(context)

# Verify online_write_batch was called twice (4 rows / batch_size 2 = 2 batches)
assert context.online_store.online_write_batch.call_count == 2
Loading