diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl index acfe93f5..8157eb15 100644 --- a/charts/model-engine/templates/_helpers.tpl +++ b/charts/model-engine/templates/_helpers.tpl @@ -354,6 +354,7 @@ env: - name: LAUNCH_SERVICE_TEMPLATE_FOLDER value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates" {{- $model_cache := default dict .Values.modelCache }} + {{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}} - name: MODEL_CACHE_ENABLED value: {{ get $model_cache "enabled" | default false | quote }} - name: MODEL_CACHE_MOUNT_PATH @@ -404,6 +405,14 @@ env: - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} + {{- if $gcp_cloud_provider }} + - name: GCP_PROJECT_ID + value: {{ (.Values.gcp).project_id | default "" | quote }} + - name: PUBSUB_TOPIC_PREFIX + value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }} + - name: PUBSUB_SUBSCRIPTION_PREFIX + value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }} + {{- end }} {{- if eq .Values.context "circleci" }} - name: CIRCLECI value: "true" diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml index 74c8cf44..6449a6b1 100644 --- a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -5,12 +5,13 @@ {{- $tag := .Values.tag }} {{- $message_broker := .Values.celeryBrokerType }} {{- $num_shards := .Values.celery_autoscaler.num_shards }} +{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}} {{- $broker_name := "redis-elasticache-message-broker-master" }} {{- if eq 
$message_broker "sqs" }} {{ $broker_name = "sqs-message-broker-master" }} {{- else if eq $message_broker "servicebus" }} {{ $broker_name = "servicebus-message-broker-master" }} -{{- else if and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }} +{{- else if $gcp_cloud_provider }} {{ $broker_name = "redis-gcp-memorystore-message-broker-master" }} {{- end }} apiVersion: apps/v1 @@ -89,6 +90,14 @@ spec: - name: SERVICEBUS_NAMESPACE value: {{ .Values.azure.servicebus_namespace }} {{- end }} + {{- if $gcp_cloud_provider }} + - name: GCP_PROJECT_ID + value: {{ (.Values.gcp).project_id | default "" | quote }} + - name: PUBSUB_TOPIC_PREFIX + value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }} + - name: PUBSUB_SUBSCRIPTION_PREFIX + value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }} + {{- end }} image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}" imagePullPolicy: Always name: main diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml index 7509a88f..bae62ce3 100644 --- a/charts/model-engine/values.yaml +++ b/charts/model-engine/values.yaml @@ -94,3 +94,9 @@ utilityImages: # Additional GPU tolerations for endpoint pods gpuTolerations: [] + +# GCP configuration for GCP-based deployments +gcp: + project_id: "" + pubsub_topic_prefix: "launch-endpoint-id-" + pubsub_subscription_prefix: "launch-endpoint-id-" diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml index b8dc6fb7..dd683a97 100644 --- a/charts/model-engine/values_sample.yaml +++ b/charts/model-engine/values_sample.yaml @@ -438,3 +438,9 @@ recommendedHardware: gpu_type: nvidia-hopper-h100 nodes_per_worker: 1 #serviceBuilderQueue: + +# GCP configuration for GCP-based deployments +gcp: + project_id: "your-gcp-project" + pubsub_topic_prefix: "launch-endpoint-id-" + pubsub_subscription_prefix: "launch-endpoint-id-" diff --git 
a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py index f28425d8..71092932 100644 --- a/model-engine/model_engine_server/api/dependencies.py +++ b/model-engine/model_engine_server/api/dependencies.py @@ -89,6 +89,9 @@ from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) @@ -248,8 +251,9 @@ def _get_external_interfaces( elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "gcp": - # GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate - queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client) + queue_delegate = GcpPubSubQueueEndpointResourceDelegate( + project_id=infra_config().gcp_project_id, + ) else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py index 6886174f..4d837d6e 100644 --- a/model-engine/model_engine_server/core/config.py +++ b/model-engine/model_engine_server/core/config.py @@ -54,6 +54,7 @@ class _InfraConfig: celery_enable_sha256: Optional[bool] = None docker_registry_type: Optional[str] = None debug_mode: Optional[bool] = None + gcp_project_id: Optional[str] = None @dataclass diff --git a/model-engine/model_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py index 36bd8e96..0084ad78 100644 --- a/model-engine/model_engine_server/entrypoints/k8s_cache.py +++ 
b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -21,6 +21,9 @@ from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) @@ -119,6 +122,10 @@ async def main(args: Any): queue_delegate = OnPremQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + queue_delegate = GcpPubSubQueueEndpointResourceDelegate( + project_id=infra_config().gcp_project_id, + ) else: queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) diff --git a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 3d659c4a..46bec9ba 100644 --- a/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -31,6 +31,9 @@ from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( ASBQueueEndpointResourceDelegate, ) +from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import ( + GcpPubSubQueueEndpointResourceDelegate, +) from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( FakeQueueEndpointResourceDelegate, ) @@ -90,6 +93,10 @@ async def run_batch_job( queue_delegate = OnPremQueueEndpointResourceDelegate() elif infra_config().cloud_provider == "azure": queue_delegate = ASBQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "gcp": + 
from typing import Any, Dict, Optional

from google.api_core import exceptions as gcp_exceptions
from google.cloud import pubsub_v1
from google.protobuf import field_mask_pb2
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
    QueueEndpointResourceDelegate,
    QueueInfo,
)

logger = make_logger(logger_name())

# Pub/Sub enforces a hard [10, 600] second range on subscription ack deadlines;
# values outside this range are rejected by the service with InvalidArgument.
GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS = 600
GCP_PUBSUB_MIN_ACK_DEADLINE_SECONDS = 10
DEFAULT_ACK_DEADLINE_SECONDS = 60


class GcpPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
    """
    Queue endpoint resource delegate backed by GCP Pub/Sub, using one topic and
    one subscription per endpoint.

    ``topic_prefix`` and ``subscription_prefix`` control the GCP resource name
    prefixes. The logical queue_name returned to callers always uses the
    canonical ``QueueEndpointResourceDelegate.endpoint_id_to_queue_name``
    format, independent of these prefixes.
    """

    def __init__(
        self,
        project_id: str,
        topic_prefix: str = "launch-endpoint-id-",
        subscription_prefix: str = "launch-endpoint-id-",
    ) -> None:
        if not project_id:
            raise ValueError(
                "GcpPubSubQueueEndpointResourceDelegate requires a non-empty project_id; "
                "set infra.gcp_project_id in the service config."
            )
        self.project_id = project_id
        self.topic_prefix = topic_prefix
        self.subscription_prefix = subscription_prefix
        self._publisher = pubsub_v1.PublisherClient()
        self._subscriber = pubsub_v1.SubscriberClient()

    def _topic_id(self, endpoint_id: str) -> str:
        # GCP-side topic ID (short name, not the fully-qualified resource path).
        return f"{self.topic_prefix}{endpoint_id}"

    def _subscription_id(self, endpoint_id: str) -> str:
        # GCP-side subscription ID (short name, not the fully-qualified resource path).
        return f"{self.subscription_prefix}{endpoint_id}"

    @staticmethod
    def _clamp_ack_deadline(queue_message_timeout_seconds: Optional[int]) -> int:
        """Clamp the requested timeout into Pub/Sub's legal ack-deadline range.

        Pub/Sub rejects deadlines below 10s or above 600s, so both bounds are
        enforced here rather than letting the API call fail.
        """
        requested = queue_message_timeout_seconds or DEFAULT_ACK_DEADLINE_SECONDS
        return min(
            max(requested, GCP_PUBSUB_MIN_ACK_DEADLINE_SECONDS),
            GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS,
        )

    async def create_queue_if_not_exists(
        self,
        endpoint_id: str,
        endpoint_name: str,
        endpoint_created_by: str,
        endpoint_labels: Dict[str, Any],
        queue_message_timeout_seconds: Optional[int] = None,
    ) -> QueueInfo:
        """Idempotently create the endpoint's topic and subscription.

        If the subscription already exists, its ack deadline is best-effort
        updated so a changed ``queue_message_timeout_seconds`` takes effect.

        Raises:
            EndpointResourceInfraException: on any non-AlreadyExists API error
                during topic or subscription creation (consistent with
                ``delete_queue``).
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
        topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"
        subscription_path = (
            f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
        )
        ack_deadline = self._clamp_ack_deadline(queue_message_timeout_seconds)

        # AlreadyExists subclasses GoogleAPIError, so it must be caught first.
        try:
            self._publisher.create_topic(name=topic_path)
        except gcp_exceptions.AlreadyExists:
            pass  # idempotent create: an existing topic is fine
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to create Pub/Sub topic {topic_path} for endpoint {endpoint_id}: {e}"
            ) from e

        try:
            self._subscriber.create_subscription(
                name=subscription_path,
                topic=topic_path,
                ack_deadline_seconds=ack_deadline,
            )
        except gcp_exceptions.AlreadyExists:
            # Subscription exists: push the (possibly new) ack deadline onto it.
            # Failures here are non-fatal -- the queue itself is usable.
            try:
                self._subscriber.update_subscription(
                    subscription=pubsub_v1.types.Subscription(
                        name=subscription_path,
                        ack_deadline_seconds=ack_deadline,
                    ),
                    update_mask=field_mask_pb2.FieldMask(paths=["ack_deadline_seconds"]),
                )
            except gcp_exceptions.GoogleAPIError as e:
                logger.warning(
                    f"Failed to update ack_deadline for Pub/Sub subscription {subscription_path}: {e}"
                )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to create Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}: {e}"
            ) from e

        # Pub/Sub has no URL concept analogous to SQS queue URLs
        return QueueInfo(queue_name, queue_url=None)

    async def delete_queue(self, endpoint_id: str) -> None:
        """Delete the endpoint's subscription, then its topic.

        The subscription is deleted first so the topic is never left with a
        dangling subscription. NotFound on either resource is treated as
        already-deleted (idempotent delete).

        Raises:
            EndpointResourceInfraException: on any non-NotFound API error.
        """
        subscription_path = (
            f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
        )
        topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"

        try:
            self._subscriber.delete_subscription(subscription=subscription_path)
        except gcp_exceptions.NotFound:
            logger.info(
                f"Could not find Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}"
            )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to delete Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}: {e}"
            ) from e

        try:
            self._publisher.delete_topic(topic=topic_path)
        except gcp_exceptions.NotFound:
            logger.info(
                f"Could not find Pub/Sub topic {topic_path} for endpoint {endpoint_id}"
            )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to delete Pub/Sub topic {topic_path} for endpoint {endpoint_id}: {e}"
            ) from e

    async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
        """Return the logical queue name and a message-count placeholder.

        ``num_undelivered_messages`` is -1 ("unknown") because Pub/Sub does not
        expose a synchronous undelivered-message count; real observability
        requires the Cloud Monitoring API as a separate concern. Callers treat
        negative counts as unknown.
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id)
        return {
            "name": queue_name,
            "num_undelivered_messages": -1,
        }
sqs_attributes = await self.queue_delegate.get_queue_attributes(endpoint_id=endpoint_id) + sqs_attributes = await self.queue_delegate.get_queue_attributes( + endpoint_id=endpoint_id + ) if ( "Attributes" in sqs_attributes and "ApproximateNumberOfMessages" in sqs_attributes["Attributes"] @@ -99,8 +109,18 @@ async def get_resources( resources.num_queued_items = int( sqs_attributes["Attributes"]["ApproximateNumberOfMessages"] ) - elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate + elif ( + "active_message_count" in sqs_attributes + ): # from ASBQueueEndpointResourceDelegate resources.num_queued_items = int(sqs_attributes["active_message_count"]) + elif ( + "num_undelivered_messages" in sqs_attributes + ): # from GcpPubSubQueueEndpointResourceDelegate + # Pub/Sub returns -1 when num_undelivered_messages is not yet wired to Cloud Monitoring. + # Treat -1 as "unknown" and skip; downstream autoscaling expects non-negative counts. + gcp_count = int(sqs_attributes["num_undelivered_messages"]) + if gcp_count >= 0: + resources.num_queued_items = gcp_count return resources @@ -125,7 +145,9 @@ async def delete_resources( sqs_result = False if self.inference_autoscaling_metrics_gateway is not None: - await self.inference_autoscaling_metrics_gateway.delete_resources(endpoint_id) + await self.inference_autoscaling_metrics_gateway.delete_resources( + endpoint_id + ) return k8s_result and sqs_result diff --git a/model-engine/requirements.in b/model-engine/requirements.in index 07e849aa..16ed6b92 100644 --- a/model-engine/requirements.in +++ b/model-engine/requirements.in @@ -12,6 +12,7 @@ azure-storage-blob~=12.19.0 # GCP dependencies gcloud-aio-storage~=9.6 google-auth~=2.25.0 +google-cloud-pubsub>=2.18 google-cloud-artifact-registry~=1.21.0 google-cloud-secret-manager>=2.24.0 google-cloud-storage~=2.14.0 diff --git a/model-engine/tests/unit/infra/gateways/resources/test_gcp_pubsub_queue_endpoint_resource_delegate.py 
"""Unit tests for GcpPubSubQueueEndpointResourceDelegate.

The Pub/Sub publisher/subscriber clients are patched at the delegate module
level, so no GCP credentials or network access are required. Tests cover
idempotent create, ack-deadline update on AlreadyExists, idempotent delete,
error wrapping, delete ordering, and the queue-attributes shape.
"""
from unittest.mock import MagicMock, patch

import pytest
from google.api_core import exceptions as gcp_exceptions
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
    GcpPubSubQueueEndpointResourceDelegate,
)

MODULE_PATH = "model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate"

ENDPOINT_ID = "test_endpoint_id"
PROJECT_ID = "test-project"
TOPIC_PREFIX = "launch-endpoint-id-"
SUBSCRIPTION_PREFIX = "launch-endpoint-id-"
# NOTE: assumes endpoint_id_to_queue_name produces "launch-endpoint-id-<id>",
# matching the default prefixes.
QUEUE_NAME = f"{TOPIC_PREFIX}{ENDPOINT_ID}"


@pytest.fixture
def mock_publisher():
    """Patch PublisherClient in the delegate module; yield the mock instance."""
    with patch(f"{MODULE_PATH}.pubsub_v1.PublisherClient") as mock_cls:
        yield mock_cls.return_value


@pytest.fixture
def mock_subscriber():
    """Patch SubscriberClient in the delegate module; yield the mock instance."""
    with patch(f"{MODULE_PATH}.pubsub_v1.SubscriberClient") as mock_cls:
        yield mock_cls.return_value


@pytest.fixture
def delegate(mock_publisher, mock_subscriber):
    """Delegate instance whose GCP clients are the mocks above."""
    return GcpPubSubQueueEndpointResourceDelegate(project_id=PROJECT_ID)


def test_init_empty_project_id_raises():
    """Constructing the delegate without a project_id raises ValueError."""
    with pytest.raises(ValueError, match="non-empty project_id"):
        GcpPubSubQueueEndpointResourceDelegate(project_id="")


@pytest.mark.asyncio
async def test_create_queue_if_not_exists_new(
    mock_publisher, mock_subscriber, delegate
):
    """Both topic and subscription are created when neither exists."""
    result = await delegate.create_queue_if_not_exists(
        endpoint_id=ENDPOINT_ID,
        endpoint_name="test_endpoint",
        endpoint_created_by="test_user",
        endpoint_labels={"team": "test"},
    )

    topic_path = f"projects/{PROJECT_ID}/topics/{TOPIC_PREFIX}{ENDPOINT_ID}"
    subscription_path = (
        f"projects/{PROJECT_ID}/subscriptions/{SUBSCRIPTION_PREFIX}{ENDPOINT_ID}"
    )

    mock_publisher.create_topic.assert_called_once_with(name=topic_path)
    mock_subscriber.create_subscription.assert_called_once_with(
        name=subscription_path,
        topic=topic_path,
        ack_deadline_seconds=60,  # default when timeout is None
    )
    assert result.queue_name == QUEUE_NAME
    assert result.queue_url is None


@pytest.mark.asyncio
async def test_create_queue_if_not_exists_topic_already_exists(
    mock_publisher, mock_subscriber, delegate
):
    """AlreadyExists on topic creation is silenced; subscription still attempts creation."""
    mock_publisher.create_topic.side_effect = gcp_exceptions.AlreadyExists(
        "topic exists"
    )

    result = await delegate.create_queue_if_not_exists(
        endpoint_id=ENDPOINT_ID,
        endpoint_name="test_endpoint",
        endpoint_created_by="test_user",
        endpoint_labels={},
    )

    mock_subscriber.create_subscription.assert_called_once()
    assert result.queue_name == QUEUE_NAME


@pytest.mark.asyncio
async def test_create_queue_if_not_exists_subscription_already_exists_updates_ack_deadline(
    mock_publisher, mock_subscriber, delegate
):
    """AlreadyExists on subscription triggers an update_subscription call."""
    mock_subscriber.create_subscription.side_effect = gcp_exceptions.AlreadyExists(
        "subscription exists"
    )

    result = await delegate.create_queue_if_not_exists(
        endpoint_id=ENDPOINT_ID,
        endpoint_name="test_endpoint",
        endpoint_created_by="test_user",
        endpoint_labels={},
        queue_message_timeout_seconds=120,
    )

    mock_publisher.create_topic.assert_called_once()
    mock_subscriber.update_subscription.assert_called_once()
    assert result.queue_name == QUEUE_NAME


@pytest.mark.asyncio
async def test_create_queue_subscription_already_exists_update_failure_is_warned(
    mock_publisher, mock_subscriber, delegate
):
    """update_subscription GoogleAPIError is swallowed with a warning (not raised)."""
    mock_subscriber.create_subscription.side_effect = gcp_exceptions.AlreadyExists(
        "exists"
    )
    mock_subscriber.update_subscription.side_effect = gcp_exceptions.GoogleAPIError(
        "boom"
    )

    # Should not raise
    result = await delegate.create_queue_if_not_exists(
        endpoint_id=ENDPOINT_ID,
        endpoint_name="test_endpoint",
        endpoint_created_by="test_user",
        endpoint_labels={},
    )
    assert result.queue_name == QUEUE_NAME


@pytest.mark.asyncio
async def test_delete_queue_subscription_not_found_silent(
    mock_publisher, mock_subscriber, delegate
):
    """NotFound on subscription deletion is silenced; topic deletion still attempts."""
    mock_subscriber.delete_subscription.side_effect = gcp_exceptions.NotFound(
        "sub not found"
    )

    await delegate.delete_queue(endpoint_id=ENDPOINT_ID)

    mock_subscriber.delete_subscription.assert_called_once()
    mock_publisher.delete_topic.assert_called_once()


@pytest.mark.asyncio
async def test_delete_queue_topic_not_found_silent(
    mock_publisher, mock_subscriber, delegate
):
    """NotFound on topic deletion is silenced."""
    mock_publisher.delete_topic.side_effect = gcp_exceptions.NotFound("topic not found")

    await delegate.delete_queue(endpoint_id=ENDPOINT_ID)

    mock_subscriber.delete_subscription.assert_called_once()
    mock_publisher.delete_topic.assert_called_once()


@pytest.mark.asyncio
async def test_delete_queue_subscription_api_error_raises(
    mock_publisher, mock_subscriber, delegate
):
    """Non-NotFound GoogleAPIError on subscription deletion raises EndpointResourceInfraException."""
    mock_subscriber.delete_subscription.side_effect = gcp_exceptions.GoogleAPIError(
        "network error"
    )

    with pytest.raises(
        EndpointResourceInfraException, match="Failed to delete Pub/Sub subscription"
    ):
        await delegate.delete_queue(endpoint_id=ENDPOINT_ID)


@pytest.mark.asyncio
async def test_delete_queue_topic_api_error_raises(
    mock_publisher, mock_subscriber, delegate
):
    """Non-NotFound GoogleAPIError on topic deletion raises EndpointResourceInfraException."""
    mock_publisher.delete_topic.side_effect = gcp_exceptions.GoogleAPIError(
        "network error"
    )

    with pytest.raises(
        EndpointResourceInfraException, match="Failed to delete Pub/Sub topic"
    ):
        await delegate.delete_queue(endpoint_id=ENDPOINT_ID)


@pytest.mark.asyncio
async def test_delete_queue_subscription_deleted_before_topic(
    mock_publisher, mock_subscriber, delegate
):
    """Subscription must be deleted before topic (Pub/Sub requirement to avoid race)."""
    # attach_mock on a shared parent records cross-mock call ordering.
    parent = MagicMock()
    parent.attach_mock(mock_subscriber.delete_subscription, "sub_del")
    parent.attach_mock(mock_publisher.delete_topic, "topic_del")

    await delegate.delete_queue(endpoint_id=ENDPOINT_ID)

    call_order = [c[0] for c in parent.mock_calls]
    assert call_order == ["sub_del", "topic_del"]


@pytest.mark.asyncio
async def test_get_queue_attributes_returns_expected_shape(delegate):
    """get_queue_attributes returns a dict with 'name' and 'num_undelivered_messages'."""
    attrs = await delegate.get_queue_attributes(endpoint_id=ENDPOINT_ID)

    assert attrs["name"] == QUEUE_NAME
    assert "num_undelivered_messages" in attrs
    assert attrs["num_undelivered_messages"] == -1