github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/tools/syz-trace2syz/proggen/call_selector.go (about)

     1  // Copyright 2018 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package proggen
     5  
     6  import (
     7  	"bytes"
     8  	"github.com/google/syzkaller/prog"
     9  	"github.com/google/syzkaller/tools/syz-trace2syz/parser"
    10  	"strconv"
    11  	"unicode"
    12  )
    13  
    14  var discriminatorArgs = map[string][]int{
    15  	"bpf":         {0},
    16  	"fcntl":       {1},
    17  	"ioprio_get":  {0},
    18  	"socket":      {0, 1, 2},
    19  	"socketpair":  {0, 1, 2},
    20  	"ioctl":       {0, 1},
    21  	"getsockopt":  {1, 2},
    22  	"setsockopt":  {1, 2},
    23  	"accept":      {0},
    24  	"accept4":     {0},
    25  	"bind":        {0},
    26  	"connect":     {0},
    27  	"recvfrom":    {0},
    28  	"sendto":      {0},
    29  	"sendmsg":     {0},
    30  	"getsockname": {0},
    31  	"openat":      {1},
    32  }
    33  
    34  var openDiscriminatorArgs = map[string]int{
    35  	"open":         0,
    36  	"openat":       1,
    37  	"syz_open_dev": 0,
    38  }
    39  
// callSelector picks the syzkaller syscall description to use for a call
// parsed from an strace log. The implementations in this file return nil when
// they cannot handle the call; openCallSelector may also rewrite call.Args in
// place so that they fit the layout of the description it returns.
type callSelector interface {
	Select(call *parser.Syscall) *prog.Syscall
}
    43  
    44  func newSelectors(target *prog.Target, returnCache returnCache) []callSelector {
    45  	sc := newSelectorCommon(target, returnCache)
    46  	return []callSelector{
    47  		&defaultCallSelector{sc},
    48  		&openCallSelector{sc},
    49  	}
    50  }
    51  
// selectorCommon is the state shared by every selector that newSelectors
// creates.
type selectorCommon struct {
	// target supplies the syscall descriptions (Syscalls) searched by callSet
	// and the named constants (ConstMap) used when rewriting arguments.
	target *prog.Target
	// returnCache is consulted by matchCall to resolve resource arguments;
	// presumably it maps a traced value back to the resource that produced
	// it — NOTE(review): confirm against returnCache's definition.
	returnCache returnCache
	// callCache memoizes callSet results, keyed by call name.
	callCache map[string][]*prog.Syscall
}
    57  
    58  func newSelectorCommon(target *prog.Target, returnCache returnCache) *selectorCommon {
    59  	return &selectorCommon{
    60  		target:      target,
    61  		returnCache: returnCache,
    62  		callCache:   make(map[string][]*prog.Syscall),
    63  	}
    64  }
    65  
    66  // matches strace file string with a constant string in openat or syz_open_dev
    67  // if the string in openat or syz_open_dev has a # then this method will
    68  // return the corresponding  id from the strace string
    69  func (cs *selectorCommon) matchFilename(syzFile, straceFile []byte) (bool, int) {
    70  	syzFile = bytes.Trim(syzFile, "\x00")
    71  	straceFile = bytes.Trim(straceFile, "\x00")
    72  	if len(syzFile) != len(straceFile) {
    73  		return false, -1
    74  	}
    75  	var id []byte
    76  	dev := -1
    77  	for i, c := range syzFile {
    78  		x := straceFile[i]
    79  		if c == x {
    80  			continue
    81  		}
    82  		if c != '#' || !unicode.IsDigit(rune(x)) {
    83  			return false, -1
    84  		}
    85  		id = append(id, x)
    86  	}
    87  	if len(id) > 0 {
    88  		dev, _ = strconv.Atoi(string(id))
    89  	}
    90  	return true, dev
    91  }
    92  
    93  // callSet returns all syscalls with the given name.
    94  func (cs *selectorCommon) callSet(callName string) []*prog.Syscall {
    95  	calls, ok := cs.callCache[callName]
    96  	if ok {
    97  		return calls
    98  	}
    99  	for _, call := range cs.target.Syscalls {
   100  		if call.CallName == callName {
   101  			calls = append(calls, call)
   102  		}
   103  	}
   104  	cs.callCache[callName] = calls
   105  	return calls
   106  }
   107  
// openCallSelector maps traced open(2) calls onto syzkaller openat or
// syz_open_dev variants by matching the traced filename against the constant
// filenames in those variants' descriptions.
type openCallSelector struct {
	*selectorCommon
}
   111  
   112  // Select returns the best matching descrimination for this syscall.
   113  func (cs *openCallSelector) Select(call *parser.Syscall) *prog.Syscall {
   114  	if _, ok := openDiscriminatorArgs[call.CallName]; !ok {
   115  		return nil
   116  	}
   117  	for callName := range openDiscriminatorArgs {
   118  		for _, variant := range cs.callSet(callName) {
   119  			match, devID := cs.matchOpen(variant, call)
   120  			if !match {
   121  				continue
   122  			}
   123  			if call.CallName == "open" && callName == "openat" {
   124  				cwd := parser.Constant(cs.target.ConstMap["AT_FDCWD"])
   125  				call.Args = append([]parser.IrType{cwd}, call.Args...)
   126  				return variant
   127  			}
   128  			if match && call.CallName == "open" && callName == "syz_open_dev" {
   129  				if devID < 0 {
   130  					return variant
   131  				}
   132  				args := []parser.IrType{call.Args[0], parser.Constant(uint64(devID))}
   133  				call.Args = append(args, call.Args[1:]...)
   134  				return variant
   135  			}
   136  		}
   137  	}
   138  	return nil
   139  }
   140  
   141  func (cs *openCallSelector) matchOpen(meta *prog.Syscall, call *parser.Syscall) (bool, int) {
   142  	straceFileArg := call.Args[openDiscriminatorArgs[call.CallName]]
   143  	straceBuf := straceFileArg.(*parser.BufferType).Val
   144  	syzFileArg := meta.Args[openDiscriminatorArgs[meta.CallName]].Type
   145  	if _, ok := syzFileArg.(*prog.PtrType); !ok {
   146  		return false, -1
   147  	}
   148  	syzBuf, ok := syzFileArg.(*prog.PtrType).Elem.(*prog.BufferType)
   149  	if !ok {
   150  		return false, -1
   151  	}
   152  	if syzBuf.Kind != prog.BufferString {
   153  		return false, -1
   154  	}
   155  	for _, val := range syzBuf.Values {
   156  		match, devID := cs.matchFilename([]byte(val), []byte(straceBuf))
   157  		if match {
   158  			return match, devID
   159  		}
   160  	}
   161  	return false, -1
   162  }
   163  
// defaultCallSelector picks, among the syzkaller variants that share a traced
// call's name, the one whose discriminating arguments (positions listed in
// discriminatorArgs) best match the traced arguments.
type defaultCallSelector struct {
	*selectorCommon
}
   167  
   168  // Select returns the best matching descrimination for this syscall.
   169  func (cs *defaultCallSelector) Select(call *parser.Syscall) *prog.Syscall {
   170  	var match *prog.Syscall
   171  	discriminators := discriminatorArgs[call.CallName]
   172  	if len(discriminators) == 0 {
   173  		return nil
   174  	}
   175  	score := 0
   176  	for _, meta := range cs.callSet(call.CallName) {
   177  		if score1 := cs.matchCall(meta, call, discriminators); score1 > score {
   178  			match, score = meta, score1
   179  		}
   180  	}
   181  	return match
   182  }
   183  
   184  // matchCall returns match score between meta and call.
   185  // Higher score means better match, -1 if they are not matching at all.
   186  func (cs *defaultCallSelector) matchCall(meta *prog.Syscall, call *parser.Syscall, discriminators []int) int {
   187  	score := 0
   188  	for _, i := range discriminators {
   189  		if i >= len(meta.Args) || i >= len(call.Args) {
   190  			return -1
   191  		}
   192  		typ := meta.Args[i].Type
   193  		arg := call.Args[i]
   194  		switch t := typ.(type) {
   195  		case *prog.ConstType:
   196  			// Consts must match precisely.
   197  			constant, ok := arg.(parser.Constant)
   198  			if !ok || constant.Val() != t.Val {
   199  				return -1
   200  			}
   201  			score += 10
   202  		case *prog.FlagsType:
   203  			// Flags may or may not match, but matched flags increase score.
   204  			constant, ok := arg.(parser.Constant)
   205  			if !ok {
   206  				return -1
   207  			}
   208  			val := constant.Val()
   209  			for _, v := range t.Vals {
   210  				if v == val {
   211  					score++
   212  					break
   213  				}
   214  			}
   215  		case *prog.ResourceType:
   216  			// Resources must match one of subtypes,
   217  			// the more precise match, the higher the score.
   218  			retArg := cs.returnCache.get(t, arg)
   219  			if retArg == nil {
   220  				return -1
   221  			}
   222  			matched := false
   223  			for i, kind := range retArg.Type().(*prog.ResourceType).Desc.Kind {
   224  				if kind == t.Desc.Name {
   225  					score += i + 1
   226  					matched = true
   227  					break
   228  				}
   229  			}
   230  			if !matched {
   231  				return -1
   232  			}
   233  		case *prog.PtrType:
   234  			switch r := t.Elem.(type) {
   235  			case *prog.BufferType:
   236  				matched := false
   237  				buffer, ok := arg.(*parser.BufferType)
   238  				if !ok {
   239  					return -1
   240  				}
   241  				if r.Kind != prog.BufferString {
   242  					return -1
   243  				}
   244  				for _, val := range r.Values {
   245  					matched, _ = cs.matchFilename([]byte(val), []byte(buffer.Val))
   246  					if matched {
   247  						score++
   248  						break
   249  					}
   250  				}
   251  				if !matched {
   252  					return -1
   253  				}
   254  			}
   255  		}
   256  	}
   257  	return score
   258  }