cuelang.org/go@v0.13.0/internal/tdtest/update.go (about)

     1  // Copyright 2023 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tdtest
    16  
import (
	"bytes"
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"go/types"
	"os"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"testing"

	"golang.org/x/tools/go/ast/astutil"
	"golang.org/x/tools/go/packages"
)
    33  
// info contains information needed to update files.
//
// It is computed once per test set by getInfo (which caches it in s.info) and
// holds the type-checked test package, the AST of the test-case table, the
// indexed assertion calls, and the pending AST patches that update applies
// when rewriting the source file.
type info struct {
	t *testing.T // the test whose table is being updated

	tcType reflect.Type // reflect type of the test-case type T of set[T]

	needsUpdate bool // an updateable field has changed

	table *ast.CompositeLit // the table that is the source of the tests

	testPkg *packages.Package // the loaded package containing the test file

	// calls indexes assertion calls by source position with Column and
	// Offset zeroed (see initFieldRef), so lookups depend only on file name
	// and line. patches maps original AST nodes to new expressions: an
	// *ast.KeyValueExpr value is inserted after the keyed node, and any other
	// value replaces it (see update).
	calls   map[token.Position]*callInfo
	patches map[ast.Node]ast.Expr
}
    49  
// callInfo describes a call, found in the body of a test function, to one of
// the functions or methods that tdtest tracks (such as New, Run or Equal).
type callInfo struct {
	ast       *ast.CallExpr // the call expression itself
	funcName  string        // the name the call matched, e.g. "Run" or "Equal"
	fieldName string        // for assertion calls: the test-case field named by the second argument
}
    55  
    56  var loadPackages = sync.OnceValues(func() ([]*packages.Package, error) {
    57  	cfg := &packages.Config{
    58  		Mode: packages.NeedFiles |
    59  			packages.NeedDeps |
    60  			packages.NeedTypes |
    61  			packages.NeedTypesInfo |
    62  			packages.NeedSyntax,
    63  		Tests: true,
    64  	}
    65  
    66  	return packages.Load(cfg, ".")
    67  })
    68  
    69  func (s *set[T]) getInfo(file string) *info {
    70  	if s.info != nil {
    71  		return s.info
    72  	}
    73  	info := &info{
    74  		t:       s.t,
    75  		tcType:  reflect.TypeFor[T](),
    76  		calls:   make(map[token.Position]*callInfo),
    77  		patches: make(map[ast.Node]ast.Expr),
    78  	}
    79  	s.info = info
    80  
    81  	t := s.t
    82  
    83  	pkgs, pkgsErr := loadPackages()
    84  	if pkgsErr != nil {
    85  		t.Fatalf("load: %v\n", pkgsErr)
    86  	}
    87  
    88  	// Get package under test.
    89  	f, pkg := findFileAndPackage(file, pkgs)
    90  	if f == nil {
    91  		t.Fatalf("failed to load package for file %s", file)
    92  	}
    93  	info.testPkg = pkg
    94  
    95  	// TODO: not necessary at the moment, but this is tricky so leaving this in
    96  	// so as to not to forget how to do it.
    97  	//
    98  	// for _, p := range pkg.Types.Imports() {
    99  	// 	if p.Path() == "cuelang.org/go/internal/tdtest" {
   100  	// 		info.thisPkg = p
   101  	// 	}
   102  	// }
   103  	// if info.thisPkg == nil {
   104  	// 	t.Fatalf("could not find test package")
   105  	// }
   106  
   107  	// In cuetdtest the function of the test may be a subtest, like TestFoo/v3.
   108  	// We need to strip the last part.
   109  	testName := strings.SplitN(t.Name(), "/", 2)[0]
   110  
   111  	// Find function declaration of this test.
   112  	var fn *ast.FuncDecl
   113  	for _, d := range f.Decls {
   114  		if fd, ok := d.(*ast.FuncDecl); ok && fd.Name.Name == testName {
   115  			fn = fd
   116  		}
   117  	}
   118  	if fn == nil {
   119  		t.Fatalf("could not find test %q in file %q", testName, file)
   120  	}
   121  
   122  	// Find CompositLit table used for the test:
   123  	// - find call to which CompositLit was passed,
   124  	a := info.findCalls(fn.Body, "New", "Run")
   125  	if len(a) != 1 {
   126  		// TODO: allow more than one.
   127  		t.Fatalf("only one Run or New function allowed per test")
   128  	}
   129  
   130  	// - analyse second argument of call,
   131  	call := a[0].ast
   132  	fset := info.testPkg.Fset
   133  	ti := info.testPkg.TypesInfo
   134  	ident, ok := call.Args[1].(*ast.Ident)
   135  	if !ok {
   136  		t.Fatalf("%v: arg 2 of %s must be a reference to the table",
   137  			fset.Position(call.Args[1].Pos()), a[0].funcName)
   138  	}
   139  	def := ti.Uses[ident]
   140  	pos := def.Pos()
   141  
   142  	// - locate the CompositeLit in the AST based on position.
   143  	v0 := findVar(pos, f)
   144  	if v0 == nil {
   145  		t.Fatalf("cannot find composite literal in source code")
   146  	}
   147  	v, ok := v0.(*ast.CompositeLit)
   148  	if !ok {
   149  		// generics should avoid this.
   150  		t.Fatalf("expected composite literal, found %T", v0)
   151  	}
   152  	info.table = v
   153  
   154  	// Find and index assertion calls.
   155  	a = info.findCalls(fn.Body, "Equal")
   156  	for _, x := range a {
   157  		info.initFieldRef(x, f)
   158  	}
   159  
   160  	return info
   161  }
   162  
   163  // initFieldRef updates c with information about the field referenced
   164  // in its corresponding call:
   165  //   - name of the field
   166  //   - indexes the field based on filename and line number.
   167  func (i *info) initFieldRef(c *callInfo, f *ast.File) {
   168  	call := c.ast
   169  	t := i.t
   170  	info := i.testPkg.TypesInfo
   171  	fset := i.testPkg.Fset
   172  	pos := fset.Position(call.Pos())
   173  
   174  	sel, ok := call.Args[1].(*ast.SelectorExpr)
   175  	s := info.Selections[sel]
   176  	if !ok || s == nil || s.Kind() != types.FieldVal {
   177  		t.Fatalf("%v: arg 2 of %s must be a reference to a test case field",
   178  			fset.Position(call.Args[1].Pos()), c.funcName)
   179  	}
   180  
   181  	obj := s.Obj()
   182  	c.fieldName = obj.Name()
   183  	if _, ok := i.tcType.FieldByName(c.fieldName); !ok {
   184  		t.Fatalf("%v: could not find field %s",
   185  			fset.Position(obj.Pos()), c.fieldName)
   186  	}
   187  
   188  	pos.Column = 0
   189  	pos.Offset = 0
   190  	i.calls[pos] = c
   191  }
   192  
   193  // findFileAndPackage locates the ast.File and package within the given slice
   194  // of packages, in which the given file is located.
   195  func findFileAndPackage(path string, pkgs []*packages.Package) (*ast.File, *packages.Package) {
   196  	for _, p := range pkgs {
   197  		for i, gf := range p.GoFiles {
   198  			if gf == path {
   199  				return p.Syntax[i], p
   200  			}
   201  		}
   202  	}
   203  	return nil, nil
   204  }
   205  
   206  func isT(s string) bool {
   207  	// TODO: parametrize this so that tdtest does not have to know of cuetdtest.
   208  	return s == "*cuelang.org/go/internal/tdtest.T" ||
   209  		s == "*cuelang.org/go/internal/cuetdtest.T" ||
   210  		s == "*cuelang.org/go/internal/cuetest.T"
   211  }
   212  
   213  // findCalls finds all call expressions within a given block for functions
   214  // or methods defined within the tdtest package.
   215  func (i *info) findCalls(block *ast.BlockStmt, names ...string) []*callInfo {
   216  	var a []*callInfo
   217  	ast.Inspect(block, func(n ast.Node) bool {
   218  		c, ok := n.(*ast.CallExpr)
   219  		if !ok {
   220  			return true
   221  		}
   222  		sel, ok := c.Fun.(*ast.SelectorExpr)
   223  		if !ok {
   224  			return true
   225  		}
   226  
   227  		// TODO: also test package. It would be better to test the equality
   228  		// using the information in the types.Info/packages to ensure that
   229  		// we really got the right function.
   230  		info := i.testPkg.TypesInfo
   231  		for _, name := range names {
   232  			if sel.Sel.Name == name {
   233  				receiver := info.TypeOf(sel.X).String()
   234  				if isT(receiver) {
   235  					// Method.
   236  				} else if len(c.Args) == 3 {
   237  					// Run function.
   238  					fn := c.Args[2].(*ast.FuncLit)
   239  					if len(fn.Type.Params.List) != 2 {
   240  						return true
   241  					}
   242  					argType := info.TypeOf(fn.Type.Params.List[0].Type).String()
   243  					if !isT(argType) {
   244  						return true
   245  					}
   246  				} else {
   247  					return true
   248  				}
   249  				ci := &callInfo{
   250  					funcName: name,
   251  					ast:      c,
   252  				}
   253  				a = append(a, ci)
   254  				return true
   255  			}
   256  		}
   257  
   258  		return true
   259  	})
   260  	return a
   261  }
   262  
   263  func findVar(pos token.Pos, n0 ast.Node) (ret ast.Expr) {
   264  	ast.Inspect(n0, func(n ast.Node) bool {
   265  		if n == nil {
   266  			return true
   267  		}
   268  		switch n := n.(type) {
   269  		case *ast.AssignStmt:
   270  			for i, v := range n.Lhs {
   271  				if v.Pos() == pos {
   272  					ret = n.Rhs[i]
   273  				}
   274  			}
   275  			return false
   276  		case *ast.ValueSpec:
   277  			for i, v := range n.Names {
   278  				if v.Pos() == pos {
   279  					ret = n.Values[i]
   280  				}
   281  			}
   282  			return false
   283  		}
   284  		return true
   285  	})
   286  	return ret
   287  }
   288  
   289  func (s *set[TC]) update() {
   290  	info := s.info
   291  
   292  	t := s.t
   293  	fset := info.testPkg.Fset
   294  
   295  	file := fset.Position(info.table.Pos()).Filename
   296  	var f *ast.File
   297  	for i, gof := range info.testPkg.GoFiles {
   298  		if gof == file {
   299  			f = info.testPkg.Syntax[i]
   300  		}
   301  	}
   302  	if f == nil {
   303  		t.Fatalf("file %s not in package", file)
   304  	}
   305  
   306  	// TODO: use text-based insertion instead:
   307  	// - sort insertions and replacements on position in descending order.
   308  	// - substitute textually.
   309  	//
   310  	// We are using Apply because this is supposed to give better handling of
   311  	// comments. In practice this only works marginally better than not handling
   312  	// positions at all. Probably a lost cause.
   313  	astutil.Apply(f, func(c *astutil.Cursor) bool {
   314  		n := c.Node()
   315  
   316  		switch x := info.patches[n]; x.(type) {
   317  		case nil:
   318  		case *ast.KeyValueExpr:
   319  			for {
   320  				c.InsertAfter(x)
   321  				x = info.patches[x]
   322  				if x == nil {
   323  					break
   324  				}
   325  			}
   326  		default:
   327  			c.Replace(x)
   328  		}
   329  		return true
   330  	}, nil)
   331  
   332  	// TODO: use tmp files?
   333  	w, err := os.Create(file)
   334  	if err != nil {
   335  		t.Fatal(err)
   336  	}
   337  	defer w.Close()
   338  
   339  	err = format.Node(w, fset, f)
   340  	if err != nil {
   341  		t.Fatal(err)
   342  	}
   343  }
   344  
   345  func (t *T) updateField(info *info, ci *callInfo, newValue any) {
   346  	info.needsUpdate = true
   347  
   348  	fset := info.testPkg.Fset
   349  
   350  	e, ok := info.table.Elts[t.iter].(*ast.CompositeLit)
   351  	if !ok {
   352  		t.Fatalf("not a composite literal")
   353  	}
   354  
   355  	isZero := false
   356  	var value ast.Expr
   357  	switch x := reflect.ValueOf(newValue); x.Kind() {
   358  	default:
   359  		s := fmt.Sprint(x)
   360  		x = reflect.ValueOf(s)
   361  		fallthrough
   362  	case reflect.String:
   363  		s := x.String()
   364  		isZero = s == ""
   365  		if !strings.ContainsRune(s, '`') && !isZero {
   366  			s = fmt.Sprintf("`%s`", s)
   367  		} else {
   368  			s = strconv.Quote(s)
   369  		}
   370  		value = &ast.BasicLit{Kind: token.STRING, Value: s}
   371  	case reflect.Bool:
   372  		if b := x.Bool(); b {
   373  			value = &ast.BasicLit{Kind: token.IDENT, Value: "true"}
   374  		} else {
   375  			value = &ast.BasicLit{Kind: token.IDENT, Value: "false"}
   376  			isZero = true
   377  		}
   378  	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8:
   379  		i := x.Int()
   380  		value = &ast.BasicLit{Kind: token.INT,
   381  			Value: strconv.FormatInt(i, 10)}
   382  		isZero = i == 0
   383  	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8:
   384  		i := x.Uint()
   385  		value = &ast.BasicLit{Kind: token.INT,
   386  			Value: strconv.FormatUint(i, 10)}
   387  		isZero = i == 0
   388  	}
   389  
   390  	for _, x := range e.Elts {
   391  		kv, ok := x.(*ast.KeyValueExpr)
   392  		if !ok {
   393  			t.Fatalf("%v: elements must be key value pairs",
   394  				fset.Position(kv.Pos()))
   395  		}
   396  		ident, ok := kv.Key.(*ast.Ident)
   397  		if !ok {
   398  			t.Fatalf("%v: key must be an identifier",
   399  				fset.Position(kv.Pos()))
   400  		}
   401  		if ident.Name == ci.fieldName {
   402  			info.patches[kv.Value] = value
   403  			return
   404  		}
   405  	}
   406  
   407  	if !isZero {
   408  		kv := &ast.KeyValueExpr{
   409  			Key:   &ast.Ident{Name: ci.fieldName},
   410  			Value: value,
   411  		}
   412  		if len(e.Elts) > 0 {
   413  			var key ast.Node = e.Elts[len(e.Elts)-1]
   414  			old := info.patches[key]
   415  			if old != nil {
   416  				info.patches[kv] = old
   417  			}
   418  			info.patches[key] = kv
   419  		} else {
   420  			info.patches[e] = &ast.CompositeLit{Elts: []ast.Expr{kv}}
   421  		}
   422  	}
   423  }