github.com/google/syzkaller@v0.0.0-20251211124644-a066d2bc4b02/pkg/csource/syscall_generation_test.go (about)

     1  // Copyright 2025 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  package csource
     4  
     5  import (
     6  	"bufio"
     7  	"flag"
     8  	"fmt"
     9  	"os"
    10  	"path"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/google/go-cmp/cmp"
    15  	"github.com/google/syzkaller/prog"
    16  	"github.com/google/syzkaller/sys/targets"
    17  	"github.com/stretchr/testify/assert"
    18  )
    19  
    20  var flagUpdate = flag.Bool("update", false, "update test files accordingly to current results")
    21  
    22  type testData struct {
    23  	filepath string
    24  	// The input syscall description, e.g. bind$netlink(r0, &(0x7f0000514ff4)={0x10, 0x0, 0x0, 0x2ffffffff}, 0xc).
    25  	input string
    26  	calls []annotatedCall
    27  }
    28  
    29  type annotatedCall struct {
    30  	comment string
    31  	syscall string
    32  }
    33  
    34  func TestGenerateSyscalls(t *testing.T) {
    35  	flag.Parse()
    36  
    37  	testCases, err := readTestCases("./testdata")
    38  	assert.NoError(t, err)
    39  
    40  	target, err := prog.GetTarget(targets.Linux, targets.AMD64)
    41  	if err != nil {
    42  		t.Fatal(err)
    43  	}
    44  
    45  	for _, tc := range testCases {
    46  		newData, equal := testGenerationImpl(t, tc, target)
    47  		if *flagUpdate && !equal {
    48  			t.Logf("writing updated contents to %s", tc.filepath)
    49  			err = os.WriteFile(tc.filepath, []byte(newData), 0640)
    50  			assert.NoError(t, err)
    51  		}
    52  	}
    53  }
    54  
    55  func readTestCases(dir string) ([]testData, error) {
    56  	var testCases []testData
    57  
    58  	testFiles, err := os.ReadDir(dir)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  
    63  	for _, testFile := range testFiles {
    64  		if testFile.IsDir() {
    65  			continue
    66  		}
    67  
    68  		testCase, err := readTestData(path.Join(dir, testFile.Name()))
    69  		if err != nil {
    70  			return nil, err
    71  		}
    72  		testCases = append(testCases, testCase)
    73  	}
    74  
    75  	return testCases, nil
    76  }
    77  
    78  func readTestData(filepath string) (testData, error) {
    79  	var td testData
    80  	td.filepath = filepath
    81  
    82  	file, err := os.Open(filepath)
    83  	if err != nil {
    84  		return testData{}, err
    85  	}
    86  
    87  	scanner := bufio.NewScanner(file)
    88  
    89  	var inputBuilder strings.Builder
    90  	for scanner.Scan() {
    91  		line := scanner.Text()
    92  		if line == "" {
    93  			break
    94  		}
    95  		inputBuilder.WriteString(line + "\n")
    96  	}
    97  	td.input = inputBuilder.String()
    98  
    99  	var commentBuilder strings.Builder
   100  	for scanner.Scan() {
   101  		line := scanner.Text()
   102  		if strings.HasPrefix(line, commentPrefix) {
   103  			if commentBuilder.Len() > 0 {
   104  				commentBuilder.WriteString("\n")
   105  			}
   106  			commentBuilder.WriteString(line)
   107  		} else {
   108  			td.calls = append(td.calls, annotatedCall{
   109  				comment: commentBuilder.String(),
   110  				syscall: line,
   111  			})
   112  			commentBuilder.Reset()
   113  		}
   114  	}
   115  
   116  	if err := scanner.Err(); err != nil {
   117  		return testData{}, err
   118  	}
   119  
   120  	if commentBuilder.Len() != 0 {
   121  		return testData{}, fmt.Errorf("expected a syscall expression but got EOF")
   122  	}
   123  	return td, nil
   124  }
   125  
   126  // Returns the generated content, and whether or not they were equal.
   127  func testGenerationImpl(t *testing.T, test testData, target *prog.Target) (string, bool) {
   128  	p, err := target.Deserialize([]byte(test.input), prog.Strict)
   129  	if err != nil {
   130  		t.Fatal(err)
   131  	}
   132  
   133  	// Generate the actual comments.
   134  	var actualComments []string
   135  	for _, call := range p.Calls {
   136  		comment := generateComment(call)
   137  		// Formatted comments make comparison easier.
   138  		formatted, err := Format([]byte(comment))
   139  		if err != nil {
   140  			t.Fatal(err)
   141  		}
   142  		actualComments = append(actualComments, string(formatted))
   143  	}
   144  
   145  	// Minimal options as we are just testing syscall output.
   146  	opts := Options{
   147  		Slowdown: 1,
   148  	}
   149  	ctx := &context{
   150  		p:         p,
   151  		opts:      opts,
   152  		target:    p.Target,
   153  		sysTarget: targets.Get(p.Target.OS, p.Target.Arch),
   154  		calls:     make(map[string]uint64),
   155  	}
   156  
   157  	// Partially replicate the flow from csource.go.
   158  	exec, err := p.SerializeForExec()
   159  	if err != nil {
   160  		t.Fatal(err)
   161  	}
   162  	decoded, err := ctx.target.DeserializeExec(exec, nil)
   163  	if err != nil {
   164  		t.Fatal(err)
   165  	}
   166  	var actualSyscalls []string
   167  	for _, execCall := range decoded.Calls {
   168  		actualSyscalls = append(actualSyscalls, ctx.fmtCallBody(execCall))
   169  	}
   170  
   171  	if len(actualSyscalls) != len(test.calls) || len(actualSyscalls) != len(actualComments) {
   172  		t.Fatal("Generated inconsistent syscalls or comments.")
   173  	}
   174  
   175  	areEqual := true
   176  	for i := range actualSyscalls {
   177  		if diffSyscalls := cmp.Diff(actualSyscalls[i], test.calls[i].syscall); diffSyscalls != "" {
   178  			fmt.Print(diffSyscalls)
   179  			t.Fail()
   180  			areEqual = false
   181  		}
   182  		if diffComments := cmp.Diff(actualComments[i], test.calls[i].comment); diffComments != "" {
   183  			fmt.Print(diffComments)
   184  			t.Fail()
   185  			areEqual = false
   186  		}
   187  	}
   188  
   189  	var outputBuilder strings.Builder
   190  	outputBuilder.WriteString(test.input + "\n")
   191  	for i := range actualSyscalls {
   192  		outputBuilder.WriteString(actualComments[i] + "\n")
   193  		outputBuilder.WriteString(actualSyscalls[i])
   194  		// Avoid trailing newline.
   195  		if i != len(test.calls)-1 {
   196  			outputBuilder.WriteString("\n")
   197  		}
   198  	}
   199  
   200  	return outputBuilder.String(), areEqual
   201  }