Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sdk/python/feast/infra/compute_engines/dag/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ def created_timestamp_column(self) -> Optional[str]:
"""
return self._get_mapped_column(self.created_ts_col)

@property
def join_keys_columns(self) -> List[str]:
"""
Get the join keys, mapped through field_mapping to their post-rename
column names. Use this when looking up columns on a DataFrame that has
already had its source columns renamed (e.g. inside DAG nodes that
consume the output of a source-read node).
"""
if not self.field_mapping:
return list(self.join_keys)
return [self.field_mapping.get(key, key) for key in self.join_keys]

def _get_mapped_column(self, column: Optional[str]) -> Optional[str]:
"""
Helper method to get the mapped column name if it exists in field_mapping.
Expand Down
14 changes: 10 additions & 4 deletions sdk/python/feast/infra/compute_engines/local/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,18 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
for val in input_values:
val.assert_format(DAGFormat.ARROW)

# The upstream source-read node has already renamed columns via
# field_mapping, so use the mapped join keys for joining (see #5942).
join_keys = self.column_info.join_keys_columns

# Convert all upstream ArrowTables to backend DataFrames
joined_df = self.backend.from_arrow(input_values[0].data)
for val in input_values[1:]:
next_df = self.backend.from_arrow(val.data)
joined_df = self.backend.join(
joined_df,
next_df,
on=self.column_info.join_keys,
on=join_keys,
how=self.how,
)

Expand All @@ -105,7 +109,7 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:
joined_df = self.backend.join(
entity_df,
joined_df,
on=self.column_info.join_keys,
on=join_keys,
how="left",
)

Expand Down Expand Up @@ -193,8 +197,10 @@ def execute(self, context: ExecutionContext) -> ArrowTableValue:

# Extract join_keys, timestamp, and created_ts from context

# Dedup strategy: sort and drop_duplicates
dedup_keys = self.column_info.join_keys
# Dedup strategy: sort and drop_duplicates. Use the mapped join key
# names so we look up the columns that the source-read node has
# already renamed (see issue #5942).
dedup_keys = self.column_info.join_keys_columns
if dedup_keys:
sort_keys = [self.column_info.timestamp_column]
if (
Expand Down
116 changes: 116 additions & 0 deletions sdk/python/tests/unit/infra/compute_engines/local/test_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,122 @@ def test_local_dedup_node():
assert set(df_result["entity_id"]) == {1, 2}


def test_local_dedup_node_with_field_mapping_on_join_key():
"""Regression test for materialization failure when a join key has a field mapping.

The source-read node renames columns via field_mapping (e.g. ``USERID`` -> ``user_id``)
before passing the table to downstream nodes. Without mapping ``column_info.join_keys``
the dedup node would look up the pre-mapping name and raise ``KeyError(['USERID'])``.

See https://github.com/feast-dev/feast/issues/5942.
"""
# Simulate a source-read node output: columns already renamed to the mapped names.
df = pd.DataFrame(
{
"user_id": [1, 1, 2],
"value": [100, 200, 300],
"event_timestamp": [
now - timedelta(seconds=1),
now,
now,
],
}
)

context = create_context(
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(df))}
)

node = LocalDedupNode(
name="dedup",
backend=backend,
column_info=ColumnInfo(
# The raw join key matches the source column name; field_mapping maps
# it to the user-facing name that the source-read node has already
# renamed the column to.
join_keys=["USERID"],
feature_cols=["value"],
ts_col="EVENT_TIMESTAMP",
created_ts_col=None,
field_mapping={"USERID": "user_id", "EVENT_TIMESTAMP": "event_timestamp"},
),
)
node.add_input(MagicMock())
node.inputs[0].name = "source"

result = node.execute(context)

df_result = result.data.to_pandas()
assert df_result.shape[0] == 2
assert set(df_result["user_id"]) == {1, 2}


def test_local_join_node_with_field_mapping_on_join_key():
"""Regression test for materialization failure when a join key has a field mapping.

The source-read node renames columns via field_mapping (e.g. ``USERID`` -> ``user_id``)
before passing the table to downstream nodes. Without mapping ``column_info.join_keys``
the join node would call ``backend.join(..., on=["USERID"], ...)`` and raise
``KeyError(['USERID'])`` because the columns have already been renamed.

See https://github.com/feast-dev/feast/issues/5942.
"""
# Simulate two source-read node outputs: columns already renamed to the mapped names.
left_df = pd.DataFrame(
{
"user_id": [1, 2],
"value": [10, 20],
"event_timestamp": [now, now],
}
)
right_df = pd.DataFrame(
{
"user_id": [1, 2],
"other_value": [100, 200],
"event_timestamp": [now, now],
}
)

context = create_context(
node_outputs={
"left": ArrowTableValue(pa.Table.from_pandas(left_df)),
"right": ArrowTableValue(pa.Table.from_pandas(right_df)),
}
)
# Bypass the trailing entity_df join — this test exercises the input-table
# join path that consumed the raw (unmapped) join keys before the fix.
context.entity_df = None

join_node = LocalJoinNode(
name="join",
backend=backend,
column_info=ColumnInfo(
# Raw join key matches the source column name; field_mapping maps it
# to the user-facing name that the source-read node has already
# renamed the column to.
join_keys=["USERID"],
feature_cols=["value", "other_value"],
ts_col="EVENT_TIMESTAMP",
created_ts_col=None,
field_mapping={"USERID": "user_id", "EVENT_TIMESTAMP": "event_timestamp"},
),
)
left_input = MagicMock()
left_input.name = "left"
right_input = MagicMock()
right_input.name = "right"
join_node.add_input(left_input)
join_node.add_input(right_input)

result = join_node.execute(context)

df_result = result.data.to_pandas()
assert df_result.shape[0] == 2
assert set(df_result["user_id"]) == {1, 2}
assert "value" in df_result.columns
assert "other_value" in df_result.columns


def test_local_transformation_node():
context = create_context(
node_outputs={"source": ArrowTableValue(pa.Table.from_pandas(sample_df))}
Expand Down
Loading