@@ -63,13 +63,19 @@ def test_c_dlpack_exchange_api_cpu(self):
         """
 
         mod: Module = tvm_ffi.cpp.load_inline(
-            name='mod', cpp_sources=cpp_source, functions='add_one_cpu'
+            name='mod',
+            cpp_sources=cpp_source,
+            functions='add_one_cpu',
+            keep_module_alive=False,
         )
 
-        x = paddle.full((3,), 1.0, dtype='float32').cpu()
-        y = paddle.zeros((3,), dtype='float32').cpu()
-        mod.add_one_cpu(x, y)
-        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+        def run_check():
+            x = paddle.full((3,), 1.0, dtype='float32').cpu()
+            y = paddle.zeros((3,), dtype='float32').cpu()
+            mod.add_one_cpu(x, y)
+            np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+        run_check()
 
     def test_c_dlpack_exchange_api_gpu(self):
         if not paddle.is_compiled_with_cuda():
@@ -116,12 +122,16 @@ def test_c_dlpack_exchange_api_gpu(self):
             cpp_sources=cpp_sources,
             cuda_sources=cuda_sources,
             functions=['add_one_cuda'],
+            keep_module_alive=False,
         )
 
-        x = paddle.full((3,), 1.0, dtype='float32').cuda()
-        y = paddle.zeros((3,), dtype='float32').cuda()
-        mod.add_one_cuda(x, y)
-        np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+        def run_check():
+            x = paddle.full((3,), 1.0, dtype='float32').cuda()
+            y = paddle.zeros((3,), dtype='float32').cuda()
+            mod.add_one_cuda(x, y)
+            np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
+
+        run_check()
 
     def test_c_dlpack_exchange_api_alloc_tensor(self):
         cpp_source = r"""
@@ -141,7 +151,10 @@ def test_c_dlpack_exchange_api_alloc_tensor(self):
         }
         """
         mod: Module = tvm_ffi.cpp.load_inline(
-            name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
+            name='mod',
+            cpp_sources=cpp_source,
+            functions=['add_one_cpu'],
+            keep_module_alive=False,
         )
 
         def run_check():
@@ -200,21 +213,28 @@ def test_data_type_as_input(self):
         }
         """
         mod: Module = tvm_ffi.cpp.load_inline(
-            name='mod', cpp_sources=cpp_source, functions='check_dtype'
+            name='mod',
+            cpp_sources=cpp_source,
+            functions='check_dtype',
+            keep_module_alive=False,
         )
-        for dtype in [
-            paddle.bool,
-            paddle.uint8,
-            paddle.int16,
-            paddle.int32,
-            paddle.int64,
-            paddle.float32,
-            paddle.float64,
-            paddle.float16,
-            paddle.bfloat16,
-        ]:
-            x = paddle.zeros((10,), dtype=dtype).cpu()
-            mod.check_dtype(x, dtype)
+
+        def run_check():
+            for dtype in [
+                paddle.bool,
+                paddle.uint8,
+                paddle.int16,
+                paddle.int32,
+                paddle.int64,
+                paddle.float32,
+                paddle.float64,
+                paddle.float16,
+                paddle.bfloat16,
+            ]:
+                x = paddle.zeros((10,), dtype=dtype).cpu()
+                mod.check_dtype(x, dtype)
+
+        run_check()
 
 
 class TestDLPackDeviceType(unittest.TestCase):
@@ -260,15 +280,21 @@ def test_dlpack_device_type_as_input(self):
         }
         """
         mod: Module = tvm_ffi.cpp.load_inline(
-            name='mod', cpp_sources=cpp_source, functions='check_device'
+            name='mod',
+            cpp_sources=cpp_source,
+            functions='check_device',
+            keep_module_alive=False,
        )
 
-        x_cpu = paddle.zeros((10,), dtype='float32').cpu()
-        mod.check_device(x_cpu, x_cpu.place)
+        def run_check():
+            x_cpu = paddle.zeros((10,), dtype='float32').cpu()
+            mod.check_device(x_cpu, x_cpu.place)
 
-        if paddle.is_compiled_with_cuda():
-            x_gpu = paddle.zeros((10,), dtype='float32').cuda()
-            mod.check_device(x_gpu, x_gpu.place)
+            if paddle.is_compiled_with_cuda():
+                x_gpu = paddle.zeros((10,), dtype='float32').cuda()
+                mod.check_device(x_gpu, x_gpu.place)
+
+        run_check()
 
 
 if __name__ == '__main__':
0 commit comments