github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/pkg/compiler/compiler.go (about)

     1  // Copyright 2017 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // Package compiler generates sys descriptions of syscalls, types and resources
     5  // from textual descriptions.
     6  package compiler
     7  
     8  import (
     9  	"fmt"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  
    14  	"github.com/google/syzkaller/pkg/ast"
    15  	"github.com/google/syzkaller/prog"
    16  	"github.com/google/syzkaller/sys/targets"
    17  )
    18  
    19  // Overview of compilation process:
    20  // 1. ast.Parse on text file does tokenization and builds AST.
    21  //    This step catches basic syntax errors. AST contains full debug info.
    22  // 2. ExtractConsts on AST returns the set of constant identifiers.
    23  //    This step also does verification of include/incdir/define AST nodes.
    24  // 3. User translates constants to values.
    25  // 4. Compile on AST and const values does the rest of the work and returns Prog
    26  //    containing generated prog objects.
    27  // 4.1. assignSyscallNumbers: uses consts to assign syscall numbers.
    28  //      This step also detects unsupported syscalls and discards no longer
    29  //      needed AST nodes (include, define, comments, etc).
    30  // 4.2. patchConsts: patches Int nodes referring to consts with corresponding values.
    31  //      Also detects unsupported syscalls, structs, resources due to missing consts.
    32  // 4.3. check: does extensive semantical checks of AST.
    33  // 4.4. gen: generates prog objects from AST.
    34  
    35  // Prog is the result of compiling a description; it is what Compile returns.
    36  type Prog struct {
    37  	Resources []*prog.ResourceDesc
    38  	Syscalls  []*prog.Syscall
    39  	Types     []prog.Type
    40  	// Set of unsupported syscalls/flags (Compile fills it from the compiler's unsupported map).
    41  	Unsupported map[string]bool
    42  	// Extracted constants. Compile returns a Prog with only this field set if consts was nil.
    43  	fileConsts map[string]*ConstInfo
    44  }
    45  
    46  func createCompiler(desc *ast.Description, target *targets.Target, eh ast.ErrorHandler) *compiler {
    47  	if eh == nil {
    48  		eh = ast.LoggingHandler
    49  	}
    50  	desc.Nodes = append(builtinDescs.Clone().Nodes, desc.Nodes...)
    51  	comp := &compiler{
    52  		desc:           desc,
    53  		target:         target,
    54  		eh:             eh,
    55  		ptrSize:        target.PtrSize,
    56  		unsupported:    make(map[string]bool),
    57  		resources:      make(map[string]*ast.Resource),
    58  		typedefs:       make(map[string]*ast.TypeDef),
    59  		structs:        make(map[string]*ast.Struct),
    60  		intFlags:       make(map[string]*ast.IntFlags),
    61  		strFlags:       make(map[string]*ast.StrFlags),
    62  		used:           make(map[string]bool),
    63  		usedTypedefs:   make(map[string]bool),
    64  		brokenTypedefs: make(map[string]bool),
    65  		structVarlen:   make(map[string]bool),
    66  		structTypes:    make(map[string]prog.Type),
    67  		builtinConsts: map[string]uint64{
    68  			"PTR_SIZE": target.PtrSize,
    69  		},
    70  	}
    71  	return comp
    72  }
    73  
    74  // Compile compiles sys description.
    75  func Compile(desc *ast.Description, consts map[string]uint64, target *targets.Target, eh ast.ErrorHandler) *Prog {
    76  	comp := createCompiler(desc.Clone(), target, eh)
    77  	comp.filterArch()
    78  	comp.typecheck()
    79  	comp.flattenFlags()
    80  	// The subsequent, more complex, checks expect basic validity of the tree,
    81  	// in particular corrent number of type arguments. If there were errors,
    82  	// don't proceed to avoid out-of-bounds references to type arguments.
    83  	if comp.errors != 0 {
    84  		return nil
    85  	}
    86  	if consts == nil {
    87  		fileConsts := comp.extractConsts()
    88  		if comp.errors != 0 {
    89  			return nil
    90  		}
    91  		return &Prog{fileConsts: fileConsts}
    92  	}
    93  	if comp.target.SyscallNumbers {
    94  		comp.assignSyscallNumbers(consts)
    95  	}
    96  	comp.patchConsts(consts)
    97  	comp.check(consts)
    98  	if comp.errors != 0 {
    99  		return nil
   100  	}
   101  	syscalls := comp.genSyscalls()
   102  	comp.layoutTypes(syscalls)
   103  	types := comp.generateTypes(syscalls)
   104  	prg := &Prog{
   105  		Resources:   comp.genResources(),
   106  		Syscalls:    syscalls,
   107  		Types:       types,
   108  		Unsupported: comp.unsupported,
   109  	}
   110  	if comp.errors != 0 {
   111  		return nil
   112  	}
   113  	for _, w := range comp.warnings {
   114  		eh(w.pos, w.msg)
   115  	}
   116  	return prg
   117  }
   118  
   119  type compiler struct { // state shared by all steps of one compilation; built by createCompiler
   120  	desc     *ast.Description // description being compiled; filterArch replaces it with the filtered one
   121  	target   *targets.Target
   122  	eh       ast.ErrorHandler // never nil: createCompiler substitutes ast.LoggingHandler
   123  	errors   int              // incremented by comp.error; Compile returns nil when non-zero
   124  	warnings []warn           // collected by comp.warning, reported at the end of Compile
   125  	ptrSize  uint64           // copy of target.PtrSize
   126  
   127  	unsupported    map[string]bool // filterArch adds "<node type> <name>" keys; exposed as Prog.Unsupported
   128  	resources      map[string]*ast.Resource
   129  	typedefs       map[string]*ast.TypeDef
   130  	structs        map[string]*ast.Struct
   131  	intFlags       map[string]*ast.IntFlags
   132  	strFlags       map[string]*ast.StrFlags
   133  	used           map[string]bool // contains used structs/resources
   134  	usedTypedefs   map[string]bool
   135  	brokenTypedefs map[string]bool
   136  
   137  	structVarlen  map[string]bool // cache of structIsVarlen results, keyed by struct name
   138  	structTypes   map[string]prog.Type
   139  	builtinConsts map[string]uint64 // seeded with PTR_SIZE by createCompiler
   140  	fileMeta      map[string]Meta   // filled lazily by filterArch; looked up by file base name
   141  }
   142  
   143  type warn struct { // a buffered warning; appended by compiler.warning, emitted by Compile
   144  	pos ast.Pos // position passed to the error handler
   145  	msg string  // already formatted message text
   146  }
   147  
   148  func (comp *compiler) error(pos ast.Pos, msg string, args ...interface{}) {
   149  	comp.errors++
   150  	comp.eh(pos, fmt.Sprintf(msg, args...))
   151  }
   152  
   153  func (comp *compiler) warning(pos ast.Pos, msg string, args ...interface{}) {
   154  	comp.warnings = append(comp.warnings, warn{pos, fmt.Sprintf(msg, args...)})
   155  }
   156  
   157  func (comp *compiler) filterArch() {
   158  	if comp.fileMeta == nil {
   159  		comp.fileMeta = comp.fileList()
   160  	}
   161  	comp.desc = comp.desc.Filter(func(n ast.Node) bool {
   162  		pos, typ, name := n.Info()
   163  		meta := comp.fileMeta[filepath.Base(pos.File)]
   164  		if meta.SupportsArch(comp.target.Arch) {
   165  			return true
   166  		}
   167  		switch n.(type) {
   168  		case *ast.Resource, *ast.Struct, *ast.Call, *ast.TypeDef:
   169  			// This is required to keep the unsupported diagnostic working,
   170  			// otherwise sysgen will think that these things are still supported on some arches.
   171  			comp.unsupported[typ+" "+name] = true
   172  		}
   173  		return false
   174  	})
   175  }
   176  
   177  func (comp *compiler) structIsVarlen(name string) bool {
   178  	if varlen, ok := comp.structVarlen[name]; ok {
   179  		return varlen
   180  	}
   181  	s := comp.structs[name]
   182  	if s.IsUnion {
   183  		res := comp.parseIntAttrs(unionAttrs, s, s.Attrs)
   184  		if res[attrVarlen] != 0 {
   185  			comp.structVarlen[name] = true
   186  			return true
   187  		}
   188  	}
   189  	comp.structVarlen[name] = false // to not hang on recursive types
   190  	varlen := false
   191  	for _, fld := range s.Fields {
   192  		hasIfAttr := false
   193  		for _, attr := range fld.Attrs {
   194  			if structFieldAttrs[attr.Ident] == attrIf {
   195  				hasIfAttr = true
   196  			}
   197  		}
   198  		if hasIfAttr || comp.isVarlen(fld.Type) {
   199  			varlen = true
   200  			break
   201  		}
   202  	}
   203  	comp.structVarlen[name] = varlen
   204  	return varlen
   205  }
   206  
   207  func (comp *compiler) parseIntAttrs(descs map[string]*attrDesc, parent ast.Node,
   208  	attrs []*ast.Type) map[*attrDesc]uint64 {
   209  	intAttrs, _ := comp.parseAttrs(descs, parent, attrs)
   210  	return intAttrs
   211  }
   212  
   213  func (comp *compiler) parseAttrs(descs map[string]*attrDesc, parent ast.Node, attrs []*ast.Type) (
   214  	map[*attrDesc]uint64, map[*attrDesc]prog.Expression) {
   215  	_, parentType, parentName := parent.Info()
   216  	resInt := make(map[*attrDesc]uint64)
   217  	resExpr := make(map[*attrDesc]prog.Expression)
   218  	for _, attr := range attrs {
   219  		if unexpected, _, ok := checkTypeKind(attr, kindIdent); !ok {
   220  			comp.error(attr.Pos, "unexpected %v, expect attribute", unexpected)
   221  			return resInt, resExpr
   222  		}
   223  		if len(attr.Colon) != 0 {
   224  			comp.error(attr.Colon[0].Pos, "unexpected ':'")
   225  			return resInt, resExpr
   226  		}
   227  		desc := descs[attr.Ident]
   228  		if desc == nil {
   229  			comp.error(attr.Pos, "unknown %v %v attribute %v", parentType, parentName, attr.Ident)
   230  			return resInt, resExpr
   231  		}
   232  		_, dupInt := resInt[desc]
   233  		_, dupExpr := resExpr[desc]
   234  		if dupInt || dupExpr {
   235  			comp.error(attr.Pos, "duplicate %v %v attribute %v", parentType, parentName, attr.Ident)
   236  			return resInt, resExpr
   237  		}
   238  
   239  		switch desc.Type {
   240  		case flagAttr:
   241  			resInt[desc] = 1
   242  			if len(attr.Args) != 0 {
   243  				comp.error(attr.Pos, "%v attribute has args", attr.Ident)
   244  				return nil, nil
   245  			}
   246  		case intAttr:
   247  			resInt[desc] = comp.parseAttrIntArg(attr)
   248  		case exprAttr:
   249  			resExpr[desc] = comp.parseAttrExprArg(attr)
   250  		default:
   251  			comp.error(attr.Pos, "attribute %v has unknown type", attr.Ident)
   252  			return nil, nil
   253  		}
   254  	}
   255  	return resInt, resExpr
   256  }
   257  
   258  func (comp *compiler) parseAttrExprArg(attr *ast.Type) prog.Expression {
   259  	if len(attr.Args) != 1 {
   260  		comp.error(attr.Pos, "%v attribute is expected to have only one argument", attr.Ident)
   261  		return nil
   262  	}
   263  	arg := attr.Args[0]
   264  	if arg.HasString {
   265  		comp.error(attr.Pos, "%v argument must be an expression", attr.Ident)
   266  		return nil
   267  	}
   268  	return comp.genExpression(arg)
   269  }
   270  
   271  func (comp *compiler) parseAttrIntArg(attr *ast.Type) uint64 {
   272  	if len(attr.Args) != 1 {
   273  		comp.error(attr.Pos, "%v attribute is expected to have 1 argument", attr.Ident)
   274  		return 0
   275  	}
   276  	sz := attr.Args[0]
   277  	if unexpected, _, ok := checkTypeKind(sz, kindInt); !ok {
   278  		comp.error(sz.Pos, "unexpected %v, expect int", unexpected)
   279  		return 0
   280  	}
   281  	if len(sz.Colon) != 0 || len(sz.Args) != 0 {
   282  		comp.error(sz.Pos, "%v attribute has colon or args", attr.Ident)
   283  		return 0
   284  	}
   285  	return sz.Value
   286  }
   287  
   288  func (comp *compiler) getTypeDesc(t *ast.Type) *typeDesc {
   289  	if desc := builtinTypes[t.Ident]; desc != nil {
   290  		return desc
   291  	}
   292  	if comp.resources[t.Ident] != nil {
   293  		return typeResource
   294  	}
   295  	if comp.structs[t.Ident] != nil {
   296  		return typeStruct
   297  	}
   298  	if comp.typedefs[t.Ident] != nil {
   299  		return typeTypedef
   300  	}
   301  	return nil
   302  }
   303  
   304  func (comp *compiler) getArgsBase(t *ast.Type, isArg bool) (*typeDesc, []*ast.Type, prog.IntTypeCommon) {
   305  	desc := comp.getTypeDesc(t)
   306  	if desc == nil {
   307  		panic(fmt.Sprintf("no type desc for %#v", *t))
   308  	}
   309  	args, opt := removeOpt(t)
   310  	com := genCommon(t.Ident, sizeUnassigned, opt != nil)
   311  	base := genIntCommon(com, 0, false)
   312  	if desc.NeedBase {
   313  		base.TypeSize = comp.ptrSize
   314  		if !isArg {
   315  			baseType := args[len(args)-1]
   316  			args = args[:len(args)-1]
   317  			base = typeInt.Gen(comp, baseType, nil, base).(*prog.IntType).IntTypeCommon
   318  		}
   319  	}
   320  	return desc, args, base
   321  }
   322  
   323  func (comp *compiler) derefPointers(t *ast.Type) (*ast.Type, *typeDesc) {
   324  	for {
   325  		desc := comp.getTypeDesc(t)
   326  		if desc != typePtr {
   327  			return t, desc
   328  		}
   329  		t = t.Args[1]
   330  	}
   331  }
   332  
   333  func (comp *compiler) foreachType(n0 ast.Node,
   334  	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
   335  	switch n := n0.(type) {
   336  	case *ast.Call:
   337  		for _, arg := range n.Args {
   338  			comp.foreachSubType(arg.Type, true, cb)
   339  		}
   340  		if n.Ret != nil {
   341  			comp.foreachSubType(n.Ret, true, cb)
   342  		}
   343  	case *ast.Resource:
   344  		comp.foreachSubType(n.Base, false, cb)
   345  	case *ast.Struct:
   346  		for _, f := range n.Fields {
   347  			comp.foreachSubType(f.Type, false, cb)
   348  		}
   349  	case *ast.TypeDef:
   350  		if len(n.Args) == 0 {
   351  			comp.foreachSubType(n.Type, false, cb)
   352  		}
   353  	default:
   354  		panic(fmt.Sprintf("unexpected node %#v", n0))
   355  	}
   356  }
   357  
   358  func (comp *compiler) foreachSubType(t *ast.Type, isArg bool,
   359  	cb func(*ast.Type, *typeDesc, []*ast.Type, prog.IntTypeCommon)) {
   360  	desc, args, base := comp.getArgsBase(t, isArg)
   361  	cb(t, desc, args, base)
   362  	for i, arg := range args {
   363  		if desc.Args[i].Type == typeArgType {
   364  			comp.foreachSubType(arg, desc.Args[i].IsArg, cb)
   365  		}
   366  	}
   367  }
   368  
   369  func removeOpt(t *ast.Type) ([]*ast.Type, *ast.Type) {
   370  	args := t.Args
   371  	if last := len(args) - 1; last >= 0 && args[last].Ident == "opt" {
   372  		return args[:last], args[last]
   373  	}
   374  	return args, nil
   375  }
   376  
   377  func (comp *compiler) parseIntType(name string) (size uint64, bigEndian bool) {
   378  	be := strings.HasSuffix(name, "be")
   379  	if be {
   380  		name = name[:len(name)-len("be")]
   381  	}
   382  	size = comp.ptrSize
   383  	if name != "intptr" {
   384  		size, _ = strconv.ParseUint(name[3:], 10, 64)
   385  		size /= 8
   386  	}
   387  	return size, be
   388  }
   389  
   390  func arrayContains(a []string, v string) bool {
   391  	for _, s := range a {
   392  		if s == v {
   393  			return true
   394  		}
   395  	}
   396  	return false
   397  }
   398  
   399  func (comp *compiler) flattenFlags() {
   400  	comp.flattenIntFlags()
   401  	comp.flattenStrFlags()
   402  
   403  	for _, n := range comp.desc.Nodes {
   404  		switch flags := n.(type) {
   405  		case *ast.IntFlags:
   406  			// It's possible that we don't find the flag in intFlags if it was
   407  			// skipped due to errors (or special name "_") when populating
   408  			// intFlags (see checkNames).
   409  			if f, ok := comp.intFlags[flags.Name.Name]; ok {
   410  				flags.Values = f.Values
   411  			}
   412  		case *ast.StrFlags:
   413  			// Same as for intFlags above.
   414  			if f, ok := comp.strFlags[flags.Name.Name]; ok {
   415  				flags.Values = f.Values
   416  			}
   417  		}
   418  	}
   419  }
   420  
   421  func (comp *compiler) flattenIntFlags() {
   422  	for name, flags := range comp.intFlags {
   423  		if err := recurFlattenFlags[*ast.IntFlags, *ast.Int](comp, name, flags, comp.intFlags,
   424  			map[string]bool{}); err != nil {
   425  			comp.error(flags.Pos, "%v", err)
   426  		}
   427  	}
   428  }
   429  
   430  func (comp *compiler) flattenStrFlags() {
   431  	for name, flags := range comp.strFlags {
   432  		if err := recurFlattenFlags[*ast.StrFlags, *ast.String](comp, name, flags, comp.strFlags,
   433  			map[string]bool{}); err != nil {
   434  			comp.error(flags.Pos, "%v", err)
   435  		}
   436  	}
   437  }
   438  
   439  func recurFlattenFlags[F ast.Flags[V], V ast.FlagValue](comp *compiler, name string, flags F,
   440  	allFlags map[string]F, visitedFlags map[string]bool) error {
   441  	if _, visited := visitedFlags[name]; visited {
   442  		return fmt.Errorf("flags %v used twice or circular dependency on %v", name, name)
   443  	}
   444  	visitedFlags[name] = true
   445  
   446  	var values []V
   447  	for _, flag := range flags.GetValues() {
   448  		if f, isFlags := allFlags[flag.GetName()]; isFlags {
   449  			if err := recurFlattenFlags[F, V](comp, flag.GetName(), f, allFlags, visitedFlags); err != nil {
   450  				return err
   451  			}
   452  			values = append(values, allFlags[flag.GetName()].GetValues()...)
   453  		} else {
   454  			values = append(values, flag)
   455  		}
   456  	}
   457  	if len(values) > 2000 {
   458  		return fmt.Errorf("%v has more than 2000 values", name)
   459  	}
   460  	flags.SetValues(values)
   461  	return nil
   462  }