github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/prog_test.go (about)

     1  // Copyright 2015 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"math/rand"
    10  	"strings"
    11  	"testing"
    12  
    13  	"github.com/google/syzkaller/pkg/testutil"
    14  )
    15  
    16  func TestGeneration(t *testing.T) {
    17  	target, rs, iters := initTest(t)
    18  	ct := target.DefaultChoiceTable()
    19  	for i := 0; i < iters; i++ {
    20  		target.Generate(rs, 20, ct)
    21  	}
    22  }
    23  
    24  func TestDefault(t *testing.T) {
    25  	target, _, _ := initTest(t)
    26  	ForeachType(target.Syscalls, func(typ Type, ctx *TypeCtx) {
    27  		arg := typ.DefaultArg(ctx.Dir)
    28  		if !isDefault(arg) {
    29  			t.Errorf("default arg is not default: %s\ntype: %#v\narg: %#v",
    30  				typ, typ, arg)
    31  		}
    32  	})
    33  }
    34  
    35  func TestDefaultCallArgs(t *testing.T) {
    36  	testEachTarget(t, func(t *testing.T, target *Target) {
    37  		for _, meta := range target.SyscallMap {
    38  			if meta.Attrs.Disabled {
    39  				continue
    40  			}
    41  			// Ensure that we can restore all arguments of all calls.
    42  			prog := fmt.Sprintf("%v()", meta.Name)
    43  			p, err := target.Deserialize([]byte(prog), NonStrict)
    44  			if err != nil {
    45  				t.Fatalf("failed to restore default args in prog %q: %v", prog, err)
    46  			}
    47  			if len(p.Calls) != 1 || p.Calls[0].Meta.Name != meta.Name {
    48  				t.Fatalf("restored bad program from prog %q: %q", prog, p.Serialize())
    49  			}
    50  			s0 := string(p.Serialize())
    51  			p.sanitizeFix()
    52  			s1 := string(p.Serialize())
    53  			if s0 != s1 {
    54  				t.Fatalf("non-sanitized program or non-idempotent sanitize\nwas: %v\ngot: %v", s0, s1)
    55  			}
    56  		}
    57  	})
    58  }
    59  
    60  func testSerialize(t *testing.T, verbose bool) {
    61  	target, rs, iters := initTest(t)
    62  	ct := target.DefaultChoiceTable()
    63  	for i := 0; i < iters; i++ {
    64  		p := target.Generate(rs, 10, ct)
    65  		var data []byte
    66  		mode := NonStrict
    67  		if verbose {
    68  			data = p.SerializeVerbose()
    69  			mode = Strict
    70  		} else {
    71  			data = p.Serialize()
    72  		}
    73  		p1, err := target.Deserialize(data, mode)
    74  		if err != nil {
    75  			t.Fatalf("failed to deserialize program: %v\n%s", err, data)
    76  		}
    77  		if p1 == nil {
    78  			t.Fatalf("deserialized nil program:\n%s", data)
    79  		}
    80  		var data1 []byte
    81  		if verbose {
    82  			data1 = p1.SerializeVerbose()
    83  		} else {
    84  			data1 = p1.Serialize()
    85  		}
    86  		if len(p.Calls) != len(p1.Calls) {
    87  			t.Fatalf("different number of calls")
    88  		}
    89  		if !bytes.Equal(data, data1) {
    90  			t.Fatalf("program changed after serialize/deserialize\noriginal:\n%s\n\nnew:\n%s", data, data1)
    91  		}
    92  	}
    93  }
    94  
    95  func TestSerialize(t *testing.T) {
    96  	testSerialize(t, false)
    97  }
    98  
    99  func TestSerializeVerbose(t *testing.T) {
   100  	testSerialize(t, true)
   101  }
   102  
   103  func TestVmaType(t *testing.T) {
   104  	target, rs, iters := initRandomTargetTest(t, "test", "64")
   105  	ct := target.DefaultChoiceTable()
   106  	meta := target.SyscallMap["test$vma0"]
   107  	r := newRand(target, rs)
   108  	pageSize := target.PageSize
   109  	for i := 0; i < iters; i++ {
   110  		s := newState(target, ct, nil)
   111  		calls := r.generateParticularCall(s, meta)
   112  		c := calls[len(calls)-1]
   113  		if c.Meta.Name != "test$vma0" {
   114  			t.Fatalf("generated wrong call %v", c.Meta.Name)
   115  		}
   116  		if len(c.Args) != 6 {
   117  			t.Fatalf("generated wrong number of args %v", len(c.Args))
   118  		}
   119  		check := func(v, l Arg, min, max uint64) {
   120  			va, ok := v.(*PointerArg)
   121  			if !ok {
   122  				t.Fatalf("vma has bad type: %v", v)
   123  			}
   124  			la, ok := l.(*ConstArg)
   125  			if !ok {
   126  				t.Fatalf("len has bad type: %v", l)
   127  			}
   128  			if va.VmaSize < min || va.VmaSize > max {
   129  				t.Fatalf("vma has bad size: %v, want [%v-%v]",
   130  					va.VmaSize, min, max)
   131  			}
   132  			if la.Val < min || la.Val > max {
   133  				t.Fatalf("len has bad value: %v, want [%v-%v]",
   134  					la.Val, min, max)
   135  			}
   136  		}
   137  		check(c.Args[0], c.Args[1], 1*pageSize, 1e5*pageSize)
   138  		check(c.Args[2], c.Args[3], 5*pageSize, 5*pageSize)
   139  		check(c.Args[4], c.Args[5], 7*pageSize, 9*pageSize)
   140  	}
   141  }
   142  
   143  // TestCrossTarget ensures that a program serialized for one arch can be
   144  // deserialized for another arch. This happens when managers exchange
   145  // programs via hub.
   146  func TestCrossTarget(t *testing.T) {
   147  	if testutil.RaceEnabled {
   148  		t.Skip("skipping in race mode, too slow")
   149  	}
   150  	t.Parallel()
   151  	const OS = "linux"
   152  	var archs []string
   153  	for _, target := range AllTargets() {
   154  		if target.OS == OS {
   155  			archs = append(archs, target.Arch)
   156  		}
   157  	}
   158  	for _, arch := range archs {
   159  		target, err := GetTarget(OS, arch)
   160  		if err != nil {
   161  			t.Fatal(err)
   162  		}
   163  		var crossTargets []*Target
   164  		for _, crossArch := range archs {
   165  			if crossArch == arch {
   166  				continue
   167  			}
   168  			crossTarget, err := GetTarget(OS, crossArch)
   169  			if err != nil {
   170  				t.Fatal(err)
   171  			}
   172  			crossTargets = append(crossTargets, crossTarget)
   173  		}
   174  		t.Run(fmt.Sprintf("%v/%v", OS, arch), func(t *testing.T) {
   175  			t.Parallel()
   176  			testCrossTarget(t, target, crossTargets)
   177  		})
   178  	}
   179  }
   180  
   181  func testCrossTarget(t *testing.T, target *Target, crossTargets []*Target) {
   182  	ct := target.DefaultChoiceTable()
   183  	rs := testutil.RandSource(t)
   184  	iters := 100
   185  	if testing.Short() {
   186  		iters /= 10
   187  	}
   188  	for i := 0; i < iters; i++ {
   189  		p := target.Generate(rs, 20, ct)
   190  		testCrossArchProg(t, p, crossTargets)
   191  		p, err := target.Deserialize(p.Serialize(), NonStrict)
   192  		if err != nil {
   193  			t.Fatal(err)
   194  		}
   195  		testCrossArchProg(t, p, crossTargets)
   196  		p.Mutate(rs, 20, ct, nil, nil)
   197  		testCrossArchProg(t, p, crossTargets)
   198  		p, _ = Minimize(p, -1, false, func(*Prog, int) bool {
   199  			return rs.Int63()%2 == 0
   200  		})
   201  		testCrossArchProg(t, p, crossTargets)
   202  	}
   203  }
   204  
   205  func testCrossArchProg(t *testing.T, p *Prog, crossTargets []*Target) {
   206  	serialized := p.Serialize()
   207  	for _, crossTarget := range crossTargets {
   208  		_, err := crossTarget.Deserialize(serialized, NonStrict)
   209  		if err == nil || strings.Contains(err.Error(), "unknown syscall") {
   210  			continue
   211  		}
   212  		t.Fatalf("failed to deserialize for %v/%v: %v\n%s",
   213  			crossTarget.OS, crossTarget.Arch, err, serialized)
   214  	}
   215  }
   216  
   217  func TestSpecialStructs(t *testing.T) {
   218  	testEachTargetRandom(t, func(t *testing.T, target *Target, rs rand.Source, iters int) {
   219  		_ = target.GenerateAllSyzProg(rs)
   220  		ct := target.DefaultChoiceTable()
   221  		for special, gen := range target.SpecialTypes {
   222  			t.Run(special, func(t *testing.T) {
   223  				var typ Type
   224  				for i := 0; i < len(target.Syscalls) && typ == nil; i++ {
   225  					ForeachCallType(target.Syscalls[i], func(t Type, ctx *TypeCtx) {
   226  						if ctx.Dir == DirOut {
   227  							return
   228  						}
   229  						if s, ok := t.(*StructType); ok && s.Name() == special {
   230  							typ = s
   231  						}
   232  						if s, ok := t.(*UnionType); ok && s.Name() == special {
   233  							typ = s
   234  						}
   235  					})
   236  				}
   237  				if typ == nil {
   238  					t.Fatal("can't find struct description")
   239  				}
   240  				g := &Gen{newRand(target, rs), newState(target, ct, nil)}
   241  				for i := 0; i < iters/len(target.SpecialTypes); i++ {
   242  					var arg Arg
   243  					for i := 0; i < 2; i++ {
   244  						arg, _ = gen(g, typ, DirInOut, arg)
   245  						if arg.Dir() != DirInOut {
   246  							t.Fatalf("got wrong arg dir %v", arg.Dir())
   247  						}
   248  					}
   249  				}
   250  			})
   251  		}
   252  	})
   253  }
   254  
   255  func TestEscapingPaths(t *testing.T) {
   256  	paths := map[string]bool{
   257  		"/":                      true,
   258  		"/\x00":                  true,
   259  		"/file/..":               true,
   260  		"/file/../..":            true,
   261  		"./..":                   true,
   262  		"..":                     true,
   263  		"file/../../file":        true,
   264  		"../file":                true,
   265  		"./file/../../file/file": true,
   266  		"":                       false,
   267  		".":                      false,
   268  		"file":                   false,
   269  		"./file":                 false,
   270  		"./file/..":              false,
   271  	}
   272  	for path, want := range paths {
   273  		got := escapingFilename(path)
   274  		if got != want {
   275  			t.Errorf("path %q: got %v, want %v", path, got, want)
   276  		}
   277  	}
   278  }
   279  
   280  func TestFallbackSignal(t *testing.T) {
   281  	type desc struct {
   282  		prog string
   283  		info []CallInfo
   284  	}
   285  	tests := []desc{
   286  		// Test restored errno values and that non-executed syscalls don't get fallback signal.
   287  		{
   288  			`
   289  fallback$0()
   290  fallback$0()
   291  fallback$0()
   292  `,
   293  			[]CallInfo{
   294  				{
   295  					Flags:  CallExecuted,
   296  					Errno:  0,
   297  					Signal: make([]uint32, 1),
   298  				},
   299  				{
   300  					Flags:  CallExecuted,
   301  					Errno:  42,
   302  					Signal: make([]uint32, 1),
   303  				},
   304  				{},
   305  			},
   306  		},
   307  		// Test different cases of argument-dependent signal and that unsuccessful calls don't get it.
   308  		{
   309  			`
   310  r0 = fallback$0()
   311  fallback$1(r0)
   312  fallback$1(r0)
   313  fallback$1(0xffffffffffffffff)
   314  fallback$1(0x0)
   315  fallback$1(0x0)
   316  `,
   317  			[]CallInfo{
   318  				{
   319  					Flags:  CallExecuted,
   320  					Errno:  0,
   321  					Signal: make([]uint32, 1),
   322  				},
   323  				{
   324  					Flags:  CallExecuted,
   325  					Errno:  1,
   326  					Signal: make([]uint32, 1),
   327  				},
   328  				{
   329  					Flags:  CallExecuted,
   330  					Errno:  0,
   331  					Signal: make([]uint32, 2),
   332  				},
   333  				{
   334  					Flags:  CallExecuted,
   335  					Errno:  0,
   336  					Signal: make([]uint32, 1),
   337  				},
   338  				{
   339  					Flags:  CallExecuted,
   340  					Errno:  0,
   341  					Signal: make([]uint32, 2),
   342  				},
   343  				{
   344  					Flags:  CallExecuted,
   345  					Errno:  2,
   346  					Signal: make([]uint32, 1),
   347  				},
   348  			},
   349  		},
   350  		// Test that calls get no signal after a successful seccomp.
   351  		{
   352  			`
   353  fallback$0()
   354  fallback$0()
   355  breaks_returns()
   356  fallback$0()
   357  breaks_returns()
   358  fallback$0()
   359  fallback$0()
   360  `,
   361  			[]CallInfo{
   362  				{
   363  					Flags:  CallExecuted,
   364  					Errno:  0,
   365  					Signal: make([]uint32, 1),
   366  				},
   367  				{
   368  					Flags:  CallExecuted,
   369  					Errno:  0,
   370  					Signal: make([]uint32, 1),
   371  				},
   372  				{
   373  					Flags:  CallExecuted,
   374  					Errno:  1,
   375  					Signal: make([]uint32, 1),
   376  				},
   377  				{
   378  					Flags: CallExecuted,
   379  					Errno: 0,
   380  				},
   381  				{
   382  					Flags: CallExecuted,
   383  					Errno: 0,
   384  				},
   385  				{
   386  					Flags: CallExecuted,
   387  				},
   388  				{
   389  					Flags: CallExecuted,
   390  				},
   391  			},
   392  		},
   393  		{
   394  			`
   395  fallback$0()
   396  breaks_returns()
   397  fallback$0()
   398  breaks_returns()
   399  fallback$0()
   400  `,
   401  			[]CallInfo{
   402  				{
   403  					Flags:  CallExecuted,
   404  					Errno:  0,
   405  					Signal: make([]uint32, 1),
   406  				},
   407  				{
   408  					Flags:  CallExecuted,
   409  					Errno:  1,
   410  					Signal: make([]uint32, 1),
   411  				},
   412  				{
   413  					Flags: CallExecuted,
   414  					Errno: 0,
   415  				},
   416  				{
   417  					Flags: CallExecuted,
   418  					Errno: 0,
   419  				},
   420  				{
   421  					Flags: CallExecuted,
   422  				},
   423  			},
   424  		},
   425  	}
   426  	target, err := GetTarget("test", "64")
   427  	if err != nil {
   428  		t.Fatal(err)
   429  	}
   430  	for i, test := range tests {
   431  		t.Run(fmt.Sprint(i), func(t *testing.T) {
   432  			p, err := target.Deserialize([]byte(test.prog), Strict)
   433  			if err != nil {
   434  				t.Fatal(err)
   435  			}
   436  			if len(p.Calls) != len(test.info) {
   437  				t.Fatalf("call=%v info=%v", len(p.Calls), len(test.info))
   438  			}
   439  			wantSignal := make([]int, len(test.info))
   440  			for i := range test.info {
   441  				wantSignal[i] = len(test.info[i].Signal)
   442  				test.info[i].Signal = nil
   443  			}
   444  			p.FallbackSignal(test.info)
   445  			for i := range test.info {
   446  				if len(test.info[i].Signal) != wantSignal[i] {
   447  					t.Errorf("call %v: signal=%v want=%v", i, len(test.info[i].Signal), wantSignal[i])
   448  				}
   449  				for _, sig := range test.info[i].Signal {
   450  					call, errno := DecodeFallbackSignal(sig)
   451  					if call != p.Calls[i].Meta.ID {
   452  						t.Errorf("call %v: sig=%x id=%v want=%v", i, sig, call, p.Calls[i].Meta.ID)
   453  					}
   454  					if errno != test.info[i].Errno {
   455  						t.Errorf("call %v: sig=%x errno=%v want=%v", i, sig, errno, test.info[i].Errno)
   456  					}
   457  				}
   458  			}
   459  		})
   460  	}
   461  }
   462  
   463  func TestSanitizeRandom(t *testing.T) {
   464  	testEachTargetRandom(t, func(t *testing.T, target *Target, rs rand.Source, iters int) {
   465  		ct := target.DefaultChoiceTable()
   466  		for i := 0; i < iters; i++ {
   467  			p := target.Generate(rs, 10, ct)
   468  			s0 := string(p.Serialize())
   469  			p.sanitizeFix()
   470  			s1 := string(p.Serialize())
   471  			if s0 != s1 {
   472  				t.Fatalf("non-sanitized program or non-idempotent sanitize\nwas: %v\ngot: %v", s0, s1)
   473  			}
   474  		}
   475  	})
   476  }