github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/bindings/python/test/test_py2and3.py (about)

     1  # Copyright 2019 The MLIR Authors.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     http://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  # ==============================================================================
    15  
    16  # RUN: %p/test_edsc %s | FileCheck %s
    17  """Python2 and 3 test for the MLIR EDSC Python bindings"""
    18  
    19  import google_mlir.bindings.python.pybind as E
    20  import inspect
    21  
    22  # Prints `str` prefixed by the current test function name so we can use it in
    23  # Filecheck label directives.
    24  # This is achieved by inspecting the stack and getting the parent name.
    25  def printWithCurrentFunctionName(str):
    26    print(inspect.stack()[1][3])
    27    print(str)
    28  
    29  class EdscTest:
    30  
    31    def setUp(self):
    32      self.module = E.MLIRModule()
    33      self.boolType = self.module.make_scalar_type("i", 1)
    34      self.i32Type = self.module.make_scalar_type("i", 32)
    35      self.f32Type = self.module.make_scalar_type("f32")
    36      self.indexType = self.module.make_index_type()
    37  
    38    def testBlockArguments(self):
    39      self.setUp()
    40      with self.module.function_context("foo", [], []) as fun:
    41        E.constant_index(42)
    42        with E.BlockContext([self.f32Type, self.f32Type]) as b:
    43          b.arg(0) + b.arg(1)
    44        printWithCurrentFunctionName(str(fun))
    45      # CHECK-LABEL: testBlockArguments
    46      #       CHECK: %{{.*}} = constant 42 : index
    47      #       CHECK: ^bb{{.*}}(%{{.*}}: f32, %{{.*}}: f32):
    48      #       CHECK:   %{{.*}} = addf %{{.*}}, %{{.*}} : f32
    49  
    50    def testBlockContext(self):
    51      self.setUp()
    52      with self.module.function_context("foo", [], []) as fun:
    53        cst = E.constant_index(42)
    54        with E.BlockContext():
    55          cst + cst
    56        printWithCurrentFunctionName(str(fun))
    57      # CHECK-LABEL: testBlockContext
    58      #       CHECK: %{{.*}} = constant 42 : index
    59      #       CHECK: ^bb
    60      #       CHECK: %{{.*}} = "affine.apply"() {map = () -> (84)} : () -> index
    61  
    62    def testBlockContextAppend(self):
    63      self.setUp()
    64      with self.module.function_context("foo", [], []) as fun:
    65        E.constant_index(41)
    66        with E.BlockContext() as b:
    67          blk = b  # save block handle for later
    68          E.constant_index(0)
    69        E.constant_index(42)
    70        with E.BlockContext(E.appendTo(blk)):
    71          E.constant_index(1)
    72        printWithCurrentFunctionName(str(fun))
    73      # CHECK-LABEL: testBlockContextAppend
    74      #       CHECK: %{{.*}} = constant 41 : index
    75      #       CHECK: %{{.*}} = constant 42 : index
    76      #       CHECK: ^bb
    77      #       CHECK: %{{.*}} = constant 0 : index
    78      #       CHECK: %{{.*}} = constant 1 : index
    79  
    80    def testBlockContextStandalone(self):
    81      self.setUp()
    82      with self.module.function_context("foo", [], []) as fun:
    83        blk1 = E.BlockContext()
    84        blk2 = E.BlockContext()
    85        with blk1:
    86          E.constant_index(0)
    87        with blk2:
    88          E.constant_index(56)
    89          E.constant_index(57)
    90        E.constant_index(41)
    91        with blk1:
    92          E.constant_index(1)
    93        E.constant_index(42)
    94        printWithCurrentFunctionName(str(fun))
    95      # CHECK-LABEL: testBlockContextStandalone
    96      #       CHECK: %{{.*}} = constant 41 : index
    97      #       CHECK: %{{.*}} = constant 42 : index
    98      #       CHECK: ^bb
    99      #       CHECK: %{{.*}} = constant 0 : index
   100      #       CHECK: %{{.*}} = constant 1 : index
   101      #       CHECK: ^bb
   102      #       CHECK: %{{.*}} = constant 56 : index
   103      #       CHECK: %{{.*}} = constant 57 : index
   104  
   105    def testBooleanOps(self):
   106      self.setUp()
   107      with self.module.function_context(
   108          "booleans", [self.boolType for _ in range(4)], []) as fun:
   109        i, j, k, l = (fun.arg(x) for x in range(4))
   110        stmt1 = (i < j) & (j >= k)
   111        stmt2 = ~(stmt1 | (k == l))
   112        printWithCurrentFunctionName(str(fun))
   113      # CHECK-LABEL: testBooleanOps
   114      #       CHECK: %{{.*}} = cmpi "slt", %{{.*}}, %{{.*}} : i1
   115      #       CHECK: %{{.*}} = cmpi "sge", %{{.*}}, %{{.*}} : i1
   116      #       CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
   117      #       CHECK: %{{.*}} = cmpi "eq", %{{.*}}, %{{.*}} : i1
   118      #       CHECK: %{{.*}} = constant 1 : i1
   119      #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
   120      #       CHECK: %{{.*}} = constant 1 : i1
   121      #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
   122      #       CHECK: %{{.*}} = muli %{{.*}}, %{{.*}} : i1
   123      #       CHECK: %{{.*}} = constant 1 : i1
   124      #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
   125      #       CHECK: %{{.*}} = constant 1 : i1
   126      #       CHECK: %{{.*}} = subi %{{.*}}, %{{.*}} : i1
   127  
   128    def testBr(self):
   129      self.setUp()
   130      with self.module.function_context("foo", [], []) as fun:
   131        with E.BlockContext() as b:
   132          blk = b
   133          E.ret()
   134        E.br(blk)
   135        printWithCurrentFunctionName(str(fun))
   136      # CHECK-LABEL: testBr
   137      #       CHECK:   br ^bb
   138      #       CHECK: ^bb
   139      #       CHECK:   return
   140  
   141    def testBrArgs(self):
   142      self.setUp()
   143      with self.module.function_context("foo", [], []) as fun:
   144        # Create an infinite loop.
   145        with E.BlockContext([self.indexType, self.indexType]) as b:
   146          E.br(b, [b.arg(1), b.arg(0)])
   147        E.br(b, [E.constant_index(0), E.constant_index(1)])
   148        printWithCurrentFunctionName(str(fun))
   149      # CHECK-LABEL: testBrArgs
   150      #       CHECK:   %{{.*}} = constant 0 : index
   151      #       CHECK:   %{{.*}} = constant 1 : index
   152      #       CHECK:   br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
   153      #       CHECK: ^bb{{.*}}(%{{.*}}: index, %{{.*}}: index):
   154      #       CHECK:   br ^bb{{.*}}(%{{.*}}, %{{.*}} : index, index)
   155  
   156    def testBrDeclaration(self):
   157      self.setUp()
   158      with self.module.function_context("foo", [], []) as fun:
   159        blk = E.BlockContext()
   160        E.br(blk.handle())
   161        with blk:
   162          E.ret()
   163        printWithCurrentFunctionName(str(fun))
   164      # CHECK-LABEL: testBrDeclaration
   165      #       CHECK:   br ^bb
   166      #       CHECK: ^bb
   167      #       CHECK:   return
   168  
   169    def testCallOp(self):
   170      self.setUp()
   171      callee = self.module.declare_function("sqrtf", [self.f32Type],
   172                                            [self.f32Type])
   173      with self.module.function_context("call", [self.f32Type], []) as fun:
   174        funCst = E.constant_function(callee)
   175        funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type)
   176        printWithCurrentFunctionName(str(self.module))
   177      # CHECK-LABEL: testCallOp
   178      #       CHECK: func @sqrtf(f32) -> f32
   179      #       CHECK:   %{{.*}} = constant @sqrtf : (f32) -> f32
   180      #       CHECK:   %{{.*}} = call_indirect %{{.*}}(%{{.*}}) : (f32) -> f32
   181  
   182    def testCondBr(self):
   183      self.setUp()
   184      with self.module.function_context("foo", [self.boolType], []) as fun:
   185        with E.BlockContext() as blk1:
   186          E.ret([])
   187        with E.BlockContext([self.indexType]) as blk2:
   188          E.ret([])
   189        cst = E.constant_index(0)
   190        E.cond_br(fun.arg(0), blk1, [], blk2, [cst])
   191        printWithCurrentFunctionName(str(fun))
   192      # CHECK-LABEL: testCondBr
   193      #       CHECK:   cond_br %{{.*}}, ^bb{{.*}}, ^bb{{.*}}(%{{.*}} : index)
   194  
   195    def testConstants(self):
   196      self.setUp()
   197      with self.module.function_context("constants", [], []) as fun:
   198        E.constant_float(1.23, self.module.make_scalar_type("bf16"))
   199        E.constant_float(1.23, self.module.make_scalar_type("f16"))
   200        E.constant_float(1.23, self.module.make_scalar_type("f32"))
   201        E.constant_float(1.23, self.module.make_scalar_type("f64"))
   202        E.constant_int(1, 1)
   203        E.constant_int(123, 8)
   204        E.constant_int(123, 16)
   205        E.constant_int(123, 32)
   206        E.constant_int(123, 64)
   207        E.constant_index(123)
   208        E.constant_function(fun)
   209        printWithCurrentFunctionName(str(fun))
   210      # CHECK-LABEL: testConstants
   211      #       CHECK:  constant 1.230000e+00 : bf16
   212      #       CHECK:  constant 1.230470e+00 : f16
   213      #       CHECK:  constant 1.230000e+00 : f32
   214      #       CHECK:  constant 1.230000e+00 : f64
   215      #       CHECK:  constant 1 : i1
   216      #       CHECK:  constant 123 : i8
   217      #       CHECK:  constant 123 : i16
   218      #       CHECK:  constant 123 : i32
   219      #       CHECK:  constant 123 : index
   220      #       CHECK:  constant @constants : () -> ()
   221  
   222    def testCustom(self):
   223      self.setUp()
   224      with self.module.function_context("custom", [self.indexType, self.f32Type],
   225                                        []) as fun:
   226        E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1)
   227        printWithCurrentFunctionName(str(fun))
   228      # CHECK-LABEL: testCustom
   229      #       CHECK: %{{.*}} = "foo"(%{{.*}}) : (index) -> f32
   230      #       CHECK:  %{{.*}} = addf %{{.*}}, %{{.*}} : f32
   231  
   232    # Create 'addi' using the generic Op interface.  We need an operation known
   233    # to the execution engine so that the engine can compile it.
   234    def testCustomOpCompilation(self):
   235      self.setUp()
   236      with self.module.function_context("adder", [self.i32Type], []) as f:
   237        c1 = E.op(
   238            "std.constant", [], [self.i32Type],
   239            value=self.module.integerAttr(self.i32Type, 42))
   240        E.op("std.addi", [c1, f.arg(0)], [self.i32Type])
   241        E.ret([])
   242      self.module.compile()
   243      printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
   244      # CHECK-LABEL: testCustomOpCompilation
   245      #       CHECK: False
   246  
   247    def testDivisions(self):
   248      self.setUp()
   249      with self.module.function_context(
   250          "division", [self.indexType, self.i32Type, self.i32Type], []) as fun:
   251        # indices only support floor division
   252        fun.arg(0) // E.constant_index(42)
   253        # regular values only support regular division
   254        fun.arg(1) / fun.arg(2)
   255        printWithCurrentFunctionName(str(self.module))
   256      # CHECK-LABEL: testDivisions
   257      #       CHECK:  floordiv 42
   258      #       CHECK:  divis %{{.*}}, %{{.*}} : i32
   259  
   260    def testFunctionArgs(self):
   261      self.setUp()
   262      with self.module.function_context("foo", [self.f32Type, self.f32Type],
   263                                        [self.indexType]) as fun:
   264        pass
   265        printWithCurrentFunctionName(str(fun))
   266      # CHECK-LABEL: testFunctionArgs
   267      #       CHECK: func @foo(%{{.*}}: f32, %{{.*}}: f32) -> index
   268  
   269    def testFunctionContext(self):
   270      self.setUp()
   271      with self.module.function_context("foo", [], []):
   272        pass
   273        printWithCurrentFunctionName(self.module.get_function("foo"))
   274      # CHECK-LABEL: testFunctionContext
   275      #       CHECK: func @foo() {
   276  
   277    def testFunctionDeclaration(self):
   278      self.setUp()
   279      boolAttr = self.module.boolAttr(True)
   280      t = self.module.make_memref_type(self.f32Type, [10])
   281      t_llvm_noalias = t({"llvm.noalias": boolAttr})
   282      t_readonly = t({"readonly": boolAttr})
   283      f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], [])
   284      printWithCurrentFunctionName(str(self.module))
   285      # CHECK-LABEL: testFunctionDeclaration
   286      #       CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias = true}, memref<10xf32> {readonly = true})
   287  
   288    def testFunctionMultiple(self):
   289      self.setUp()
   290      with self.module.function_context("foo", [], []):
   291        pass
   292      with self.module.function_context("foo", [], []):
   293        E.constant_index(0)
   294      printWithCurrentFunctionName(str(self.module))
   295      # CHECK-LABEL: testFunctionMultiple
   296      #       CHECK: func @foo()
   297      #       CHECK: func @foo_0()
   298      #       CHECK: %{{.*}} = constant 0 : index
   299  
   300    def testIndexedValue(self):
   301      self.setUp()
   302      memrefType = self.module.make_memref_type(self.f32Type, [10, 42])
   303      with self.module.function_context("indexed", [memrefType],
   304                                        [memrefType]) as fun:
   305        A = E.IndexedValue(fun.arg(0))
   306        cst = E.constant_float(1., self.f32Type)
   307        with E.LoopNestContext(
   308            [E.constant_index(0), E.constant_index(0)],
   309            [E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j):
   310          A.store([i, j], A.load([i, j]) + cst)
   311        E.ret([fun.arg(0)])
   312        printWithCurrentFunctionName(str(fun))
   313      # CHECK-LABEL: testIndexedValue
   314      #       CHECK: "affine.for"()
   315      #       CHECK: "affine.for"()
   316      #       CHECK: "affine.load"
   317      #  CHECK-SAME: memref<10x42xf32>
   318      #       CHECK:  %{{.*}} = addf %{{.*}}, %{{.*}} : f32
   319      #       CHECK:  "affine.store"
   320      #  CHECK-SAME:  memref<10x42xf32>
   321      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
   322      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (10)}
   323  
   324    def testLoopContext(self):
   325      self.setUp()
   326      with self.module.function_context("foo", [], []) as fun:
   327        lhs = E.constant_index(0)
   328        rhs = E.constant_index(42)
   329        with E.LoopContext(lhs, rhs, 1) as i:
   330          lhs + rhs + i
   331          with E.LoopContext(rhs, rhs + rhs, 2) as j:
   332            x = i + j
   333        printWithCurrentFunctionName(str(fun))
   334      # CHECK-LABEL: testLoopContext
   335      #       CHECK: "affine.for"() (
   336      #       CHECK:   ^bb{{.*}}(%{{.*}}: index):
   337      #       CHECK: "affine.for"(%{{.*}}, %{{.*}}) (
   338      #       CHECK: ^bb{{.*}}(%{{.*}}: index):
   339      #       CHECK: "affine.apply"(%{{.*}}, %{{.*}}) {map = (d0, d1) -> (d0 + d1)} : (index, index) -> index
   340      #       CHECK: {lower_bound = (d0) -> (d0), step = 2 : index, upper_bound = (d0) -> (d0)} : (index, index) -> ()
   341      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (42)}
   342  
   343    def testLoopNestContext(self):
   344      self.setUp()
   345      with self.module.function_context("foo", [], []) as fun:
   346        lbs = [E.constant_index(i) for i in range(4)]
   347        ubs = [E.constant_index(10 * i + 5) for i in range(4)]
   348        with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l):
   349          i + j + k + l
   350      printWithCurrentFunctionName(str(fun))
   351      # CHECK-LABEL: testLoopNestContext
   352      #       CHECK: "affine.for"() (
   353      #       CHECK: ^bb{{.*}}(%{{.*}}: index):
   354      #       CHECK: "affine.for"() (
   355      #       CHECK: ^bb{{.*}}(%{{.*}}: index):
   356      #       CHECK: "affine.for"() (
   357      #       CHECK: ^bb{{.*}}(%{{.*}}: index):
   358      #       CHECK: "affine.for"() (
   359      #       CHECK: ^bb{{.*}}(%{{.*}}: index):
   360      #       CHECK: %{{.*}} = "affine.apply"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {map = (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index
   361  
   362    def testMLIRBooleanCompilation(self):
   363      self.setUp()
   364      m = self.module.make_memref_type(self.boolType, [10])  # i1 tensor
   365      with self.module.function_context("mkbooltensor", [m, m], []) as f:
   366        input = E.IndexedValue(f.arg(0))
   367        output = E.IndexedValue(f.arg(1))
   368        zero = E.constant_index(0)
   369        ten = E.constant_index(10)
   370        with E.LoopNestContext([zero] * 3, [ten] * 3, [1] * 3) as (i, j, k):
   371          b1 = (i < j) & (j < k)
   372          b2 = ~b1
   373          b3 = b2 | (k < j)
   374          output.store([i], input.load([i]) & b3)
   375        E.ret([])
   376      self.module.compile()
   377      printWithCurrentFunctionName(str(self.module.get_engine_address() == 0))
   378      # CHECK-LABEL: testMLIRBooleanCompilation
   379      #       CHECK: False
   380  
   381    def testMLIRFunctionCreation(self):
   382      self.setUp()
   383      module = E.MLIRModule()
   384      t = module.make_scalar_type("f32")
   385      m = module.make_memref_type(t, [3, 4, -1, 5])
   386      printWithCurrentFunctionName(str(t))
   387      print(str(m))
   388      print(str(module.make_function("copy", [m, m], [])))
   389      print(str(module.make_function("sqrtf", [t], [t])))
   390      # CHECK-LABEL: testMLIRFunctionCreation
   391      #       CHECK:  f32
   392      #       CHECK:  memref<3x4x?x5xf32>
   393      #       CHECK: func @copy(%{{.*}}: memref<3x4x?x5xf32>, %{{.*}}: memref<3x4x?x5xf32>) {
   394      #       CHECK:  func @sqrtf(%{{.*}}: f32) -> f32
   395  
   396    def testMLIRScalarTypes(self):
   397      self.setUp()
   398      module = E.MLIRModule()
   399      printWithCurrentFunctionName(str(module.make_scalar_type("bf16")))
   400      print(str(module.make_scalar_type("f16")))
   401      print(str(module.make_scalar_type("f32")))
   402      print(str(module.make_scalar_type("f64")))
   403      print(str(module.make_scalar_type("i", 1)))
   404      print(str(module.make_scalar_type("i", 8)))
   405      print(str(module.make_scalar_type("i", 32)))
   406      print(str(module.make_scalar_type("i", 123)))
   407      print(str(module.make_scalar_type("index")))
   408      # CHECK-LABEL: testMLIRScalarTypes
   409      #       CHECK:  bf16
   410      #       CHECK:  f16
   411      #       CHECK:  f32
   412      #       CHECK:  f64
   413      #       CHECK:  i1
   414      #       CHECK:  i8
   415      #       CHECK:  i32
   416      #       CHECK:  i123
   417      #       CHECK:  index
   418  
   419    def testMatrixMultiply(self):
   420      self.setUp()
   421      memrefType = self.module.make_memref_type(self.f32Type, [32, 32])
   422      with self.module.function_context(
   423          "matmul", [memrefType, memrefType, memrefType], []) as fun:
   424        A = E.IndexedValue(fun.arg(0))
   425        B = E.IndexedValue(fun.arg(1))
   426        C = E.IndexedValue(fun.arg(2))
   427        c0 = E.constant_index(0)
   428        c32 = E.constant_index(32)
   429        with E.LoopNestContext([c0, c0, c0], [c32, c32, c32], [1, 1, 1]) as (i, j,
   430                                                                             k):
   431          C.store([i, j], A.load([i, k]) * B.load([k, j]))
   432        E.ret([])
   433        printWithCurrentFunctionName(str(fun))
   434      # CHECK-LABEL: testMatrixMultiply
   435      #       CHECK: "affine.for"()
   436      #       CHECK: "affine.for"()
   437      #       CHECK: "affine.for"()
   438      #   CHECK-DAG:  %{{.*}} = "affine.load"
   439      #   CHECK-DAG:  %{{.*}} = "affine.load"
   440      #       CHECK:  %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
   441      #       CHECK:  "affine.store"
   442      #  CHECK-SAME:  memref<32x32xf32>
   443      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
   444      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
   445      #       CHECK: {lower_bound = () -> (0), step = 1 : index, upper_bound = () -> (32)} : () -> ()
   446  
   447    def testRet(self):
   448      self.setUp()
   449      with self.module.function_context("foo", [],
   450                                        [self.indexType, self.indexType]) as fun:
   451        c42 = E.constant_index(42)
   452        c0 = E.constant_index(0)
   453        E.ret([c42, c0])
   454        printWithCurrentFunctionName(str(fun))
   455      # CHECK-LABEL: testRet
   456      #       CHECK:    %{{.*}} = constant 42 : index
   457      #       CHECK:    %{{.*}} = constant 0 : index
   458      #       CHECK:    return %{{.*}}, %{{.*}} : index, index
   459  
   460    def testSelectOp(self):
   461      self.setUp()
   462      with self.module.function_context("foo", [self.boolType],
   463                                        [self.i32Type]) as fun:
   464        a = E.constant_int(42, 32)
   465        b = E.constant_int(0, 32)
   466        E.ret([E.select(fun.arg(0), a, b)])
   467        printWithCurrentFunctionName(str(fun))
   468      # CHECK-LABEL: testSelectOp
   469      #       CHECK:  %{{.*}} = select %{{.*}}, %{{.*}}, %{{.*}} : i32
   470  
   471  
   472  # Until python 3.6 this cannot be used because the order in the dict is not the
   473  # order of method declaration.
   474  def runTests():
   475    def isTest(attr):
   476      return inspect.ismethod(attr) and "EdscTest.setUp " not in str(attr)
   477  
   478    edscTest = EdscTest()
   479    tests = sorted(filter(isTest,
   480                          (getattr(edscTest, attr) for attr in dir(edscTest))),
   481                   key = lambda x : str(x))
   482    for test in tests:
   483      test()
   484  
   485  if __name__ == '__main__':
   486    runTests()