@@ -63,7 +63,10 @@ def test_c_dlpack_exchange_api_cpu(self):
6363 """
6464
6565 mod : Module = tvm_ffi .cpp .load_inline (
66- name = 'mod' , cpp_sources = cpp_source , functions = 'add_one_cpu'
66+ name = 'mod' ,
67+ cpp_sources = cpp_source ,
68+ functions = 'add_one_cpu' ,
69+ keep_module_alive = False ,
6770 )
6871
6972 x = paddle .full ((3 ,), 1.0 , dtype = 'float32' ).cpu ()
@@ -141,11 +144,23 @@ def test_c_dlpack_exchange_api_alloc_tensor(self):
141144 }
142145 """
143146 mod : Module = tvm_ffi .cpp .load_inline (
144- name = 'mod' , cpp_sources = cpp_source , functions = ['add_one_cpu' ]
147+ name = 'mod' ,
148+ cpp_sources = cpp_source ,
149+ functions = ['add_one_cpu' ],
150+ keep_module_alive = False ,
145151 )
146- x = paddle .full ((3 ,), 1.0 , dtype = 'float32' ).cpu ()
147- y = mod .add_one_cpu (x )
148- np .testing .assert_allclose (y .numpy (), [2.0 , 2.0 , 2.0 ])
152+
153+ def run_check ():
154+ """Must run in a separate function to ensure deletion happens before mod unloads.
155+
156+ When a module returns an object, the object deleter address is part of the
157+ loaded library. We need to keep the module loaded until the object is deleted.
158+ """
159+ x = paddle .full ((3 ,), 1.0 , dtype = 'float32' ).cpu ()
160+ y = mod .add_one_cpu (x )
161+ np .testing .assert_allclose (y .numpy (), [2.0 , 2.0 , 2.0 ])
162+
163+ run_check ()
149164
150165
151166class TestDLPackDataType (unittest .TestCase ):
@@ -191,7 +206,10 @@ def test_data_type_as_input(self):
191206 }
192207 """
193208 mod : Module = tvm_ffi .cpp .load_inline (
194- name = 'mod' , cpp_sources = cpp_source , functions = 'check_dtype'
209+ name = 'mod' ,
210+ cpp_sources = cpp_source ,
211+ functions = 'check_dtype' ,
212+ keep_module_alive = False ,
195213 )
196214 for dtype in [
197215 paddle .bool ,
@@ -251,7 +269,10 @@ def test_dlpack_device_type_as_input(self):
251269 }
252270 """
253271 mod : Module = tvm_ffi .cpp .load_inline (
254- name = 'mod' , cpp_sources = cpp_source , functions = 'check_device'
272+ name = 'mod' ,
273+ cpp_sources = cpp_source ,
274+ functions = 'check_device' ,
275+ keep_module_alive = False ,
255276 )
256277
257278 x_cpu = paddle .zeros ((10 ,), dtype = 'float32' ).cpu ()
0 commit comments