import logging
import os
from typing import Dict, Optional

try:
    import boto3
    from botocore.client import Config as BotoConfig
except ImportError:
    # boto3 is optional: without it the event-log pre-creation is skipped
    # with a warning instead of breaking module import.
    boto3 = None  # type: ignore[assignment]
    BotoConfig = None  # type: ignore[assignment,misc]

logger = logging.getLogger(__name__)


def _conf_str(
    spark_config: Dict[str, str],
    key: str,
    env_var: Optional[str] = None,
    default: str = "",
) -> str:
    """Return ``spark_config[key]`` as a string, with an optional env fallback.

    A missing, ``None`` or empty config value falls back to ``env_var`` (when
    given) and then to ``default``. The result is always coerced with ``str``
    so that non-string values (e.g. a boolean parsed from YAML) cannot raise
    ``AttributeError`` in the callers.
    """
    value = spark_config.get(key)
    if value is None or value == "":
        value = os.environ.get(env_var, default) if env_var else default
    return str(value)


def _conf_flag(spark_config: Dict[str, str], key: str, default: str = "false") -> bool:
    """Return True when the config value for ``key`` spells "true" (any case)."""
    return _conf_str(spark_config, key, default=default).strip().lower() == "true"


def _split_s3a_event_log_dir(event_dir: str):
    """Split an ``s3a://bucket/prefix`` URI into ``(bucket, prefix, placeholder_key)``.

    ``prefix`` is normalised to end with exactly one ``/`` unless it is empty
    (bucket root), so the placeholder never becomes ``/.keep``.
    """
    bucket, _, raw_prefix = event_dir[len("s3a://") :].partition("/")
    prefix = raw_prefix.rstrip("/")
    if prefix:
        prefix += "/"
    return bucket, prefix, prefix + ".keep"


def _normalize_s3_endpoint(endpoint: str, ssl_enabled: bool) -> Optional[str]:
    """Turn an S3A endpoint setting into a URL that boto3 accepts.

    Hadoop S3A allows scheme-less endpoints (``host:port``) and picks the
    protocol from ``fs.s3a.connection.ssl.enabled``; boto3 requires a complete
    URL. An empty endpoint maps to ``None`` (boto3's default endpoint).
    """
    if not endpoint:
        return None
    if "://" in endpoint:
        return endpoint
    return ("https://" if ssl_enabled else "http://") + endpoint


def _ensure_s3a_event_log_dir(spark_config: Dict[str, str]) -> None:
    """Pre-create the S3A event log prefix before SparkContext initialisation.

    Spark's EventLogFileWriter.requireLogBaseDirAsDirectory() is called inside
    SparkContext.__init__ and crashes if the S3A path doesn't exist yet (S3 has
    no real directories, so an empty prefix returns a 404). This function
    writes a zero-byte ``.keep`` placeholder so the prefix exists before the
    SparkContext is built.

    This is only attempted when:
      - spark.eventLog.enabled == "true" (case-insensitive)
      - spark.eventLog.dir starts with "s3a://"

    Endpoint and credentials are read from the ``spark.hadoop.fs.s3a.*``
    settings first and from the ``AWS_*`` environment variables when a setting
    is missing or blank. Scheme-less endpoints get ``https://`` or ``http://``
    prepended according to ``spark.hadoop.fs.s3a.connection.ssl.enabled``
    (default "true").

    Args:
        spark_config: Spark configuration key/value pairs.

    Returns:
        None. Failures are non-fatal: they are logged as warnings and Spark
        will surface its own error if the directory is still missing.
    """
    if not _conf_flag(spark_config, "spark.eventLog.enabled"):
        return
    event_dir = _conf_str(spark_config, "spark.eventLog.dir")
    if not event_dir.startswith("s3a://"):
        return

    bucket, prefix, placeholder_key = _split_s3a_event_log_dir(event_dir)

    ssl_enabled = _conf_flag(
        spark_config, "spark.hadoop.fs.s3a.connection.ssl.enabled", default="true"
    )
    endpoint_url = _normalize_s3_endpoint(
        _conf_str(spark_config, "spark.hadoop.fs.s3a.endpoint", "AWS_ENDPOINT_URL"),
        ssl_enabled,
    )
    access_key = _conf_str(
        spark_config, "spark.hadoop.fs.s3a.access.key", "AWS_ACCESS_KEY_ID"
    )
    secret_key = _conf_str(
        spark_config, "spark.hadoop.fs.s3a.secret.key", "AWS_SECRET_ACCESS_KEY"
    )
    session_token = _conf_str(
        spark_config, "spark.hadoop.fs.s3a.session.token", "AWS_SESSION_TOKEN"
    )
    addressing_style = (
        "path"
        if _conf_flag(spark_config, "spark.hadoop.fs.s3a.path.style.access")
        else "auto"
    )

    # Everything that can fail lives inside this best-effort boundary so the
    # caller never sees an exception from the pre-creation attempt.
    try:
        if not bucket:
            raise ValueError(f"no bucket name in {event_dir!r}")
        if boto3 is None:
            raise ImportError("boto3 is not installed")

        s3 = boto3.client(
            "s3",
            endpoint_url=endpoint_url,
            # Empty strings become None so boto3 uses its own credential chain.
            aws_access_key_id=access_key or None,
            aws_secret_access_key=secret_key or None,
            aws_session_token=session_token or None,
            config=BotoConfig(
                signature_version="s3v4",
                s3={"addressing_style": addressing_style},
            ),
        )
        # Any object under the prefix makes S3A treat it as an existing
        # directory, so the placeholder is only written for an empty prefix.
        resp = s3.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
        if resp.get("KeyCount", 0) == 0:
            s3.put_object(Bucket=bucket, Key=placeholder_key, Body=b"")
            logger.debug(
                "Created S3A event log dir placeholder: s3a://%s/%s",
                bucket,
                placeholder_key,
            )
    except Exception as exc:
        logger.warning(
            "Could not pre-create S3A event log dir s3a://%s/%s — "
            "SparkContext may fail if the path still doesn't exist: %s",
            bucket,
            prefix,
            exc,
        )
"""Unit tests for ``_ensure_s3a_event_log_dir``.

boto3 and BotoConfig are patched at the location where the module under test
looks them up, so no network access or real AWS credentials are required.
"""

from unittest.mock import MagicMock, patch

from feast.infra.compute_engines.spark.utils import _ensure_s3a_event_log_dir

BOTO3_PATH = "feast.infra.compute_engines.spark.utils.boto3"
BOTOCONFIG_PATH = "feast.infra.compute_engines.spark.utils.BotoConfig"

# Environment credentials shared by the precedence / fallback tests.
_ENV_CREDENTIALS = {
    "AWS_ACCESS_KEY_ID": "env-ak",
    "AWS_SECRET_ACCESS_KEY": "env-sk",  # pragma: allowlist secret
    "AWS_SESSION_TOKEN": "env-st",
}


def _minimal_conf(event_log_dir: str) -> dict:
    """Return a config with event logging enabled and no S3A endpoint."""
    return {
        "spark.eventLog.enabled": "true",
        "spark.eventLog.dir": event_log_dir,
    }


def _base_conf(event_log_dir: str) -> dict:
    """Return an event-log-enabled config that points at a MinIO endpoint."""
    conf = _minimal_conf(event_log_dir)
    conf["spark.hadoop.fs.s3a.endpoint"] = "http://minio:9000"
    return conf


def _stub_s3_client(mock_boto3, key_count: int) -> MagicMock:
    """Make ``boto3.client`` return a mock whose listing reports ``key_count``."""
    client = MagicMock()
    client.list_objects_v2.return_value = {"KeyCount": key_count}
    mock_boto3.client.return_value = client
    return client


def _client_kwargs(mock_boto3) -> dict:
    """Assert ``boto3.client`` was called exactly once and return its kwargs."""
    mock_boto3.client.assert_called_once()
    return mock_boto3.client.call_args.kwargs


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_creates_placeholder_when_empty(mock_boto3):
    """An empty prefix gets a ``.keep`` placeholder written under it."""
    client = _stub_s3_client(mock_boto3, key_count=0)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

    client.list_objects_v2.assert_called_once_with(
        Bucket="my-bucket", Prefix="spark-events/", MaxKeys=1
    )
    client.put_object.assert_called_once_with(
        Bucket="my-bucket", Key="spark-events/.keep", Body=b""
    )


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_skips_when_prefix_exists(mock_boto3):
    """A prefix that already holds objects is left alone."""
    client = _stub_s3_client(mock_boto3, key_count=3)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))

    client.put_object.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_when_event_log_disabled(mock_boto3):
    """With event logging disabled, boto3 is never touched."""
    disabled_conf = {
        "spark.eventLog.enabled": "false",
        "spark.eventLog.dir": "s3a://b/p/",
    }

    _ensure_s3a_event_log_dir(disabled_conf)

    mock_boto3.client.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_noop_for_non_s3a_path(mock_boto3):
    """Event log dirs on other filesystems (hdfs://, file://, ...) are ignored."""
    hdfs_conf = {
        "spark.eventLog.enabled": "true",
        "spark.eventLog.dir": "hdfs:///spark-logs",
    }

    _ensure_s3a_event_log_dir(hdfs_conf)

    mock_boto3.client.assert_not_called()


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_non_fatal_on_s3_error(mock_boto3):
    """An S3 failure is swallowed; the call must simply not raise."""
    client = MagicMock()
    mock_boto3.client.return_value = client
    client.list_objects_v2.side_effect = Exception("connection refused")

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/spark-events/"))


# ---------------------------------------------------------------------------
# Bucket-root edge cases (s3a://bucket, s3a://bucket/)
# ---------------------------------------------------------------------------


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_no_trailing_slash(mock_boto3):
    """``s3a://bucket`` puts ``.keep`` at the bucket root, never ``/.keep``."""
    client = _stub_s3_client(mock_boto3, key_count=0)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket"))

    client.list_objects_v2.assert_called_once_with(
        Bucket="my-bucket", Prefix="", MaxKeys=1
    )
    client.put_object.assert_called_once_with(
        Bucket="my-bucket", Key=".keep", Body=b""
    )


@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_bucket_root_trailing_slash(mock_boto3):
    """``s3a://bucket/`` (empty prefix) also puts ``.keep`` at the bucket root."""
    client = _stub_s3_client(mock_boto3, key_count=0)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/"))

    client.list_objects_v2.assert_called_once_with(
        Bucket="my-bucket", Prefix="", MaxKeys=1
    )
    client.put_object.assert_called_once_with(
        Bucket="my-bucket", Key=".keep", Body=b""
    )


# ---------------------------------------------------------------------------
# Credentials from spark config / env var fallback
# ---------------------------------------------------------------------------


@patch.dict("os.environ", _ENV_CREDENTIALS)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_uses_spark_config_credentials(mock_boto3):
    """Credentials set in the spark config win over the environment."""
    _stub_s3_client(mock_boto3, key_count=1)
    conf = _base_conf("s3a://my-bucket/logs/")
    conf.update(
        {
            "spark.hadoop.fs.s3a.access.key": "spark-ak",
            "spark.hadoop.fs.s3a.secret.key": "spark-sk",  # pragma: allowlist secret
            "spark.hadoop.fs.s3a.session.token": "spark-st",
        }
    )

    _ensure_s3a_event_log_dir(conf)

    kwargs = _client_kwargs(mock_boto3)
    assert kwargs["aws_access_key_id"] == "spark-ak"
    assert kwargs["aws_secret_access_key"] == "spark-sk"  # pragma: allowlist secret
    assert kwargs["aws_session_token"] == "spark-st"


@patch.dict("os.environ", _ENV_CREDENTIALS)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_falls_back_to_env_credentials(mock_boto3):
    """Environment credentials are used when the spark config has none."""
    _stub_s3_client(mock_boto3, key_count=1)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    kwargs = _client_kwargs(mock_boto3)
    assert kwargs["aws_access_key_id"] == "env-ak"
    assert kwargs["aws_secret_access_key"] == "env-sk"  # pragma: allowlist secret
    assert kwargs["aws_session_token"] == "env-st"


@patch.dict("os.environ", {}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_no_credentials_passes_none(mock_boto3):
    """With no credentials anywhere, boto3 receives None for each of them."""
    _stub_s3_client(mock_boto3, key_count=1)

    _ensure_s3a_event_log_dir(_minimal_conf("s3a://my-bucket/logs/"))

    kwargs = _client_kwargs(mock_boto3)
    assert kwargs["aws_access_key_id"] is None
    assert kwargs["aws_secret_access_key"] is None
    assert kwargs["aws_session_token"] is None


# ---------------------------------------------------------------------------
# Path-style addressing (MinIO / S3-compatible)
# ---------------------------------------------------------------------------


@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_path_style_when_enabled(mock_boto3, mock_config_cls):
    """``path.style.access=true`` selects ``addressing_style='path'``."""
    _stub_s3_client(mock_boto3, key_count=1)
    conf = _base_conf("s3a://my-bucket/logs/")
    conf["spark.hadoop.fs.s3a.path.style.access"] = "true"

    _ensure_s3a_event_log_dir(conf)

    mock_config_cls.assert_called_once()
    config_call = mock_config_cls.call_args
    assert config_call.kwargs["s3"] == {"addressing_style": "path"}


@patch(BOTOCONFIG_PATH)
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_virtual_hosted_style_by_default(
    mock_boto3, mock_config_cls
):
    """Without ``path.style.access`` the addressing style stays ``'auto'``."""
    _stub_s3_client(mock_boto3, key_count=1)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    mock_config_cls.assert_called_once()
    config_call = mock_config_cls.call_args
    assert config_call.kwargs["s3"] == {"addressing_style": "auto"}


# ---------------------------------------------------------------------------
# Endpoint env var fallback (AWS_ENDPOINT_URL)
# ---------------------------------------------------------------------------


@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://localhost:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_endpoint_from_env(mock_boto3):
    """``AWS_ENDPOINT_URL`` is used when the spark config sets no endpoint."""
    _stub_s3_client(mock_boto3, key_count=1)

    _ensure_s3a_event_log_dir(_minimal_conf("s3a://my-bucket/logs/"))

    kwargs = _client_kwargs(mock_boto3)
    assert kwargs["endpoint_url"] == "http://localhost:9000"


@patch.dict("os.environ", {"AWS_ENDPOINT_URL": "http://env-endpoint:9000"}, clear=True)
@patch(BOTOCONFIG_PATH, MagicMock())
@patch(BOTO3_PATH)
def test_ensure_s3a_event_log_dir_spark_endpoint_over_env(mock_boto3):
    """``spark.hadoop.fs.s3a.endpoint`` wins over ``AWS_ENDPOINT_URL``."""
    _stub_s3_client(mock_boto3, key_count=1)

    _ensure_s3a_event_log_dir(_base_conf("s3a://my-bucket/logs/"))

    kwargs = _client_kwargs(mock_boto3)
    assert kwargs["endpoint_url"] == "http://minio:9000"