Skip to content

Commit 4d4da57

Browse files
authored
Fix ML flakes (#38088)
* Fix ML flakes * Changed logging of CUDA probe failures to warning level
1 parent 923ed03 commit 4d4da57

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

sdks/python/apache_beam/ml/inference/pytorch_inference.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@
5656
Iterable[PredictionResult]]
5757

5858

59+
def _cuda_device_is_usable() -> bool:
  """Returns True only when CUDA can actually allocate tensors.

  ``torch.cuda.is_available()`` is consulted first; when it reports False no
  allocation is attempted. Otherwise a one-element tensor is created on the
  ``'cuda'`` device as a probe, because some environments report CUDA as
  available but fail at first real use (for example when a driver is missing
  or inaccessible).

  Returns:
    True if the probe allocation succeeds. False if CUDA is reported
    unavailable or if the probe raises any ``Exception``; in the latter case
    the failure is logged at warning level together with its traceback.
  """
  if not torch.cuda.is_available():
    return False
  try:
    # Force a real allocation: the availability flag alone is not trusted.
    torch.empty(1, device='cuda')
  except Exception:  # pylint: disable=broad-except
    # Deliberately broad: whatever the probe raises, the caller only needs to
    # know that the device cannot be used.
    logging.warning("CUDA probe failed", exc_info=True)
    return False
  else:
    return True
71+
72+
5973
def _validate_constructor_args(
6074
state_dict_path, model_class, torch_script_model_path):
6175
message = (
@@ -86,7 +100,7 @@ def _load_model(
86100
model_params: Optional[dict[str, Any]],
87101
torch_script_model_path: Optional[str],
88102
load_model_args: Optional[dict[str, Any]]):
89-
if device == torch.device('cuda') and not torch.cuda.is_available():
103+
if device == torch.device('cuda') and not _cuda_device_is_usable():
90104
logging.warning(
91105
"Model handler specified a 'GPU' device, but GPUs are not available. "
92106
"Switching to CPU.")

sdks/python/apache_beam/ml/rag/ingestion/milvus_search_it_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def create_client():
204204

205205
self._test_client = retry_with_backoff(
206206
create_client,
207-
max_retries=3,
208-
retry_delay=1.0,
207+
max_retries=5,
208+
retry_delay=2.0,
209209
operation_name="Test Milvus client connection",
210210
exception_types=(MilvusException, ))
211211

sdks/python/apache_beam/ml/rag/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,8 @@ def create_client():
204204

205205
client = retry_with_backoff(
206206
create_client,
207-
max_retries=3,
208-
retry_delay=1.0,
207+
max_retries=5,
208+
retry_delay=2.0,
209209
operation_name="Test Milvus client connection",
210210
exception_types=(MilvusException, ))
211211

0 commit comments

Comments
 (0)