diff --git a/sdk/python/feast/infra/compute_engines/local/nodes.py b/sdk/python/feast/infra/compute_engines/local/nodes.py index db65761a5e2..ae539d2b5b1 100644 --- a/sdk/python/feast/infra/compute_engines/local/nodes.py +++ b/sdk/python/feast/infra/compute_engines/local/nodes.py @@ -374,16 +374,25 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue: for entity in self.feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto( - input_table, self.feature_view, join_key_to_value_type + batch_size = ( + context.repo_config.materialization_config.online_write_batch_size ) - - online_store.online_write_batch( - config=context.repo_config, - table=self.feature_view, - data=rows_to_write, - progress=lambda x: None, + # Single batch if None (backward compatible), otherwise use configured batch_size + batches = ( + [input_table] + if batch_size is None + else input_table.to_batches(max_chunksize=batch_size) ) + for batch in batches: + rows_to_write = _convert_arrow_to_proto( + batch, self.feature_view, join_key_to_value_type + ) + online_store.online_write_batch( + config=context.repo_config, + table=self.feature_view, + data=rows_to_write, + progress=lambda x: None, + ) if self.feature_view.offline: offline_store = context.offline_store diff --git a/sdk/python/feast/infra/compute_engines/ray/utils.py b/sdk/python/feast/infra/compute_engines/ray/utils.py index 94ebbe2c643..ff8cbba760a 100644 --- a/sdk/python/feast/infra/compute_engines/ray/utils.py +++ b/sdk/python/feast/infra/compute_engines/ray/utils.py @@ -45,19 +45,32 @@ def write_to_online_store( for entity in feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto( - arrow_table, feature_view, join_key_to_value_type + batch_size = repo_config.materialization_config.online_write_batch_size + # Single batch if None (backward compatible), otherwise use configured batch_size + batches = ( + [arrow_table] + if batch_size is None + else arrow_table.to_batches(max_chunksize=batch_size) ) - if rows_to_write: - online_store.online_write_batch( - config=repo_config, - table=feature_view, - data=rows_to_write, - progress=lambda x: None, + total_rows = 0 + for batch in batches: + rows_to_write = _convert_arrow_to_proto( + batch, feature_view, join_key_to_value_type ) + + if rows_to_write: + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=rows_to_write, + progress=lambda x: None, + ) + total_rows += len(rows_to_write) + + if total_rows > 0: logger.debug( - f"Successfully wrote {len(rows_to_write)} rows to online store for {feature_view.name}" + f"Successfully wrote {total_rows} rows to online store for {feature_view.name}" ) else: logger.warning(f"No rows to write for {feature_view.name}") diff --git a/sdk/python/feast/infra/compute_engines/spark/utils.py b/sdk/python/feast/infra/compute_engines/spark/utils.py index 4e429f8e075..1234f895464 100644 --- a/sdk/python/feast/infra/compute_engines/spark/utils.py +++ b/sdk/python/feast/infra/compute_engines/spark/utils.py @@ -47,16 +47,24 @@ def map_in_arrow( for entity in feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto( - table, feature_view, join_key_to_value_type - ) - - online_store.online_write_batch( - config=repo_config, - table=feature_view, - data=rows_to_write, - progress=lambda x: None, + batch_size = repo_config.materialization_config.online_write_batch_size + # Single batch if None (backward compatible), otherwise use configured batch_size + sub_batches = ( + [table] + if batch_size is None + else table.to_batches(max_chunksize=batch_size) ) + for sub_batch in sub_batches: + rows_to_write = _convert_arrow_to_proto( + sub_batch, feature_view, join_key_to_value_type + ) + + online_store.online_write_batch( + config=repo_config, + table=feature_view, + data=rows_to_write, + progress=lambda x: None, + ) if mode == "offline": offline_store.offline_write_batch( config=repo_config, @@ -95,15 +103,23 @@ def map_in_pandas(iterator, serialized_artifacts: SerializedArtifacts): for entity in feature_view.entity_columns } - rows_to_write = _convert_arrow_to_proto( - table, feature_view, join_key_to_value_type - ) - online_store.online_write_batch( - repo_config, - feature_view, - rows_to_write, - lambda x: None, + batch_size = repo_config.materialization_config.online_write_batch_size + # Single batch if None (backward compatible), otherwise use configured batch_size + sub_batches = ( + [table] + if batch_size is None + else table.to_batches(max_chunksize=batch_size) ) + for sub_batch in sub_batches: + rows_to_write = _convert_arrow_to_proto( + sub_batch, feature_view, join_key_to_value_type + ) + online_store.online_write_batch( + repo_config, + feature_view, + rows_to_write, + lambda x: None, + ) yield pd.DataFrame( [pd.Series(range(1, 2))] diff --git a/sdk/python/feast/repo_config.py b/sdk/python/feast/repo_config.py index 4007ba97e39..4e6c3bde68e 100644 --- a/sdk/python/feast/repo_config.py +++ b/sdk/python/feast/repo_config.py @@ -214,6 +214,12 @@ class MaterializationConfig(BaseModel): """ bool: If true, feature retrieval jobs will only pull the latest feature values for each entity. If false, feature retrieval jobs will pull all feature values within the specified time range. """ + online_write_batch_size: Optional[int] = Field(default=None, gt=0) + """ int: Number of rows to write to online store per batch during materialization. + If None (default), all rows are written in a single batch for backward compatibility. + Set to a positive integer (e.g., 10000) to enable batched writes. + Supported compute engines: local, spark, ray. """ + class OpenLineageConfig(FeastBaseModel): """Configuration for OpenLineage integration. diff --git a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py index 905ea65ae42..edf7d74db1e 100644 --- a/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py +++ b/sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py @@ -15,6 +15,7 @@ LocalOutputNode, LocalTransformationNode, ) +from feast.repo_config import MaterializationConfig backend = PandasBackend() now = pd.Timestamp.utcnow() @@ -37,9 +38,11 @@ def create_context(node_outputs): # Setup execution context + repo_config = MagicMock() + repo_config.materialization_config = MaterializationConfig() return ExecutionContext( project="test_proj", - repo_config=MagicMock(), + repo_config=repo_config, offline_store=MagicMock(), online_store=MagicMock(), entity_defs=MagicMock(), @@ -214,3 +217,52 @@ def test_local_output_node(): node.inputs[0].name = "source" result = node.execute(context) assert result.num_rows == 4 + + +def test_local_output_node_online_write_default_batch(): + """Test that online_write_batch is called once when batch_size is None (default).""" + # Create a feature view with online=True + feature_view = MagicMock() + feature_view.online = True + feature_view.offline = False + feature_view.entity_columns = [] + + # Create context with default materialization config (batch_size=None) + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + + node = LocalOutputNode("output", feature_view) + node.add_input(MagicMock()) + node.inputs[0].name = "source" + + node.execute(context) + + # Verify online_write_batch was called exactly once (all rows in single batch) + assert context.online_store.online_write_batch.call_count == 1 + + +def test_local_output_node_online_write_batched(): + """Test that online_write_batch is called multiple times when batch_size is configured.""" + # Create a feature view with online=True + feature_view = MagicMock() + feature_view.online = True + feature_view.offline = False + feature_view.entity_columns = [] + + # Create context with batch_size=2 (sample_df has 4 rows, so expect 2 batches) + context = create_context( + node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))} + ) + context.repo_config.materialization_config = MaterializationConfig( + online_write_batch_size=2 + ) + + node = LocalOutputNode("output", feature_view) + node.add_input(MagicMock()) + node.inputs[0].name = "source" + + node.execute(context) + + # Verify online_write_batch was called twice (4 rows / batch_size 2 = 2 batches) + assert context.online_store.online_write_batch.call_count == 2