github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/compiler/compiler.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // Package compiler generates sys descriptions of syscalls, types and resources
     5  // from textual descriptions.
     6  package compiler
     7  
     8  import (
     9  	"fmt"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/google/syzkaller/pkg/ast"
    14  	"github.com/google/syzkaller/prog"
    15  	"github.com/google/syzkaller/sys/targets"
    16  )
    17  
    18  // Overview of compilation process:
    19  // 1. ast.Parse on text file does tokenization and builds AST.
    20  //    This step catches basic syntax errors. AST contains full debug info.
// 2. ExtractConsts on AST returns the set of constant identifiers.
    22  //    This step also does verification of include/incdir/define AST nodes.
    23  // 3. User translates constants to values.
    24  // 4. Compile on AST and const values does the rest of the work and returns Prog
    25  //    containing generated prog objects.
    26  // 4.1. assignSyscallNumbers: uses consts to assign syscall numbers.
    27  //      This step also detects unsupported syscalls and discards no longer
//      needed AST nodes (include, define, comments, etc).
    29  // 4.2. patchConsts: patches Int nodes referring to consts with corresponding values.
    30  //      Also detects unsupported syscalls, structs, resources due to missing consts.
// 4.3. check: does extensive semantic checks of AST.
    32  // 4.4. gen: generates prog objects from AST.
    33  
// Prog is description compilation result.
//
// Compile returns a Prog in one of two forms. When called with nil consts,
// only fileConsts is set (the constants extracted from the descriptions).
// Otherwise the generated Resources, Syscalls and Types are set, together
// with the Unsupported set.
type Prog struct {
	Resources []*prog.ResourceDesc
	Syscalls  []*prog.Syscall
	Types     []prog.Type
	// Set of unsupported syscalls/flags.
	// filterArch, for one, adds entries keyed as "<node type> <name>" for
	// nodes dropped because their file does not support the target arch.
	Unsupported map[string]bool
	// Returned if consts was nil.
	fileConsts map[string]*ConstInfo
}
    44  
    45  func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) *compiler {
    46  	if eh == nil {
    47  		eh = ast.LoggingHandler
    48  	}
    49  	desc.Nodes = append(builtinDescs.Clone().Nodes, desc.Nodes...)
    50  	comp := &compiler{
    51  		desc:           desc,
    52  		target:         target,
    53  		eh:             eh,
    54  		ptrSize:        target.PtrSize,
    55  		unsupported:    make(map[string]bool),
    56  		resources:      make(map[string]*ast.Resource),
    57  		typedefs:       make(map[string]*ast.TypeDef),
    58  		structs:        make(map[string]*ast.Struct),
    59  		intFlags:       make(map[string]*ast.IntFlags),
    60  		strFlags:       make(map[string]*ast.StrFlags),
    61  		used:           make(map[string]bool),
    62  		usedTypedefs:   make(map[string]bool),
    63  		brokenTypedefs: make(map[string]bool),
    64  		structVarlen:   make(map[string]bool),
    65  		structTypes:    make(map[string]prog.Type),
    66  		structFiles:    make(map[*ast.Struct]map[string]ast.Pos),
    67  		recursiveQuery: make(map[ast.Node]bool),
    68  		builtinConsts: map[string]uint64{
    69  			"PTR_SIZE": target.PtrSize,
    70  		},
    71  	}
    72  	return comp
    73  }
    74  
    75  // Compile compiles sys description.
    76  func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Target, eh ast.ErrorHandler) *Prog {
    77  	comp := createCompiler(desc.Clone(), target, eh)
    78  	comp.filterArch()
    79  	comp.typecheck()
    80  	comp.flattenFlags()
    81  	// The subsequent, more complex, checks expect basic validity of the tree,
    82  	// in particular corrent number of type arguments. If there were errors,
    83  	// don't proceed to avoid out-of-bounds references to type arguments.
    84  	if comp.errors != 0 {
    85  		return nil
    86  	}
    87  	if consts == nil {
    88  		fileConsts := comp.extractConsts()
    89  		if comp.errors != 0 {
    90  			return nil
    91  		}
    92  		return &Prog{fileConsts: fileConsts}
    93  	}
    94  	if comp.target.SyscallNumbers {
    95  		comp.assignSyscallNumbers(consts)
    96  	}
    97  	comp.patchConsts(consts)
    98  	comp.check(consts)
    99  	if comp.errors != 0 {
   100  		return nil
   101  	}
   102  	syscalls := comp.genSyscalls()
   103  	comp.layoutTypes(syscalls)
   104  	types := comp.generateTypes(syscalls)
   105  	prg := &Prog{
   106  		Resources:   comp.genResources(),
   107  		Syscalls:    syscalls,
   108  		Types:       types,
   109  		Unsupported: comp.unsupported,
   110  	}
   111  	if comp.errors != 0 {
   112  		return nil
   113  	}
   114  	for _, w := range comp.warnings {
   115  		eh(w.pos, w.msg)
   116  	}
   117  	return prg
   118  }
   119  
// compiler holds the state of a single Compile invocation. This includes the
// description being compiled, the target, collected diagnostics, and the
// lookup tables and caches shared by the compilation passes. It is created
// by createCompiler.
type compiler struct {
	// desc is the description being compiled. createCompiler prepends the
	// builtin nodes, and filterArch replaces it with the result of desc.Filter.
	desc     *ast.Description
	target   *targets.Target
	// eh receives errors (and, from Compile, warnings). createCompiler
	// replaces a nil handler with ast.LoggingHandler.
	eh       ast.ErrorHandler
	// errors counts errors reported via error(). Compile checks it between
	// passes and returns nil if it is non-zero.
	errors   int
	// warnings are buffered by warning() and reported by Compile only on success.
	warnings []warn
	// ptrSize is initialized from target.PtrSize.
	ptrSize  uint64

	// unsupported becomes Prog.Unsupported. filterArch adds entries for nodes
	// dropped for the target arch.
	unsupported    map[string]bool
	// Declarations by name. getTypeDesc consults resources, structs and
	// typedefs. Per the comment in flattenFlags, the flag tables are filled
	// by checkNames and may omit flags skipped there.
	resources      map[string]*ast.Resource
	typedefs       map[string]*ast.TypeDef
	structs        map[string]*ast.Struct
	intFlags       map[string]*ast.IntFlags
	strFlags       map[string]*ast.StrFlags
	used           map[string]bool // contains used structs/resources
	usedTypedefs   map[string]bool
	brokenTypedefs map[string]bool

	// structVarlen memoizes structIsVarlen results by struct name.
	structVarlen   map[string]bool
	structTypes    map[string]prog.Type
	structFiles    map[*ast.Struct]map[string]ast.Pos
	// builtinConsts is seeded with PTR_SIZE by createCompiler.
	builtinConsts  map[string]uint64
	// fileMetas is not initialized by createCompiler; presumably it is filled
	// lazily elsewhere (e.g. by fileMeta) — TODO confirm.
	fileMetas      map[string]Meta
	recursiveQuery map[ast.Node]bool
}
   145  
// warn is a warning recorded by compiler.warning. Warnings are buffered and
// passed to the error handler by Compile only after compilation has succeeded.
type warn struct {
	pos ast.Pos // position the warning refers to
	msg string  // already-formatted message
}
   150  
   151  func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) {
   152  	comp.errors++
   153  	comp.eh(pos, fmt.Sprintf(msg, args...))
   154  }
   155  
   156  func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) {
   157  	comp.warnings = append(comp.warnings, warn{pos, fmt.Sprintf(msg, args...)})
   158  }
   159  
   160  func (comp *compiler) filterArch() {
   161  	comp.desc = comp.desc.Filter(func(n ast.Node) bool {
   162  		pos, typ, name := n.Info()
   163  		if comp.fileMeta(pos).SupportsArch(comp.target.Arch) {
   164  			return true
   165  		}
   166  		switch n.(type) {
   167  		case *ast.Resource, *ast.Struct, *ast.Call, *ast.TypeDef:
   168  			// This is required to keep the unsupported diagnostic working,
   169  			// otherwise sysgen will think that these things are still supported on some arches.
   170  			comp.unsupported[typ+" "+name] = true
   171  		}
   172  		return false
   173  	})
   174  }
   175  
   176  func (comp *compiler) structIsVarlen(name string) bool {
   177  	if varlen, ok := comp.structVarlen[name]; ok {
   178  		return varlen
   179  	}
   180  	s := comp.structs[name]
   181  	if s.IsUnion {
   182  		res := comp.parseIntAttrs(unionAttrs, s, s.Attrs)
   183  		if res[attrVarlen] != 0 {
   184  			comp.structVarlen[name] = true
   185  			return true
   186  		}
   187  	}
   188  	comp.structVarlen[name] = false // to not hang on recursive types
   189  	varlen := false
   190  	for _, fld := range s.Fields {
   191  		hasIfAttr := false
   192  		for _, attr := range fld.Attrs {
   193  			if structFieldAttrs[attr.Ident] == attrIf {
   194  				hasIfAttr = true
   195  			}
   196  		}
   197  		if hasIfAttr || comp.isVarlen(fld.Type) {
   198  			varlen = true
   199  			break
   200  		}
   201  	}
   202  	comp.structVarlen[name] = varlen
   203  	return varlen
   204  }
   205  
   206  func (comp *compiler) parseIntAttrs(descs map[string]*attrDesc, parent ast.Node,
   207  	attrs []*ast.Type) map[*attrDesc]uint64 {
   208  	intAttrs, _, _ := comp.parseAttrs(descs, parent, attrs)
   209  	return intAttrs
   210  }
   211  
   212  func (comp *compiler) parseAttrs(descs map[string]*attrDesc, parent ast.Node, attrs []*ast.Type) (
   213  	map[*attrDesc]uint64, map[*attrDesc]prog.Expression, map[*attrDesc]string) {
   214  	_, parentType, parentName := parent.Info()
   215  	resInt := make(map[*attrDesc]uint64)
   216  	resExpr := make(map[*attrDesc]prog.Expression)
   217  	resString := make(map[*attrDesc]string)
   218  	for _, attr := range attrs {
   219  		if unexpected, _, ok := checkTypeKind(attr, kindIdent); !ok {
   220  			comp.error(attr.Pos, "unexpected %v, expect attribute", unexpected)
   221  			return resInt, resExpr, resString
   222  		}
   223  		if len(attr.Colon) != 0 {
   224  			comp.error(attr.Colon[0].Pos, "unexpected ':'")
   225  			return resInt, resExpr, resString
   226  		}
   227  		desc := descs[attr.Ident]
   228  		if desc == nil {
   229  			comp.error(attr.Pos, "unknown %v %v attribute %v", parentType, parentName, attr.Ident)
   230  			return resInt, resExpr, resString
   231  		}
   232  		_, dupInt := resInt[desc]
   233  		_, dupExpr := resExpr[desc]
   234  		_, dupString := resString[desc]
   235  		if dupInt || dupExpr || dupString {
   236  			comp.error(attr.Pos, "duplicate %v %v attribute %v", parentType, parentName, attr.Ident)
   237  			return resInt, resExpr, resString
   238  		}
   239  
   240  		switch desc.Type {
   241  		case flagAttr:
   242  			resInt[desc] = 1
   243  			if len(attr.Args) != 0 {
   244  				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
   245  				return nil, nil, nil
   246  			}
   247  		case intAttr:
   248  			resInt[desc] = comp.parseAttrIntArg(attr)
   249  		case exprAttr:
   250  			resExpr[desc] = comp.parseAttrExprArg(attr)
   251  		case stringAttr:
   252  			resString[desc] = comp.parseAttrStringArg(attr)
   253  		default:
   254  			comp.error(attr.Pos, "attribute %v has unknown type", attr.Ident)
   255  			return nil, nil, nil
   256  		}
   257  	}
   258  	return resInt, resExpr, resString
   259  }
   260  
   261  func (comp *compiler) parseAttrExprArg(attr *ast.Type) prog.Expression {
   262  	if len(attr.Args) != 1 {
   263  		comp.error(attr.Pos, "%v attribute is expected to have only one argument", attr.Ident)
   264  		return nil
   265  	}
   266  	arg := attr.Args[0]
   267  	if arg.HasString {
   268  		comp.error(attr.Pos, "%v argument must be an expression", attr.Ident)
   269  		return nil
   270  	}
   271  	return comp.genExpression(arg)
   272  }
   273  
   274  func (comp *compiler) parseAttrIntArg(attr *ast.Type) uint64 {
   275  	if len(attr.Args) != 1 {
   276  		comp.error(attr.Pos, "%v attribute is expected to have 1 argument", attr.Ident)
   277  		return 0
   278  	}
   279  	sz := attr.Args[0]
   280  	if unexpected, _, ok := checkTypeKind(sz, kindInt); !ok {
   281  		comp.error(sz.Pos, "unexpected %v, expect int", unexpected)
   282  		return 0
   283  	}
   284  	if len(sz.Colon) != 0 || len(sz.Args) != 0 {
   285  		comp.error(sz.Pos, "%v attribute has colon or args", attr.Ident)
   286  		return 0
   287  	}
   288  	return sz.Value
   289  }
   290  
   291  func (comp *compiler) parseAttrStringArg(attr *ast.Type) string {
   292  	if len(attr.Args) != 1 {
   293  		comp.error(attr.Pos, "%v attribute is expected to have 1 argument", attr.Ident)
   294  		return ""
   295  	}
   296  	arg := attr.Args[0]
   297  	if !arg.HasString {
   298  		comp.error(attr.Pos, "%v argument must be a string", attr.Ident)
   299  		return ""
   300  	}
   301  	return arg.String
   302  }
   303  
   304  func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc {
   305  	if desc := builtinTypes[t.Ident]; desc != nil {
   306  		return desc
   307  	}
   308  	if comp.resources[t.Ident] != nil {
   309  		return typeResource
   310  	}
   311  	if comp.structs[t.Ident] != nil {
   312  		return typeStruct
   313  	}
   314  	if comp.typedefs[t.Ident] != nil {
   315  		return typeTypedef
   316  	}
   317  	return nil
   318  }
   319  
   320  func (comp *compiler) getArgsBase(t *ast.Type, isArg bool) (*typeDesc, []*ast.Type, prog.IntTypeCommon) {
   321  	desc := comp.getTypeDesc(t)
   322  	if desc == nil {
   323  		panic(fmt.Sprintf("no type desc for %#v", *t))
   324  	}
   325  	args, opt := removeOpt(t)
   326  	com := genCommon(t.Ident, sizeUnassigned, opt != nil)
   327  	base := genIntCommon(com, 0, false)
   328  	if desc.NeedBase {
   329  		base.TypeSize = comp.ptrSize
   330  		if !isArg {
   331  			baseType := args[len(args)-1]
   332  			args = args[:len(args)-1]
   333  			base = typeInt.Gen(comp, baseType, nil, base).(*prog.IntType).IntTypeCommon
   334  		}
   335  	}
   336  	return desc, args, base
   337  }
   338  
   339  func (comp *compiler) derefPointers(t *ast.Type) (*ast.Type, *typeDesc) {
   340  	for {
   341  		desc := comp.getTypeDesc(t)
   342  		if desc != typePtr {
   343  			return t, desc
   344  		}
   345  		t = t.Args[1]
   346  	}
   347  }
   348  
   349  func (comp *compiler) foreachType(n0 ast.Node,
   350  	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
   351  	switch n := n0.(type) {
   352  	case *ast.Call:
   353  		for _, arg := range n.Args {
   354  			comp.foreachSubType(arg.Type, true, cb)
   355  		}
   356  		if n.Ret != nil {
   357  			comp.foreachSubType(n.Ret, true, cb)
   358  		}
   359  	case *ast.Resource:
   360  		comp.foreachSubType(n.Base, false, cb)
   361  	case *ast.Struct:
   362  		for _, f := range n.Fields {
   363  			comp.foreachSubType(f.Type, false, cb)
   364  		}
   365  	case *ast.TypeDef:
   366  		if len(n.Args) == 0 {
   367  			comp.foreachSubType(n.Type, false, cb)
   368  		}
   369  	default:
   370  		panic(fmt.Sprintf("unexpected node %#v", n0))
   371  	}
   372  }
   373  
   374  func (comp *compiler) foreachSubType(t *ast.Type, isArg bool,
   375  	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
   376  	desc, args, base := comp.getArgsBase(t, isArg)
   377  	cb(t, desc, args, base)
   378  	for i, arg := range args {
   379  		if desc.Args[i].Type == typeArgType {
   380  			comp.foreachSubType(arg, desc.Args[i].IsArg, cb)
   381  		}
   382  	}
   383  }
   384  
   385  func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) {
   386  	args := t.Args
   387  	if last := len(args) - 1; last >= 0 && args[last].Ident == "opt" {
   388  		return args[:last], args[last]
   389  	}
   390  	return args, nil
   391  }
   392  
   393  func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) {
   394  	be := strings.HasSuffix(name, "be")
   395  	if be {
   396  		name = name[:len(name)-len("be")]
   397  	}
   398  	size = comp.ptrSize
   399  	if name != "intptr" {
   400  		size, _ = strconv.ParseUint(name[3:], 10, 64)
   401  		size /= 8
   402  	}
   403  	return size, be
   404  }
   405  
   406  func arrayContains(a []string, v string) bool {
   407  	for _, s := range a {
   408  		if s == v {
   409  			return true
   410  		}
   411  	}
   412  	return false
   413  }
   414  
   415  func (comp *compiler) flattenFlags() {
   416  	comp.flattenIntFlags()
   417  	comp.flattenStrFlags()
   418  
   419  	for _, n := range comp.desc.Nodes {
   420  		switch flags := n.(type) {
   421  		case *ast.IntFlags:
   422  			// It's possible that we don't find the flag in intFlags if it was
   423  			// skipped due to errors (or special name "_") when populating
   424  			// intFlags (see checkNames).
   425  			if f, ok := comp.intFlags[flags.Name.Name]; ok {
   426  				flags.Values = f.Values
   427  			}
   428  		case *ast.StrFlags:
   429  			// Same as for intFlags above.
   430  			if f, ok := comp.strFlags[flags.Name.Name]; ok {
   431  				flags.Values = f.Values
   432  			}
   433  		}
   434  	}
   435  }
   436  
   437  func (comp *compiler) flattenIntFlags() {
   438  	for name, flags := range comp.intFlags {
   439  		if err := recurFlattenFlags[*ast.IntFlags, *ast.Int](comp, name, flags, comp.intFlags,
   440  			map[string]bool{}); err != nil {
   441  			comp.error(flags.Pos, "%v", err)
   442  		}
   443  	}
   444  }
   445  
   446  func (comp *compiler) flattenStrFlags() {
   447  	for name, flags := range comp.strFlags {
   448  		if err := recurFlattenFlags[*ast.StrFlags, *ast.String](comp, name, flags, comp.strFlags,
   449  			map[string]bool{}); err != nil {
   450  			comp.error(flags.Pos, "%v", err)
   451  		}
   452  	}
   453  }
   454  
   455  func recurFlattenFlags[F ast.Flags[V], V ast.FlagValue](comp *compiler, name string, flags F,
   456  	allFlags map[string]F, visitedFlags map[string]bool) error {
   457  	if _, visited := visitedFlags[name]; visited {
   458  		return fmt.Errorf("flags %v used twice or circular dependency on %v", name, name)
   459  	}
   460  	visitedFlags[name] = true
   461  
   462  	var values []V
   463  	for _, flag := range flags.GetValues() {
   464  		if f, isFlags := allFlags[flag.GetName()]; isFlags {
   465  			if err := recurFlattenFlags[F, V](comp, flag.GetName(), f, allFlags, visitedFlags); err != nil {
   466  				return err
   467  			}
   468  			values = append(values, allFlags[flag.GetName()].GetValues()...)
   469  		} else {
   470  			values = append(values, flag)
   471  		}
   472  	}
   473  	if len(values) > 100000 {
   474  		return fmt.Errorf("%v has more than 100000 values %v", name, len(values))
   475  	}
   476  	flags.SetValues(values)
   477  	return nil
   478  }