github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/third_party/mlir/lib/TableGen/Predicate.cpp (about)

     1  //===- Predicate.cpp - Predicate class ------------------------------------===//
     2  //
     3  // Copyright 2019 The MLIR Authors.
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //   http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  // =============================================================================
    17  //
    18  // Wrapper around predicates defined in TableGen.
    19  //
    20  //===----------------------------------------------------------------------===//
    21  
    22  #include "mlir/TableGen/Predicate.h"
    23  #include "llvm/ADT/SetVector.h"
    24  #include "llvm/ADT/SmallPtrSet.h"
    25  #include "llvm/ADT/StringExtras.h"
    26  #include "llvm/Support/FormatVariadic.h"
    27  #include "llvm/TableGen/Error.h"
    28  #include "llvm/TableGen/Record.h"
    29  
    30  using namespace mlir;
    31  
    32  // Construct a Predicate from a record.
    33  tblgen::Pred::Pred(const llvm::Record *record) : def(record) {
    34    assert(def->isSubClassOf("Pred") &&
    35           "must be a subclass of TableGen 'Pred' class");
    36  }
    37  
    38  // Construct a Predicate from an initializer.
    39  tblgen::Pred::Pred(const llvm::Init *init) : def(nullptr) {
    40    if (const auto *defInit = dyn_cast_or_null<llvm::DefInit>(init))
    41      def = defInit->getDef();
    42  }
    43  
    44  std::string tblgen::Pred::getCondition() const {
    45    // Static dispatch to subclasses.
    46    if (def->isSubClassOf("CombinedPred"))
    47      return static_cast<const CombinedPred *>(this)->getConditionImpl();
    48    if (def->isSubClassOf("CPred"))
    49      return static_cast<const CPred *>(this)->getConditionImpl();
    50    llvm_unreachable("Pred::getCondition must be overridden in subclasses");
    51  }
    52  
    53  bool tblgen::Pred::isCombined() const {
    54    return def && def->isSubClassOf("CombinedPred");
    55  }
    56  
    57  ArrayRef<llvm::SMLoc> tblgen::Pred::getLoc() const { return def->getLoc(); }
    58  
    59  tblgen::CPred::CPred(const llvm::Record *record) : Pred(record) {
    60    assert(def->isSubClassOf("CPred") &&
    61           "must be a subclass of Tablegen 'CPred' class");
    62  }
    63  
    64  tblgen::CPred::CPred(const llvm::Init *init) : Pred(init) {
    65    assert((!def || def->isSubClassOf("CPred")) &&
    66           "must be a subclass of Tablegen 'CPred' class");
    67  }
    68  
    69  // Get condition of the C Predicate.
    70  std::string tblgen::CPred::getConditionImpl() const {
    71    assert(!isNull() && "null predicate does not have a condition");
    72    return def->getValueAsString("predExpr");
    73  }
    74  
    75  tblgen::CombinedPred::CombinedPred(const llvm::Record *record) : Pred(record) {
    76    assert(def->isSubClassOf("CombinedPred") &&
    77           "must be a subclass of Tablegen 'CombinedPred' class");
    78  }
    79  
    80  tblgen::CombinedPred::CombinedPred(const llvm::Init *init) : Pred(init) {
    81    assert((!def || def->isSubClassOf("CombinedPred")) &&
    82           "must be a subclass of Tablegen 'CombinedPred' class");
    83  }
    84  
    85  const llvm::Record *tblgen::CombinedPred::getCombinerDef() const {
    86    assert(def->getValue("kind") && "CombinedPred must have a value 'kind'");
    87    return def->getValueAsDef("kind");
    88  }
    89  
    90  const std::vector<llvm::Record *> tblgen::CombinedPred::getChildren() const {
    91    assert(def->getValue("children") &&
    92           "CombinedPred must have a value 'children'");
    93    return def->getValueAsListOfDefs("children");
    94  }
    95  
    96  namespace {
    97  // Kinds of nodes in a logical predicate tree.
    98  enum class PredCombinerKind {
    99    Leaf,
   100    And,
   101    Or,
   102    Not,
   103    SubstLeaves,
   104    Concat,
   105    // Special kinds that are used in simplification.
   106    False,
   107    True
   108  };
   109  
   110  // A node in a logical predicate tree.
   111  struct PredNode {
   112    PredCombinerKind kind;
   113    const tblgen::Pred *predicate;
   114    SmallVector<PredNode *, 4> children;
   115    std::string expr;
   116  
   117    // Prefix and suffix are used by ConcatPred.
   118    std::string prefix;
   119    std::string suffix;
   120  };
   121  } // end anonymous namespace
   122  
   123  // Get a predicate tree node kind based on the kind used in the predicate
   124  // TableGen record.
   125  static PredCombinerKind getPredCombinerKind(const tblgen::Pred &pred) {
   126    if (!pred.isCombined())
   127      return PredCombinerKind::Leaf;
   128  
   129    const auto &combinedPred = static_cast<const tblgen::CombinedPred &>(pred);
   130    return llvm::StringSwitch<PredCombinerKind>(
   131               combinedPred.getCombinerDef()->getName())
   132        .Case("PredCombinerAnd", PredCombinerKind::And)
   133        .Case("PredCombinerOr", PredCombinerKind::Or)
   134        .Case("PredCombinerNot", PredCombinerKind::Not)
   135        .Case("PredCombinerSubstLeaves", PredCombinerKind::SubstLeaves)
   136        .Case("PredCombinerConcat", PredCombinerKind::Concat);
   137  }
   138  
   139  namespace {
   140  // Substitution<pattern, replacement>.
   141  using Subst = std::pair<StringRef, StringRef>;
   142  } // end anonymous namespace
   143  
   144  // Build the predicate tree starting from the top-level predicate, which may
   145  // have children, and perform leaf substitutions inplace.  Note that after
   146  // substitution, nodes are still pointing to the original TableGen record.
   147  // All nodes are created within "allocator".
   148  static PredNode *buildPredicateTree(const tblgen::Pred &root,
   149                                      llvm::BumpPtrAllocator &allocator,
   150                                      ArrayRef<Subst> substitutions) {
   151    auto *rootNode = allocator.Allocate<PredNode>();
   152    new (rootNode) PredNode;
   153    rootNode->kind = getPredCombinerKind(root);
   154    rootNode->predicate = &root;
   155    if (!root.isCombined()) {
   156      rootNode->expr = root.getCondition();
   157      // Apply all parent substitutions from innermost to outermost.
   158      for (const auto &subst : llvm::reverse(substitutions)) {
   159        auto pos = rootNode->expr.find(subst.first);
   160        while (pos != std::string::npos) {
   161          rootNode->expr.replace(pos, subst.first.size(), subst.second);
   162          // Skip the newly inserted substring, which itself may consider the
   163          // pattern to match.
   164          pos += subst.second.size();
   165          // Find the next possible match position.
   166          pos = rootNode->expr.find(subst.first, pos);
   167        }
   168      }
   169      return rootNode;
   170    }
   171  
   172    // If the current combined predicate is a leaf substitution, append it to the
   173    // list before contiuing.
   174    auto allSubstitutions = llvm::to_vector<4>(substitutions);
   175    if (rootNode->kind == PredCombinerKind::SubstLeaves) {
   176      const auto &substPred = static_cast<const tblgen::SubstLeavesPred &>(root);
   177      allSubstitutions.push_back(
   178          {substPred.getPattern(), substPred.getReplacement()});
   179    }
   180    // If the current predicate is a ConcatPred, record the prefix and suffix.
   181    else if (rootNode->kind == PredCombinerKind::Concat) {
   182      const auto &concatPred = static_cast<const tblgen::ConcatPred &>(root);
   183      rootNode->prefix = concatPred.getPrefix();
   184      rootNode->suffix = concatPred.getSuffix();
   185    }
   186  
   187    // Build child subtrees.
   188    auto combined = static_cast<const tblgen::CombinedPred &>(root);
   189    for (const auto *record : combined.getChildren()) {
   190      auto childTree =
   191          buildPredicateTree(tblgen::Pred(record), allocator, allSubstitutions);
   192      rootNode->children.push_back(childTree);
   193    }
   194    return rootNode;
   195  }
   196  
   197  // Simplify a predicate tree rooted at "node" using the predicates that are
   198  // known to be true(false).  For AND(OR) combined predicates, if any of the
   199  // children is known to be false(true), the result is also false(true).
   200  // Furthermore, for AND(OR) combined predicates, children that are known to be
   201  // true(false) don't have to be checked dynamically.
   202  static PredNode *propagateGroundTruth(
   203      PredNode *node, const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownTruePreds,
   204      const llvm::SmallPtrSetImpl<tblgen::Pred *> &knownFalsePreds) {
   205    // If the current predicate is known to be true or false, change the kind of
   206    // the node and return immediately.
   207    if (knownTruePreds.count(node->predicate) != 0) {
   208      node->kind = PredCombinerKind::True;
   209      node->children.clear();
   210      return node;
   211    }
   212    if (knownFalsePreds.count(node->predicate) != 0) {
   213      node->kind = PredCombinerKind::False;
   214      node->children.clear();
   215      return node;
   216    }
   217  
   218    // If the current node is a substitution, stop recursion now.
   219    // The expressions in the leaves below this node were rewritten, but the nodes
   220    // still point to the original predicate records.  While the original
   221    // predicate may be known to be true or false, it is not necessarily the case
   222    // after rewriting.
   223    // TODO(zinenko,jpienaar): we can support ground truth for rewritten
   224    // predicates by either (a) having our own unique'ing of the predicates
   225    // instead of relying on TableGen record pointers or (b) taking ground truth
   226    // values optinally prefixed with a list of substitutions to apply, e.g.
   227    // "predX is true by itself as well as predSubY leaf substitution had been
   228    // applied to it".
   229    if (node->kind == PredCombinerKind::SubstLeaves) {
   230      return node;
   231    }
   232  
   233    // Otherwise, look at child nodes.
   234  
   235    // Move child nodes into some local variable so that they can be optimized
   236    // separately and re-added if necessary.
   237    llvm::SmallVector<PredNode *, 4> children;
   238    std::swap(node->children, children);
   239  
   240    for (auto &child : children) {
   241      // First, simplify the child.  This maintains the predicate as it was.
   242      auto simplifiedChild =
   243          propagateGroundTruth(child, knownTruePreds, knownFalsePreds);
   244  
   245      // Just add the child if we don't know how to simplify the current node.
   246      if (node->kind != PredCombinerKind::And &&
   247          node->kind != PredCombinerKind::Or) {
   248        node->children.push_back(simplifiedChild);
   249        continue;
   250      }
   251  
   252      // Second, based on the type define which known values of child predicates
   253      // immediately collapse this predicate to a known value, and which others
   254      // may be safely ignored.
   255      //   OR(..., True, ...) = True
   256      //   OR(..., False, ...) = OR(..., ...)
   257      //   AND(..., False, ...) = False
   258      //   AND(..., True, ...) = AND(..., ...)
   259      auto collapseKind = node->kind == PredCombinerKind::And
   260                              ? PredCombinerKind::False
   261                              : PredCombinerKind::True;
   262      auto eraseKind = node->kind == PredCombinerKind::And
   263                           ? PredCombinerKind::True
   264                           : PredCombinerKind::False;
   265      const auto &collapseList =
   266          node->kind == PredCombinerKind::And ? knownFalsePreds : knownTruePreds;
   267      const auto &eraseList =
   268          node->kind == PredCombinerKind::And ? knownTruePreds : knownFalsePreds;
   269      if (simplifiedChild->kind == collapseKind ||
   270          collapseList.count(simplifiedChild->predicate) != 0) {
   271        node->kind = collapseKind;
   272        node->children.clear();
   273        return node;
   274      } else if (simplifiedChild->kind == eraseKind ||
   275                 eraseList.count(simplifiedChild->predicate) != 0) {
   276        continue;
   277      }
   278      node->children.push_back(simplifiedChild);
   279    }
   280    return node;
   281  }
   282  
   283  // Combine a list of predicate expressions using a binary combiner.  If a list
   284  // is empty, return "init".
   285  static std::string combineBinary(ArrayRef<std::string> children,
   286                                   std::string combiner, std::string init) {
   287    if (children.empty())
   288      return init;
   289  
   290    auto size = children.size();
   291    if (size == 1)
   292      return children.front();
   293  
   294    std::string str;
   295    llvm::raw_string_ostream os(str);
   296    os << '(' << children.front() << ')';
   297    for (unsigned i = 1; i < size; ++i) {
   298      os << ' ' << combiner << " (" << children[i] << ')';
   299    }
   300    return os.str();
   301  }
   302  
   303  // Prepend negation to the only condition in the predicate expression list.
   304  static std::string combineNot(ArrayRef<std::string> children) {
   305    assert(children.size() == 1 && "expected exactly one child predicate of Neg");
   306    return (Twine("!(") + children.front() + Twine(')')).str();
   307  }
   308  
   309  // Recursively traverse the predicate tree in depth-first post-order and build
   310  // the final expression.
   311  static std::string getCombinedCondition(const PredNode &root) {
   312    // Immediately return for non-combiner predicates that don't have children.
   313    if (root.kind == PredCombinerKind::Leaf)
   314      return root.expr;
   315    if (root.kind == PredCombinerKind::True)
   316      return "true";
   317    if (root.kind == PredCombinerKind::False)
   318      return "false";
   319  
   320    // Recurse into children.
   321    llvm::SmallVector<std::string, 4> childExpressions;
   322    childExpressions.reserve(root.children.size());
   323    for (const auto &child : root.children)
   324      childExpressions.push_back(getCombinedCondition(*child));
   325  
   326    // Combine the expressions based on the predicate node kind.
   327    if (root.kind == PredCombinerKind::And)
   328      return combineBinary(childExpressions, "&&", "true");
   329    if (root.kind == PredCombinerKind::Or)
   330      return combineBinary(childExpressions, "||", "false");
   331    if (root.kind == PredCombinerKind::Not)
   332      return combineNot(childExpressions);
   333    if (root.kind == PredCombinerKind::Concat) {
   334      assert(childExpressions.size() == 1 &&
   335             "ConcatPred should only have one child");
   336      return root.prefix + childExpressions.front() + root.suffix;
   337    }
   338  
   339    // Substitutions were applied before so just ignore them.
   340    if (root.kind == PredCombinerKind::SubstLeaves) {
   341      assert(childExpressions.size() == 1 &&
   342             "substitution predicate must have one child");
   343      return childExpressions[0];
   344    }
   345  
   346    llvm::PrintFatalError(root.predicate->getLoc(), "unsupported predicate kind");
   347  }
   348  
   349  std::string tblgen::CombinedPred::getConditionImpl() const {
   350    llvm::BumpPtrAllocator allocator;
   351    auto predicateTree = buildPredicateTree(*this, allocator, {});
   352    predicateTree = propagateGroundTruth(
   353        predicateTree,
   354        /*knownTruePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>(),
   355        /*knownFalsePreds=*/llvm::SmallPtrSet<tblgen::Pred *, 2>());
   356  
   357    return getCombinedCondition(*predicateTree);
   358  }
   359  
   360  StringRef tblgen::SubstLeavesPred::getPattern() const {
   361    return def->getValueAsString("pattern");
   362  }
   363  
   364  StringRef tblgen::SubstLeavesPred::getReplacement() const {
   365    return def->getValueAsString("replacement");
   366  }
   367  
   368  StringRef tblgen::ConcatPred::getPrefix() const {
   369    return def->getValueAsString("prefix");
   370  }
   371  
   372  StringRef tblgen::ConcatPred::getSuffix() const {
   373    return def->getValueAsString("suffix");
   374  }