github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/cmd/compile/test/testdata/gen/constFoldGen.go (about) 1 // Copyright 2016 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // This program generates a test to verify that the standard arithmetic 6 // operators properly handle constant folding. The test file should be 7 // generated with a known working version of go. 8 // launch with `go run constFoldGen.go` a file called constFold_test.go 9 // will be written into the grandparent directory containing the tests. 10 11 package main 12 13 import ( 14 "bytes" 15 "fmt" 16 "go/format" 17 "log" 18 "os" 19 ) 20 21 type op struct { 22 name, symbol string 23 } 24 type szD struct { 25 name string 26 sn string 27 u []uint64 28 i []int64 29 } 30 31 var szs []szD = []szD{ 32 szD{name: "uint64", sn: "64", u: []uint64{0, 1, 4294967296, 0xffffFFFFffffFFFF}}, 33 szD{name: "int64", sn: "64", i: []int64{-0x8000000000000000, -0x7FFFFFFFFFFFFFFF, 34 -4294967296, -1, 0, 1, 4294967296, 0x7FFFFFFFFFFFFFFE, 0x7FFFFFFFFFFFFFFF}}, 35 36 szD{name: "uint32", sn: "32", u: []uint64{0, 1, 4294967295}}, 37 szD{name: "int32", sn: "32", i: []int64{-0x80000000, -0x7FFFFFFF, -1, 0, 38 1, 0x7FFFFFFF}}, 39 40 szD{name: "uint16", sn: "16", u: []uint64{0, 1, 65535}}, 41 szD{name: "int16", sn: "16", i: []int64{-32768, -32767, -1, 0, 1, 32766, 32767}}, 42 43 szD{name: "uint8", sn: "8", u: []uint64{0, 1, 255}}, 44 szD{name: "int8", sn: "8", i: []int64{-128, -127, -1, 0, 1, 126, 127}}, 45 } 46 47 var ops = []op{ 48 op{"add", "+"}, op{"sub", "-"}, op{"div", "/"}, op{"mul", "*"}, 49 op{"lsh", "<<"}, op{"rsh", ">>"}, op{"mod", "%"}, 50 } 51 52 // compute the result of i op j, cast as type t. 53 func ansU(i, j uint64, t, op string) string { 54 var ans uint64 55 switch op { 56 case "+": 57 ans = i + j 58 case "-": 59 ans = i - j 60 case "*": 61 ans = i * j 62 case "/": 63 if j != 0 { 64 ans = i / j 65 } 66 case "%": 67 if j != 0 { 68 ans = i % j 69 } 70 case "<<": 71 ans = i << j 72 case ">>": 73 ans = i >> j 74 } 75 switch t { 76 case "uint32": 77 ans = uint64(uint32(ans)) 78 case "uint16": 79 ans = uint64(uint16(ans)) 80 case "uint8": 81 ans = uint64(uint8(ans)) 82 } 83 return fmt.Sprintf("%d", ans) 84 } 85 86 // compute the result of i op j, cast as type t. 87 func ansS(i, j int64, t, op string) string { 88 var ans int64 89 switch op { 90 case "+": 91 ans = i + j 92 case "-": 93 ans = i - j 94 case "*": 95 ans = i * j 96 case "/": 97 if j != 0 { 98 ans = i / j 99 } 100 case "%": 101 if j != 0 { 102 ans = i % j 103 } 104 case "<<": 105 ans = i << uint64(j) 106 case ">>": 107 ans = i >> uint64(j) 108 } 109 switch t { 110 case "int32": 111 ans = int64(int32(ans)) 112 case "int16": 113 ans = int64(int16(ans)) 114 case "int8": 115 ans = int64(int8(ans)) 116 } 117 return fmt.Sprintf("%d", ans) 118 } 119 120 func main() { 121 w := new(bytes.Buffer) 122 fmt.Fprintf(w, "// run\n") 123 fmt.Fprintf(w, "// Code generated by gen/constFoldGen.go. DO NOT EDIT.\n\n") 124 fmt.Fprintf(w, "package gc\n") 125 fmt.Fprintf(w, "import \"testing\"\n") 126 127 for _, s := range szs { 128 for _, o := range ops { 129 if o.symbol == "<<" || o.symbol == ">>" { 130 // shifts handled separately below, as they can have 131 // different types on the LHS and RHS. 132 continue 133 } 134 fmt.Fprintf(w, "func TestConstFold%s%s(t *testing.T) {\n", s.name, o.name) 135 fmt.Fprintf(w, "\tvar x, y, r %s\n", s.name) 136 // unsigned test cases 137 for _, c := range s.u { 138 fmt.Fprintf(w, "\tx = %d\n", c) 139 for _, d := range s.u { 140 if d == 0 && (o.symbol == "/" || o.symbol == "%") { 141 continue 142 } 143 fmt.Fprintf(w, "\ty = %d\n", d) 144 fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) 145 want := ansU(c, d, s.name, o.symbol) 146 fmt.Fprintf(w, "\tif r != %s {\n", want) 147 fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) 148 fmt.Fprintf(w, "\t}\n") 149 } 150 } 151 // signed test cases 152 for _, c := range s.i { 153 fmt.Fprintf(w, "\tx = %d\n", c) 154 for _, d := range s.i { 155 if d == 0 && (o.symbol == "/" || o.symbol == "%") { 156 continue 157 } 158 fmt.Fprintf(w, "\ty = %d\n", d) 159 fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) 160 want := ansS(c, d, s.name, o.symbol) 161 fmt.Fprintf(w, "\tif r != %s {\n", want) 162 fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) 163 fmt.Fprintf(w, "\t}\n") 164 } 165 } 166 fmt.Fprintf(w, "}\n") 167 } 168 } 169 170 // Special signed/unsigned cases for shifts 171 for _, ls := range szs { 172 for _, rs := range szs { 173 if rs.name[0] != 'u' { 174 continue 175 } 176 for _, o := range ops { 177 if o.symbol != "<<" && o.symbol != ">>" { 178 continue 179 } 180 fmt.Fprintf(w, "func TestConstFold%s%s%s(t *testing.T) {\n", ls.name, rs.name, o.name) 181 fmt.Fprintf(w, "\tvar x, r %s\n", ls.name) 182 fmt.Fprintf(w, "\tvar y %s\n", rs.name) 183 // unsigned LHS 184 for _, c := range ls.u { 185 fmt.Fprintf(w, "\tx = %d\n", c) 186 for _, d := range rs.u { 187 fmt.Fprintf(w, "\ty = %d\n", d) 188 fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) 189 want := ansU(c, d, ls.name, o.symbol) 190 fmt.Fprintf(w, "\tif r != %s {\n", want) 191 fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) 192 fmt.Fprintf(w, "\t}\n") 193 } 194 } 195 // signed LHS 196 for _, c := range ls.i { 197 fmt.Fprintf(w, "\tx = %d\n", c) 198 for _, d := range rs.u { 199 fmt.Fprintf(w, "\ty = %d\n", d) 200 fmt.Fprintf(w, "\tr = x %s y\n", o.symbol) 201 want := ansS(c, int64(d), ls.name, o.symbol) 202 fmt.Fprintf(w, "\tif r != %s {\n", want) 203 fmt.Fprintf(w, "\t\tt.Errorf(\"%d %%s %d = %%d, want %s\", %q, r)\n", c, d, want, o.symbol) 204 fmt.Fprintf(w, "\t}\n") 205 } 206 } 207 fmt.Fprintf(w, "}\n") 208 } 209 } 210 } 211 212 // Constant folding for comparisons 213 for _, s := range szs { 214 fmt.Fprintf(w, "func TestConstFoldCompare%s(t *testing.T) {\n", s.name) 215 for _, x := range s.i { 216 for _, y := range s.i { 217 fmt.Fprintf(w, "\t{\n") 218 fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x) 219 fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y) 220 if x == y { 221 fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n") 222 } else { 223 fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n") 224 } 225 if x != y { 226 fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n") 227 } else { 228 fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n") 229 } 230 if x < y { 231 fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n") 232 } else { 233 fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n") 234 } 235 if x > y { 236 fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n") 237 } else { 238 fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n") 239 } 240 if x <= y { 241 fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n") 242 } else { 243 fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n") 244 } 245 if x >= y { 246 fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n") 247 } else { 248 fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n") 249 } 250 fmt.Fprintf(w, "\t}\n") 251 } 252 } 253 for _, x := range s.u { 254 for _, y := range s.u { 255 fmt.Fprintf(w, "\t{\n") 256 fmt.Fprintf(w, "\t\tvar x %s = %d\n", s.name, x) 257 fmt.Fprintf(w, "\t\tvar y %s = %d\n", s.name, y) 258 if x == y { 259 fmt.Fprintf(w, "\t\tif !(x == y) { t.Errorf(\"!(%%d == %%d)\", x, y) }\n") 260 } else { 261 fmt.Fprintf(w, "\t\tif x == y { t.Errorf(\"%%d == %%d\", x, y) }\n") 262 } 263 if x != y { 264 fmt.Fprintf(w, "\t\tif !(x != y) { t.Errorf(\"!(%%d != %%d)\", x, y) }\n") 265 } else { 266 fmt.Fprintf(w, "\t\tif x != y { t.Errorf(\"%%d != %%d\", x, y) }\n") 267 } 268 if x < y { 269 fmt.Fprintf(w, "\t\tif !(x < y) { t.Errorf(\"!(%%d < %%d)\", x, y) }\n") 270 } else { 271 fmt.Fprintf(w, "\t\tif x < y { t.Errorf(\"%%d < %%d\", x, y) }\n") 272 } 273 if x > y { 274 fmt.Fprintf(w, "\t\tif !(x > y) { t.Errorf(\"!(%%d > %%d)\", x, y) }\n") 275 } else { 276 fmt.Fprintf(w, "\t\tif x > y { t.Errorf(\"%%d > %%d\", x, y) }\n") 277 } 278 if x <= y { 279 fmt.Fprintf(w, "\t\tif !(x <= y) { t.Errorf(\"!(%%d <= %%d)\", x, y) }\n") 280 } else { 281 fmt.Fprintf(w, "\t\tif x <= y { t.Errorf(\"%%d <= %%d\", x, y) }\n") 282 } 283 if x >= y { 284 fmt.Fprintf(w, "\t\tif !(x >= y) { t.Errorf(\"!(%%d >= %%d)\", x, y) }\n") 285 } else { 286 fmt.Fprintf(w, "\t\tif x >= y { t.Errorf(\"%%d >= %%d\", x, y) }\n") 287 } 288 fmt.Fprintf(w, "\t}\n") 289 } 290 } 291 fmt.Fprintf(w, "}\n") 292 } 293 294 // gofmt result 295 b := w.Bytes() 296 src, err := format.Source(b) 297 if err != nil { 298 fmt.Printf("%s\n", b) 299 panic(err) 300 } 301 302 // write to file 303 err = os.WriteFile("../../constFold_test.go", src, 0666) 304 if err != nil { 305 log.Fatalf("can't write output: %v\n", err) 306 } 307 }