File tree: 3 files changed, +19 -5 lines changed
sdks/python/apache_beam/ml Expand file tree Collapse file tree 3 files changed +19
-5
lines changed Original file line number Diff line number Diff line change 5656 Iterable [PredictionResult ]]
5757
5858
def _cuda_device_is_usable() -> bool:
  """Returns True only when CUDA can actually allocate tensors.

  ``torch.cuda.is_available()`` is consulted first. When it reports True, a
  one-element tensor is allocated on the ``'cuda'`` device as a real-use
  probe, because availability alone does not guarantee a working device.

  Returns:
    True if the probe allocation succeeds. False if CUDA is reported as
    unavailable, or if the probe raises (in which case a warning including
    the traceback is logged).
  """
  if not torch.cuda.is_available():
    return False
  try:
    # Some environments report CUDA available but fail at first real use
    # because a driver is missing or inaccessible.
    torch.empty(1, device='cuda')
  except Exception:  # pylint: disable=broad-except
    # Best-effort probe: any failure here means the device is not usable.
    logging.warning("CUDA probe failed", exc_info=True)
    return False
  # Only the probe call is guarded; success is reported outside the try.
  return True
71+
72+
5973def _validate_constructor_args (
6074 state_dict_path , model_class , torch_script_model_path ):
6175 message = (
@@ -86,7 +100,7 @@ def _load_model(
86100 model_params : Optional [dict [str , Any ]],
87101 torch_script_model_path : Optional [str ],
88102 load_model_args : Optional [dict [str , Any ]]):
89- if device == torch .device ('cuda' ) and not torch . cuda . is_available ():
103+ if device == torch .device ('cuda' ) and not _cuda_device_is_usable ():
90104 logging .warning (
91105 "Model handler specified a 'GPU' device, but GPUs are not available. "
92106 "Switching to CPU." )
Original file line number Diff line number Diff line change @@ -204,8 +204,8 @@ def create_client():
204204
205205 self ._test_client = retry_with_backoff (
206206 create_client ,
207- max_retries = 3 ,
208- retry_delay = 1 .0 ,
207+ max_retries = 5 ,
208+ retry_delay = 2 .0 ,
209209 operation_name = "Test Milvus client connection" ,
210210 exception_types = (MilvusException , ))
211211
Original file line number Diff line number Diff line change @@ -204,8 +204,8 @@ def create_client():
204204
205205 client = retry_with_backoff (
206206 create_client ,
207- max_retries = 3 ,
208- retry_delay = 1 .0 ,
207+ max_retries = 5 ,
208+ retry_delay = 2 .0 ,
209209 operation_name = "Test Milvus client connection" ,
210210 exception_types = (MilvusException , ))
211211
You can’t perform that action at this time.
0 commit comments