# Source: github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TorchServing/CustomOp/tests/add_index_test.py
"""Test for add_index custom op of pytorch.

Loads the compiled ``my_ops`` extension and checks ``add_index`` eagerly
against a reference built from plain torch ops. It then traces a small
function that uses the op with ``torch.jit.trace``, prints the resulting
graph, and checks that the traced module matches the eager result.
"""
import os

import numpy as np
from numpy.testing import assert_allclose
import torch

# Path to the compiled extension. The default is the original bare filename;
# the ADD_INDEX_LIB environment variable lets a caller point at a build
# directory without editing this file.
# NOTE(review): the default is a relative filename, so this presumably
# expects to be run from the directory holding libadd_index.so — confirm.
LIB_PATH = os.environ.get("ADD_INDEX_LIB", "libadd_index.so")

torch.ops.load_library(LIB_PATH)
print(torch.ops.my_ops.add_index)

# Fixed seed so that a failing comparison can be reproduced exactly.
torch.manual_seed(0)

a = torch.randint(1, 10, size=(3, 4), dtype=torch.int32)
o = torch.ops.my_ops.add_index(a)

# Reference: each element plus its row-major flat index (0..numel-1).
# Built in int32 to match the input dtype; torch.Tensor(...) would have
# produced a float32 expectation instead.
index = torch.arange(a.numel(), dtype=torch.int32).reshape(a.shape)
expected = (a + index).numpy()

# Integer arithmetic: require exact equality rather than a relative tolerance.
assert_allclose(o.numpy().astype(np.int32), expected, rtol=0, atol=0)


def compute(x, y, z):
    """Apply add_index to ``x``, then return ``x.float() @ y + relu(z)``.

    Used as the tracing target below. There, ``x`` is an int32 (8, 8) tensor
    and ``y`` and ``z`` are float (8, 5) tensors.
    """
    x = torch.ops.my_ops.add_index(x)
    x = x.float()
    return x.matmul(y) + torch.relu(z)


# torch.jit.trace documents example_inputs as a tuple (or a single tensor).
inputs = (
    torch.randint(4, 8, size=(8, 8), dtype=torch.int32),
    torch.randn(8, 5),
    torch.randn(8, 5),
)
trace = torch.jit.trace(compute, inputs)
print(trace.graph)

# The traced module must reproduce the eager result on the example inputs.
assert_allclose(
    trace(*inputs).numpy(),
    compute(*inputs).numpy(),
    rtol=1e-5,
    atol=1e-5,
)