github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/tools/syz-check/check.go (about)

     1  // Copyright 2019 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  // syz-check does best-effort static correctness checking of the syscall descriptions in sys/os/*.txt.
     5  // Use:
     6  //
     7  //	$ go install ./tools/syz-check
     8  //	$ syz-check -obj-amd64 /linux_amd64/vmlinux -obj-arm64 /linux_arm64/vmlinux \
     9  //		-obj-386 /linux_386/vmlinux -obj-arm /linux_arm/vmlinux
    10  //
    11  // The vmlinux files should include debug info, enable all relevant configs (since we parse dwarf),
    12  // and be compiled with -gdwarf-3 -fno-eliminate-unused-debug-types -fno-eliminate-unused-debug-symbols flags.
    13  // -gdwarf-3 is required because version 4 changes the way bitfields are encoded and Go before 1.18
// does not support the new encoding and at least earlier versions mis-handle it, see:
    15  // https://go-review.googlesource.com/c/go/+/328709/comments/edf0619d_daec236f
    16  //
    17  // Use the following configs for kernels (x86_64 config for i386 as well):
    18  // upstream-apparmor-kasan.config, upstream-arm-full.config, upstream-arm64-full.config
    19  //
    20  // You may check only one arch as well (but then don't commit changes to warn files):
    21  //
    22  //	$ syz-check -obj-amd64 /linux_amd64/vmlinux
    23  //
    24  // You may also disable dwarf or netlink checks with the corresponding flags.
    25  // E.g. -dwarf=0 greatly speeds up checking if you are only interested in netlink warnings
    26  // (but then again don't commit changes).
    27  //
    28  // The results are produced in sys/os/*.warn files.
    29  // On implementation level syz-check parses vmlinux dwarf, extracts struct descriptions
    30  // and compares them with what we have (size, fields, alignment, etc). Netlink checking extracts policy symbols
    31  // from the object files and parses them.
    32  package main
    33  
    34  import (
    35  	"bytes"
    36  	"debug/dwarf"
    37  	"debug/elf"
    38  	"flag"
    39  	"fmt"
    40  	"os"
    41  	"path/filepath"
    42  	"runtime"
    43  	"sort"
    44  	"strings"
    45  	"unsafe"
    46  
    47  	"github.com/google/syzkaller/pkg/ast"
    48  	"github.com/google/syzkaller/pkg/compiler"
    49  	"github.com/google/syzkaller/pkg/osutil"
    50  	"github.com/google/syzkaller/pkg/symbolizer"
    51  	"github.com/google/syzkaller/pkg/tool"
    52  	"github.com/google/syzkaller/prog"
    53  	"github.com/google/syzkaller/sys/targets"
    54  )
    55  
    56  func main() {
    57  	var (
    58  		flagOS      = flag.String("os", runtime.GOOS, "OS")
    59  		flagDWARF   = flag.Bool("dwarf", true, "do checking based on DWARF")
    60  		flagNetlink = flag.Bool("netlink", true, "do checking of netlink policies")
    61  	)
    62  	arches := make(map[string]*string)
    63  	for arch := range targets.List[targets.Linux] {
    64  		arches[arch] = flag.String("obj-"+arch, "", arch+" kernel object file")
    65  	}
    66  	defer tool.Init()()
    67  	var warnings []Warn
    68  	for arch, obj := range arches {
    69  		if *obj == "" {
    70  			delete(arches, arch)
    71  			continue
    72  		}
    73  		warnings1, err := check(*flagOS, arch, *obj, *flagDWARF, *flagNetlink)
    74  		if err != nil {
    75  			tool.Fail(err)
    76  		}
    77  		warnings = append(warnings, warnings1...)
    78  		runtime.GC()
    79  	}
    80  	if len(arches) == 0 {
    81  		fmt.Fprintf(os.Stderr, "specify at least one -obj-arch flag\n")
    82  		flag.PrintDefaults()
    83  		os.Exit(1)
    84  	}
    85  	if err := writeWarnings(*flagOS, len(arches), warnings); err != nil {
    86  		fmt.Fprintln(os.Stderr, err)
    87  		os.Exit(1)
    88  	}
    89  }
    90  
    91  func check(OS, arch, obj string, dwarf, netlink bool) ([]Warn, error) {
    92  	var warnings []Warn
    93  	if obj == "" {
    94  		return nil, fmt.Errorf("no object file in -obj-%v flag", arch)
    95  	}
    96  	structTypes, locs, warnings1, err := parseDescriptions(OS, arch)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	warnings = append(warnings, warnings1...)
   101  	if dwarf {
   102  		structs, err := parseKernelObject(obj)
   103  		if err != nil {
   104  			return nil, err
   105  		}
   106  		warnings2, err := checkImpl(structs, structTypes, locs)
   107  		if err != nil {
   108  			return nil, err
   109  		}
   110  		warnings = append(warnings, warnings2...)
   111  	}
   112  	if netlink {
   113  		warnings3, err := checkNetlink(OS, arch, obj, structTypes, locs)
   114  		if err != nil {
   115  			return nil, err
   116  		}
   117  		warnings = append(warnings, warnings3...)
   118  	}
   119  	for i := range warnings {
   120  		warnings[i].arch = arch
   121  	}
   122  	return warnings, nil
   123  }
   124  
   125  const (
   126  	WarnCompiler           = "compiler"
   127  	WarnNoSuchStruct       = "no-such-struct"
   128  	WarnBadStructSize      = "bad-struct-size"
   129  	WarnBadFieldNumber     = "bad-field-number"
   130  	WarnBadFieldSize       = "bad-field-size"
   131  	WarnBadFieldOffset     = "bad-field-offset"
   132  	WarnBadBitfield        = "bad-bitfield"
   133  	WarnNoNetlinkPolicy    = "no-such-netlink-policy"
   134  	WarnNetlinkBadSize     = "bad-kernel-netlink-policy-size"
   135  	WarnNetlinkBadAttrType = "bad-netlink-attr-type"
   136  	WarnNetlinkBadAttr     = "bad-netlink-attr"
   137  )
   138  
   139  type Warn struct {
   140  	pos  ast.Pos
   141  	arch string
   142  	typ  string
   143  	msg  string
   144  }
   145  
   146  func writeWarnings(OS string, narches int, warnings []Warn) error {
   147  	allFiles, err := filepath.Glob(filepath.Join("sys", OS, "*.warn"))
   148  	if err != nil {
   149  		return err
   150  	}
   151  	toRemove := make(map[string]bool)
   152  	for _, file := range allFiles {
   153  		toRemove[file] = true
   154  	}
   155  	byFile := make(map[string][]Warn)
   156  	for _, warn := range warnings {
   157  		byFile[warn.pos.File] = append(byFile[warn.pos.File], warn)
   158  	}
   159  	for file, warns := range byFile {
   160  		sort.Slice(warns, func(i, j int) bool {
   161  			w1, w2 := warns[i], warns[j]
   162  			if w1.pos.Line != w2.pos.Line {
   163  				return w1.pos.Line < w2.pos.Line
   164  			}
   165  			if w1.typ != w2.typ {
   166  				return w1.typ < w2.typ
   167  			}
   168  			if w1.msg != w2.msg {
   169  				return w1.msg < w2.msg
   170  			}
   171  			return w1.arch < w2.arch
   172  		})
   173  		buf := new(bytes.Buffer)
   174  		for i := 0; i < len(warns); i++ {
   175  			warn := warns[i]
   176  			arch := warn.arch
   177  			arches := []string{warn.arch}
   178  			for i < len(warns)-1 && warn.msg == warns[i+1].msg {
   179  				if arch != warns[i+1].arch {
   180  					arch = warns[i+1].arch
   181  					arches = append(arches, arch)
   182  				}
   183  				i++
   184  			}
   185  			archStr := ""
   186  			// We do netlink checking only on amd64, so don't add arch.
   187  			if len(arches) < narches && !strings.Contains(warn.typ, "netlink") {
   188  				archStr = fmt.Sprintf(" [%v]", strings.Join(arches, ","))
   189  			}
   190  			fmt.Fprintf(buf, "%v: %v%v\n", warn.typ, warn.msg, archStr)
   191  		}
   192  		warnFile := file + ".warn"
   193  		if err := osutil.WriteFile(warnFile, buf.Bytes()); err != nil {
   194  			return err
   195  		}
   196  		delete(toRemove, warnFile)
   197  	}
   198  	for file := range toRemove {
   199  		os.Remove(file)
   200  	}
   201  	return nil
   202  }
   203  
   204  func checkImpl(structs map[string]*dwarf.StructType, structTypes []prog.Type,
   205  	locs map[string]*ast.Struct) ([]Warn, error) {
   206  	var warnings []Warn
   207  	for _, typ := range structTypes {
   208  		name := typ.TemplateName()
   209  		astStruct := locs[name]
   210  		if astStruct == nil {
   211  			continue
   212  		}
   213  		if _, ok := isNetlinkPolicy(typ); ok {
   214  			continue // netlink policies are not structs even if we describe them as structs
   215  		}
   216  		// In some cases we split a single struct into multiple ones
   217  		// (more precise description), so try to match our foo$bar with kernel foo.
   218  		kernelStruct := structs[name]
   219  		if delim := strings.LastIndexByte(name, '$'); kernelStruct == nil && delim != -1 {
   220  			kernelStruct = structs[name[:delim]]
   221  		}
   222  		warns, err := checkStruct(typ, astStruct, kernelStruct)
   223  		if err != nil {
   224  			return nil, err
   225  		}
   226  		warnings = append(warnings, warns...)
   227  	}
   228  	return warnings, nil
   229  }
   230  
   231  func checkStruct(typ prog.Type, astStruct *ast.Struct, str *dwarf.StructType) ([]Warn, error) {
   232  	var warnings []Warn
   233  	warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
   234  		warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
   235  	}
   236  	name := typ.TemplateName()
   237  	if str == nil {
   238  		// Varlen structs are frequently not described in kernel (not possible in C).
   239  		if !typ.Varlen() {
   240  			warn(astStruct.Pos, WarnNoSuchStruct, "%v", name)
   241  		}
   242  		return warnings, nil
   243  	}
   244  	if !typ.Varlen() && typ.Size() != uint64(str.ByteSize) {
   245  		warn(astStruct.Pos, WarnBadStructSize, "%v: syz=%v kernel=%v", name, typ.Size(), str.ByteSize)
   246  	}
   247  	// TODO: handle unions, currently we should report some false errors.
   248  	if _, ok := typ.(*prog.UnionType); ok || str.Kind == "union" {
   249  		return warnings, nil
   250  	}
   251  	// Ignore structs with out_overlay attribute.
   252  	// They are never described in the kernel as a simple struct.
   253  	// We could only match and check fields based on some common conventions,
   254  	// but since we have very few of them it's unclear what are these conventions
   255  	// and implementing something complex will have low RoI.
   256  	if typ.(*prog.StructType).OverlayField != 0 {
   257  		return warnings, nil
   258  	}
   259  	// TODO: we could also check enums (elements match corresponding flags in syzkaller).
   260  	// TODO: we could also check values of literal constants (dwarf should have that, right?).
   261  	// TODO: handle nested structs/unions, e.g.:
   262  	// struct foo {
   263  	//	union {
   264  	//		...
   265  	//	} bar;
   266  	// };
   267  	// should be matched with:
   268  	// foo_bar [
   269  	//	...
   270  	// ]
   271  	// TODO: consider making guesses about semantic types of fields,
   272  	// e.g. if a name contains filedes/uid/pid/gid that may be the corresponding resource.
   273  	ai := 0
   274  	offset := uint64(0)
   275  	for _, field := range typ.(*prog.StructType).Fields {
   276  		if field.Type.Varlen() {
   277  			ai = len(str.Field)
   278  			break
   279  		}
   280  		if prog.IsPad(field.Type) {
   281  			offset += field.Type.Size()
   282  			continue
   283  		}
   284  		if ai < len(str.Field) {
   285  			fld := str.Field[ai]
   286  			pos := astStruct.Fields[ai].Pos
   287  			desc := fmt.Sprintf("%v.%v", name, field.Name)
   288  			if field.Name != fld.Name {
   289  				desc += "/" + fld.Name
   290  			}
   291  			if field.Type.UnitSize() != uint64(fld.Type.Size()) {
   292  				warn(pos, WarnBadFieldSize, "%v: syz=%v kernel=%v",
   293  					desc, field.Type.UnitSize(), fld.Type.Size())
   294  			}
   295  			byteOffset := offset - field.Type.UnitOffset()
   296  			if byteOffset != uint64(fld.ByteOffset) {
   297  				warn(pos, WarnBadFieldOffset, "%v: syz=%v kernel=%v",
   298  					desc, byteOffset, fld.ByteOffset)
   299  			}
   300  			// How would you define bitfield offset?
   301  			// Offset of the beginning of the field from the beginning of the memory location, right?
   302  			// No, DWARF defines it as offset of the end of the field from the end of the memory location.
   303  			bitOffset := fld.Type.Size()*8 - fld.BitOffset - fld.BitSize
   304  			if fld.BitSize == 0 {
   305  				// And to make things even more interesting this calculation
   306  				// does not work for normal variables.
   307  				bitOffset = 0
   308  			}
   309  			if field.Type.BitfieldLength() != uint64(fld.BitSize) ||
   310  				field.Type.BitfieldOffset() != uint64(bitOffset) {
   311  				warn(pos, WarnBadBitfield, "%v: size/offset: syz=%v/%v kernel=%v/%v",
   312  					desc, field.Type.BitfieldLength(), field.Type.BitfieldOffset(),
   313  					fld.BitSize, bitOffset)
   314  			}
   315  		}
   316  		ai++
   317  		offset += field.Size()
   318  	}
   319  	if ai != len(str.Field) {
   320  		warn(astStruct.Pos, WarnBadFieldNumber, "%v: syz=%v kernel=%v", name, ai, len(str.Field))
   321  	}
   322  	return warnings, nil
   323  }
   324  
   325  func parseDescriptions(OS, arch string) ([]prog.Type, map[string]*ast.Struct, []Warn, error) {
   326  	errorBuf := new(bytes.Buffer)
   327  	var warnings []Warn
   328  	eh := func(pos ast.Pos, msg string) {
   329  		warnings = append(warnings, Warn{pos: pos, typ: WarnCompiler, msg: msg})
   330  		fmt.Fprintf(errorBuf, "%v: %v\n", pos, msg)
   331  	}
   332  	top := ast.ParseGlob(filepath.Join("sys", OS, "*.txt"), eh)
   333  	if top == nil {
   334  		return nil, nil, nil, fmt.Errorf("failed to parse txt files:\n%s", errorBuf.Bytes())
   335  	}
   336  	consts := compiler.DeserializeConstFile(filepath.Join("sys", OS, "*.const"), eh).Arch(arch)
   337  	if consts == nil {
   338  		return nil, nil, nil, fmt.Errorf("failed to parse const files:\n%s", errorBuf.Bytes())
   339  	}
   340  	prg := compiler.Compile(top, consts, targets.Get(OS, arch), eh)
   341  	if prg == nil {
   342  		return nil, nil, nil, fmt.Errorf("failed to compile descriptions:\n%s", errorBuf.Bytes())
   343  	}
   344  	prog.RestoreLinks(prg.Syscalls, prg.Resources, prg.Types)
   345  	locs := make(map[string]*ast.Struct)
   346  	for _, decl := range top.Nodes {
   347  		switch n := decl.(type) {
   348  		case *ast.Struct:
   349  			locs[n.Name.Name] = n
   350  		case *ast.TypeDef:
   351  			if n.Struct != nil {
   352  				locs[n.Name.Name] = n.Struct
   353  			}
   354  		}
   355  	}
   356  	var structs []prog.Type
   357  	for _, typ := range prg.Types {
   358  		switch typ.(type) {
   359  		case *prog.StructType, *prog.UnionType:
   360  			structs = append(structs, typ)
   361  		}
   362  	}
   363  	return structs, locs, warnings, nil
   364  }
   365  
   366  // Overall idea of netlink checking.
   367  // Currnetly we check netlink policies for common detectable mistakes.
   368  // First, we detect what looks like a netlink policy in our descriptions
   369  // (these are structs/unions only with nlattr/nlnext/nlnetw fields).
   370  // Then we find corresponding symbols (offset/size) in vmlinux using nm.
   371  // Then we read elf headers and locate where these symbols are in the rodata section.
   372  // Then read in the symbol data, which is an array of nla_policy structs.
   373  // These structs allow to easily figure out type/size of attributes.
   374  // Finally we compare our descriptions with the kernel policy description.
   375  func checkNetlink(OS, arch, obj string, structTypes []prog.Type,
   376  	locs map[string]*ast.Struct) ([]Warn, error) {
   377  	if arch != targets.AMD64 {
   378  		// Netlink policies are arch-independent (?),
   379  		// so no need to check all arches.
   380  		// Also our definition of nlaPolicy below is 64-bit specific.
   381  		return nil, nil
   382  	}
   383  	ef, err := elf.Open(obj)
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  	rodata := ef.Section(".rodata")
   388  	if rodata == nil {
   389  		return nil, fmt.Errorf("object file %v does not contain .rodata section", obj)
   390  	}
   391  	symb := symbolizer.NewSymbolizer(targets.Get(OS, arch))
   392  	symbols, err := symb.ReadRodataSymbols(obj)
   393  	if err != nil {
   394  		return nil, err
   395  	}
   396  	var warnings []Warn
   397  	structMap := make(map[string]prog.Type)
   398  	for _, typ := range structTypes {
   399  		structMap[typ.Name()] = typ
   400  	}
   401  	checkedAttrs := make(map[string]*checkAttr)
   402  	for _, typ := range structTypes {
   403  		warnings1, err := checkNetlinkStruct(locs, symbols, rodata, structMap, checkedAttrs, typ)
   404  		if err != nil {
   405  			return nil, err
   406  		}
   407  		warnings = append(warnings, warnings1...)
   408  	}
   409  	warnings = append(warnings, checkMissingAttrs(checkedAttrs)...)
   410  	return warnings, nil
   411  }
   412  
   413  func checkNetlinkStruct(locs map[string]*ast.Struct, symbols map[string][]symbolizer.Symbol, rodata *elf.Section,
   414  	structMap map[string]prog.Type, checkedAttrs map[string]*checkAttr, typ prog.Type) ([]Warn, error) {
   415  	name := typ.TemplateName()
   416  	astStruct := locs[name]
   417  	if astStruct == nil {
   418  		return nil, nil
   419  	}
   420  	fields, ok := isNetlinkPolicy(typ)
   421  	if !ok {
   422  		return nil, nil
   423  	}
   424  	// In some cases we split a single policy into multiple ones (more precise description),
   425  	// so try to match our foo$bar with kernel foo as well.
   426  	kernelName, ss := name, symbols[name]
   427  	if delim := strings.LastIndexByte(name, '$'); len(ss) == 0 && delim != -1 {
   428  		kernelName = name[:delim]
   429  		ss = symbols[kernelName]
   430  	}
   431  	if len(ss) == 0 {
   432  		return []Warn{{pos: astStruct.Pos, typ: WarnNoNetlinkPolicy, msg: name}}, nil
   433  	}
   434  	var warnings []Warn
   435  	var warnings1 *[]Warn
   436  	var policy1 []nlaPolicy
   437  	var attrs1 map[int]bool
   438  	// We may have several symbols with the same name (they frequently have internal linking),
   439  	// in such case we choose the one that produces fewer warnings.
   440  	for _, symb := range ss {
   441  		if symb.Size == 0 || symb.Size%int(unsafe.Sizeof(nlaPolicy{})) != 0 {
   442  			warnings = append(warnings, Warn{pos: astStruct.Pos, typ: WarnNetlinkBadSize,
   443  				msg: fmt.Sprintf("%v (%v), size %v", kernelName, name, ss[0].Size)})
   444  			continue
   445  		}
   446  		binary := make([]byte, symb.Size)
   447  		addr := symb.Addr - rodata.Addr
   448  		if _, err := rodata.ReadAt(binary, int64(addr)); err != nil {
   449  			return nil, fmt.Errorf("failed to read policy %v (%v) at %v: %w",
   450  				kernelName, name, symb.Addr, err)
   451  		}
   452  		policy := (*[1e6]nlaPolicy)(unsafe.Pointer(&binary[0]))[:symb.Size/int(unsafe.Sizeof(nlaPolicy{}))]
   453  		warnings2, attrs2, err := checkNetlinkPolicy(structMap, typ, fields, astStruct, policy)
   454  		if err != nil {
   455  			return nil, err
   456  		}
   457  		if warnings1 == nil || len(*warnings1) > len(warnings2) {
   458  			warnings1 = &warnings2
   459  			policy1 = policy
   460  			attrs1 = attrs2
   461  		}
   462  	}
   463  	if warnings1 != nil {
   464  		warnings = append(warnings, *warnings1...)
   465  		ca := checkedAttrs[kernelName]
   466  		if ca == nil {
   467  			ca = &checkAttr{
   468  				pos:    astStruct.Pos,
   469  				name:   name,
   470  				policy: policy1,
   471  				attrs:  make(map[int]bool),
   472  			}
   473  			checkedAttrs[kernelName] = ca
   474  		}
   475  		for attr := range attrs1 {
   476  			ca.attrs[attr] = true
   477  		}
   478  	}
   479  	return warnings, nil
   480  }
   481  
   482  type checkAttr struct {
   483  	pos    ast.Pos
   484  	name   string
   485  	policy []nlaPolicy
   486  	attrs  map[int]bool
   487  }
   488  
   489  func checkMissingAttrs(checkedAttrs map[string]*checkAttr) []Warn {
   490  	// Missing attribute checking is a bit tricky because we may split a single
   491  	// kernel policy into several policies for better precision.
   492  	// They have different names, but map to the same kernel policy.
   493  	// We want to report a missing attribute iff it's missing in all copies of the policy.
   494  	var warnings []Warn
   495  	for _, ca := range checkedAttrs {
   496  		var missing []int
   497  		for i, pol := range ca.policy {
   498  			// Ignore attributes that are not described in the policy
   499  			// (some of them are unused at all, however there are cases where
   500  			// they are not described but used as inputs, and these are actually
   501  			// the worst ones).
   502  			if !ca.attrs[i] && (pol.typ != NLA_UNSPEC && pol.typ != NLA_REJECT || pol.len != 0) {
   503  				missing = append(missing, i)
   504  			}
   505  		}
   506  		// If we miss too many, there is probably something else going on.
   507  		if len(missing) != 0 && len(missing) <= 5 {
   508  			warnings = append(warnings, Warn{
   509  				pos: ca.pos,
   510  				typ: WarnNetlinkBadAttr,
   511  				msg: fmt.Sprintf("%v: missing attributes: %v", ca.name, missing),
   512  			})
   513  		}
   514  	}
   515  	return warnings
   516  }
   517  
   518  func isNetlinkPolicy(typ prog.Type) ([]prog.Field, bool) {
   519  	var fields []prog.Field
   520  	switch t := typ.(type) {
   521  	case *prog.StructType:
   522  		fields = t.Fields
   523  	case *prog.UnionType:
   524  		fields = t.Fields
   525  	default:
   526  		return nil, false
   527  	}
   528  	haveAttr := false
   529  	for _, fld := range fields {
   530  		field := fld.Type
   531  		if prog.IsPad(field) {
   532  			continue
   533  		}
   534  		if isNlattr(field) {
   535  			haveAttr = true
   536  			continue
   537  		}
   538  		if arr, ok := field.(*prog.ArrayType); ok {
   539  			field = arr.Elem
   540  		}
   541  		if _, ok := isNetlinkPolicy(field); ok {
   542  			continue
   543  		}
   544  		return nil, false
   545  	}
   546  	return fields, haveAttr
   547  }
   548  
   549  const (
   550  	nlattrT  = "nlattr_t"
   551  	nlattrTT = "nlattr_tt"
   552  )
   553  
   554  func isNlattr(typ prog.Type) bool {
   555  	name := typ.TemplateName()
   556  	return name == nlattrT || name == nlattrTT
   557  }
   558  
   559  func checkNetlinkPolicy(structMap map[string]prog.Type, typ prog.Type, fields []prog.Field,
   560  	astStruct *ast.Struct, policy []nlaPolicy) ([]Warn, map[int]bool, error) {
   561  	var warnings []Warn
   562  	warn := func(pos ast.Pos, typ, msg string, args ...interface{}) {
   563  		warnings = append(warnings, Warn{pos: pos, typ: typ, msg: fmt.Sprintf(msg, args...)})
   564  	}
   565  	checked := make(map[int]bool)
   566  	ai := 0
   567  	for _, field := range fields {
   568  		if prog.IsPad(field.Type) {
   569  			continue
   570  		}
   571  		fld := astStruct.Fields[ai]
   572  		ai++
   573  		if !isNlattr(field.Type) {
   574  			continue
   575  		}
   576  		ft := field.Type.(*prog.StructType)
   577  		attr := int(ft.Fields[1].Type.(*prog.ConstType).Val)
   578  		if attr >= len(policy) {
   579  			warn(fld.Pos, WarnNetlinkBadAttrType, "%v.%v: type %v, kernel policy size %v",
   580  				typ.TemplateName(), field.Name, attr, len(policy))
   581  			continue
   582  		}
   583  		if checked[attr] {
   584  			warn(fld.Pos, WarnNetlinkBadAttr, "%v.%v: duplicate attribute",
   585  				typ.TemplateName(), field.Name)
   586  		}
   587  		checked[attr] = true
   588  		w := checkNetlinkAttr(ft, policy[attr])
   589  		if w != "" {
   590  			warn(fld.Pos, WarnNetlinkBadAttr, "%v.%v: %v",
   591  				typ.TemplateName(), field.Name, w)
   592  		}
   593  	}
   594  	return warnings, checked, nil
   595  }
   596  
   597  func checkNetlinkAttr(typ *prog.StructType, policy nlaPolicy) string {
   598  	payload := typ.Fields[2].Type
   599  	if typ.TemplateName() == nlattrTT {
   600  		payload = typ.Fields[4].Type
   601  	}
   602  	if warn := checkAttrType(typ, payload, policy); warn != "" {
   603  		return warn
   604  	}
   605  	size, minSize, maxSize := attrSize(policy)
   606  	payloadSize := minTypeSize(payload)
   607  	if size != -1 && size != payloadSize {
   608  		return fmt.Sprintf("bad size %v, expect %v", payloadSize, size)
   609  	}
   610  	if minSize != -1 && minSize > payloadSize {
   611  		return fmt.Sprintf("bad size %v, expect min %v", payloadSize, minSize)
   612  	}
   613  	if maxSize != -1 && maxSize < payloadSize {
   614  		return fmt.Sprintf("bad size %v, expect max %v", payloadSize, maxSize)
   615  	}
   616  
   617  	valMin, valMax, haveVal := typeMinMaxValue(payload)
   618  	if haveVal {
   619  		if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MIN {
   620  			if int64(valMin) < int64(policy.minVal) {
   621  				// This is a common case that occurs several times: limit on min value of 1.
   622  				// Not worth fixing (at least not in initial batch), it just crosses out a
   623  				// single value of 0, which we shuold test anyway.
   624  				if !(policy.validation == NLA_VALIDATE_MIN && policy.minVal == 1) {
   625  					return fmt.Sprintf("bad min value %v, expect %v",
   626  						int64(valMin), policy.minVal)
   627  				}
   628  			}
   629  		}
   630  		if policy.validation == NLA_VALIDATE_RANGE || policy.validation == NLA_VALIDATE_MAX {
   631  			if int64(valMax) > int64(policy.maxVal) {
   632  				return fmt.Sprintf("bad max value %v, expect %v",
   633  					int64(valMax), policy.maxVal)
   634  			}
   635  		}
   636  	}
   637  	return ""
   638  }
   639  
   640  func minTypeSize(t prog.Type) int {
   641  	if !t.Varlen() {
   642  		return int(t.Size())
   643  	}
   644  	switch typ := t.(type) {
   645  	case *prog.StructType:
   646  		if typ.OverlayField != 0 {
   647  			// Overlayed structs are not supported here
   648  			// (and should not be used in netlink).
   649  			// Make this always produce a warning.
   650  			return -1
   651  		}
   652  		// Some struct args has trailing arrays, but are only checked for min size.
   653  		// Try to get some estimation for min size of this struct.
   654  		size := 0
   655  		for _, field := range typ.Fields {
   656  			if !field.Varlen() {
   657  				size += int(field.Size())
   658  			}
   659  		}
   660  		return size
   661  	case *prog.ArrayType:
   662  		if typ.Kind == prog.ArrayRangeLen && !typ.Elem.Varlen() {
   663  			return int(typ.RangeBegin * typ.Elem.Size())
   664  		}
   665  	case *prog.UnionType:
   666  		size := 0
   667  		for _, field := range typ.Fields {
   668  			if size1 := minTypeSize(field.Type); size1 != -1 && size > size1 || size == 0 {
   669  				size = size1
   670  			}
   671  		}
   672  		return size
   673  	}
   674  	return -1
   675  }
   676  
   677  func checkAttrType(typ *prog.StructType, payload prog.Type, policy nlaPolicy) string {
   678  	switch policy.typ {
   679  	case NLA_STRING, NLA_NUL_STRING:
   680  		if _, ok := payload.(*prog.BufferType); !ok {
   681  			return "expect string"
   682  		}
   683  	case NLA_NESTED:
   684  		if typ.TemplateName() != nlattrTT || typ.Fields[3].Type.(*prog.ConstType).Val != 1 {
   685  			return "should be nlnest"
   686  		}
   687  	case NLA_BITFIELD32:
   688  		if typ.TemplateName() != nlattrT || payload.TemplateName() != "nla_bitfield32" {
   689  			return "should be nlattr[nla_bitfield32]"
   690  		}
   691  	case NLA_NESTED_ARRAY:
   692  		if _, ok := payload.(*prog.ArrayType); !ok {
   693  			return "expect array"
   694  		}
   695  	case NLA_REJECT:
   696  		return "NLA_REJECT attribute will always be rejected"
   697  	}
   698  	return ""
   699  }
   700  
   701  func attrSize(policy nlaPolicy) (int, int, int) {
   702  	switch policy.typ {
   703  	case NLA_UNSPEC:
   704  		if policy.len != 0 {
   705  			return -1, int(policy.len), -1
   706  		}
   707  	case NLA_MIN_LEN:
   708  		return -1, int(policy.len), -1
   709  	case NLA_EXACT_LEN, NLA_EXACT_LEN_WARN:
   710  		return int(policy.len), -1, -1
   711  	case NLA_U8, NLA_S8:
   712  		return 1, -1, -1
   713  	case NLA_U16, NLA_S16:
   714  		return 2, -1, -1
   715  	case NLA_U32, NLA_S32:
   716  		return 4, -1, -1
   717  	case NLA_U64, NLA_S64, NLA_MSECS:
   718  		return 8, -1, -1
   719  	case NLA_FLAG:
   720  		return 0, -1, -1
   721  	case NLA_BINARY:
   722  		if policy.len != 0 {
   723  			return -1, -1, int(policy.len)
   724  		}
   725  	}
   726  	return -1, -1, -1
   727  }
   728  
   729  func typeMinMaxValue(payload prog.Type) (min, max uint64, ok bool) {
   730  	switch typ := payload.(type) {
   731  	case *prog.ConstType:
   732  		return typ.Val, typ.Val, true
   733  	case *prog.IntType:
   734  		if typ.Kind == prog.IntRange {
   735  			return typ.RangeBegin, typ.RangeEnd, true
   736  		}
   737  		return 0, ^uint64(0), true
   738  	case *prog.FlagsType:
   739  		min, max := ^uint64(0), uint64(0)
   740  		for _, v := range typ.Vals {
   741  			if min > v {
   742  				min = v
   743  			}
   744  			if max < v {
   745  				max = v
   746  			}
   747  		}
   748  		return min, max, true
   749  	}
   750  	return 0, 0, false
   751  }
   752  
   753  type nlaPolicy struct {
   754  	typ        uint8
   755  	validation uint8
   756  	len        uint16
   757  	_          uint32
   758  	minVal     int16
   759  	maxVal     int16
   760  	_          int32
   761  }
   762  
   763  // nolint
   764  const (
   765  	NLA_UNSPEC = iota
   766  	NLA_U8
   767  	NLA_U16
   768  	NLA_U32
   769  	NLA_U64
   770  	NLA_STRING
   771  	NLA_FLAG
   772  	NLA_MSECS
   773  	NLA_NESTED
   774  	NLA_NESTED_ARRAY
   775  	NLA_NUL_STRING
   776  	NLA_BINARY
   777  	NLA_S8
   778  	NLA_S16
   779  	NLA_S32
   780  	NLA_S64
   781  	NLA_BITFIELD32
   782  	NLA_REJECT
   783  	NLA_EXACT_LEN
   784  	NLA_EXACT_LEN_WARN
   785  	NLA_MIN_LEN
   786  )
   787  
   788  // nolint
   789  const (
   790  	_ = iota
   791  	NLA_VALIDATE_RANGE
   792  	NLA_VALIDATE_MIN
   793  	NLA_VALIDATE_MAX
   794  )