github.com/codysnider/go-ethereum@v1.10.18-0.20220420071915-14f4ae99222a/rlp/rlpgen/gen_test.go (about)

     1  package main
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/importer"
     8  	"go/parser"
     9  	"go/token"
    10  	"go/types"
    11  	"io/ioutil"
    12  	"os"
    13  	"path/filepath"
    14  	"testing"
    15  )
    16  
    17  // Package RLP is loaded only once and reused for all tests.
    18  var (
    19  	testFset       = token.NewFileSet()
    20  	testImporter   = importer.ForCompiler(testFset, "source", nil).(types.ImporterFrom)
    21  	testPackageRLP *types.Package
    22  )
    23  
    24  func init() {
    25  	cwd, err := os.Getwd()
    26  	if err != nil {
    27  		panic(err)
    28  	}
    29  	testPackageRLP, err = testImporter.ImportFrom(pathOfPackageRLP, cwd, 0)
    30  	if err != nil {
    31  		panic(fmt.Errorf("can't load package RLP: %v", err))
    32  	}
    33  }
    34  
    35  var tests = []string{"uints", "nil", "rawvalue", "optional", "bigint"}
    36  
    37  func TestOutput(t *testing.T) {
    38  	for _, test := range tests {
    39  		test := test
    40  		t.Run(test, func(t *testing.T) {
    41  			inputFile := filepath.Join("testdata", test+".in.txt")
    42  			outputFile := filepath.Join("testdata", test+".out.txt")
    43  			bctx, typ, err := loadTestSource(inputFile, "Test")
    44  			if err != nil {
    45  				t.Fatal("error loading test source:", err)
    46  			}
    47  			output, err := bctx.generate(typ, true, true)
    48  			if err != nil {
    49  				t.Fatal("error in generate:", err)
    50  			}
    51  
    52  			// Set this environment variable to regenerate the test outputs.
    53  			if os.Getenv("WRITE_TEST_FILES") != "" {
    54  				ioutil.WriteFile(outputFile, output, 0644)
    55  			}
    56  
    57  			// Check if output matches.
    58  			wantOutput, err := ioutil.ReadFile(outputFile)
    59  			if err != nil {
    60  				t.Fatal("error loading expected test output:", err)
    61  			}
    62  			if !bytes.Equal(output, wantOutput) {
    63  				t.Fatal("output mismatch:\n", string(output))
    64  			}
    65  		})
    66  	}
    67  }
    68  
    69  func loadTestSource(file string, typeName string) (*buildContext, *types.Named, error) {
    70  	// Load the test input.
    71  	content, err := ioutil.ReadFile(file)
    72  	if err != nil {
    73  		return nil, nil, err
    74  	}
    75  	f, err := parser.ParseFile(testFset, file, content, 0)
    76  	if err != nil {
    77  		return nil, nil, err
    78  	}
    79  	conf := types.Config{Importer: testImporter}
    80  	pkg, err := conf.Check("test", testFset, []*ast.File{f}, nil)
    81  	if err != nil {
    82  		return nil, nil, err
    83  	}
    84  
    85  	// Find the test struct.
    86  	bctx := newBuildContext(testPackageRLP)
    87  	typ, err := lookupStructType(pkg.Scope(), typeName)
    88  	if err != nil {
    89  		return nil, nil, fmt.Errorf("can't find type %s: %v", typeName, err)
    90  	}
    91  	return bctx, typ, nil
    92  }