github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/tools/mlir-tblgen/ReferenceImplGen.cpp (about)

     1  //===- ReferenceImplGen.cpp - MLIR reference implementation generator -----===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // ReferenceImplGen uses the description of operations to generate reference
    19  // implementations for the ops.
    20  //
    21  //===----------------------------------------------------------------------===//
    22  
    23  #include "mlir/TableGen/GenInfo.h"
    24  #include "mlir/TableGen/Operator.h"
    25  #include "llvm/ADT/StringExtras.h"
    26  #include "llvm/Support/FormatVariadic.h"
    27  #include "llvm/Support/Signals.h"
    28  #include "llvm/TableGen/Error.h"
    29  #include "llvm/TableGen/Record.h"
    30  #include "llvm/TableGen/TableGenBackend.h"
    31  
    32  using namespace llvm;
    33  using namespace mlir;
    34  
    35  using mlir::tblgen::Operator;
    36  
    37  static void emitReferenceImplementations(const RecordKeeper &recordKeeper,
    38                                           raw_ostream &os) {
    39    emitSourceFileHeader("Reference implementation file", os);
    40    const auto &defs = recordKeeper.getAllDerivedDefinitions("Op");
    41  
    42    os << "void printRefImplementation(StringRef opName, mlir::FuncOp *f) {\n"
    43       << "  using namespace ::mlir::edsc;\n"
    44       << "if (false) {}";
    45    for (auto *def : defs) {
    46      Operator op(def);
    47      auto referenceImplGenerator = def->getValueInit("referenceImplementation");
    48      if (!referenceImplGenerator)
    49        continue;
    50      os << " else if (opName == \"" << op.getOperationName() << "\") {\n"
    51         << "  edsc::ScopedContext scope(f);\n";
    52  
    53      for (auto en : llvm::enumerate(op.getOperands())) {
    54        os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
    55                                "(void)arg_{0};\n",
    56                                en.value().name, en.index());
    57        // TODO(jpienaar): this is generally incorrect, not all args are memref
    58        // in the general case.
    59        os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
    60                                "(void)view_{0};\n",
    61                                en.value().name, en.index());
    62      }
    63      unsigned numOperands = op.getNumOperands();
    64      unsigned numResults = op.getNumResults();
    65      for (unsigned idx = 0; idx < numResults; ++idx) {
    66        os.indent(2) << formatv("ValueHandle arg_{0}(f->getArgument({1})); "
    67                                "(void)arg_{0};\n",
    68                                op.getResult(idx).name, numOperands + idx);
    69        // TODO(jpienaar): this is generally incorrect, not all args are memref
    70        // in the general case.
    71        os.indent(2) << formatv("MemRefView view_{0}(f->getArgument({1})); "
    72                                "(void)view_{0};\n",
    73                                op.getResult(idx).name, numOperands + idx);
    74      }
    75  
    76      // Print the EDSC.
    77      os << referenceImplGenerator->getAsUnquotedString() << "\n";
    78      os.indent(2) << "f->print(llvm::outs());\n\n";
    79      os << "}";
    80    }
    81    os << " else {\n";
    82    os.indent(2) << "f->emitError(\"no reference impl. for \" + opName);\n";
    83    os.indent(2) << "return;\n";
    84    os << "}\n";
    85    os << "}\n";
    86  }
    87  
    88  static mlir::GenRegistration
    89      genRegister("gen-reference-implementations",
    90                  "Generate reference implemenations",
    91                  [](const RecordKeeper &records, raw_ostream &os) {
    92                    emitReferenceImplementations(records, os);
    93                    return false;
    94                  });