github.com/alwaysproblem/mlserving-tutorial@v0.0.0-20221124033215-121cfddbfbf4/TorchServing/CustomOp/tests/add_index_test.py (about)

     1  """Test for add_index custom op of pytorch"""
     2  import numpy as np
     3  from numpy.testing import assert_allclose
     4  import torch
     5  
     6  torch.ops.load_library("libadd_index.so")
     7  print(torch.ops.my_ops.add_index)
     8  
     9  a = torch.randint(1, 10, size=(3, 4), dtype=torch.int32)
    10  o = torch.ops.my_ops.add_index(a)
    11  
    12  assert_allclose(
    13      o.numpy().astype(np.int32),
    14      a + torch.Tensor(list(range(12))).reshape((3, 4))
    15  )
    16  
    17  
    18  def compute(x, y, z):
    19    x = torch.ops.my_ops.add_index(x)
    20    x = x.float()
    21    return x.matmul(y) + torch.relu(z)
    22  
    23  
    24  inputs = [
    25      torch.randint(4, 8, size=(8, 8), dtype=torch.int32),
    26      torch.randn(8, 5),
    27      torch.randn(8, 5)
    28  ]
    29  trace = torch.jit.trace(compute, inputs)
    30  print(trace.graph)