github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/test/testdata/gen/arithConstGen.go (about)

     1  // Copyright 2016 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
// This program generates a test to verify that the standard arithmetic
// operators properly handle const cases. The test file should be
// generated with a known working version of Go.
// Launch with `go run arithConstGen.go`; a file called arithConst_test.go
// containing the tests will be written into the parent directory.
    10  
    11  package main
    12  
import (
	"bytes"
	"fmt"
	"go/format"
	"log"
	"os"
	"strings"
	"text/template"
)
    21  
    22  type op struct {
    23  	name, symbol string
    24  }
    25  type szD struct {
    26  	name   string
    27  	sn     string
    28  	u      []uint64
    29  	i      []int64
    30  	oponly string
    31  }
    32  
    33  var szs = []szD{
    34  	{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0x8000000000000000, 0xffffFFFFffffFFFF}},
    35  	{name: "uint64", sn: "64", u: []uint64{3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
    36  
    37  	{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF,
    38  		-4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}},
    39  	{name: "int64", sn: "64", i: []int64{-9, -5, -3, 3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
    40  
    41  	{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}},
    42  	{name: "uint32", sn: "32", u: []uint64{3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
    43  
    44  	{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0,
    45  		1, 0x7FFFFFFF}},
    46  	{name: "int32", sn: "32", i: []int64{-9, -5, -3, 3, 5, 7, 9, 10, 11, 13, 19, 21, 25, 27, 37, 41, 45, 73, 81}, oponly: "mul"},
    47  
    48  	{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}},
    49  	{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}},
    50  
    51  	{name: "uint8", sn: "8", u: []uint64{0, 1, 255}},
    52  	{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}},
    53  }
    54  
    55  var ops = []op{
    56  	{"add", "+"},
    57  	{"sub", "-"},
    58  	{"div", "/"},
    59  	{"mul", "*"},
    60  	{"lsh", "<<"},
    61  	{"rsh", ">>"},
    62  	{"mod", "%"},
    63  	{"and", "&"},
    64  	{"or", "|"},
    65  	{"xor", "^"},
    66  }
    67  
    68  // compute the result of i op j, cast as type t.
    69  func ansU(i, j uint64, t, op string) string {
    70  	var ans uint64
    71  	switch op {
    72  	case "+":
    73  		ans = i + j
    74  	case "-":
    75  		ans = i - j
    76  	case "*":
    77  		ans = i * j
    78  	case "/":
    79  		if j != 0 {
    80  			ans = i / j
    81  		}
    82  	case "%":
    83  		if j != 0 {
    84  			ans = i % j
    85  		}
    86  	case "<<":
    87  		ans = i << j
    88  	case ">>":
    89  		ans = i >> j
    90  	case "&":
    91  		ans = i & j
    92  	case "|":
    93  		ans = i | j
    94  	case "^":
    95  		ans = i ^ j
    96  	}
    97  	switch t {
    98  	case "uint32":
    99  		ans = uint64(uint32(ans))
   100  	case "uint16":
   101  		ans = uint64(uint16(ans))
   102  	case "uint8":
   103  		ans = uint64(uint8(ans))
   104  	}
   105  	return fmt.Sprintf("%d", ans)
   106  }
   107  
   108  // compute the result of i op j, cast as type t.
   109  func ansS(i, j int64, t, op string) string {
   110  	var ans int64
   111  	switch op {
   112  	case "+":
   113  		ans = i + j
   114  	case "-":
   115  		ans = i - j
   116  	case "*":
   117  		ans = i * j
   118  	case "/":
   119  		if j != 0 {
   120  			ans = i / j
   121  		}
   122  	case "%":
   123  		if j != 0 {
   124  			ans = i % j
   125  		}
   126  	case "<<":
   127  		ans = i << uint64(j)
   128  	case ">>":
   129  		ans = i >> uint64(j)
   130  	case "&":
   131  		ans = i & j
   132  	case "|":
   133  		ans = i | j
   134  	case "^":
   135  		ans = i ^ j
   136  	}
   137  	switch t {
   138  	case "int32":
   139  		ans = int64(int32(ans))
   140  	case "int16":
   141  		ans = int64(int16(ans))
   142  	case "int8":
   143  		ans = int64(int8(ans))
   144  	}
   145  	return fmt.Sprintf("%d", ans)
   146  }
   147  
   148  func main() {
   149  	w := new(bytes.Buffer)
   150  	fmt.Fprintf(w, "// Code generated by gen/arithConstGen.go. DO NOT EDIT.\n\n")
   151  	fmt.Fprintf(w, "package main;\n")
   152  	fmt.Fprintf(w, "import \"testing\"\n")
   153  
   154  	fncCnst1 := template.Must(template.New("fnc").Parse(
   155  		`//go:noinline
   156  func {{.Name}}_{{.Type_}}_{{.FNumber}}(a {{.Type_}}) {{.Type_}} { return a {{.Symbol}} {{.Number}} }
   157  `))
   158  	fncCnst2 := template.Must(template.New("fnc").Parse(
   159  		`//go:noinline
   160  func {{.Name}}_{{.FNumber}}_{{.Type_}}(a {{.Type_}}) {{.Type_}} { return {{.Number}} {{.Symbol}} a }
   161  `))
   162  
   163  	type fncData struct {
   164  		Name, Type_, Symbol, FNumber, Number string
   165  	}
   166  
   167  	for _, s := range szs {
   168  		for _, o := range ops {
   169  			if s.oponly != "" && s.oponly != o.name {
   170  				continue
   171  			}
   172  			fd := fncData{o.name, s.name, o.symbol, "", ""}
   173  
   174  			// unsigned test cases
   175  			if len(s.u) > 0 {
   176  				for _, i := range s.u {
   177  					fd.Number = fmt.Sprintf("%d", i)
   178  					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
   179  
   180  					// avoid division by zero
   181  					if o.name != "mod" && o.name != "div" || i != 0 {
   182  						// introduce uint64 cast for rhs shift operands
   183  						// if they are too large for default uint type
   184  						number := fd.Number
   185  						if (o.name == "lsh" || o.name == "rsh") && uint64(uint32(i)) != i {
   186  							fd.Number = fmt.Sprintf("uint64(%s)", number)
   187  						}
   188  						fncCnst1.Execute(w, fd)
   189  						fd.Number = number
   190  					}
   191  
   192  					fncCnst2.Execute(w, fd)
   193  				}
   194  			}
   195  
   196  			// signed test cases
   197  			if len(s.i) > 0 {
   198  				// don't generate tests for shifts by signed integers
   199  				if o.name == "lsh" || o.name == "rsh" {
   200  					continue
   201  				}
   202  				for _, i := range s.i {
   203  					fd.Number = fmt.Sprintf("%d", i)
   204  					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
   205  
   206  					// avoid division by zero
   207  					if o.name != "mod" && o.name != "div" || i != 0 {
   208  						fncCnst1.Execute(w, fd)
   209  					}
   210  					fncCnst2.Execute(w, fd)
   211  				}
   212  			}
   213  		}
   214  	}
   215  
   216  	vrf1 := template.Must(template.New("vrf1").Parse(`
   217  		test_{{.Size}}{fn: {{.Name}}_{{.FNumber}}_{{.Type_}}, fnname: "{{.Name}}_{{.FNumber}}_{{.Type_}}", in: {{.Input}}, want: {{.Ans}}},`))
   218  
   219  	vrf2 := template.Must(template.New("vrf2").Parse(`
   220  		test_{{.Size}}{fn: {{.Name}}_{{.Type_}}_{{.FNumber}}, fnname: "{{.Name}}_{{.Type_}}_{{.FNumber}}", in: {{.Input}}, want: {{.Ans}}},`))
   221  
   222  	type cfncData struct {
   223  		Size, Name, Type_, Symbol, FNumber, Number string
   224  		Ans, Input                                 string
   225  	}
   226  	for _, s := range szs {
   227  		fmt.Fprintf(w, `
   228  type test_%[1]s%[2]s struct {
   229  	fn func (%[1]s) %[1]s
   230  	fnname string
   231  	in %[1]s
   232  	want %[1]s
   233  }
   234  `, s.name, s.oponly)
   235  		fmt.Fprintf(w, "var tests_%[1]s%[2]s =[]test_%[1]s {\n\n", s.name, s.oponly)
   236  
   237  		if len(s.u) > 0 {
   238  			for _, o := range ops {
   239  				if s.oponly != "" && s.oponly != o.name {
   240  					continue
   241  				}
   242  				fd := cfncData{s.name, o.name, s.name, o.symbol, "", "", "", ""}
   243  				for _, i := range s.u {
   244  					fd.Number = fmt.Sprintf("%d", i)
   245  					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
   246  
   247  					// unsigned
   248  					for _, j := range s.u {
   249  
   250  						if o.name != "mod" && o.name != "div" || j != 0 {
   251  							fd.Ans = ansU(i, j, s.name, o.symbol)
   252  							fd.Input = fmt.Sprintf("%d", j)
   253  							if err := vrf1.Execute(w, fd); err != nil {
   254  								panic(err)
   255  							}
   256  						}
   257  
   258  						if o.name != "mod" && o.name != "div" || i != 0 {
   259  							fd.Ans = ansU(j, i, s.name, o.symbol)
   260  							fd.Input = fmt.Sprintf("%d", j)
   261  							if err := vrf2.Execute(w, fd); err != nil {
   262  								panic(err)
   263  							}
   264  						}
   265  
   266  					}
   267  				}
   268  
   269  			}
   270  		}
   271  
   272  		// signed
   273  		if len(s.i) > 0 {
   274  			for _, o := range ops {
   275  				if s.oponly != "" && s.oponly != o.name {
   276  					continue
   277  				}
   278  				// don't generate tests for shifts by signed integers
   279  				if o.name == "lsh" || o.name == "rsh" {
   280  					continue
   281  				}
   282  				fd := cfncData{s.name, o.name, s.name, o.symbol, "", "", "", ""}
   283  				for _, i := range s.i {
   284  					fd.Number = fmt.Sprintf("%d", i)
   285  					fd.FNumber = strings.Replace(fd.Number, "-", "Neg", -1)
   286  					for _, j := range s.i {
   287  						if o.name != "mod" && o.name != "div" || j != 0 {
   288  							fd.Ans = ansS(i, j, s.name, o.symbol)
   289  							fd.Input = fmt.Sprintf("%d", j)
   290  							if err := vrf1.Execute(w, fd); err != nil {
   291  								panic(err)
   292  							}
   293  						}
   294  
   295  						if o.name != "mod" && o.name != "div" || i != 0 {
   296  							fd.Ans = ansS(j, i, s.name, o.symbol)
   297  							fd.Input = fmt.Sprintf("%d", j)
   298  							if err := vrf2.Execute(w, fd); err != nil {
   299  								panic(err)
   300  							}
   301  						}
   302  
   303  					}
   304  				}
   305  
   306  			}
   307  		}
   308  
   309  		fmt.Fprintf(w, "}\n\n")
   310  	}
   311  
   312  	fmt.Fprint(w, `
   313  
   314  // TestArithmeticConst tests results for arithmetic operations against constants.
   315  func TestArithmeticConst(t *testing.T) {
   316  `)
   317  
   318  	for _, s := range szs {
   319  		fmt.Fprintf(w, `for _, test := range tests_%s%s {`, s.name, s.oponly)
   320  		// Use WriteString here to avoid a vet warning about formatting directives.
   321  		w.WriteString(`if got := test.fn(test.in); got != test.want {
   322  			t.Errorf("%s(%d) = %d, want %d\n", test.fnname, test.in, got, test.want)
   323  		}
   324  	}
   325  `)
   326  	}
   327  
   328  	fmt.Fprint(w, `
   329  }
   330  `)
   331  
   332  	// gofmt result
   333  	b := w.Bytes()
   334  	src, err := format.Source(b)
   335  	if err != nil {
   336  		fmt.Printf("%s\n", b)
   337  		panic(err)
   338  	}
   339  
   340  	// write to file
   341  	err = os.WriteFile("../arithConst_test.go", src, 0666)
   342  	if err != nil {
   343  		log.Fatalf("can't write output: %v\n", err)
   344  	}
   345  }