github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/tools/clang/declextract/declextract.cpp (about)

     1  // Copyright 2024 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  #include "json.h"
     5  #include "output.h"
     6  
     7  #include "clang/AST/APValue.h"
     8  #include "clang/AST/ASTContext.h"
     9  #include "clang/AST/Attr.h"
    10  #include "clang/AST/Attrs.inc"
    11  #include "clang/AST/Decl.h"
    12  #include "clang/AST/DeclarationName.h"
    13  #include "clang/AST/Expr.h"
    14  #include "clang/AST/PrettyPrinter.h"
    15  #include "clang/AST/RecursiveASTVisitor.h"
    16  #include "clang/AST/Stmt.h"
    17  #include "clang/AST/Type.h"
    18  #include "clang/ASTMatchers/ASTMatchFinder.h"
    19  #include "clang/ASTMatchers/ASTMatchers.h"
    20  #include "clang/Basic/CharInfo.h"
    21  #include "clang/Basic/LLVM.h"
    22  #include "clang/Basic/SourceManager.h"
    23  #include "clang/Basic/TypeTraits.h"
    24  #include "clang/Frontend/CompilerInstance.h"
    25  #include "clang/Tooling/CommonOptionsParser.h"
    26  #include "clang/Tooling/Tooling.h"
    27  #include "llvm/ADT/StringRef.h"
    28  #include "llvm/Support/Casting.h"
    29  #include "llvm/Support/CommandLine.h"
    30  #include "llvm/Support/ErrorHandling.h"
    31  
    32  #include <algorithm>
    33  #include <cstddef>
    34  #include <cstdint>
    35  #include <filesystem>
    36  #include <stack>
    37  #include <string>
    38  #include <string_view>
    39  #include <tuple>
    40  #include <unordered_map>
    41  #include <vector>
    42  
    43  #include <sys/ioctl.h>
    44  
    45  using namespace clang;
    46  using namespace clang::ast_matchers;
    47  
    48  // MacroDef/MacroMap hold information about macros defined in the file.
    49  struct MacroDef {
    50    std::string Value;       // value as written in the source
    51    SourceRange SourceRange; // soruce range of the value
    52  };
    53  using MacroMap = std::unordered_map<std::string, MacroDef>;
    54  
    55  // ConstDesc describes a macro or an enum value.
    56  struct ConstDesc {
    57    std::string Name;
    58    std::string Value;
    59    SourceRange SourceRange;
    60    int64_t IntValue;
    61  };
    62  
// Extractor is the main driver of the tool. It is a MatchFinder that registers AST matchers
// for the constructs we extract (see the constructor). It is also a
// tooling::SourceFileCallbacks, so handleBeginSource() can install a PPCallbacksTracker that
// records macro definitions into Macros. Extracted entities are accumulated in Output and
// written by print().
class Extractor : public MatchFinder, public tooling::SourceFileCallbacks {
public:
  Extractor() {
    // Every function definition.
    match(&Extractor::matchFunctionDef, functionDecl(isDefinition()).bind("function"));

    // Syscall bodies: functions named __do_sys_* that were expanded from SYSCALL_DEFINEx.
    match(&Extractor::matchSyscall,
          functionDecl(isExpandedFromMacro("SYSCALL_DEFINEx"), matchesName("__do_sys_.*")).bind("syscall"));

    // Definitions of constant-size arrays whose element type is the record io_issue_def.
    match(&Extractor::matchIouring,
          translationUnitDecl(forEachDescendant(
              varDecl(hasType(constantArrayType(hasElementType(hasDeclaration(recordDecl(hasName("io_issue_def")))))),
                      isDefinition())
                  .bind("io_issue_defs"))));

    // Definitions of constant-size arrays whose element type is the record nla_policy.
    match(&Extractor::matchNetlinkPolicy,
          translationUnitDecl(forEachDescendant(
              varDecl(hasType(constantArrayType(hasElementType(hasDeclaration(recordDecl(hasName("nla_policy")))))),
                      isDefinition())
                  .bind("netlink_policy"))));

    // Variables of record type genl_family that have an initializer list as a direct child.
    match(&Extractor::matchNetlinkFamily, varDecl(hasType(recordDecl(hasName("genl_family")).bind("genl_family")),
                                                  has(initListExpr().bind("genl_family_init"))));

    // Variables with a file_operations initializer list anywhere among their descendants
    // (forEachDescendant reports each such init list as a separate match).
    match(&Extractor::matchFileOps,
          varDecl(forEachDescendant(initListExpr(hasType(recordDecl(hasName("file_operations")))).bind("init")))
              .bind("var"));
  }

  // Prints everything extracted so far (delegates to Output).
  void print() const { Output.print(); }

private:
  // FunctionAnalyzer reads Context/SourceManager and calls isMacroOrEnum().
  friend struct FunctionAnalyzer;
  // Pointer to a parameterless Extractor member that handles one kind of match.
  using MatchFunc = void (Extractor::*)();
  // Thunk that redirects MatchCallback::run method to one of the methods of the Extractor class.
  struct MatchCallbackThunk : MatchFinder::MatchCallback {
    Extractor& Ex;
    MatchFunc Action;
    MatchCallbackThunk(Extractor& Ex, MatchFunc Action) : Ex(Ex), Action(Action) {}
    void run(const MatchFinder::MatchResult& Result) override { Ex.run(Result, Action); }
  };
  // Owns the thunks whose raw pointers are handed to addMatcher() in match().
  std::vector<std::unique_ptr<MatchCallbackThunk>> Matchers;

  // These are set by run() to point to the Result of the current match (to avoid passing them through all methods).
  const BoundNodes* Nodes = nullptr;
  ASTContext* Context = nullptr;
  SourceManager* SourceManager = nullptr;

  // Accumulates all extracted entities until print().
  Output Output;
  // Macro definitions recorded by PPCallbacksTracker (installed in handleBeginSource).
  MacroMap Macros;
  // Names already emitted by extractEnum / extractRecord, so each is emitted only once.
  std::unordered_map<std::string, bool> EnumDedup;
  std::unordered_map<std::string, bool> StructDedup;
  // Used by matchFileOps (defined outside this view); presumably dedups file_operations — TODO confirm.
  std::unordered_map<std::string, int> FileOpsDedup;

  // Match actions, invoked through MatchCallbackThunk -> run().
  void matchFunctionDef();
  void matchSyscall();
  void matchIouring();
  void matchNetlinkPolicy();
  void matchNetlinkFamily();
  void matchFileOps();
  bool handleBeginSource(CompilerInstance& CI) override;
  template <typename M> void match(MatchFunc Action, const M& Matcher);
  void run(const MatchFinder::MatchResult& Result, MatchFunc Action);
  // Helpers used by the match actions.
  template <typename T> const T* getResult(StringRef ID) const;
  FieldType extractRecord(QualType QT, const RecordType* Typ, const std::string& BackupName);
  std::string extractEnum(QualType QT, const EnumDecl* Decl);
  void emitConst(const std::string& Name, int64_t Val, SourceLocation Loc);
  std::string getFuncName(const Expr* Expr);
  std::string getDeclName(const Expr* Expr);
  const ValueDecl* getValueDecl(const Expr* Expr);
  std::string getDeclFileID(const Decl* Decl);
  std::string getUniqueDeclName(const NamedDecl* Decl);
  std::vector<std::pair<int, std::string>> extractDesignatedInitConsts(const VarDecl& ArrayDecl);
  FieldType genType(QualType Typ, const std::string& BackupName = "");
  std::unordered_map<std::string, unsigned> structFieldIndexes(const RecordDecl* Decl);
  template <typename T = int64_t> T evaluate(const Expr* E);
  template <typename T, typename Node, typename Condition>
  std::vector<const T*> findAllMatches(const Node* Expr, const Condition& Cond);
  template <typename T, typename Node, typename Condition>
  const T* findFirstMatch(const Node* Expr, const Condition& Cond);
  std::optional<QualType> getSizeofType(const Expr* E);
  int sizeofType(const Type* T);
  int alignofType(const Type* T);
  void extractIoctl(const Expr* Cmd, const ConstDesc& Const);
  std::optional<ConstDesc> isMacroOrEnum(const Expr* E);
  ConstDesc constDesc(const Expr* E, const std::string& Str, const std::string& Value, const SourceRange& SourceRange);
};
   149  
   150  // PPCallbacksTracker records all macro definitions (name/value/source location).
   151  class PPCallbacksTracker : public PPCallbacks {
   152  public:
   153    PPCallbacksTracker(Preprocessor& PP, MacroMap& Macros) : SM(PP.getSourceManager()), Macros(Macros) {}
   154  
   155  private:
   156    SourceManager& SM;
   157    MacroMap& Macros;
   158  
   159    void MacroDefined(const Token& MacroName, const MacroDirective* MD) override {
   160      const char* NameBegin = SM.getCharacterData(MacroName.getLocation());
   161      const char* NameEnd = SM.getCharacterData(MacroName.getEndLoc());
   162      std::string Name(NameBegin, NameEnd - NameBegin);
   163      const char* ValBegin = SM.getCharacterData(MD->getMacroInfo()->getDefinitionLoc());
   164      const char* ValEnd = SM.getCharacterData(MD->getMacroInfo()->getDefinitionEndLoc()) + 1;
   165      // Definition includes the macro name, remove it.
   166      ValBegin += std::min<size_t>(Name.size(), ValEnd - ValBegin);
   167      // Trim whitespace from both ends.
   168      while (ValBegin < ValEnd && isspace(*ValBegin))
   169        ValBegin++;
   170      while (ValBegin < ValEnd && isspace(*(ValEnd - 1)))
   171        ValEnd--;
   172      std::string Value(ValBegin, ValEnd - ValBegin);
   173      Macros[Name] = MacroDef{
   174          .Value = Value,
   175          .SourceRange = SourceRange(MD->getMacroInfo()->getDefinitionLoc(), MD->getMacroInfo()->getDefinitionEndLoc()),
   176      };
   177    }
   178  };
   179  
   180  const Expr* removeCasts(const Expr* E) {
   181    for (;;) {
   182      if (auto* P = dyn_cast<ParenExpr>(E))
   183        E = P->getSubExpr();
   184      else if (auto* C = dyn_cast<CastExpr>(E))
   185        E = C->getSubExpr();
   186      else
   187        break;
   188    }
   189    return E;
   190  }
   191  
   192  bool Extractor::handleBeginSource(CompilerInstance& CI) {
   193    Preprocessor& PP = CI.getPreprocessor();
   194    PP.addPPCallbacks(std::make_unique<PPCallbacksTracker>(PP, Macros));
   195    return true;
   196  }
   197  
   198  template <typename M> void Extractor::match(MatchFunc Action, const M& Matcher) {
   199    Matchers.emplace_back(new MatchCallbackThunk(*this, Action));
   200    addMatcher(Matcher, Matchers.back().get());
   201  }
   202  
   203  void Extractor::run(const MatchFinder::MatchResult& Result, MatchFunc Action) {
   204    Nodes = &Result.Nodes;
   205    Context = Result.Context;
   206    SourceManager = Result.SourceManager;
   207    (this->*Action)();
   208  }
   209  
   210  template <typename T> const T* Extractor::getResult(StringRef ID) const { return Nodes->getNodeAs<T>(ID); }
   211  
   212  std::string TypeName(QualType QT) {
   213    std::string Name = QT.getAsString();
   214    auto Attr = Name.find(" __attribute__");
   215    if (Attr != std::string::npos)
   216      Name = Name.substr(0, Attr);
   217    return Name;
   218  }
   219  
// Top function that converts any clang type QT to our output type.
// QT is stripped of parens, top-level qualifiers and sugar (e.g. typedefs) before dispatching
// on the underlying type node. BackupName is forwarded to nested genType/extractRecord calls,
// where it is used as the name of otherwise-anonymous records. Any type kind not handled below
// dumps QT and aborts the tool.
FieldType Extractor::genType(QualType QT, const std::string& BackupName) {
  const Type* T = QT.IgnoreParens().getUnqualifiedType().getDesugaredType(*Context).getTypePtr();
  // Builtin types: Name keeps the spelling of QT as written (e.g. a typedef name, without
  // attributes), Base is the spelling of the desugared builtin type.
  if (llvm::isa<BuiltinType>(T)) {
    return IntType{.ByteSize = sizeofType(T), .Name = TypeName(QT), .Base = QualType(T, 0).getAsString()};
  }
  // Enums are emitted once via extractEnum and referenced by name.
  if (auto* Typ = llvm::dyn_cast<EnumType>(T)) {
    return IntType{.ByteSize = sizeofType(T), .Enum = extractEnum(QT, Typ->getDecl())};
  }
  // A prototyped function type is represented as a const pointer whose element is TodoType().
  if (llvm::isa<FunctionProtoType>(T)) {
    return PtrType{.Elem = TodoType(), .IsConst = true};
  }
  // Arrays without a size (e.g. `int data[]`): only the element type is set.
  if (auto* Typ = llvm::dyn_cast<IncompleteArrayType>(T)) {
    return ArrType{.Elem = genType(Typ->getElementType(), BackupName)};
  }
  // Structs/unions are emitted via extractRecord and referenced by name.
  if (auto* Typ = llvm::dyn_cast<RecordType>(T)) {
    return extractRecord(QT, Typ, BackupName);
  }
  // Fixed-size arrays: min and max size are both the declared element count.
  if (auto* Typ = llvm::dyn_cast<ConstantArrayType>(T)) {
    // TODO: the size may be a macro that is different for each arch, e.g.:
    //   long foo[FOOSIZE/sizeof(long)];
    int Size = Typ->getSize().getZExtValue();
    return ArrType{
        .Elem = genType(Typ->getElementType(), BackupName),
        .MinSize = Size,
        .MaxSize = Size,
        .Align = alignofType(Typ),
        .IsConstSize = true,
    };
  }
  // Pointers: char-like pointees become string buffers, void pointees become arrays of
  // TodoType(), anything else is converted recursively. IsConst reflects the pointee's const.
  if (auto* Typ = llvm::dyn_cast<PointerType>(T)) {
    FieldType Elem;
    const QualType& Pointee = Typ->getPointeeType();
    if (Pointee->isAnyCharacterType())
      Elem = BufferType{.IsString = true};
    else if (Pointee->isVoidType())
      Elem = ArrType{.Elem = TodoType()};
    else
      Elem = genType(Pointee, BackupName); // note: it may be an array as well
    return PtrType{
        .Elem = std::move(Elem),
        .IsConst = Pointee.isConstQualified(),
    };
  }
  // Unsupported type kind (e.g. unprototyped functions, vectors): dump it and abort.
  QT.dump();
  llvm::report_fatal_error("unhandled type");
}
   267  
   268  FieldType Extractor::extractRecord(QualType QT, const RecordType* Typ, const std::string& BackupName) {
   269    auto* Decl = Typ->getDecl()->getDefinition();
   270    if (!Decl)
   271      return TodoType(); // definition is in a different TU
   272    std::string Name = Decl->getDeclName().getAsString();
   273    // If it's a typedef of anon struct, we want to use the typedef name:
   274    //   typedef struct {...} foo_t;
   275    if (Name.empty() && QT->isTypedefNameType())
   276      Name = QualType(Typ, 0).getAsString();
   277    // If no other names, fallback to the parent-struct-based name.
   278    if (Name.empty()) {
   279      assert(!BackupName.empty());
   280      // The BackupName is supposed to be unique.
   281      assert(!StructDedup[BackupName]);
   282      Name = BackupName;
   283    }
   284    if (Name.find("struct ") == 0)
   285      Name = Name.substr(strlen("struct "));
   286    if (StructDedup[Name])
   287      return Name;
   288    StructDedup[Name] = true;
   289    std::vector<Field> Fields;
   290    for (const FieldDecl* F : Decl->fields()) {
   291      std::string FieldName = F->getNameAsString();
   292      std::string BackupFieldName = Name + "_" + FieldName;
   293      bool IsAnonymous = false;
   294      if (FieldName.empty()) {
   295        BackupFieldName = Name + "_" + std::to_string(F->getFieldIndex());
   296        FieldName = BackupFieldName;
   297        IsAnonymous = true;
   298      }
   299      FieldType FieldType = genType(F->getType(), BackupFieldName);
   300      int BitWidth = F->isBitField() ? F->getBitWidthValue() : 0;
   301      int CountedBy = F->getType()->isCountAttributedType()
   302                          ? llvm::dyn_cast<FieldDecl>(
   303                                F->getType()->getAs<CountAttributedType>()->getCountExpr()->getReferencedDeclOfCallee())
   304                                ->getFieldIndex()
   305                          : -1;
   306      Fields.push_back(Field{
   307          .Name = FieldName,
   308          .IsAnonymous = IsAnonymous,
   309          .BitWidth = BitWidth,
   310          .CountedBy = CountedBy,
   311          .Type = std::move(FieldType),
   312      });
   313    }
   314    int AlignAttr = 0;
   315    bool Packed = false;
   316    if (Decl->isStruct() && Decl->hasAttrs()) {
   317      for (const auto& A : Decl->getAttrs()) {
   318        if (auto* Attr = llvm::dyn_cast<AlignedAttr>(A))
   319          AlignAttr = Attr->getAlignment(*Context) / 8;
   320        else if (llvm::isa<PackedAttr>(A))
   321          Packed = true;
   322      }
   323    }
   324    Output.emit(Struct{
   325        .Name = Name,
   326        .ByteSize = sizeofType(Typ),
   327        .Align = alignofType(Typ),
   328        .IsUnion = Decl->isUnion(),
   329        .IsPacked = Packed,
   330        .AlignAttr = AlignAttr,
   331        .Fields = std::move(Fields),
   332    });
   333    return Name;
   334  }
   335  
   336  std::string Extractor::extractEnum(QualType QT, const EnumDecl* Decl) {
   337    std::string Name = Decl->getNameAsString();
   338    if (Name.empty()) {
   339      // This is an unnamed enum declared with a typedef:
   340      //   typedef enum {...} enum_name;
   341      auto Typedef = dyn_cast<TypedefType>(QT.getTypePtr());
   342      if (Typedef)
   343        Name = Typedef->getDecl()->getNameAsString();
   344      if (Name.empty()) {
   345        QT.dump();
   346        llvm::report_fatal_error("enum with empty name");
   347      }
   348    }
   349    if (EnumDedup[Name])
   350      return Name;
   351    EnumDedup[Name] = true;
   352    std::vector<std::string> Values;
   353    for (const auto* Enumerator : Decl->enumerators()) {
   354      const std::string& Name = Enumerator->getNameAsString();
   355      emitConst(Name, Enumerator->getInitVal().getExtValue(), Decl->getBeginLoc());
   356      Values.push_back(Name);
   357    }
   358    Output.emit(Enum{
   359        .Name = Name,
   360        .Values = Values,
   361    });
   362    return Name;
   363  }
   364  
   365  void Extractor::emitConst(const std::string& Name, int64_t Val, SourceLocation Loc) {
   366    Output.emit(ConstInfo{
   367        .Name = Name,
   368        .Filename = std::filesystem::relative(SourceManager->getFilename(Loc).str()),
   369        .Value = Val,
   370    });
   371  }
   372  
   373  // Returns base part of the source file containing the canonical declaration.
   374  // If the passed declaration is also a definition, then it will look for a preceeding declaration.
   375  // This is used to generate unique names for static definitions that may have duplicate names
   376  // across different TUs. We assume that the base part of the source file is enough
   377  // to make them unique.
   378  std::string Extractor::getDeclFileID(const Decl* Decl) {
   379    std::string file =
   380        std::filesystem::path(SourceManager->getFilename(Decl->getCanonicalDecl()->getSourceRange().getBegin()).str())
   381            .filename()
   382            .stem()
   383            .string();
   384    std::replace(file.begin(), file.end(), '-', '_');
   385    return file;
   386  }
   387  
   388  std::optional<ConstDesc> Extractor::isMacroOrEnum(const Expr* E) {
   389    if (!E)
   390      return {};
   391    if (auto* Enum = removeCasts(E)->getEnumConstantDecl())
   392      return constDesc(E, Enum->getNameAsString(), "", Enum->getSourceRange());
   393    auto Range = Lexer::getAsCharRange(E->getSourceRange(), *SourceManager, Context->getLangOpts());
   394    const std::string& Str = Lexer::getSourceText(Range, *SourceManager, Context->getLangOpts()).str();
   395    auto MacroDef = Macros.find(Str);
   396    if (MacroDef == Macros.end())
   397      return {};
   398    return constDesc(E, Str, MacroDef->second.Value, MacroDef->second.SourceRange);
   399  }
   400  
   401  ConstDesc Extractor::constDesc(const Expr* E, const std::string& Str, const std::string& Value,
   402                                 const SourceRange& SourceRange) {
   403    int64_t Val = evaluate(E);
   404    emitConst(Str, Val, SourceRange.getBegin());
   405    return ConstDesc{
   406        .Name = Str,
   407        .Value = Value,
   408        .SourceRange = SourceRange,
   409        .IntValue = Val,
   410    };
   411  }
   412  
   413  template <typename Node> void matchHelper(MatchFinder& Finder, ASTContext* Context, const Node* Expr) {
   414    Finder.match(*Expr, *Context);
   415  }
   416  
   417  void matchHelper(MatchFinder& Finder, ASTContext* Context, const ASTContext* Expr) {
   418    assert(Context == Expr);
   419    Finder.matchAST(*Context);
   420  }
   421  
   422  // Returns all matches of Cond named "res" in Expr and returns them casted to T.
   423  // Expr can point to Context for a global match.
   424  template <typename T, typename Node, typename Condition>
   425  std::vector<const T*> Extractor::findAllMatches(const Node* Expr, const Condition& Cond) {
   426    if (!Expr)
   427      return {};
   428    struct Matcher : MatchFinder::MatchCallback {
   429      std::vector<const T*> Matches;
   430      void run(const MatchFinder::MatchResult& Result) override {
   431        if (const T* M = Result.Nodes.getNodeAs<T>("res"))
   432          Matches.push_back(M);
   433      }
   434    };
   435    MatchFinder Finder;
   436    Matcher Matcher;
   437    Finder.addMatcher(Cond, &Matcher);
   438    matchHelper(Finder, Context, Expr);
   439    return std::move(Matcher.Matches);
   440  }
   441  
   442  // Returns the first match of Cond named "res" in Expr and returns it casted to T.
   443  // If no match is found, returns nullptr.
   444  template <typename T, typename Node, typename Condition>
   445  const T* Extractor::findFirstMatch(const Node* Expr, const Condition& Cond) {
   446    const auto& Matches = findAllMatches<T>(Expr, Cond);
   447    return Matches.empty() ? nullptr : Matches[0];
   448  }
   449  
   450  // Extracts the first function reference from the expression.
   451  // TODO: try to extract the actual function reference the expression will be evaluated to
   452  // (the first one is not necessarily the right one).
   453  std::string Extractor::getFuncName(const Expr* Expr) {
   454    auto* Decl =
   455        findFirstMatch<DeclRefExpr>(Expr, stmt(forEachDescendant(declRefExpr(hasType(functionType())).bind("res"))));
   456    return Decl ? Decl->getDecl()->getNameAsString() : "";
   457  }
   458  
   459  // If expression refers to some identifier, returns the identifier name.
   460  // Otherwise returns an empty string.
   461  // For example, if the expression is `function_name`, returns "function_name" string.
   462  std::string Extractor::getDeclName(const Expr* Expr) {
   463    // The expression can be complex and include casts and e.g. InitListExpr,
   464    // to remove all of these we match the first/any DeclRefExpr.
   465    auto* Decl = getValueDecl(Expr);
   466    return Decl ? Decl->getNameAsString() : "";
   467  }
   468  
   469  // Returns the first ValueDecl in the expression.
   470  const ValueDecl* Extractor::getValueDecl(const Expr* Expr) {
   471    // The expression can be complex and include casts and e.g. InitListExpr,
   472    // to remove all of these we match the first/any DeclRefExpr.
   473    auto* Decl = findFirstMatch<DeclRefExpr>(Expr, stmt(forEachDescendant(declRefExpr().bind("res"))));
   474    return Decl ? Decl->getDecl() : nullptr;
   475  }
   476  
   477  // Recursively finds first sizeof in the expression and return the type passed to sizeof.
   478  std::optional<QualType> Extractor::getSizeofType(const Expr* E) {
   479    auto* Res = findFirstMatch<UnaryExprOrTypeTraitExpr>(
   480        E, stmt(forEachDescendant(unaryExprOrTypeTraitExpr(ofKind(UETT_SizeOf)).bind("res"))));
   481    if (!Res)
   482      return {};
   483    if (Res->isArgumentType())
   484      return Res->getArgumentType();
   485    return Res->getArgumentExpr()->getType();
   486  }
   487  
   488  // Returns map of field name -> field index.
   489  std::unordered_map<std::string, unsigned> Extractor::structFieldIndexes(const RecordDecl* Decl) {
   490    // TODO: this is wrong for structs that contain unions and anonymous sub-structs (e.g. genl_split_ops).
   491    // To handle these we would need to look at InitListExpr::getInitializedFieldInUnion, and recurse
   492    // into anonymous structs.
   493    std::unordered_map<std::string, unsigned> Indexes;
   494    for (const auto& F : Decl->fields())
   495      Indexes[F->getNameAsString()] = F->getFieldIndex();
   496    return Indexes;
   497  }
   498  
   499  // Extracts enum info from array variable designated initialization.
   500  // For example, for the following code:
   501  //
   502  //	enum Foo {
   503  //		FooA = 11,
   504  //		FooB = 42,
   505  //	};
   506  //
   507  //	struct Bar bars[] = {
   508  //		[FooA] = {...},
   509  //		[FooB] = {...},
   510  //	};
   511  //
   512  // it returns the following vector: {{11, "FooA"}, {42, "FooB"}}.
   513  std::vector<std::pair<int, std::string>> Extractor::extractDesignatedInitConsts(const VarDecl& ArrayDecl) {
   514    const auto& Matches = findAllMatches<ConstantExpr>(
   515        &ArrayDecl,
   516        decl(forEachDescendant(designatedInitExpr(optionally(has(constantExpr(has(declRefExpr())).bind("res")))))));
   517    std::vector<std::pair<int, std::string>> Inits;
   518    for (auto* Match : Matches) {
   519      const int64_t Val = *Match->getAPValueResult().getInt().getRawData();
   520      const auto& Name = Match->getEnumConstantDecl()->getNameAsString();
   521      const auto& Loc = Match->getEnumConstantDecl()->getBeginLoc();
   522      emitConst(Name, Val, Loc);
   523      Inits.emplace_back(Val, Name);
   524    }
   525    return Inits;
   526  }
   527  
   528  int Extractor::sizeofType(const Type* T) { return static_cast<int>(Context->getTypeInfo(T).Width) / 8; }
   529  int Extractor::alignofType(const Type* T) { return static_cast<int>(Context->getTypeInfo(T).Align) / 8; }
   530  
   531  template <typename T> T Extractor::evaluate(const Expr* E) {
   532    Expr::EvalResult Res;
   533    E->EvaluateAsConstantExpr(Res, *Context);
   534    // TODO: it's unclear what to do if it's not Int (in some cases we see None here).
   535    if (Res.Val.getKind() != APValue::Int)
   536      return 0;
   537    auto val = Res.Val.getInt();
   538    if (val.isSigned())
   539      return val.sextOrTrunc(64).getSExtValue();
   540    return val.zextOrTrunc(64).getZExtValue();
   541  }
   542  
   543  void Extractor::matchNetlinkPolicy() {
   544    const auto* PolicyArray = getResult<VarDecl>("netlink_policy");
   545    const auto* Init = llvm::dyn_cast_if_present<InitListExpr>(PolicyArray->getInit());
   546    if (!Init)
   547      return;
   548    const auto& InitConsts = extractDesignatedInitConsts(*PolicyArray);
   549    auto Fields = structFieldIndexes(Init->getInit(0)->getType()->getAsRecordDecl());
   550    std::vector<NetlinkAttr> Attrs;
   551    for (const auto& [I, Name] : InitConsts) {
   552      const auto* AttrInit = llvm::dyn_cast<InitListExpr>(Init->getInit(I));
   553      const std::string& AttrKind = getDeclName(AttrInit->getInit(Fields["type"]));
   554      if (AttrKind == "NLA_REJECT")
   555        continue;
   556      auto* LenExpr = AttrInit->getInit(Fields["len"]);
   557      int MaxSize = 0;
   558      std::string NestedPolicy;
   559      std::unique_ptr<FieldType> Elem;
   560      if (AttrKind == "NLA_NESTED" || AttrKind == "NLA_NESTED_ARRAY") {
   561        if (const auto* NestedDecl = getValueDecl(AttrInit->getInit(2)))
   562          NestedPolicy = getUniqueDeclName(NestedDecl);
   563      } else {
   564        MaxSize = evaluate<int>(LenExpr);
   565        if (auto SizeofType = getSizeofType(LenExpr))
   566          Elem = std::make_unique<FieldType>(genType(*SizeofType));
   567      }
   568      Attrs.push_back(NetlinkAttr{
   569          .Name = Name,
   570          .Kind = AttrKind,
   571          .MaxSize = MaxSize,
   572          .NestedPolicy = NestedPolicy,
   573          .Elem = std::move(Elem),
   574      });
   575    }
   576    Output.emit(NetlinkPolicy{
   577        .Name = getUniqueDeclName(PolicyArray),
   578        .Attrs = std::move(Attrs),
   579    });
   580  }
   581  
   582  void Extractor::matchNetlinkFamily() {
   583    const auto* FamilyInit = getResult<InitListExpr>("genl_family_init");
   584    auto Fields = structFieldIndexes(getResult<RecordDecl>("genl_family"));
   585    const std::string& FamilyName = llvm::dyn_cast<StringLiteral>(FamilyInit->getInit(Fields["name"]))->getString().str();
   586    std::string DefaultPolicy;
   587    if (const auto* PolicyDecl = FamilyInit->getInit(Fields["policy"])->getAsBuiltinConstantDeclRef(*Context))
   588      DefaultPolicy = getUniqueDeclName(PolicyDecl);
   589    std::vector<NetlinkOp> Ops;
   590    for (const auto& OpsName : {"ops", "small_ops", "split_ops"}) {
   591      const auto* OpsDecl =
   592          llvm::dyn_cast_if_present<VarDecl>(FamilyInit->getInit(Fields[OpsName])->getAsBuiltinConstantDeclRef(*Context));
   593      const auto NumOps = FamilyInit->getInit(Fields[std::string("n_") + OpsName])->getIntegerConstantExpr(*Context);
   594      // The ops variable may be defined in another TU.
   595      // TODO: extract variables from another TUs.
   596      if (!OpsDecl || !OpsDecl->getInit() || !NumOps)
   597        continue;
   598      const auto* OpsInit = llvm::dyn_cast<InitListExpr>(OpsDecl->getInit());
   599      auto OpsFields = structFieldIndexes(OpsInit->getInit(0)->getType()->getAsRecordDecl());
   600      for (int I = 0; I < *NumOps; I++) {
   601        const auto* OpInit = llvm::dyn_cast<InitListExpr>(OpsInit->getInit(I));
   602        const auto* CmdInit = OpInit->getInit(OpsFields["cmd"])->getEnumConstantDecl();
   603        if (!CmdInit)
   604          continue;
   605        const std::string& OpName = CmdInit->getNameAsString();
   606        emitConst(OpName, CmdInit->getInitVal().getExtValue(), CmdInit->getBeginLoc());
   607        std::string Policy;
   608        if (OpsFields.count("policy") != 0) {
   609          if (const auto* PolicyDecl = OpInit->getInit(OpsFields["policy"])->getAsBuiltinConstantDeclRef(*Context))
   610            Policy = getUniqueDeclName(PolicyDecl);
   611        }
   612        if (Policy.empty())
   613          Policy = DefaultPolicy;
   614        std::string Func = getFuncName(OpInit->getInit(OpsFields["doit"]));
   615        if (Func.empty())
   616          Func = getFuncName(OpInit->getInit(OpsFields["dumpit"]));
   617        int Flags = evaluate(OpInit->getInit(OpsFields["flags"]));
   618        const char* Access = AccessUser;
   619        constexpr int GENL_ADMIN_PERM = 0x01;
   620        constexpr int GENL_UNS_ADMIN_PERM = 0x10;
   621        if (Flags & GENL_ADMIN_PERM)
   622          Access = AccessAdmin;
   623        else if (Flags & GENL_UNS_ADMIN_PERM)
   624          Access = AccessNsAdmin;
   625        Ops.push_back(NetlinkOp{
   626            .Name = OpName,
   627            .Func = Func,
   628            .Access = Access,
   629            .Policy = Policy,
   630        });
   631      }
   632    }
   633    Output.emit(NetlinkFamily{
   634        .Name = FamilyName,
   635        .Ops = std::move(Ops),
   636    });
   637  }
   638  
   639  std::string Extractor::getUniqueDeclName(const NamedDecl* Decl) {
   640    return Decl->getNameAsString() + "_" + getDeclFileID(Decl);
   641  }
   642  
   643  bool isInterestingCall(const CallExpr* Call) {
   644    auto* CalleeDecl = Call->getDirectCallee();
   645    // We don't handle indirect calls yet.
   646    if (!CalleeDecl)
   647      return false;
   648    // Builtins are not interesting and won't have a body.
   649    if (CalleeDecl->getBuiltinID() != Builtin::ID::NotBuiltin)
   650      return false;
   651    const std::string& Callee = CalleeDecl->getNameAsString();
   652    // There are too many of these and they should only be called at runtime in broken builds.
   653    if (Callee.rfind("__compiletime_assert", 0) == 0 || Callee == "____wrong_branch_error" ||
   654        Callee == "__bad_size_call_parameter")
   655      return false;
   656    return true;
   657  }
   658  
// FunctionAnalyzer walks the body of a single function definition and collects the
// per-scope data that matchFunctionDef emits for it:
//  - Scopes[0] is the scope of the whole function (Arg = -1);
//  - every "interesting" switch (see VisitSwitchStmt) opens one extra scope per group of
//    case labels. Each such scope records the index of the switched-on argument, the case
//    values, and the StartLine/EndLine of that case's code.
// Every scope accumulates the interesting direct calls made in it (Calls) and "facts":
// (source, destination) pairs of typing entities coming from assignments, local variable
// initializers, return statements and call arguments.
// The whole traversal runs inside the constructor; callers read Scopes afterwards.
struct FunctionAnalyzer : RecursiveASTVisitor<FunctionAnalyzer> {
  FunctionAnalyzer(Extractor* Extractor, const FunctionDecl* Func)
      : Extractor(Extractor), CurrentFunc(Func->getNameAsString()), Context(Extractor->Context),
        SourceManager(Extractor->SourceManager) {
    // The global function scope.
    Scopes.push_back(FunctionScope{.Arg = -1});
    Current = &Scopes[0];
    TraverseStmt(Func->getBody());
  }

  // For assignment operators (as classified by BinaryOperator::isAssignmentOp), the RHS
  // entity flows into the LHS entity.
  bool VisitBinaryOperator(const BinaryOperator* B) {
    if (B->isAssignmentOp())
      noteFact(getTypingEntity(B->getRHS()), getTypingEntity(B->getLHS()));
    return true;
  }

  // The initializer of an automatic-storage variable flows into that variable.
  // Variables without an initializer produce no fact (getTypingEntity(nullptr) is empty).
  bool VisitVarDecl(const VarDecl* D) {
    if (D->getStorageDuration() == SD_Automatic)
      noteFact(getTypingEntity(D->getInit()), getDeclTypingEntity(D));
    return true;
  }

  // The returned value flows into this function's return entity. A bare `return;` has no
  // value and produces no fact.
  bool VisitReturnStmt(const ReturnStmt* Ret) {
    noteFact(getTypingEntity(Ret->getRetValue()), EntityReturn{.Func = CurrentFunc});
    return true;
  }

  // Records every interesting direct call in the current scope. For each call, argument
  // number AI flows into the callee's parameter number AI.
  bool VisitCallExpr(const CallExpr* Call) {
    if (isInterestingCall(Call)) {
      const std::string& Callee = Call->getDirectCallee()->getNameAsString();
      Current->Calls.push_back(Callee);
      for (unsigned AI = 0; AI < Call->getNumArgs(); AI++) {
        noteFact(getTypingEntity(Call->getArg(AI)), EntityArgument{
                                                        .Func = Callee,
                                                        .Arg = AI,
                                                    });
      }
    }
    return true;
  }

  // Every switch is pushed onto SwitchStack (interesting or not), so that VisitSwitchCase
  // and dataTraverseStmtPost always see the innermost enclosing switch on top.
  // A switch is "interesting" when all of the following hold:
  //  - it appears directly in the function-level scope (not inside another case scope);
  //  - its condition maps to a function argument;
  //  - at least one case bound is a macro or enum according to Extractor->isMacroOrEnum.
  bool VisitSwitchStmt(const SwitchStmt* S) {
    // We are only interested in switches on the function arguments
    // with cases that mention defines from uapi headers.
    // This covers ioctl/fcntl/prctl/ptrace/etc.
    bool IsInteresting = false;
    auto Param = getTypingEntity(S->getCond());
    if (Current == &Scopes[0] && Param && Param->Argument) {
      for (auto* C = S->getSwitchCaseList(); C; C = C->getNextSwitchCase()) {
        auto* Case = dyn_cast<CaseStmt>(C);
        if (!Case)
          continue;
        auto LMacro = Extractor->isMacroOrEnum(Case->getLHS());
        auto RMacro = Extractor->isMacroOrEnum(Case->getRHS());
        if (LMacro || RMacro) {
          IsInteresting = true;
          break;
        }
      }
    }

    // Arg is the switched-on argument index, or -1 for uninteresting switches.
    SwitchStack.push({S, IsInteresting, IsInteresting ? static_cast<int>(Param->Argument->Arg) : -1});
    return true;
  }

  // For case labels of an interesting switch, opens a new scope where needed and records
  // the label's values in the current scope.
  bool VisitSwitchCase(const SwitchCase* C) {
    if (!SwitchStack.top().IsInteresting)
      return true;
    // If there are several cases with the same "body", we want to create new scope
    // only for the first one:
    //   case FOO:
    //   case BAR:
    //     ... some code ...
    // A label that is itself the sub-statement of the label linked after it in the switch-case
    // list (BAR above) reuses that label's scope. NOTE(review): this relies on Clang linking
    // each case to the one preceding it in source order; confirm if Clang changes.
    if (!C->getNextSwitchCase() || C->getNextSwitchCase()->getSubStmt() != C) {
      int Line = SourceManager->getExpansionLineNumber(C->getBeginLoc());
      // The previous case scope (if any) ends where this label starts.
      if (Current != &Scopes[0])
        Current->EndLine = Line;
      Scopes.push_back(FunctionScope{
          .Arg = SwitchStack.top().Arg,
          .StartLine = Line,
      });
      // push_back may reallocate Scopes, so Current is re-pointed at the new element.
      Current = &Scopes.back();
    }
    // Only `case` labels carry values; a `default:` label adds none.
    if (auto* Case = dyn_cast<CaseStmt>(C)) {
      int64_t LVal = Extractor->evaluate(Case->getLHS());
      auto LMacro = Extractor->isMacroOrEnum(Case->getLHS());
      // Prefer the macro/enum name over the numeric value when there is one.
      if (LMacro) {
        Current->Values.push_back(LMacro->Name);
        Extractor->extractIoctl(Case->getLHS(), *LMacro);
      } else {
        Current->Values.push_back(std::to_string(LVal));
      }
      if (Case->caseStmtIsGNURange()) {
        // GNU range is:
        //   case FOO ... BAR:
        // Add all values in the range.
        // When BAR is a macro/enum, the loop stops one short and BAR's name is added instead
        // of its number.
        int64_t RVal = Extractor->evaluate(Case->getRHS());
        auto RMacro = Extractor->isMacroOrEnum(Case->getRHS());
        for (int64_t V = LVal + 1; V <= RVal - (RMacro ? 1 : 0); V++)
          Current->Values.push_back(std::to_string(V));
        if (RMacro)
          Current->Values.push_back(RMacro->Name);
      }
    }
    return true;
  }

  // RecursiveASTVisitor hook that runs after a statement's subtree has been traversed.
  // When the statement is the innermost tracked switch, that switch is popped. For an
  // interesting switch, the last case scope is also closed at the switch's end line and
  // collection returns to the function-level scope.
  bool dataTraverseStmtPost(const Stmt* S) {
    if (SwitchStack.empty())
      return true;
    auto Top = SwitchStack.top();
    if (Top.S != S)
      return true;
    if (Top.IsInteresting) {
      if (Current != &Scopes[0])
        Current->EndLine = SourceManager->getExpansionLineNumber(S->getEndLoc());
      Current = &Scopes[0];
    }
    SwitchStack.pop();
    return true;
  }

  // Records the fact "Src flows into Dst" in the current scope, only when both sides are
  // known typing entities.
  void noteFact(std::optional<TypingEntity>&& Src, std::optional<TypingEntity>&& Dst) {
    if (Src && Dst)
      Current->Facts.push_back({std::move(*Src), std::move(*Dst)});
  }

  std::optional<TypingEntity> getTypingEntity(const Expr* E);
  std::optional<TypingEntity> getDeclTypingEntity(const Decl* Decl);

  // One entry of SwitchStack: the switch statement, whether it is interesting, and the
  // switched-on argument index (-1 when not interesting).
  struct SwitchDesc {
    const SwitchStmt* S;
    bool IsInteresting;
    int Arg;
  };

  Extractor* Extractor;
  std::string CurrentFunc; // name of the analyzed function
  ASTContext* Context;
  SourceManager* SourceManager;
  std::vector<FunctionScope> Scopes; // Scopes[0] is the whole-function scope
  FunctionScope* Current = nullptr;  // scope that currently receives calls/facts/values
  // Per-name ordinal assigned to each local variable, and the next ordinal for each name
  // (used by getDeclTypingEntity to disambiguate same-named locals).
  std::unordered_map<const VarDecl*, int> LocalVars;
  std::unordered_map<std::string, int> LocalSeq;
  std::stack<SwitchDesc> SwitchStack; // switches enclosing the current traversal point
};
   806  
   807  void Extractor::matchFunctionDef() {
   808    const auto* Func = getResult<FunctionDecl>("function");
   809    if (!Func->getBody())
   810      return;
   811    auto Range = Func->getSourceRange();
   812    const std::string& SourceFile =
   813        std::filesystem::relative(SourceManager->getFilename(SourceManager->getExpansionLoc(Range.getBegin())).str());
   814    const int StartLine = SourceManager->getExpansionLineNumber(Range.getBegin());
   815    const int EndLine = SourceManager->getExpansionLineNumber(Range.getEnd());
   816    FunctionAnalyzer Analyzer(this, Func);
   817    Output.emit(Function{
   818        .Name = Func->getNameAsString(),
   819        .File = SourceFile,
   820        .StartLine = StartLine,
   821        .EndLine = EndLine,
   822        .IsStatic = Func->isStatic(),
   823        .Scopes = std::move(Analyzer.Scopes),
   824    });
   825  }
   826  
   827  std::optional<TypingEntity> FunctionAnalyzer::getTypingEntity(const Expr* E) {
   828    if (!E)
   829      return {};
   830    E = removeCasts(E);
   831    if (auto* DeclRef = dyn_cast<DeclRefExpr>(E)) {
   832      return getDeclTypingEntity(DeclRef->getDecl());
   833    } else if (auto* Member = dyn_cast<MemberExpr>(E)) {
   834      const Type* StructType =
   835          Member->getBase()->getType().IgnoreParens().getUnqualifiedType().getDesugaredType(*Context).getTypePtr();
   836      if (auto* T = dyn_cast<PointerType>(StructType))
   837        StructType = T->getPointeeType().IgnoreParens().getUnqualifiedType().getDesugaredType(*Context).getTypePtr();
   838      auto* StructDecl = dyn_cast<RecordType>(StructType)->getDecl();
   839      std::string StructName = StructDecl->getNameAsString();
   840      if (StructName.empty()) {
   841        // The struct may be anonymous, but we need some name.
   842        // Ideally we generate the same name we generate in struct definitions, then it will be possible
   843        // to match them between each other. However, it does not seem to be easy. We can use DeclContext::getParent
   844        // to get declaration of the enclosing struct, but we will also need to figure out the field index
   845        // and handle all corner cases. For now we just use the following quick hack: hash declaration file:line.
   846        // Note: the hash must be stable across different machines (for test golden files), so we take just
   847        // the last part of the file name.
   848        const std::string& SourceFile =
   849            std::filesystem::path(
   850                SourceManager->getFilename(SourceManager->getExpansionLoc(StructDecl->getBeginLoc())).str())
   851                .filename()
   852                .string();
   853        int Line = SourceManager->getExpansionLineNumber(StructDecl->getBeginLoc());
   854        StructName = std::to_string(std::hash<std::string>()(SourceFile) + std::hash<int>()(Line));
   855      }
   856      return EntityField{
   857          .Struct = StructName,
   858          .Field = Member->getMemberDecl()->getNameAsString(),
   859      };
   860    } else if (auto* Unary = dyn_cast<UnaryOperator>(E)) {
   861      if (Unary->getOpcode() == UnaryOperatorKind::UO_AddrOf) {
   862        if (auto* DeclRef = dyn_cast<DeclRefExpr>(removeCasts(Unary->getSubExpr()))) {
   863          if (auto* Var = dyn_cast<VarDecl>(DeclRef->getDecl())) {
   864            if (Var->hasGlobalStorage()) {
   865              return EntityGlobalAddr{
   866                  .Name = Extractor->getUniqueDeclName(Var),
   867              };
   868            }
   869          }
   870        }
   871      }
   872    } else if (auto* Call = dyn_cast<CallExpr>(E)) {
   873      if (isInterestingCall(Call)) {
   874        return EntityReturn{
   875            .Func = Call->getDirectCallee()->getNameAsString(),
   876        };
   877      }
   878    }
   879    return {};
   880  }
   881  
   882  std::optional<TypingEntity> FunctionAnalyzer::getDeclTypingEntity(const Decl* Decl) {
   883    if (auto* Parm = dyn_cast<ParmVarDecl>(Decl)) {
   884      return EntityArgument{
   885          .Func = CurrentFunc,
   886          .Arg = Parm->getFunctionScopeIndex(),
   887      };
   888    } else if (auto* Var = dyn_cast<VarDecl>(Decl)) {
   889      if (Var->hasLocalStorage()) {
   890        std::string VarName = Var->getNameAsString();
   891        // Theoretically there can be several local vars with the same name.
   892        // Give them unique suffixes if that's the case.
   893        if (LocalVars.count(Var) == 0)
   894          LocalVars[Var] = LocalSeq[VarName]++;
   895        if (int Seq = LocalVars[Var])
   896          VarName += std::to_string(Seq);
   897        return EntityLocal{
   898            .Name = VarName,
   899        };
   900      }
   901    }
   902    return {};
   903  }
   904  
   905  void Extractor::matchSyscall() {
   906    const auto* Func = getResult<FunctionDecl>("syscall");
   907    std::vector<Field> Args;
   908    for (const auto& Param : Func->parameters()) {
   909      Args.push_back(Field{
   910          .Name = Param->getNameAsString(),
   911          .Type = genType(Param->getType()),
   912      });
   913    }
   914    Output.emit(Syscall{
   915        .Func = Func->getNameAsString(),
   916        .Args = std::move(Args),
   917    });
   918  }
   919  
   920  void Extractor::matchIouring() {
   921    const auto* IssueDefs = getResult<VarDecl>("io_issue_defs");
   922    const auto& InitConsts = extractDesignatedInitConsts(*IssueDefs);
   923    const auto* InitList = llvm::dyn_cast<InitListExpr>(IssueDefs->getInit());
   924    auto Fields = structFieldIndexes(InitList->getInit(0)->getType()->getAsRecordDecl());
   925    for (const auto& [I, Name] : InitConsts) {
   926      const auto& Init = llvm::dyn_cast<InitListExpr>(InitList->getInit(I));
   927      std::string Prep = getFuncName(Init->getInit(Fields["prep"]));
   928      if (Prep == "io_eopnotsupp_prep")
   929        continue;
   930      Output.emit(IouringOp{
   931          .Name = Name,
   932          .Func = getFuncName(Init->getInit(Fields["issue"])),
   933      });
   934    }
   935  }
   936  
   937  void Extractor::matchFileOps() {
   938    const auto* Fops = getResult<InitListExpr>("init");
   939    if (Fops->getNumInits() == 0 || isa<DesignatedInitExpr>(Fops->getInit(0))) {
   940      // Some code constructs produce init list with DesignatedInitExpr.
   941      // Unclear why, but it won't be handled by the following code, and is not necessary to handle.
   942      return;
   943    }
   944    const auto* Var = getResult<VarDecl>("var");
   945    std::string VarName = getUniqueDeclName(Var);
   946    int NameSeq = FileOpsDedup[VarName]++;
   947    if (NameSeq)
   948      VarName += std::to_string(NameSeq);
   949    auto Fields = structFieldIndexes(Fops->getType()->getAsRecordDecl());
   950    std::string Open = getFuncName(Fops->getInit(Fields["open"]));
   951    std::string Ioctl = getFuncName(Fops->getInit(Fields["unlocked_ioctl"]));
   952    std::string Read = getFuncName(Fops->getInit(Fields["read"]));
   953    if (Read.empty())
   954      Read = getFuncName(Fops->getInit(Fields["read_iter"]));
   955    std::string Write = getFuncName(Fops->getInit(Fields["write"]));
   956    if (Write.empty())
   957      Write = getFuncName(Fops->getInit(Fields["write_iter"]));
   958    std::string Mmap = getFuncName(Fops->getInit(Fields["mmap"]));
   959    if (Mmap.empty())
   960      Mmap = getFuncName(Fops->getInit(Fields["get_unmapped_area"]));
   961    Output.emit(FileOps{
   962        .Name = VarName,
   963        .Open = std::move(Open),
   964        .Read = std::move(Read),
   965        .Write = std::move(Write),
   966        .Mmap = std::move(Mmap),
   967        .Ioctl = std::move(Ioctl),
   968    });
   969  }
   970  
   971  void Extractor::extractIoctl(const Expr* Cmd, const ConstDesc& Const) {
   972    // This is old style ioctl defined directly via a number.
   973    // We can't infer anything about it.
   974    if (Const.Value.find("_IO") != 0)
   975      return;
   976    FieldType Type;
   977    auto Dir = _IOC_DIR(Const.IntValue);
   978    if (Dir == _IOC_NONE) {
   979      Type = IntType{.ByteSize = 1, .IsConst = true};
   980    } else if (std::optional<QualType> Arg = getSizeofType(Cmd)) {
   981      Type = PtrType{
   982          .Elem = genType(*Arg),
   983          .IsConst = Dir == _IOC_READ,
   984      };
   985    } else {
   986      // It is an ioctl, but we failed to get the arg type.
   987      // Let the Go part figure out a good arg type.
   988      return;
   989    }
   990    Output.emit(Ioctl{
   991        .Name = Const.Name,
   992        .Type = std::move(Type),
   993    });
   994  }
   995  
   996  int main(int argc, const char** argv) {
   997    llvm::cl::OptionCategory Options("syz-declextract options");
   998    auto OptionsParser = tooling::CommonOptionsParser::create(argc, argv, Options);
   999    if (!OptionsParser) {
  1000      llvm::errs() << OptionsParser.takeError();
  1001      return 1;
  1002    }
  1003    Extractor Ex;
  1004    tooling::ClangTool Tool(OptionsParser->getCompilations(), OptionsParser->getSourcePathList());
  1005    if (Tool.run(tooling::newFrontendActionFactory(&Ex, &Ex).get()))
  1006      return 1;
  1007    Ex.print();
  1008    return 0;
  1009  }