Skip to content

Commit 184cc2f

Browse files
committed
[TVM FFI] localize load_inline lifetime in tests
1 parent e3d76e4 commit 184cc2f

1 file changed

Lines changed: 56 additions & 30 deletions

File tree

test/legacy_test/test_tvm_ffi.py

Lines changed: 56 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,19 @@ def test_c_dlpack_exchange_api_cpu(self):
6363
"""
6464

6565
mod: Module = tvm_ffi.cpp.load_inline(
66-
name='mod', cpp_sources=cpp_source, functions='add_one_cpu'
66+
name='mod',
67+
cpp_sources=cpp_source,
68+
functions='add_one_cpu',
69+
keep_module_alive=False,
6770
)
6871

69-
x = paddle.full((3,), 1.0, dtype='float32').cpu()
70-
y = paddle.zeros((3,), dtype='float32').cpu()
71-
mod.add_one_cpu(x, y)
72-
np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
72+
def run_check():
73+
x = paddle.full((3,), 1.0, dtype='float32').cpu()
74+
y = paddle.zeros((3,), dtype='float32').cpu()
75+
mod.add_one_cpu(x, y)
76+
np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
77+
78+
run_check()
7379

7480
def test_c_dlpack_exchange_api_gpu(self):
7581
if not paddle.is_compiled_with_cuda():
@@ -116,12 +122,16 @@ def test_c_dlpack_exchange_api_gpu(self):
116122
cpp_sources=cpp_sources,
117123
cuda_sources=cuda_sources,
118124
functions=['add_one_cuda'],
125+
keep_module_alive=False,
119126
)
120127

121-
x = paddle.full((3,), 1.0, dtype='float32').cuda()
122-
y = paddle.zeros((3,), dtype='float32').cuda()
123-
mod.add_one_cuda(x, y)
124-
np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
128+
def run_check():
129+
x = paddle.full((3,), 1.0, dtype='float32').cuda()
130+
y = paddle.zeros((3,), dtype='float32').cuda()
131+
mod.add_one_cuda(x, y)
132+
np.testing.assert_allclose(y.numpy(), [2.0, 2.0, 2.0])
133+
134+
run_check()
125135

126136
def test_c_dlpack_exchange_api_alloc_tensor(self):
127137
cpp_source = r"""
@@ -141,7 +151,10 @@ def test_c_dlpack_exchange_api_alloc_tensor(self):
141151
}
142152
"""
143153
mod: Module = tvm_ffi.cpp.load_inline(
144-
name='mod', cpp_sources=cpp_source, functions=['add_one_cpu']
154+
name='mod',
155+
cpp_sources=cpp_source,
156+
functions=['add_one_cpu'],
157+
keep_module_alive=False,
145158
)
146159

147160
def run_check():
@@ -200,21 +213,28 @@ def test_data_type_as_input(self):
200213
}
201214
"""
202215
mod: Module = tvm_ffi.cpp.load_inline(
203-
name='mod', cpp_sources=cpp_source, functions='check_dtype'
216+
name='mod',
217+
cpp_sources=cpp_source,
218+
functions='check_dtype',
219+
keep_module_alive=False,
204220
)
205-
for dtype in [
206-
paddle.bool,
207-
paddle.uint8,
208-
paddle.int16,
209-
paddle.int32,
210-
paddle.int64,
211-
paddle.float32,
212-
paddle.float64,
213-
paddle.float16,
214-
paddle.bfloat16,
215-
]:
216-
x = paddle.zeros((10,), dtype=dtype).cpu()
217-
mod.check_dtype(x, dtype)
221+
222+
def run_check():
223+
for dtype in [
224+
paddle.bool,
225+
paddle.uint8,
226+
paddle.int16,
227+
paddle.int32,
228+
paddle.int64,
229+
paddle.float32,
230+
paddle.float64,
231+
paddle.float16,
232+
paddle.bfloat16,
233+
]:
234+
x = paddle.zeros((10,), dtype=dtype).cpu()
235+
mod.check_dtype(x, dtype)
236+
237+
run_check()
218238

219239

220240
class TestDLPackDeviceType(unittest.TestCase):
@@ -260,15 +280,21 @@ def test_dlpack_device_type_as_input(self):
260280
}
261281
"""
262282
mod: Module = tvm_ffi.cpp.load_inline(
263-
name='mod', cpp_sources=cpp_source, functions='check_device'
283+
name='mod',
284+
cpp_sources=cpp_source,
285+
functions='check_device',
286+
keep_module_alive=False,
264287
)
265288

266-
x_cpu = paddle.zeros((10,), dtype='float32').cpu()
267-
mod.check_device(x_cpu, x_cpu.place)
289+
def run_check():
290+
x_cpu = paddle.zeros((10,), dtype='float32').cpu()
291+
mod.check_device(x_cpu, x_cpu.place)
268292

269-
if paddle.is_compiled_with_cuda():
270-
x_gpu = paddle.zeros((10,), dtype='float32').cuda()
271-
mod.check_device(x_gpu, x_gpu.place)
293+
if paddle.is_compiled_with_cuda():
294+
x_gpu = paddle.zeros((10,), dtype='float32').cuda()
295+
mod.check_device(x_gpu, x_gpu.place)
296+
297+
run_check()
272298

273299

274300
if __name__ == '__main__':

0 commit comments

Comments
 (0)