golang.org/x/sys@v0.20.1-0.20240517151509-673e0f94c16d/unix/internal/mkmerge/mkmerge_test.go (about)

     1  // Copyright 2020 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"go/parser"
    11  	"go/token"
    12  	"html/template"
    13  	"strings"
    14  	"testing"
    15  )
    16  
    17  func TestImports(t *testing.T) {
    18  	t.Run("importName", func(t *testing.T) {
    19  		cases := []struct {
    20  			src   string
    21  			ident string
    22  		}{
    23  			{`"syscall"`, "syscall"},
    24  			{`. "foobar"`, "."},
    25  			{`"go/ast"`, "ast"},
    26  			{`moo "go/format"`, "moo"},
    27  			{`. "go/token"`, "."},
    28  			{`"golang.org/x/sys/unix"`, "unix"},
    29  			{`nix "golang.org/x/sys/unix"`, "nix"},
    30  			{`_ "golang.org/x/sys/unix"`, "_"},
    31  		}
    32  
    33  		for _, c := range cases {
    34  			pkgSrc := fmt.Sprintf("package main\nimport %s", c.src)
    35  
    36  			f, err := parser.ParseFile(token.NewFileSet(), "", pkgSrc, parser.ImportsOnly)
    37  			if err != nil {
    38  				t.Error(err)
    39  				continue
    40  			}
    41  			if len(f.Imports) != 1 {
    42  				t.Errorf("Got %d imports, expected 1", len(f.Imports))
    43  				continue
    44  			}
    45  
    46  			got, err := importName(f.Imports[0])
    47  			if err != nil {
    48  				t.Fatal(err)
    49  			}
    50  			if got != c.ident {
    51  				t.Errorf("Got %q, expected %q", got, c.ident)
    52  			}
    53  		}
    54  	})
    55  
    56  	t.Run("filterImports", func(t *testing.T) {
    57  		cases := []struct{ before, after string }{
    58  			{`package test
    59  
    60  			import (
    61  				"foo"
    62  				"bar"
    63  			)`,
    64  				"package test\n"},
    65  			{`package test
    66  
    67  			import (
    68  				"foo"
    69  				"bar"
    70  			)
    71  
    72  			func useFoo() { foo.Usage() }`,
    73  				`package test
    74  
    75  import (
    76  	"foo"
    77  )
    78  
    79  func useFoo() { foo.Usage() }
    80  `},
    81  		}
    82  		for _, c := range cases {
    83  			got, err := filterImports([]byte(c.before))
    84  			if err != nil {
    85  				t.Error(err)
    86  			}
    87  
    88  			if string(got) != c.after {
    89  				t.Errorf("Got:\n%s\nExpected:\n%s\n", got, c.after)
    90  			}
    91  		}
    92  	})
    93  }
    94  
// TestMerge exercises mkmerge's merge pipeline on four synthetic
// "architecture" files (A, B, C, D), all rendered from inTmpl. The subtests
// check, in turn: getCodeSet on a single file, getCommonSet across all files,
// filter with keepCommon (expected to produce mergedFile), and filter with
// keepArchSpecific (expected to produce each file's outTmpl rendering).
func TestMerge(t *testing.T) {
	// Input architecture files
	//
	// Every {{.}} is replaced by the architecture letter. Identifiers that
	// contain it therefore differ per file, while everything else is identical
	// across the four files.
	//
	// The closing braces of commonF and uniqueF are deliberately indented
	// ("\t}"). The expected outputs below show them at column 0, so the filter
	// is expected to reformat what it emits.
	//
	// NOTE(review): these templates use html/template (see the file's
	// imports) even though they render Go source; text/template would be the
	// natural fit. The substituted values are only the letters A-D, which
	// presumably pass through unescaped. Confirm before using other values.
	inTmpl := template.Must(template.New("input").Parse(`
// Package comments

// build directives for arch{{.}}

//go:build goos && arch{{.}}

package main

/*
#include <stdint.h>
#include <stddef.h>
int utimes(uintptr_t, uintptr_t);
int utimensat(int, uintptr_t, uintptr_t, int);
*/
import "C"

// The imports
import (
	"commonDep"
	"uniqueDep{{.}}"
)

// Vars
var (
	commonVar = commonDep.Use("common")

	uniqueVar{{.}} = "unique{{.}}"
)

// Common free standing comment

// Common comment
const COMMON_INDEPENDENT = 1234
const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"

// Group comment
const (
	COMMON_GROUP = "COMMON_GROUP"
	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
)

// Group2 comment
const (
	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
)

// Group3 comment
const (
	sub1Common1 = 11
	sub1Unique2{{.}} = 12
	sub1Common3_LONG = 13

	sub2Unique1{{.}} = 21
	sub2Common2 = 22
	sub2Common3 = 23
	sub2Unique4{{.}} = 24
)

type commonInt int

type uniqueInt{{.}} int

func commonF() string {
	return commonDep.Use("common")
	}

func uniqueF() string {
	C.utimes(0, 0)
	return uniqueDep{{.}}.Use("{{.}}")
	}

// Group4 comment
const (
	sub3Common1 = 31
	sub3Unique2{{.}} = 32
	sub3Unique3{{.}} = 33
	sub3Common4 = 34

	sub4Common1, sub4Unique2{{.}} = 41, 42
	sub4Unique3{{.}}, sub4Common4 = 43, 44
)
`))

	// Filtered architecture files
	//
	// This is the expected keepArchSpecific output for each architecture.
	// It keeps the header comments, build constraint, cgo preamble, imports
	// and the var block. Common declarations are removed. Const specs that
	// mix common and arch-specific names (sub4...) are kept whole.
	outTmpl := template.Must(template.New("output").Parse(`// Package comments

// build directives for arch{{.}}

//go:build goos && arch{{.}}

package main

/*
#include <stdint.h>
#include <stddef.h>
int utimes(uintptr_t, uintptr_t);
int utimensat(int, uintptr_t, uintptr_t, int);
*/
import "C"

// The imports
import (
	"commonDep"
	"uniqueDep{{.}}"
)

// Vars
var (
	commonVar = commonDep.Use("common")

	uniqueVar{{.}} = "unique{{.}}"
)

const UNIQUE_INDEPENDENT_{{.}} = "UNIQUE_INDEPENDENT_{{.}}"

// Group comment
const (
	UNIQUE_GROUP_{{.}} = "UNIQUE_GROUP_{{.}}"
)

// Group2 comment
const (
	UNIQUE_GROUP21_{{.}} = "UNIQUE_GROUP21_{{.}}"
	UNIQUE_GROUP22_{{.}} = "UNIQUE_GROUP22_{{.}}"
)

// Group3 comment
const (
	sub1Unique2{{.}} = 12

	sub2Unique1{{.}} = 21
	sub2Unique4{{.}} = 24
)

type uniqueInt{{.}} int

func uniqueF() string {
	C.utimes(0, 0)
	return uniqueDep{{.}}.Use("{{.}}")
}

// Group4 comment
const (
	sub3Unique2{{.}} = 32
	sub3Unique3{{.}} = 33

	sub4Common1, sub4Unique2{{.}} = 41, 42
	sub4Unique3{{.}}, sub4Common4 = 43, 44
)
`))

	// Expected keepCommon output. It contains only declarations shared by all
	// four inputs, plus the only import they still use. The var block, the
	// cgo preamble, the build constraint and the mixed sub4 specs do not
	// appear in it. Aligned const values reflect gofmt-style formatting.
	const mergedFile = `// Package comments

package main

// The imports
import (
	"commonDep"
)

// Common free standing comment

// Common comment
const COMMON_INDEPENDENT = 1234

// Group comment
const (
	COMMON_GROUP = "COMMON_GROUP"
)

// Group3 comment
const (
	sub1Common1      = 11
	sub1Common3_LONG = 13

	sub2Common2 = 22
	sub2Common3 = 23
)

type commonInt int

func commonF() string {
	return commonDep.Use("common")
}

// Group4 comment
const (
	sub3Common1 = 31
	sub3Common4 = 34
)
`

	// Generate source code for different "architectures"
	// inFiles[i] and outFiles[i] always describe the same architecture.
	var inFiles, outFiles []srcFile
	for _, arch := range strings.Fields("A B C D") {
		buf := new(bytes.Buffer)
		err := inTmpl.Execute(buf, arch)
		if err != nil {
			t.Fatal(err)
		}
		inFiles = append(inFiles, srcFile{"file" + arch, buf.Bytes()})

		// A fresh buffer, so the bytes stored in inFiles are not overwritten.
		buf = new(bytes.Buffer)
		err = outTmpl.Execute(buf, arch)
		if err != nil {
			t.Fatal(err)
		}
		outFiles = append(outFiles, srcFile{"file" + arch, buf.Bytes()})
	}

	// getCodeSet on file A is expected to yield the comments above the
	// package clause and every const spec, type spec and function. Vars,
	// imports and the other comments are not in the expected set. The two
	// sets are compared by size and by membership in both directions, so
	// order does not matter.
	t.Run("getCodeSet", func(t *testing.T) {
		got, err := getCodeSet(inFiles[0].src)
		if err != nil {
			t.Fatal(err)
		}

		expectedElems := []codeElem{
			{token.COMMENT, "Package comments\n"},
			{token.COMMENT, "build directives for archA\n"},
			{token.COMMENT, "go:build goos && archA\n"},
			{token.CONST, `COMMON_INDEPENDENT = 1234`},
			{token.CONST, `UNIQUE_INDEPENDENT_A = "UNIQUE_INDEPENDENT_A"`},
			{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`},
			{token.CONST, `UNIQUE_GROUP_A = "UNIQUE_GROUP_A"`},
			{token.CONST, `UNIQUE_GROUP21_A = "UNIQUE_GROUP21_A"`},
			{token.CONST, `UNIQUE_GROUP22_A = "UNIQUE_GROUP22_A"`},
			{token.CONST, `sub1Common1 = 11`},
			{token.CONST, `sub1Unique2A = 12`},
			{token.CONST, `sub1Common3_LONG = 13`},
			{token.CONST, `sub2Unique1A = 21`},
			{token.CONST, `sub2Common2 = 22`},
			{token.CONST, `sub2Common3 = 23`},
			{token.CONST, `sub2Unique4A = 24`},
			{token.CONST, `sub3Common1 = 31`},
			{token.CONST, `sub3Unique2A = 32`},
			{token.CONST, `sub3Unique3A = 33`},
			{token.CONST, `sub3Common4 = 34`},
			{token.CONST, `sub4Common1, sub4Unique2A = 41, 42`},
			{token.CONST, `sub4Unique3A, sub4Common4 = 43, 44`},
			{token.TYPE, `commonInt int`},
			{token.TYPE, `uniqueIntA int`},
			{token.FUNC, `func commonF() string {
	return commonDep.Use("common")
}`},
			{token.FUNC, `func uniqueF() string {
	C.utimes(0, 0)
	return uniqueDepA.Use("A")
}`},
		}
		expected := newCodeSet()
		for _, d := range expectedElems {
			expected.add(d)
		}

		if len(got.set) != len(expected.set) {
			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
		}
		for expElem := range expected.set {
			if !got.has(expElem) {
				t.Errorf("Didn't get expected codeElem %#v", expElem)
			}
		}
		for gotElem := range got.set {
			if !expected.has(gotElem) {
				t.Errorf("Got unexpected codeElem %#v", gotElem)
			}
		}
	})

	// Only elements present in all four inputs are expected in the common set.
	t.Run("getCommonSet", func(t *testing.T) {
		got, err := getCommonSet(inFiles)
		if err != nil {
			t.Fatal(err)
		}

		expected := newCodeSet()
		expected.add(codeElem{token.COMMENT, "Package comments\n"})
		expected.add(codeElem{token.CONST, `COMMON_INDEPENDENT = 1234`})
		expected.add(codeElem{token.CONST, `COMMON_GROUP = "COMMON_GROUP"`})
		expected.add(codeElem{token.CONST, `sub1Common1 = 11`})
		expected.add(codeElem{token.CONST, `sub1Common3_LONG = 13`})
		expected.add(codeElem{token.CONST, `sub2Common2 = 22`})
		expected.add(codeElem{token.CONST, `sub2Common3 = 23`})
		expected.add(codeElem{token.CONST, `sub3Common1 = 31`})
		expected.add(codeElem{token.CONST, `sub3Common4 = 34`})
		expected.add(codeElem{token.TYPE, `commonInt int`})
		expected.add(codeElem{token.FUNC, `func commonF() string {
	return commonDep.Use("common")
}`})

		if len(got.set) != len(expected.set) {
			t.Errorf("Got %d codeElems, expected %d", len(got.set), len(expected.set))
		}
		for expElem := range expected.set {
			if !got.has(expElem) {
				t.Errorf("Didn't get expected codeElem %#v", expElem)
			}
		}
		for gotElem := range got.set {
			if !expected.has(gotElem) {
				t.Errorf("Got unexpected codeElem %#v", gotElem)
			}
		}
	})

	// Filtering one input (file A) with keepCommon is expected to reproduce
	// mergedFile exactly.
	t.Run("filter(keepCommon)", func(t *testing.T) {
		commonSet, err := getCommonSet(inFiles)
		if err != nil {
			t.Fatal(err)
		}

		got, err := filter(inFiles[0].src, commonSet.keepCommon)
		if err != nil {
			t.Fatal(err)
		}

		expected := []byte(mergedFile)

		if !bytes.Equal(got, expected) {
			t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
			diffLines(t, got, expected)
		}
	})

	// Filtering each input with keepArchSpecific is expected to reproduce the
	// outTmpl rendering for the same architecture.
	t.Run("filter(keepArchSpecific)", func(t *testing.T) {
		commonSet, err := getCommonSet(inFiles)
		if err != nil {
			t.Fatal(err)
		}

		for i := range inFiles {
			got, err := filter(inFiles[i].src, commonSet.keepArchSpecific)
			if err != nil {
				t.Fatal(err)
			}

			expected := outFiles[i].src

			if !bytes.Equal(got, expected) {
				t.Errorf("Got:\n%s\nExpected:\n%s", addLineNr(got), addLineNr(expected))
				diffLines(t, got, expected)
			}
		}
	})
}
   444  
   445  func TestMergedName(t *testing.T) {
   446  	t.Run("getValidGOOS", func(t *testing.T) {
   447  		testcases := []struct {
   448  			filename, goos string
   449  			ok             bool
   450  		}{
   451  			{"zerrors_aix.go", "aix", true},
   452  			{"zerrors_darwin.go", "darwin", true},
   453  			{"zerrors_dragonfly.go", "dragonfly", true},
   454  			{"zerrors_freebsd.go", "freebsd", true},
   455  			{"zerrors_linux.go", "linux", true},
   456  			{"zerrors_netbsd.go", "netbsd", true},
   457  			{"zerrors_openbsd.go", "openbsd", true},
   458  			{"zerrors_solaris.go", "solaris", true},
   459  			{"zerrors_multics.go", "", false},
   460  		}
   461  		for _, tc := range testcases {
   462  			goos, ok := getValidGOOS(tc.filename)
   463  			if goos != tc.goos {
   464  				t.Errorf("got GOOS %q, expected %q", goos, tc.goos)
   465  			}
   466  			if ok != tc.ok {
   467  				t.Errorf("got ok %v, expected %v", ok, tc.ok)
   468  			}
   469  		}
   470  	})
   471  }
   472  
   473  // Helper functions to diff test sources
   474  
   475  func diffLines(t *testing.T, got, expected []byte) {
   476  	t.Helper()
   477  
   478  	gotLines := bytes.Split(got, []byte{'\n'})
   479  	expLines := bytes.Split(expected, []byte{'\n'})
   480  
   481  	i := 0
   482  	for i < len(gotLines) && i < len(expLines) {
   483  		if !bytes.Equal(gotLines[i], expLines[i]) {
   484  			t.Errorf("Line %d: Got:\n%q\nExpected:\n%q", i+1, gotLines[i], expLines[i])
   485  			return
   486  		}
   487  		i++
   488  	}
   489  
   490  	if i < len(gotLines) && i >= len(expLines) {
   491  		t.Errorf("Line %d: got %q, expected EOF", i+1, gotLines[i])
   492  	}
   493  	if i >= len(gotLines) && i < len(expLines) {
   494  		t.Errorf("Line %d: got EOF, expected %q", i+1, gotLines[i])
   495  	}
   496  }
   497  
   498  func addLineNr(src []byte) []byte {
   499  	lines := bytes.Split(src, []byte("\n"))
   500  	for i, line := range lines {
   501  		lines[i] = []byte(fmt.Sprintf("%d: %s", i+1, line))
   502  	}
   503  	return bytes.Join(lines, []byte("\n"))
   504  }