Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import threading
from unittest.mock import MagicMock, patch

from feast.infra.offline_stores.contrib.trino_offline_store.trino_queries import (
Query,
)


def test_query_init_in_main_thread_registers_signals():
"""signal.signal() should work fine in main thread."""

# Should not raise any exception in main thread
cursor = MagicMock()
with patch("signal.signal") as mock_signal:
query = Query(query_text="SELECT 1", cursor=cursor)
assert query.query_text == "SELECT 1"
# Expected signal.signal to be called twice for SIGINT and SIGTERM
Comment on lines +15 to +17
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Suggestion] Test should verify signal handler registration details

The test only checks call_count but should also verify that the correct signals and handler function are being registered for more thorough validation.

Current code:

+        # Expected signal.signal to be called twice for SIGINT and SIGTERM
+        assert mock_signal.call_count == 2

Suggested:

Suggested change
query = Query(query_text="SELECT 1", cursor=cursor)
assert query.query_text == "SELECT 1"
# Expected signal.signal to be called twice for SIGINT and SIGTERM
# Verify signal handlers are registered correctly
assert mock_signal.call_count == 2
mock_signal.assert_any_call(signal.SIGINT, query.cancel)
mock_signal.assert_any_call(signal.SIGTERM, query.cancel)

assert mock_signal.call_count == 2


def test_query_init_in_worker_thread_does_not_raise():
"""Regression test: signal.signal() fails in non-main threads."""
# signal.signal() raises ValueError when called outside the main thread.
# This test verifies the fix guards against that by running Query.__init__
# in a worker thread and ensuring no exception is raised.

errors = []
cursor = MagicMock()

def create_query():
try:
query = Query(query_text="SELECT 1", cursor=cursor)
assert query.query_text == "SELECT 1"
except ValueError as e:
errors.append(e)

thread = threading.Thread(target=create_query)
thread.start()
thread.join()

assert not errors, f"Unexpected ValueError in worker thread: {errors[0]}"
Comment on lines +40 to +41
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Suggestion] Improve error message handling

The error message assumes errors[0] exists, but this could cause an IndexError if errors is empty. Consider a more defensive approach.

Current code:

+    assert not errors, f"Unexpected ValueError in worker thread: {errors[0]}"

Suggested:

Suggested change
assert not errors, f"Unexpected ValueError in worker thread: {errors[0]}"
assert not errors, f"Unexpected ValueError in worker thread: {errors[0] if errors else 'No errors captured'}"

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import signal
import threading
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional
Expand Down Expand Up @@ -92,8 +93,9 @@ def __init__(self, query_text: str, cursor: Cursor):
self.status = QueryStatus.PENDING
self._cursor = cursor

signal.signal(signal.SIGINT, self.cancel)
signal.signal(signal.SIGTERM, self.cancel)
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, self.cancel)
signal.signal(signal.SIGTERM, self.cancel)

def execute(self) -> Results:
try:
Expand Down
Loading