diff --git a/sdk/python/tests/unit/test_on_demand_python_transformation.py b/sdk/python/tests/unit/test_on_demand_python_transformation.py index 9a09037d422..6a0f777b283 100644 --- a/sdk/python/tests/unit/test_on_demand_python_transformation.py +++ b/sdk/python/tests/unit/test_on_demand_python_transformation.py @@ -45,202 +45,201 @@ class TestOnDemandPythonTransformation(unittest.TestCase): def setUp(self): - with tempfile.TemporaryDirectory() as data_dir: - self.store = FeatureStore( - config=RepoConfig( - project="test_on_demand_python_transformation", - registry=os.path.join(data_dir, "registry.db"), - provider="local", - entity_key_serialization_version=3, - online_store=SqliteOnlineStoreConfig( - path=os.path.join(data_dir, "online.db") - ), - ) + self.data_dir = tempfile.mkdtemp() + data_dir = self.data_dir + self.store = FeatureStore( + config=RepoConfig( + project="test_on_demand_python_transformation", + registry=os.path.join(data_dir, "registry.db"), + provider="local", + entity_key_serialization_version=3, + online_store=SqliteOnlineStoreConfig( + path=os.path.join(data_dir, "online.db") + ), ) + ) - # Generate test data. - end_date = datetime.now().replace(microsecond=0, second=0, minute=0) - start_date = end_date - timedelta(days=15) + # Generate test data. + end_date = datetime.now().replace(microsecond=0, second=0, minute=0) + start_date = end_date - timedelta(days=15) - driver_entities = [1001, 1002, 1003, 1004, 1005] - driver_df = create_driver_hourly_stats_df( - driver_entities, start_date, end_date - ) - driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") - driver_df.to_parquet( - path=driver_stats_path, allow_truncated_timestamps=True - ) + driver_entities = [1001, 1002, 1003, 1004, 1005] + driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date) + driver_stats_path = os.path.join(data_dir, "driver_stats.parquet") + driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True) - driver = Entity( - name="driver", join_keys=["driver_id"], value_type=ValueType.INT64 - ) + driver = Entity( + name="driver", join_keys=["driver_id"], value_type=ValueType.INT64 + ) - driver_stats_source = FileSource( - name="driver_hourly_stats_source", - path=driver_stats_path, - timestamp_field="event_timestamp", - created_timestamp_column="created", - ) - input_request_source = RequestSource( - name="counter_source", - schema=[ - Field(name="counter", dtype=Int64), - Field(name="input_datetime", dtype=UnixTimestamp), - ], - ) + driver_stats_source = FileSource( + name="driver_hourly_stats_source", + path=driver_stats_path, + timestamp_field="event_timestamp", + created_timestamp_column="created", + ) + input_request_source = RequestSource( + name="counter_source", + schema=[ + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + ) - driver_stats_fv = FeatureView( - name="driver_hourly_stats", - entities=[driver], - ttl=timedelta(days=0), - schema=[ - Field(name="conv_rate", dtype=Float32), - Field(name="acc_rate", dtype=Float32), - Field(name="avg_daily_trips", dtype=Int64), - ], - online=True, - source=driver_stats_source, - ) + driver_stats_fv = FeatureView( + name="driver_hourly_stats", + entities=[driver], + ttl=timedelta(days=0), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=driver_stats_source, + ) - driver_stats_entity_less_fv = FeatureView( - name="driver_hourly_stats_no_entity", - entities=[], - ttl=timedelta(days=0), - schema=[ - Field(name="conv_rate", dtype=Float32), - Field(name="acc_rate", dtype=Float32), - Field(name="avg_daily_trips", dtype=Int64), - ], - online=True, - source=driver_stats_source, - ) + driver_stats_entity_less_fv = FeatureView( + name="driver_hourly_stats_no_entity", + entities=[], + ttl=timedelta(days=0), + schema=[ + Field(name="conv_rate", dtype=Float32), + Field(name="acc_rate", dtype=Float32), + Field(name="avg_daily_trips", dtype=Int64), + ], + online=True, + source=driver_stats_source, + ) - @on_demand_feature_view( - sources=[driver_stats_fv], - schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)], - mode="pandas", - ) - def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame: - df = pd.DataFrame() - df["conv_rate_plus_acc_pandas"] = ( - inputs["conv_rate"] + inputs["acc_rate"] - ) - return df + @on_demand_feature_view( + sources=[driver_stats_fv], + schema=[Field(name="conv_rate_plus_acc_pandas", dtype=Float64)], + mode="pandas", + ) + def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame: + df = pd.DataFrame() + df["conv_rate_plus_acc_pandas"] = inputs["conv_rate"] + inputs["acc_rate"] + return df - @on_demand_feature_view( - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], - schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)], - mode="python", - ) - def python_view(inputs: dict[str, Any]) -> dict[str, Any]: - output: dict[str, Any] = { - "conv_rate_plus_acc_python": conv_rate + acc_rate + @on_demand_feature_view( + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], + schema=[Field(name="conv_rate_plus_acc_python", dtype=Float64)], + mode="python", + ) + def python_view(inputs: dict[str, Any]) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc_python": conv_rate + acc_rate + for conv_rate, acc_rate in zip(inputs["conv_rate"], inputs["acc_rate"]) + } + return output + + @on_demand_feature_view( + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], + schema=[ + Field(name="conv_rate_plus_val1_python", dtype=Float64), + Field(name="conv_rate_plus_val2_python", dtype=Float64), + ], + mode="python", + ) + def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_val1_python": [ + conv_rate + acc_rate for conv_rate, acc_rate in zip( inputs["conv_rate"], inputs["acc_rate"] ) - } - return output - - @on_demand_feature_view( - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], - schema=[ - Field(name="conv_rate_plus_val1_python", dtype=Float64), - Field(name="conv_rate_plus_val2_python", dtype=Float64), ], - mode="python", - ) - def python_demo_view(inputs: dict[str, Any]) -> dict[str, Any]: - output: dict[str, Any] = { - "conv_rate_plus_val1_python": [ - conv_rate + acc_rate - for conv_rate, acc_rate in zip( - inputs["conv_rate"], inputs["acc_rate"] - ) - ], - "conv_rate_plus_val2_python": [ - conv_rate + acc_rate - for conv_rate, acc_rate in zip( - inputs["conv_rate"], inputs["acc_rate"] - ) - ], - } - return output - - @on_demand_feature_view( - sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], - schema=[ - Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64), - Field( - name="conv_rate_plus_acc_python_singleton_array", - dtype=Array(Float64), - ), + "conv_rate_plus_val2_python": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) ], - mode="python", - singleton=True, + } + return output + + @on_demand_feature_view( + sources=[driver_stats_fv[["conv_rate", "acc_rate"]]], + schema=[ + Field(name="conv_rate_plus_acc_python_singleton", dtype=Float64), + Field( + name="conv_rate_plus_acc_python_singleton_array", + dtype=Array(Float64), + ), + ], + mode="python", + singleton=True, + ) + def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: + output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf")) + output["conv_rate_plus_acc_python_singleton"] = ( + inputs["conv_rate"] + inputs["acc_rate"] ) - def python_singleton_view(inputs: dict[str, Any]) -> dict[str, Any]: - output: dict[str, Any] = dict(conv_rate_plus_acc_python=float("-inf")) - output["conv_rate_plus_acc_python_singleton"] = ( - inputs["conv_rate"] + inputs["acc_rate"] - ) - output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3] - return output + output["conv_rate_plus_acc_python_singleton_array"] = [0.1, 0.2, 0.3] + return output - @on_demand_feature_view( - sources=[ - driver_stats_fv[["conv_rate", "acc_rate"]], - input_request_source, - ], - schema=[ - Field(name="conv_rate_plus_acc", dtype=Float64), - Field(name="current_datetime", dtype=UnixTimestamp), - Field(name="counter", dtype=Int64), - Field(name="input_datetime", dtype=UnixTimestamp), + @on_demand_feature_view( + sources=[ + driver_stats_fv[["conv_rate", "acc_rate"]], + input_request_source, + ], + schema=[ + Field(name="conv_rate_plus_acc", dtype=Float64), + Field(name="current_datetime", dtype=UnixTimestamp), + Field(name="counter", dtype=Int64), + Field(name="input_datetime", dtype=UnixTimestamp), + ], + mode="python", + write_to_online_store=True, + ) + def python_stored_writes_feature_view( + inputs: dict[str, Any], + ) -> dict[str, Any]: + output: dict[str, Any] = { + "conv_rate_plus_acc": [ + conv_rate + acc_rate + for conv_rate, acc_rate in zip( + inputs["conv_rate"], inputs["acc_rate"] + ) ], - mode="python", - write_to_online_store=True, - ) - def python_stored_writes_feature_view( - inputs: dict[str, Any], - ) -> dict[str, Any]: - output: dict[str, Any] = { - "conv_rate_plus_acc": [ - conv_rate + acc_rate - for conv_rate, acc_rate in zip( - inputs["conv_rate"], inputs["acc_rate"] - ) - ], - "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], - "counter": [c + 1 for c in inputs["counter"]], - "input_datetime": [d for d in inputs["input_datetime"]], - } - return output + "current_datetime": [datetime.now() for _ in inputs["conv_rate"]], + "counter": [c + 1 for c in inputs["counter"]], + "input_datetime": [d for d in inputs["input_datetime"]], + } + return output - self.store.apply( - [ - driver, - driver_stats_source, - driver_stats_fv, - pandas_view, - python_view, - python_singleton_view, - python_demo_view, - driver_stats_entity_less_fv, - python_stored_writes_feature_view, - ] - ) - self.store.write_to_online_store( - feature_view_name="driver_hourly_stats", df=driver_df - ) - assert driver_stats_fv.entity_columns == [ - Field(name=driver.join_key, dtype=from_value_type(driver.value_type)) + self.store.apply( + [ + driver, + driver_stats_source, + driver_stats_fv, + pandas_view, + python_view, + python_singleton_view, + python_demo_view, + driver_stats_entity_less_fv, + python_stored_writes_feature_view, ] - assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD] + ) + self.store.write_to_online_store( + feature_view_name="driver_hourly_stats", df=driver_df + ) + assert driver_stats_fv.entity_columns == [ + Field(name=driver.join_key, dtype=from_value_type(driver.value_type)) + ] + assert driver_stats_entity_less_fv.entity_columns == [DUMMY_ENTITY_FIELD] - assert len(self.store.list_all_feature_views()) == 7 - assert len(self.store.list_feature_views()) == 2 - assert len(self.store.list_on_demand_feature_views()) == 5 - assert len(self.store.list_stream_feature_views()) == 0 + assert len(self.store.list_all_feature_views()) == 7 + assert len(self.store.list_feature_views()) == 2 + assert len(self.store.list_on_demand_feature_views()) == 5 + assert len(self.store.list_stream_feature_views()) == 0 + + def tearDown(self): + import shutil + + if hasattr(self, "data_dir"): + shutil.rmtree(self.data_dir, ignore_errors=True) def test_setup(self): pass