Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions charts/model-engine/templates/_helpers.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ env:
- name: LAUNCH_SERVICE_TEMPLATE_FOLDER
value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates"
{{- $model_cache := default dict .Values.modelCache }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}}
- name: MODEL_CACHE_ENABLED
value: {{ get $model_cache "enabled" | default false | quote }}
- name: MODEL_CACHE_MOUNT_PATH
Expand Down Expand Up @@ -404,6 +405,14 @@ env:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
{{- if eq .Values.context "circleci" }}
- name: CIRCLECI
value: "true"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
{{- $tag := .Values.tag }}
{{- $message_broker := .Values.celeryBrokerType }}
{{- $num_shards := .Values.celery_autoscaler.num_shards }}
{{- $gcp_cloud_provider := and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") -}}
{{- $broker_name := "redis-elasticache-message-broker-master" }}
{{- if eq $message_broker "sqs" }}
{{ $broker_name = "sqs-message-broker-master" }}
{{- else if eq $message_broker "servicebus" }}
{{ $broker_name = "servicebus-message-broker-master" }}
{{- else if and .Values.config .Values.config.values .Values.config.values.infra (eq (.Values.config.values.infra.cloud_provider | default "") "gcp") }}
{{- else if $gcp_cloud_provider }}
{{ $broker_name = "redis-gcp-memorystore-message-broker-master" }}
{{- end }}
apiVersion: apps/v1
Expand Down Expand Up @@ -89,6 +90,14 @@ spec:
- name: SERVICEBUS_NAMESPACE
value: {{ .Values.azure.servicebus_namespace }}
{{- end }}
{{- if $gcp_cloud_provider }}
- name: GCP_PROJECT_ID
value: {{ (.Values.gcp).project_id | default "" | quote }}
- name: PUBSUB_TOPIC_PREFIX
value: {{ (.Values.gcp).pubsub_topic_prefix | default "" | quote }}
- name: PUBSUB_SUBSCRIPTION_PREFIX
value: {{ (.Values.gcp).pubsub_subscription_prefix | default "" | quote }}
{{- end }}
image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}"
imagePullPolicy: Always
name: main
Expand Down
6 changes: 6 additions & 0 deletions charts/model-engine/values.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,9 @@ utilityImages:

# Additional GPU tolerations for endpoint pods
gpuTolerations: []

# GCP configuration for GCP-based deployments
gcp:
project_id: ""
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
6 changes: 6 additions & 0 deletions charts/model-engine/values_sample.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,9 @@ recommendedHardware:
gpu_type: nvidia-hopper-h100
nodes_per_worker: 1
#serviceBuilderQueue:

# GCP configuration for GCP-based deployments
gcp:
project_id: "your-gcp-project"
pubsub_topic_prefix: "launch-endpoint-id-"
pubsub_subscription_prefix: "launch-endpoint-id-"
8 changes: 6 additions & 2 deletions model-engine/model_engine_server/api/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import (
EndpointResourceGateway,
)
Expand Down Expand Up @@ -248,8 +251,9 @@ def _get_external_interfaces(
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
# GCP uses Redis (Memorystore) for Celery, so use Redis-based queue delegate
queue_delegate = RedisQueueEndpointResourceDelegate(redis_client=redis_client)
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(
project_id=infra_config().gcp_project_id,
Comment on lines 253 to +255
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 GCP_PROJECT_ID env var is dead config — delegate will always raise ValueError on GCP clusters

The Helm chart injects GCP_PROJECT_ID from .Values.gcp.project_id, but the delegate is instantiated with project_id=infra_config().gcp_project_id, which is loaded from the infra_service_config YAML ConfigMap. That ConfigMap is rendered from .Values.config.values.infra in service_config_map.yaml — not from .Values.gcp. Since the new gcp.project_id key is in a separate Helm section and gcp_project_id is never injected into config.values.infra, infra_config().gcp_project_id will always be None at runtime. The delegate's own if not project_id: raise ValueError(...) guard will fire on every startup on any GCP cluster that follows the sample values. The same broken path is repeated in k8s_cache.py and start_batch_job_orchestration.py. Either read os.getenv("GCP_PROJECT_ID") directly in the delegate (consistent with how ASBQueueEndpointResourceDelegate reads os.getenv("SERVICEBUS_NAMESPACE")), or add gcp_project_id to config.values.infra in the chart.

Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/api/dependencies.py
Line: 253-255

Comment:
**`GCP_PROJECT_ID` env var is dead config — delegate will always raise `ValueError` on GCP clusters**

The Helm chart injects `GCP_PROJECT_ID` from `.Values.gcp.project_id`, but the delegate is instantiated with `project_id=infra_config().gcp_project_id`, which is loaded from the `infra_service_config` YAML ConfigMap. That ConfigMap is rendered from `.Values.config.values.infra` in `service_config_map.yaml` — not from `.Values.gcp`. Since the new `gcp.project_id` key is in a separate Helm section and `gcp_project_id` is never injected into `config.values.infra`, `infra_config().gcp_project_id` will always be `None` at runtime. The delegate's own `if not project_id: raise ValueError(...)` guard will fire on every startup on any GCP cluster that follows the sample values. The same broken path is repeated in `k8s_cache.py` and `start_batch_job_orchestration.py`. Either read `os.getenv("GCP_PROJECT_ID")` directly in the delegate (consistent with how `ASBQueueEndpointResourceDelegate` reads `os.getenv("SERVICEBUS_NAMESPACE")`), or add `gcp_project_id` to `config.values.infra` in the chart.

How can I resolve this? If you propose a fix, please make it concise.

Fix in Cursor Fix in Claude Code Fix in Codex

)
Comment on lines 253 to +256
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 gcp_project_id is Optional[str] — None produces silently-invalid resource paths

infra_config().gcp_project_id is typed Optional[str] and defaults to None if not set in the YAML config. Passing None to GcpPubSubQueueEndpointResourceDelegate(project_id=None) won't raise at construction time; it will silently produce paths like projects/None/topics/launch-endpoint-id-<id>, causing every Pub/Sub API call to fail at runtime with a cryptic NotFound error. An explicit guard (if not infra_config().gcp_project_id: raise ...) or an assertion at construction time would surface the misconfiguration early. The same pattern is repeated in k8s_cache.py and start_batch_job_orchestration.py.

Prompt To Fix With AI
This is a comment left during a code review.
Path: model-engine/model_engine_server/api/dependencies.py
Line: 253-256

Comment:
**`gcp_project_id` is `Optional[str]` — `None` produces silently-invalid resource paths**

`infra_config().gcp_project_id` is typed `Optional[str]` and defaults to `None` if not set in the YAML config. Passing `None` to `GcpPubSubQueueEndpointResourceDelegate(project_id=None)` won't raise at construction time; it will silently produce paths like `projects/None/topics/launch-endpoint-id-<id>`, causing every Pub/Sub API call to fail at runtime with a cryptic `NotFound` error. An explicit guard (`if not infra_config().gcp_project_id: raise ...`) or an assertion at construction time would surface the misconfiguration early. The same pattern is repeated in `k8s_cache.py` and `start_batch_job_orchestration.py`.

How can I resolve this? If you propose a fix, please make it concise.

Fix in Cursor Fix in Claude Code Fix in Codex

else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
1 change: 1 addition & 0 deletions model-engine/model_engine_server/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class _InfraConfig:
celery_enable_sha256: Optional[bool] = None
docker_registry_type: Optional[str] = None
debug_mode: Optional[bool] = None
gcp_project_id: Optional[str] = None


@dataclass
Expand Down
7 changes: 7 additions & 0 deletions model-engine/model_engine_server/entrypoints/k8s_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import (
EndpointResourceGateway,
)
Expand Down Expand Up @@ -119,6 +122,10 @@ async def main(args: Any):
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(
project_id=infra_config().gcp_project_id,
)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import (
ASBQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.gcp_pubsub_queue_endpoint_resource_delegate import (
GcpPubSubQueueEndpointResourceDelegate,
)
from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import (
FakeQueueEndpointResourceDelegate,
)
Expand Down Expand Up @@ -90,6 +93,10 @@ async def run_batch_job(
queue_delegate = OnPremQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "azure":
queue_delegate = ASBQueueEndpointResourceDelegate()
elif infra_config().cloud_provider == "gcp":
queue_delegate = GcpPubSubQueueEndpointResourceDelegate(
project_id=infra_config().gcp_project_id,
)
else:
queue_delegate = SQSQueueEndpointResourceDelegate(
sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile)
Expand All @@ -110,6 +117,9 @@ async def run_batch_job(
if infra_config().cloud_provider == "azure":
inference_task_queue_gateway = servicebus_task_queue_gateway
infra_task_queue_gateway = servicebus_task_queue_gateway
elif infra_config().cloud_provider == "gcp":
inference_task_queue_gateway = redis_task_queue_gateway
infra_task_queue_gateway = redis_task_queue_gateway
elif infra_config().cloud_provider == "onprem" or infra_config().celery_broker_type_redis:
# On-prem uses Redis-based task queues
inference_task_queue_gateway = redis_task_queue_gateway
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import os
from typing import Any, Dict, Optional

from google.api_core import exceptions as gcp_exceptions
from google.cloud import pubsub_v1
from google.protobuf import field_mask_pb2
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.exceptions import EndpointResourceInfraException
from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import (
    QueueEndpointResourceDelegate,
    QueueInfo,
)

logger = make_logger(logger_name())

GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS = 600  # Pub/Sub hard upper limit for ack deadlines
GCP_PUBSUB_MIN_ACK_DEADLINE_SECONDS = 10  # Pub/Sub hard lower limit for ack deadlines
DEFAULT_PUBSUB_NAME_PREFIX = "launch-endpoint-id-"


class GcpPubSubQueueEndpointResourceDelegate(QueueEndpointResourceDelegate):
    """
    Queue delegate backed by GCP Pub/Sub (one topic + one subscription per endpoint).

    topic_prefix and subscription_prefix control the GCP resource name prefix.
    The logical queue_name returned to callers always uses the canonical
    QueueEndpointResourceDelegate.endpoint_id_to_queue_name format, independent
    of these prefixes.
    """

    def __init__(
        self,
        project_id: Optional[str] = None,
        topic_prefix: Optional[str] = None,
        subscription_prefix: Optional[str] = None,
    ) -> None:
        """
        Args:
            project_id: GCP project that owns the Pub/Sub resources. Falls back
                to the GCP_PROJECT_ID env var (injected by the Helm chart from
                .Values.gcp.project_id) when not provided.
            topic_prefix: Prefix for topic names. Falls back to the
                PUBSUB_TOPIC_PREFIX env var, then "launch-endpoint-id-".
            subscription_prefix: Prefix for subscription names. Falls back to the
                PUBSUB_SUBSCRIPTION_PREFIX env var, then "launch-endpoint-id-".

        Raises:
            ValueError: if no project id can be resolved. Failing fast here
                surfaces misconfiguration at startup instead of producing
                "projects/None/..." resource paths that only fail later with
                cryptic NotFound errors on every Pub/Sub call.
        """
        # Env-var fallback mirrors how ASBQueueEndpointResourceDelegate reads
        # SERVICEBUS_NAMESPACE, so Helm-injected config is honored even when the
        # infra_service_config YAML does not carry gcp_project_id.
        project_id = project_id or os.getenv("GCP_PROJECT_ID")
        if not project_id:
            raise ValueError(
                "GcpPubSubQueueEndpointResourceDelegate requires a non-empty project_id; "
                "set infra.gcp_project_id in the service config or the GCP_PROJECT_ID env var."
            )
        self.project_id = project_id
        self.topic_prefix = (
            topic_prefix or os.getenv("PUBSUB_TOPIC_PREFIX") or DEFAULT_PUBSUB_NAME_PREFIX
        )
        self.subscription_prefix = (
            subscription_prefix
            or os.getenv("PUBSUB_SUBSCRIPTION_PREFIX")
            or DEFAULT_PUBSUB_NAME_PREFIX
        )
        self._publisher = pubsub_v1.PublisherClient()
        self._subscriber = pubsub_v1.SubscriberClient()

    def _topic_id(self, endpoint_id: str) -> str:
        # GCP resource name (not the logical queue_name) for the endpoint's topic.
        return f"{self.topic_prefix}{endpoint_id}"

    def _subscription_id(self, endpoint_id: str) -> str:
        # GCP resource name (not the logical queue_name) for the endpoint's subscription.
        return f"{self.subscription_prefix}{endpoint_id}"

    async def create_queue_if_not_exists(
        self,
        endpoint_id: str,
        endpoint_name: str,
        endpoint_created_by: str,
        endpoint_labels: Dict[str, Any],
        queue_message_timeout_seconds: Optional[int] = None,
    ) -> QueueInfo:
        """
        Idempotently create the endpoint's topic and subscription.

        The requested message timeout is clamped into Pub/Sub's allowed
        ack-deadline range [10, 600] seconds; values outside that range would be
        rejected by the API with InvalidArgument. If the subscription already
        exists, its ack deadline is updated (best-effort) to match.

        Raises:
            EndpointResourceInfraException: if topic or subscription creation
                fails for any reason other than the resource already existing.
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(
            endpoint_id
        )
        topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"
        subscription_path = f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
        ack_deadline = min(
            max(
                queue_message_timeout_seconds or 60,
                GCP_PUBSUB_MIN_ACK_DEADLINE_SECONDS,
            ),
            GCP_PUBSUB_MAX_ACK_DEADLINE_SECONDS,
        )

        try:
            self._publisher.create_topic(name=topic_path)
        except gcp_exceptions.AlreadyExists:
            pass
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to create Pub/Sub topic {topic_path} for endpoint {endpoint_id}: {e}"
            ) from e

        try:
            self._subscriber.create_subscription(
                name=subscription_path,
                topic=topic_path,
                ack_deadline_seconds=ack_deadline,
            )
        except gcp_exceptions.AlreadyExists:
            # Subscription exists from a prior create; best-effort sync of the
            # ack deadline so timeout changes on endpoint update take effect.
            try:
                self._subscriber.update_subscription(
                    subscription=pubsub_v1.types.Subscription(
                        name=subscription_path,
                        ack_deadline_seconds=ack_deadline,
                    ),
                    update_mask=field_mask_pb2.FieldMask(
                        paths=["ack_deadline_seconds"]
                    ),
                )
            except gcp_exceptions.GoogleAPIError as e:
                logger.warning(
                    f"Failed to update ack_deadline for Pub/Sub subscription {subscription_path}: {e}"
                )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to create Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}: {e}"
            ) from e

        # Pub/Sub has no URL concept analogous to SQS queue URLs
        return QueueInfo(queue_name, queue_url=None)

    async def delete_queue(self, endpoint_id: str) -> None:
        """
        Delete the endpoint's subscription then topic. Missing resources are
        logged and ignored (deletion is idempotent); other API errors are wrapped
        in EndpointResourceInfraException.
        """
        subscription_path = f"projects/{self.project_id}/subscriptions/{self._subscription_id(endpoint_id)}"
        topic_path = f"projects/{self.project_id}/topics/{self._topic_id(endpoint_id)}"

        try:
            self._subscriber.delete_subscription(subscription=subscription_path)
        except gcp_exceptions.NotFound:
            logger.info(
                f"Could not find Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}"
            )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to delete Pub/Sub subscription {subscription_path} for endpoint {endpoint_id}: {e}"
            ) from e

        try:
            self._publisher.delete_topic(topic=topic_path)
        except gcp_exceptions.NotFound:
            logger.info(
                f"Could not find Pub/Sub topic {topic_path} for endpoint {endpoint_id}"
            )
        except gcp_exceptions.GoogleAPIError as e:
            raise EndpointResourceInfraException(
                f"Failed to delete Pub/Sub topic {topic_path} for endpoint {endpoint_id}: {e}"
            ) from e

    async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]:
        """
        Return the logical queue name and a sentinel message count.

        num_undelivered_messages is -1 ("unknown") because Pub/Sub does not
        expose a synchronous undelivered-message count; real observability
        requires the Cloud Monitoring API as a separate concern.
        """
        queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(
            endpoint_id
        )
        return {
            "name": queue_name,
            "num_undelivered_messages": -1,
        }
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, Optional, Tuple

from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest
from model_engine_server.common.dtos.resource_manager import (
CreateOrUpdateResourcesRequest,
)
from model_engine_server.core.loggers import logger_name, make_logger
from model_engine_server.domain.entities import (
ModelEndpointInfraState,
Expand Down Expand Up @@ -28,11 +30,15 @@ class LiveEndpointResourceGateway(EndpointResourceGateway[QueueInfo]):
def __init__(
self,
queue_delegate: QueueEndpointResourceDelegate,
inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway],
inference_autoscaling_metrics_gateway: Optional[
InferenceAutoscalingMetricsGateway
],
):
self.k8s_delegate = K8SEndpointResourceDelegate()
self.queue_delegate = queue_delegate
self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway
self.inference_autoscaling_metrics_gateway = (
inference_autoscaling_metrics_gateway
)

async def create_queue(
self,
Expand Down Expand Up @@ -79,7 +85,9 @@ async def create_or_update_resources(
sqs_queue_name=queue_name,
sqs_queue_url=queue_url,
)
return EndpointResourceGatewayCreateOrUpdateResourcesResponse(destination=destination)
return EndpointResourceGatewayCreateOrUpdateResourcesResponse(
destination=destination
)

async def get_resources(
self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType
Expand All @@ -91,16 +99,28 @@ async def get_resources(
)

if endpoint_type == ModelEndpointType.ASYNC:
sqs_attributes = await self.queue_delegate.get_queue_attributes(endpoint_id=endpoint_id)
sqs_attributes = await self.queue_delegate.get_queue_attributes(
endpoint_id=endpoint_id
)
if (
"Attributes" in sqs_attributes
and "ApproximateNumberOfMessages" in sqs_attributes["Attributes"]
):
resources.num_queued_items = int(
sqs_attributes["Attributes"]["ApproximateNumberOfMessages"]
)
elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate
elif (
"active_message_count" in sqs_attributes
): # from ASBQueueEndpointResourceDelegate
resources.num_queued_items = int(sqs_attributes["active_message_count"])
elif (
"num_undelivered_messages" in sqs_attributes
): # from GcpPubSubQueueEndpointResourceDelegate
# Pub/Sub returns -1 when num_undelivered_messages is not yet wired to Cloud Monitoring.
# Treat -1 as "unknown" and skip; downstream autoscaling expects non-negative counts.
gcp_count = int(sqs_attributes["num_undelivered_messages"])
if gcp_count >= 0:
resources.num_queued_items = gcp_count

return resources

Expand All @@ -125,7 +145,9 @@ async def delete_resources(
sqs_result = False

if self.inference_autoscaling_metrics_gateway is not None:
await self.inference_autoscaling_metrics_gateway.delete_resources(endpoint_id)
await self.inference_autoscaling_metrics_gateway.delete_resources(
endpoint_id
)

return k8s_result and sqs_result

Expand Down
1 change: 1 addition & 0 deletions model-engine/requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ azure-storage-blob~=12.19.0
# GCP dependencies
gcloud-aio-storage~=9.6
google-auth~=2.25.0
google-cloud-pubsub>=2.18
google-cloud-artifact-registry~=1.21.0
google-cloud-secret-manager>=2.24.0
google-cloud-storage~=2.14.0
Expand Down
Loading