github.com/google/syzkaller@v0.0.0-20240517125934-c0f1611a36d6/prog/images_test.go (about)

     1  // Copyright 2022 syzkaller project authors. All rights reserved.
     2  // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
     3  
     4  package prog_test
     5  
     6  import (
     7  	"flag"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"path/filepath"
    12  	"reflect"
    13  	"sort"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/google/go-cmp/cmp"
    18  	"github.com/google/syzkaller/pkg/osutil"
    19  	. "github.com/google/syzkaller/prog"
    20  	"github.com/google/syzkaller/sys/targets"
    21  )
    22  
    23  var flagUpdate = flag.Bool("update", false, "update test files accordingly to current results")
    24  
    25  func TestForEachAsset(t *testing.T) {
    26  	target, err := GetTarget(targets.Linux, targets.AMD64)
    27  	if err != nil {
    28  		t.Fatal(err)
    29  	}
    30  	files, err := filepath.Glob(filepath.Join("testdata", "fs_images", "*.in"))
    31  	if err != nil {
    32  		t.Fatalf("directory read failed: %v", err)
    33  	}
    34  	allOutFiles, err := filepath.Glob(filepath.Join("testdata", "fs_images", "*.out*"))
    35  	if err != nil {
    36  		t.Fatalf("directory read failed: %v", err)
    37  	}
    38  	testedOutFiles := []string{}
    39  	for _, file := range files {
    40  		sourceProg, err := os.ReadFile(file)
    41  		if err != nil {
    42  			t.Fatal(err)
    43  		}
    44  		p, err := target.Deserialize(sourceProg, NonStrict)
    45  		if err != nil {
    46  			t.Fatalf("failed to deserialize %s: %s", file, err)
    47  		}
    48  		base := strings.TrimSuffix(file, ".in")
    49  		p.ForEachAsset(func(name string, typ AssetType, r io.Reader) {
    50  			if typ != MountInRepro {
    51  				t.Fatalf("unknown asset type %v", typ)
    52  			}
    53  			testResult, err := io.ReadAll(r)
    54  			if err != nil {
    55  				t.Fatal(err)
    56  			}
    57  			outFilePath := fmt.Sprintf("%v.out_%v", base, name)
    58  			if *flagUpdate {
    59  				if err := osutil.WriteFile(outFilePath, testResult); err != nil {
    60  					t.Fatal(err)
    61  				}
    62  			}
    63  			if !osutil.IsExist(outFilePath) {
    64  				t.Fatalf("asset %v does not exist", outFilePath)
    65  			}
    66  			testedOutFiles = append(testedOutFiles, outFilePath)
    67  			outFile, err := os.ReadFile(outFilePath)
    68  			if err != nil {
    69  				t.Fatal(err)
    70  			}
    71  			if !reflect.DeepEqual(testResult, outFile) {
    72  				t.Fatalf("output not equal:\nWant: %x\nGot: %x", outFile, testResult)
    73  			}
    74  		})
    75  	}
    76  	sort.Strings(testedOutFiles)
    77  	sort.Strings(allOutFiles)
    78  	if diff := cmp.Diff(allOutFiles, testedOutFiles); diff != "" {
    79  		t.Fatalf("not all output files used: %v", diff)
    80  	}
    81  }