github.com/neonyo/sys@v0.0.0-20230720094341-b1ee14be3ce8/unix/internal/mkmerge/mkmerge_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/parser"
    11  	"go/token"
    12  	"html/template"
    13  	"strings"
    14  	"testing"
    15  )
    16  
    17  func TestImports(t *testing.T) {
    18  	t.Run("importName", func(t *testing.T) {
    19  		cases := []struct {
    20  			src   string
    21  			ident string
    22  		}{
    23  			{`"syscall"`, "syscall"},
    24  			{`. "foobar"`, "."},
    25  			{`"go/ast"`, "ast"},
    26  			{`moo "go/format"`, "moo"},
    27  			{`. "go/token"`, "."},
    28  			{`"golang.org/x/sys/unix"`, "unix"},
    29  			{`nix "golang.org/x/sys/unix"`, "nix"},
    30  			{`_ "golang.org/x/sys/unix"`, "_"},
    31  		}
    32  
    33  		for _, c := range cases {
    34  			pkgSrc := fmt.Sprintf("package main\nimport %s", c.src)
    35  
    36  			f, err := parser.ParseFile(token.NewFileSet(), "", pkgSrc, parser.ImportsOnly)
    37  			if err != nil {
    38  				t.Error(err)
    39  				continue
    40  			}
    41  			if len(f.Imports) != 1 {
    42  				t.Errorf("Got %d imports, expected 1", len(f.Imports))
    43  				continue
    44  			}
    45  
    46  			got, err := importName(f.Imports[0])
    47  			if err != nil {
    48  				t.Fatal(err)
    49  			}
    50  			if got != c.ident {
    51  				t.Errorf("Got %q, expected %q", got, c.ident)
    52  			}
    53  		}
    54  	})
    55  
    56  	t.Run("filterImports", func(t *testing.T) {
    57  		cases := []struct{ before, after string }{
    58  			{`package test
    59  
    60  			import (
    61  				"foo"
    62  				"bar"
    63  			)`,
    64  				"package test\n"},
    65  			{`package test
    66  
    67  			import (
    68  				"foo"
    69  				"bar"
    70  			)
    71  
    72  			func useFoo() { foo.Usage() }`,
    73  				`package test
    74  
    75  import (
    76  	"foo"
    77  )
    78  
    79  func useFoo() { foo.Usage() }
    80  `},
    81  		}
    82  		for _, c := range cases {
    83  			got, err := filterImports([]byte(c.before))
    84  			if err != nil {
    85  				t.Error(err)
    86  			}
    87  
    88  			if string(got) != c.after {
    89  				t.Errorf("Got:\n%s\nExpected:\n%s\n", got, c.after)
    90  			}
    91  		}
    92  	})
    93  }
    94  
    95  func TestMerge(t *testing.T) {
    96  	// Input architecture files
    97  	inTmpl := template.Must(template.New("input").Parse(`
    98  // Package comments
    99  
   100  // build directives for arch{{.}}
   101  
   102  //go:build goos && arch{{.}}
   103  // +build goos,arch{{.}}
   104  
   105  package main
   106  
   107  /*
   108  #include <stdint.h>
   109  #include <stddef.h>
   110  int utimes(uintptr_t, uintptr_t);
   111  int utimensat(int, uintptr_t, uintptr_t, int);
   112  */
   113  import "C"
   114  
   115  // The imports
   116  import (
   117  	"commonDep"
   118  	"uniqueDep{{.}}"
   119  )
   120  
   121  // Vars
   122  var (
   123  	commonVar = commonDep.Use("common")
   124  
   125  	uniqueVar{{.}} = "unique{{.}}"
   126  )
   127  
   128  // Common free standing comment
   129  
   130  // Common comment
   131  const COMMON_INDEPENDENT = 1234
   132  const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"
   133  
   134  // Group comment
   135  const (
   136  	COMMON_GROUP = "COMMON_GROUP"
   137  	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
   138  )
   139  
   140  // Group2 comment
   141  const (
   142  	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
   143  	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
   144  )
   145  
   146  // Group3 comment
   147  const (
   148  	sub1Common1 = 11
   149  	sub1Unique2{{.}} = 12
   150  	sub1Common3_LONG = 13
   151  
   152  	sub2Unique1{{.}} = 21
   153  	sub2Common2 = 22
   154  	sub2Common3 = 23
   155  	sub2Unique4{{.}} = 24
   156  )
   157  
   158  type commonInt int
   159  
   160  type uniqueInt{{.}} int
   161  
   162  func commonF() string {
   163  	return commonDep.Use("common")
   164  	}
   165  
   166  func uniqueF() string {
   167  	C.utimes(0, 0)
   168  	return uniqueDep{{.}}.Use("{{.}}")
   169  	}
   170  
   171  // Group4 comment
   172  const (
   173  	sub3Common1 = 31
   174  	sub3Unique2{{.}} = 32
   175  	sub3Unique3{{.}} = 33
   176  	sub3Common4 = 34
   177  
   178  	sub4Common1, sub4Unique2{{.}} = 41, 42
   179  	sub4Unique3{{.}}, sub4Common4 = 43, 44
   180  )
   181  `))
   182  
   183  	// Filtered architecture files
   184  	outTmpl := template.Must(template.New("output").Parse(`// Package comments
   185  
   186  // build directives for arch{{.}}
   187  
   188  //go:build goos && arch{{.}}
   189  // +build goos,arch{{.}}
   190  
   191  package main
   192  
   193  /*
   194  #include <stdint.h>
   195  #include <stddef.h>
   196  int utimes(uintptr_t, uintptr_t);
   197  int utimensat(int, uintptr_t, uintptr_t, int);
   198  */
   199  import "C"
   200  
   201  // The imports
   202  import (
   203  	"commonDep"
   204  	"uniqueDep{{.}}"
   205  )
   206  
   207  // Vars
   208  var (
   209  	commonVar = commonDep.Use("common")
   210  
   211  	uniqueVar{{.}} = "unique{{.}}"
   212  )
   213  
   214  const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"
   215  
   216  // Group comment
   217  const (
   218  	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
   219  )
   220  
   221  // Group2 comment
   222  const (
   223  	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
   224  	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
   225  )
   226  
   227  // Group3 comment
   228  const (
   229  	sub1Unique2{{.}} = 12
   230  
   231  	sub2Unique1{{.}} = 21
   232  	sub2Unique4{{.}} = 24
   233  )
   234  
   235  type uniqueInt{{.}} int
   236  
   237  func uniqueF() string {
   238  	C.utimes(0, 0)
   239  	return uniqueDep{{.}}.Use("{{.}}")
   240  }
   241  
   242  // Group4 comment
   243  const (
   244  	sub3Unique2{{.}} = 32
   245  	sub3Unique3{{.}} = 33
   246  
   247  	sub4Common1, sub4Unique2{{.}} = 41, 42
   248  	sub4Unique3{{.}}, sub4Common4 = 43, 44
   249  )
   250  `))
   251  
   252  	const mergedFile = `// Package comments
   253  
   254  package main
   255  
   256  // The imports
   257  import (
   258  	"commonDep"
   259  )
   260  
   261  // Common free standing comment
   262  
   263  // Common comment
   264  const COMMON_INDEPENDENT = 1234
   265  
   266  // Group comment
   267  const (
   268  	COMMON_GROUP = "COMMON_GROUP"
   269  )
   270  
   271  // Group3 comment
   272  const (
   273  	sub1Common1      = 11
   274  	sub1Common3_LONG = 13
   275  
   276  	sub2Common2 = 22
   277  	sub2Common3 = 23
   278  )
   279  
   280  type commonInt int
   281  
   282  func commonF() string {
   283  	return commonDep.Use("common")
   284  }
   285  
   286  // Group4 comment
   287  const (
   288  	sub3Common1 = 31
   289  	sub3Common4 = 34
   290  )
   291  `
   292  
   293  	// Generate source code for different "architectures"
   294  	var inFiles, outFiles []srcFile
   295  	for _, arch := range strings.Fields("A B C D") {
   296  		buf := new(bytes.Buffer)
   297  		err := inTmpl.Execute(buf, arch)
   298  		if err != nil {
   299  			t.Fatal(err)
   300  		}
   301  		inFiles = append(inFiles, srcFile{"file" + arch, buf.Bytes()})
   302  
   303  		buf = new(bytes.Buffer)
   304  		err = outTmpl.Execute(buf, arch)
   305  		if err != nil {
   306  			t.Fatal(err)
   307  		}
   308  		outFiles = append(outFiles, srcFile{"file" + arch, buf.Bytes()})
   309  	}
   310  
   311  	t.Run("getCodeSet", func(t *testing.T) {
   312  		got, err := getCodeSet(inFiles[0].src)
   313  		if err != nil {
   314  			t.Fatal(err)
   315  		}
   316  
   317  		expectedElems := []codeElem{
   318  			{token.COMMENT, "Package comments\n"},
   319  			{token.COMMENT, "build directives for archA\n"},
   320  			{token.COMMENT, "+build goos,archA\n"},
   321  			{token.CONST, `COMMON_INDEPENDENT = 1234`},
   322  			{token.CONST, `UNIQUE_INDEPENDENT_A = "UNIQUE_INDEPENDENT_A"`},
   323  			{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`},
   324  			{token.CONST, `UNIQUE_GROUP_A = "UNIQUE_GROUP_A"`},
   325  			{token.CONST, `UNIQUE_GROUP21_A = "UNIQUE_GROUP21_A"`},
   326  			{token.CONST, `UNIQUE_GROUP22_A = "UNIQUE_GROUP22_A"`},
   327  			{token.CONST, `sub1Common1 = 11`},
   328  			{token.CONST, `sub1Unique2A = 12`},
   329  			{token.CONST, `sub1Common3_LONG = 13`},
   330  			{token.CONST, `sub2Unique1A = 21`},
   331  			{token.CONST, `sub2Common2 = 22`},
   332  			{token.CONST, `sub2Common3 = 23`},
   333  			{token.CONST, `sub2Unique4A = 24`},
   334  			{token.CONST, `sub3Common1 = 31`},
   335  			{token.CONST, `sub3Unique2A = 32`},
   336  			{token.CONST, `sub3Unique3A = 33`},
   337  			{token.CONST, `sub3Common4 = 34`},
   338  			{token.CONST, `sub4Common1, sub4Unique2A = 41, 42`},
   339  			{token.CONST, `sub4Unique3A, sub4Common4 = 43, 44`},
   340  			{token.TYPE, `commonInt int`},
   341  			{token.TYPE, `uniqueIntA int`},
   342  			{token.FUNC, `func commonF() string {
   343  	return commonDep.Use("common")
   344  }`},
   345  			{token.FUNC, `func uniqueF() string {
   346  	C.utimes(0, 0)
   347  	return uniqueDepA.Use("A")
   348  }`},
   349  		}
   350  		expected := newCodeSet()
   351  		for _, d := range expectedElems {
   352  			expected.add(d)
   353  		}
   354  
   355  		if len(got.set) != len(expected.set) {
   356  			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
   357  		}
   358  		for expElem := range expected.set {
   359  			if !got.has(expElem) {
   360  				t.Errorf("Didn't get expected codeElem %#v", expElem)
   361  			}
   362  		}
   363  		for gotElem := range got.set {
   364  			if !expected.has(gotElem) {
   365  				t.Errorf("Got unexpected codeElem %#v", gotElem)
   366  			}
   367  		}
   368  	})
   369  
   370  	t.Run("getCommonSet", func(t *testing.T) {
   371  		got, err := getCommonSet(inFiles)
   372  		if err != nil {
   373  			t.Fatal(err)
   374  		}
   375  
   376  		expected := newCodeSet()
   377  		expected.add(codeElem{token.COMMENT, "Package comments\n"})
   378  		expected.add(codeElem{token.CONST, `COMMON_INDEPENDENT = 1234`})
   379  		expected.add(codeElem{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`})
   380  		expected.add(codeElem{token.CONST, `sub1Common1 = 11`})
   381  		expected.add(codeElem{token.CONST, `sub1Common3_LONG = 13`})
   382  		expected.add(codeElem{token.CONST, `sub2Common2 = 22`})
   383  		expected.add(codeElem{token.CONST, `sub2Common3 = 23`})
   384  		expected.add(codeElem{token.CONST, `sub3Common1 = 31`})
   385  		expected.add(codeElem{token.CONST, `sub3Common4 = 34`})
   386  		expected.add(codeElem{token.TYPE, `commonInt int`})
   387  		expected.add(codeElem{token.FUNC, `func commonF() string {
   388  	return commonDep.Use("common")
   389  }`})
   390  
   391  		if len(got.set) != len(expected.set) {
   392  			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
   393  		}
   394  		for expElem := range expected.set {
   395  			if !got.has(expElem) {
   396  				t.Errorf("Didn't get expected codeElem %#v", expElem)
   397  			}
   398  		}
   399  		for gotElem := range got.set {
   400  			if !expected.has(gotElem) {
   401  				t.Errorf("Got unexpected codeElem %#v", gotElem)
   402  			}
   403  		}
   404  	})
   405  
   406  	t.Run("filter(keepCommon)", func(t *testing.T) {
   407  		commonSet, err := getCommonSet(inFiles)
   408  		if err != nil {
   409  			t.Fatal(err)
   410  		}
   411  
   412  		got, err := filter(inFiles[0].src, commonSet.keepCommon)
   413  		if err != nil {
   414  			t.Fatal(err)
   415  		}
   416  
   417  		expected := []byte(mergedFile)
   418  
   419  		if !bytes.Equal(got, expected) {
   420  			t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
   421  			diffLines(t, got, expected)
   422  		}
   423  	})
   424  
   425  	t.Run("filter(keepArchSpecific)", func(t *testing.T) {
   426  		commonSet, err := getCommonSet(inFiles)
   427  		if err != nil {
   428  			t.Fatal(err)
   429  		}
   430  
   431  		for i := range inFiles {
   432  			got, err := filter(inFiles[i].src, commonSet.keepArchSpecific)
   433  			if err != nil {
   434  				t.Fatal(err)
   435  			}
   436  
   437  			expected := outFiles[i].src
   438  
   439  			if !bytes.Equal(got, expected) {
   440  				t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
   441  				diffLines(t, got, expected)
   442  			}
   443  		}
   444  	})
   445  }
   446  
   447  func TestMergedName(t *testing.T) {
   448  	t.Run("getValidGOOS", func(t *testing.T) {
   449  		testcases := []struct {
   450  			filename, goos string
   451  			ok             bool
   452  		}{
   453  			{"zerrors_aix.go", "aix", true},
   454  			{"zerrors_darwin.go", "darwin", true},
   455  			{"zerrors_dragonfly.go", "dragonfly", true},
   456  			{"zerrors_freebsd.go", "freebsd", true},
   457  			{"zerrors_linux.go", "linux", true},
   458  			{"zerrors_netbsd.go", "netbsd", true},
   459  			{"zerrors_openbsd.go", "openbsd", true},
   460  			{"zerrors_solaris.go", "solaris", true},
   461  			{"zerrors_multics.go", "", false},
   462  		}
   463  		for _, tc := range testcases {
   464  			goos, ok := getValidGOOS(tc.filename)
   465  			if goos != tc.goos {
   466  				t.Errorf("got GOOS %q, expected %q", goos, tc.goos)
   467  			}
   468  			if ok != tc.ok {
   469  				t.Errorf("got ok %v, expected %v", ok, tc.ok)
   470  			}
   471  		}
   472  	})
   473  }
   474  
   475  // Helper functions to diff test sources
   476  
   477  func diffLines(t *testing.T, got, expected []byte) {
   478  	t.Helper()
   479  
   480  	gotLines := bytes.Split(got, []byte{'\n'})
   481  	expLines := bytes.Split(expected, []byte{'\n'})
   482  
   483  	i := 0
   484  	for i < len(gotLines) && i < len(expLines) {
   485  		if !bytes.Equal(gotLines[i], expLines[i]) {
   486  			t.Errorf("Line %d: Got:\n%q\nExpected:\n%q", i+1, gotLines[i], expLines[i])
   487  			return
   488  		}
   489  		i++
   490  	}
   491  
   492  	if i < len(gotLines) && i >= len(expLines) {
   493  		t.Errorf("Line %d: got %q, expected EOF", i+1, gotLines[i])
   494  	}
   495  	if i >= len(gotLines) && i < len(expLines) {
   496  		t.Errorf("Line %d: got EOF, expected %q", i+1, gotLines[i])
   497  	}
   498  }
   499  
   500  func addLineNr(src []byte) []byte {
   501  	lines := bytes.Split(src, []byte("\n"))
   502  	for i, line := range lines {
   503  		lines[i] = []byte(fmt.Sprintf("%d: %s", i+1, line))
   504  	}
   505  	return bytes.Join(lines, []byte("\n"))
   506  }