cuelang.org/go@v0.10.1/internal/tdtest/update.go (about) 1 // Copyright 2023 CUE Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package tdtest 16 17 import ( 18 "fmt" 19 "go/ast" 20 "go/format" 21 "go/token" 22 "go/types" 23 "os" 24 "reflect" 25 "strconv" 26 "strings" 27 "sync" 28 "testing" 29 30 "golang.org/x/tools/go/ast/astutil" 31 "golang.org/x/tools/go/packages" 32 ) 33 34 // info contains information needed to update files. 35 type info struct { 36 t *testing.T 37 38 tcType reflect.Type 39 40 needsUpdate bool // an updateable field has changed 41 42 table *ast.CompositeLit // the table that is the source of the tests 43 44 testPkg *packages.Package 45 46 calls map[token.Position]*callInfo 47 patches map[ast.Node]ast.Expr 48 } 49 50 type callInfo struct { 51 ast *ast.CallExpr 52 funcName string 53 fieldName string 54 } 55 56 var ( 57 once sync.Once 58 pkgs []*packages.Package 59 pkgsErr error 60 ) 61 62 func initPackages() ([]*packages.Package, error) { 63 once.Do(func() { 64 cfg := &packages.Config{ 65 Mode: packages.NeedFiles | 66 packages.NeedDeps | 67 packages.NeedTypes | 68 packages.NeedTypesInfo | 69 packages.NeedSyntax, 70 Tests: true, 71 } 72 73 pkgs, pkgsErr = packages.Load(cfg, ".") 74 }) 75 return pkgs, pkgsErr 76 } 77 78 func (s *set[T]) getInfo(file string) *info { 79 if s.info != nil { 80 return s.info 81 } 82 info := &info{ 83 t: s.t, 84 tcType: reflect.TypeFor[T](), 85 calls: make(map[token.Position]*callInfo), 86 patches: make(map[ast.Node]ast.Expr), 87 } 88 s.info = info 89 90 t := s.t 91 92 pkgs, pkgsErr = initPackages() 93 if pkgsErr != nil { 94 t.Fatalf("load: %v\n", pkgsErr) 95 } 96 97 // Get package under test. 98 f, pkg := findFileAndPackage(file, pkgs) 99 if f == nil { 100 t.Fatalf("failed to load package for file %s", file) 101 } 102 info.testPkg = pkg 103 104 // TODO: not necessary at the moment, but this is tricky so leaving this in 105 // so as to not to forget how to do it. 106 // 107 // for _, p := range pkg.Types.Imports() { 108 // if p.Path() == "cuelang.org/go/internal/tdtest" { 109 // info.thisPkg = p 110 // } 111 // } 112 // if info.thisPkg == nil { 113 // t.Fatalf("could not find test package") 114 // } 115 116 // Find function declaration of this test. 117 var fn *ast.FuncDecl 118 for _, d := range f.Decls { 119 if fd, ok := d.(*ast.FuncDecl); ok && fd.Name.Name == t.Name() { 120 fn = fd 121 } 122 } 123 if fn == nil { 124 t.Fatalf("could not find test %q in file %q", t.Name(), file) 125 } 126 127 // Find CompositLit table used for the test: 128 // - find call to which CompositLit was passed, 129 a := info.findCalls(fn.Body, "New", "Run") 130 if len(a) != 1 { 131 // TODO: allow more than one. 132 t.Fatalf("only one Run or New function allowed per test") 133 } 134 135 // - analyse second argument of call, 136 call := a[0].ast 137 fset := info.testPkg.Fset 138 ti := info.testPkg.TypesInfo 139 ident, ok := call.Args[1].(*ast.Ident) 140 if !ok { 141 t.Fatalf("%v: arg 2 of %s must be a reference to the table", 142 fset.Position(call.Args[1].Pos()), a[0].funcName) 143 } 144 def := ti.Uses[ident] 145 pos := def.Pos() 146 147 // - locate the CompositeLit in the AST based on position. 148 v0 := findVar(pos, f) 149 if v0 == nil { 150 t.Fatalf("cannot find composite literal in source code") 151 } 152 v, ok := v0.(*ast.CompositeLit) 153 if !ok { 154 // generics should avoid this. 155 t.Fatalf("expected composite literal, found %T", v0) 156 } 157 info.table = v 158 159 // Find and index assertion calls. 160 a = info.findCalls(fn.Body, "Equal") 161 for _, x := range a { 162 info.initFieldRef(x, f) 163 } 164 165 return info 166 } 167 168 // initFieldRef updates c with information about the field referenced 169 // in its corresponding call: 170 // - name of the field 171 // - indexes the field based on filename and line number. 172 func (i *info) initFieldRef(c *callInfo, f *ast.File) { 173 call := c.ast 174 t := i.t 175 info := i.testPkg.TypesInfo 176 fset := i.testPkg.Fset 177 pos := fset.Position(call.Pos()) 178 179 sel, ok := call.Args[1].(*ast.SelectorExpr) 180 s := info.Selections[sel] 181 if !ok || s == nil || s.Kind() != types.FieldVal { 182 t.Fatalf("%v: arg 2 of %s must be a reference to a test case field", 183 fset.Position(call.Args[1].Pos()), c.funcName) 184 } 185 186 obj := s.Obj() 187 c.fieldName = obj.Name() 188 if _, ok := i.tcType.FieldByName(c.fieldName); !ok { 189 t.Fatalf("%v: could not find field %s", 190 fset.Position(obj.Pos()), c.fieldName) 191 } 192 193 pos.Column = 0 194 pos.Offset = 0 195 i.calls[pos] = c 196 } 197 198 // findFileAndPackage locates the ast.File and package within the given slice 199 // of packages, in which the given file is located. 200 func findFileAndPackage(path string, pkgs []*packages.Package) (*ast.File, *packages.Package) { 201 for _, p := range pkgs { 202 for i, gf := range p.GoFiles { 203 if gf == path { 204 return p.Syntax[i], p 205 } 206 } 207 } 208 return nil, nil 209 } 210 211 const typeT = "*cuelang.org/go/internal/tdtest.T" 212 213 // findCalls finds all call expressions within a given block for functions 214 // or methods defined within the tdtest package. 215 func (i *info) findCalls(block *ast.BlockStmt, names ...string) []*callInfo { 216 var a []*callInfo 217 ast.Inspect(block, func(n ast.Node) bool { 218 c, ok := n.(*ast.CallExpr) 219 if !ok { 220 return true 221 } 222 sel, ok := c.Fun.(*ast.SelectorExpr) 223 if !ok { 224 return true 225 } 226 227 // TODO: also test package. It would be better to test the equality 228 // using the information in the types.Info/packages to ensure that 229 // we really got the right function. 230 info := i.testPkg.TypesInfo 231 for _, name := range names { 232 if sel.Sel.Name == name { 233 receiver := info.TypeOf(sel.X).String() 234 if receiver == typeT { 235 // Method. 236 } else if len(c.Args) == 3 { 237 // Run function. 238 fn := c.Args[2].(*ast.FuncLit) 239 if len(fn.Type.Params.List) != 2 { 240 return true 241 } 242 argType := info.TypeOf(fn.Type.Params.List[0].Type).String() 243 if argType != typeT { 244 return true 245 } 246 } else { 247 return true 248 } 249 ci := &callInfo{ 250 funcName: name, 251 ast: c, 252 } 253 a = append(a, ci) 254 return true 255 } 256 } 257 258 return true 259 }) 260 return a 261 } 262 263 func findVar(pos token.Pos, n0 ast.Node) (ret ast.Expr) { 264 ast.Inspect(n0, func(n ast.Node) bool { 265 if n == nil { 266 return true 267 } 268 switch n := n.(type) { 269 case *ast.AssignStmt: 270 for i, v := range n.Lhs { 271 if v.Pos() == pos { 272 ret = n.Rhs[i] 273 } 274 } 275 return false 276 case *ast.ValueSpec: 277 for i, v := range n.Names { 278 if v.Pos() == pos { 279 ret = n.Values[i] 280 } 281 } 282 return false 283 } 284 return true 285 }) 286 return ret 287 } 288 289 func (s *set[TC]) update() { 290 info := s.info 291 292 t := s.t 293 fset := info.testPkg.Fset 294 295 file := fset.Position(info.table.Pos()).Filename 296 var f *ast.File 297 for i, gof := range info.testPkg.GoFiles { 298 if gof == file { 299 f = info.testPkg.Syntax[i] 300 } 301 } 302 if f == nil { 303 t.Fatalf("file %s not in package", file) 304 } 305 306 // TODO: use text-based insertion instead: 307 // - sort insertions and replacements on position in descending order. 308 // - substitute textually. 309 // 310 // We are using Apply because this is supposed to give better handling of 311 // comments. In practice this only works marginally better than not handling 312 // positions at all. Probably a lost cause. 313 astutil.Apply(f, func(c *astutil.Cursor) bool { 314 n := c.Node() 315 316 switch x := info.patches[n]; x.(type) { 317 case nil: 318 case *ast.KeyValueExpr: 319 for { 320 c.InsertAfter(x) 321 x = info.patches[x] 322 if x == nil { 323 break 324 } 325 } 326 default: 327 c.Replace(x) 328 } 329 return true 330 }, nil) 331 332 // TODO: use tmp files? 333 w, err := os.Create(file) 334 if err != nil { 335 t.Fatal(err) 336 } 337 defer w.Close() 338 339 err = format.Node(w, fset, f) 340 if err != nil { 341 t.Fatal(err) 342 } 343 } 344 345 func (t *T) updateField(info *info, ci *callInfo, newValue any) { 346 info.needsUpdate = true 347 348 fset := info.testPkg.Fset 349 350 e, ok := info.table.Elts[t.iter].(*ast.CompositeLit) 351 if !ok { 352 t.Fatalf("not a composite literal") 353 } 354 355 isZero := false 356 var value ast.Expr 357 switch x := reflect.ValueOf(newValue); x.Kind() { 358 default: 359 s := fmt.Sprint(x) 360 x = reflect.ValueOf(s) 361 fallthrough 362 case reflect.String: 363 s := x.String() 364 isZero = s == "" 365 if !strings.ContainsRune(s, '`') && !isZero { 366 s = fmt.Sprintf("`%s`", s) 367 } else { 368 s = strconv.Quote(s) 369 } 370 value = &ast.BasicLit{Kind: token.STRING, Value: s} 371 case reflect.Bool: 372 if b := x.Bool(); b { 373 value = &ast.BasicLit{Kind: token.IDENT, Value: "true"} 374 } else { 375 value = &ast.BasicLit{Kind: token.IDENT, Value: "false"} 376 isZero = true 377 } 378 case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8: 379 i := x.Int() 380 value = &ast.BasicLit{Kind: token.INT, 381 Value: strconv.FormatInt(i, 10)} 382 isZero = i == 0 383 case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8: 384 i := x.Uint() 385 value = &ast.BasicLit{Kind: token.INT, 386 Value: strconv.FormatUint(i, 10)} 387 isZero = i == 0 388 } 389 390 for _, x := range e.Elts { 391 kv, ok := x.(*ast.KeyValueExpr) 392 if !ok { 393 t.Fatalf("%v: elements must be key value pairs", 394 fset.Position(kv.Pos())) 395 } 396 ident, ok := kv.Key.(*ast.Ident) 397 if !ok { 398 t.Fatalf("%v: key must be an identifier", 399 fset.Position(kv.Pos())) 400 } 401 if ident.Name == ci.fieldName { 402 info.patches[kv.Value] = value 403 return 404 } 405 } 406 407 if !isZero { 408 kv := &ast.KeyValueExpr{ 409 Key: &ast.Ident{Name: ci.fieldName}, 410 Value: value, 411 } 412 if len(e.Elts) > 0 { 413 var key ast.Node = e.Elts[len(e.Elts)-1] 414 old := info.patches[key] 415 if old != nil { 416 info.patches[kv] = old 417 } 418 info.patches[key] = kv 419 } else { 420 info.patches[e] = &ast.CompositeLit{Elts: []ast.Expr{kv}} 421 } 422 } 423 }