cuelang.org/go@v0.10.1/internal/tdtest/update.go (about)

     1  // Copyright 2023 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package tdtest
    16  
    17  import (
    18  	"fmt"
    19  	"go/ast"
    20  	"go/format"
    21  	"go/token"
    22  	"go/types"
    23  	"os"
    24  	"reflect"
    25  	"strconv"
    26  	"strings"
    27  	"sync"
    28  	"testing"
    29  
    30  	"golang.org/x/tools/go/ast/astutil"
    31  	"golang.org/x/tools/go/packages"
    32  )
    33  
    34  // info contains information needed to update files.
    35  type info struct {
    36  	t *testing.T
    37  
    38  	tcType reflect.Type
    39  
    40  	needsUpdate bool // an updateable field has changed
    41  
    42  	table *ast.CompositeLit // the table that is the source of the tests
    43  
    44  	testPkg *packages.Package
    45  
    46  	calls   map[token.Position]*callInfo
    47  	patches map[ast.Node]ast.Expr
    48  }
    49  
    50  type callInfo struct {
    51  	ast       *ast.CallExpr
    52  	funcName  string
    53  	fieldName string
    54  }
    55  
    56  var (
    57  	once    sync.Once
    58  	pkgs    []*packages.Package
    59  	pkgsErr error
    60  )
    61  
    62  func initPackages() ([]*packages.Package, error) {
    63  	once.Do(func() {
    64  		cfg := &packages.Config{
    65  			Mode: packages.NeedFiles |
    66  				packages.NeedDeps |
    67  				packages.NeedTypes |
    68  				packages.NeedTypesInfo |
    69  				packages.NeedSyntax,
    70  			Tests: true,
    71  		}
    72  
    73  		pkgs, pkgsErr = packages.Load(cfg, ".")
    74  	})
    75  	return pkgs, pkgsErr
    76  }
    77  
    78  func (s *set[T]) getInfo(file string) *info {
    79  	if s.info != nil {
    80  		return s.info
    81  	}
    82  	info := &info{
    83  		t:       s.t,
    84  		tcType:  reflect.TypeFor[T](),
    85  		calls:   make(map[token.Position]*callInfo),
    86  		patches: make(map[ast.Node]ast.Expr),
    87  	}
    88  	s.info = info
    89  
    90  	t := s.t
    91  
    92  	pkgs, pkgsErr = initPackages()
    93  	if pkgsErr != nil {
    94  		t.Fatalf("load: %v\n", pkgsErr)
    95  	}
    96  
    97  	// Get package under test.
    98  	f, pkg := findFileAndPackage(file, pkgs)
    99  	if f == nil {
   100  		t.Fatalf("failed to load package for file %s", file)
   101  	}
   102  	info.testPkg = pkg
   103  
   104  	// TODO: not necessary at the moment, but this is tricky so leaving this in
   105  	// so as to not to forget how to do it.
   106  	//
   107  	// for _, p := range pkg.Types.Imports() {
   108  	// 	if p.Path() == "cuelang.org/go/internal/tdtest" {
   109  	// 		info.thisPkg = p
   110  	// 	}
   111  	// }
   112  	// if info.thisPkg == nil {
   113  	// 	t.Fatalf("could not find test package")
   114  	// }
   115  
   116  	// Find function declaration of this test.
   117  	var fn *ast.FuncDecl
   118  	for _, d := range f.Decls {
   119  		if fd, ok := d.(*ast.FuncDecl); ok && fd.Name.Name == t.Name() {
   120  			fn = fd
   121  		}
   122  	}
   123  	if fn == nil {
   124  		t.Fatalf("could not find test %q in file %q", t.Name(), file)
   125  	}
   126  
   127  	// Find CompositLit table used for the test:
   128  	// - find call to which CompositLit was passed,
   129  	a := info.findCalls(fn.Body, "New", "Run")
   130  	if len(a) != 1 {
   131  		// TODO: allow more than one.
   132  		t.Fatalf("only one Run or New function allowed per test")
   133  	}
   134  
   135  	// - analyse second argument of call,
   136  	call := a[0].ast
   137  	fset := info.testPkg.Fset
   138  	ti := info.testPkg.TypesInfo
   139  	ident, ok := call.Args[1].(*ast.Ident)
   140  	if !ok {
   141  		t.Fatalf("%v: arg 2 of %s must be a reference to the table",
   142  			fset.Position(call.Args[1].Pos()), a[0].funcName)
   143  	}
   144  	def := ti.Uses[ident]
   145  	pos := def.Pos()
   146  
   147  	// - locate the CompositeLit in the AST based on position.
   148  	v0 := findVar(pos, f)
   149  	if v0 == nil {
   150  		t.Fatalf("cannot find composite literal in source code")
   151  	}
   152  	v, ok := v0.(*ast.CompositeLit)
   153  	if !ok {
   154  		// generics should avoid this.
   155  		t.Fatalf("expected composite literal, found %T", v0)
   156  	}
   157  	info.table = v
   158  
   159  	// Find and index assertion calls.
   160  	a = info.findCalls(fn.Body, "Equal")
   161  	for _, x := range a {
   162  		info.initFieldRef(x, f)
   163  	}
   164  
   165  	return info
   166  }
   167  
   168  // initFieldRef updates c with information about the field referenced
   169  // in its corresponding call:
   170  //   - name of the field
   171  //   - indexes the field based on filename and line number.
   172  func (i *info) initFieldRef(c *callInfo, f *ast.File) {
   173  	call := c.ast
   174  	t := i.t
   175  	info := i.testPkg.TypesInfo
   176  	fset := i.testPkg.Fset
   177  	pos := fset.Position(call.Pos())
   178  
   179  	sel, ok := call.Args[1].(*ast.SelectorExpr)
   180  	s := info.Selections[sel]
   181  	if !ok || s == nil || s.Kind() != types.FieldVal {
   182  		t.Fatalf("%v: arg 2 of %s must be a reference to a test case field",
   183  			fset.Position(call.Args[1].Pos()), c.funcName)
   184  	}
   185  
   186  	obj := s.Obj()
   187  	c.fieldName = obj.Name()
   188  	if _, ok := i.tcType.FieldByName(c.fieldName); !ok {
   189  		t.Fatalf("%v: could not find field %s",
   190  			fset.Position(obj.Pos()), c.fieldName)
   191  	}
   192  
   193  	pos.Column = 0
   194  	pos.Offset = 0
   195  	i.calls[pos] = c
   196  }
   197  
   198  // findFileAndPackage locates the ast.File and package within the given slice
   199  // of packages, in which the given file is located.
   200  func findFileAndPackage(path string, pkgs []*packages.Package) (*ast.File, *packages.Package) {
   201  	for _, p := range pkgs {
   202  		for i, gf := range p.GoFiles {
   203  			if gf == path {
   204  				return p.Syntax[i], p
   205  			}
   206  		}
   207  	}
   208  	return nil, nil
   209  }
   210  
   211  const typeT = "*cuelang.org/go/internal/tdtest.T"
   212  
   213  // findCalls finds all call expressions within a given block for functions
   214  // or methods defined within the tdtest package.
   215  func (i *info) findCalls(block *ast.BlockStmt, names ...string) []*callInfo {
   216  	var a []*callInfo
   217  	ast.Inspect(block, func(n ast.Node) bool {
   218  		c, ok := n.(*ast.CallExpr)
   219  		if !ok {
   220  			return true
   221  		}
   222  		sel, ok := c.Fun.(*ast.SelectorExpr)
   223  		if !ok {
   224  			return true
   225  		}
   226  
   227  		// TODO: also test package. It would be better to test the equality
   228  		// using the information in the types.Info/packages to ensure that
   229  		// we really got the right function.
   230  		info := i.testPkg.TypesInfo
   231  		for _, name := range names {
   232  			if sel.Sel.Name == name {
   233  				receiver := info.TypeOf(sel.X).String()
   234  				if receiver == typeT {
   235  					// Method.
   236  				} else if len(c.Args) == 3 {
   237  					// Run function.
   238  					fn := c.Args[2].(*ast.FuncLit)
   239  					if len(fn.Type.Params.List) != 2 {
   240  						return true
   241  					}
   242  					argType := info.TypeOf(fn.Type.Params.List[0].Type).String()
   243  					if argType != typeT {
   244  						return true
   245  					}
   246  				} else {
   247  					return true
   248  				}
   249  				ci := &callInfo{
   250  					funcName: name,
   251  					ast:      c,
   252  				}
   253  				a = append(a, ci)
   254  				return true
   255  			}
   256  		}
   257  
   258  		return true
   259  	})
   260  	return a
   261  }
   262  
   263  func findVar(pos token.Pos, n0 ast.Node) (ret ast.Expr) {
   264  	ast.Inspect(n0, func(n ast.Node) bool {
   265  		if n == nil {
   266  			return true
   267  		}
   268  		switch n := n.(type) {
   269  		case *ast.AssignStmt:
   270  			for i, v := range n.Lhs {
   271  				if v.Pos() == pos {
   272  					ret = n.Rhs[i]
   273  				}
   274  			}
   275  			return false
   276  		case *ast.ValueSpec:
   277  			for i, v := range n.Names {
   278  				if v.Pos() == pos {
   279  					ret = n.Values[i]
   280  				}
   281  			}
   282  			return false
   283  		}
   284  		return true
   285  	})
   286  	return ret
   287  }
   288  
   289  func (s *set[TC]) update() {
   290  	info := s.info
   291  
   292  	t := s.t
   293  	fset := info.testPkg.Fset
   294  
   295  	file := fset.Position(info.table.Pos()).Filename
   296  	var f *ast.File
   297  	for i, gof := range info.testPkg.GoFiles {
   298  		if gof == file {
   299  			f = info.testPkg.Syntax[i]
   300  		}
   301  	}
   302  	if f == nil {
   303  		t.Fatalf("file %s not in package", file)
   304  	}
   305  
   306  	// TODO: use text-based insertion instead:
   307  	// - sort insertions and replacements on position in descending order.
   308  	// - substitute textually.
   309  	//
   310  	// We are using Apply because this is supposed to give better handling of
   311  	// comments. In practice this only works marginally better than not handling
   312  	// positions at all. Probably a lost cause.
   313  	astutil.Apply(f, func(c *astutil.Cursor) bool {
   314  		n := c.Node()
   315  
   316  		switch x := info.patches[n]; x.(type) {
   317  		case nil:
   318  		case *ast.KeyValueExpr:
   319  			for {
   320  				c.InsertAfter(x)
   321  				x = info.patches[x]
   322  				if x == nil {
   323  					break
   324  				}
   325  			}
   326  		default:
   327  			c.Replace(x)
   328  		}
   329  		return true
   330  	}, nil)
   331  
   332  	// TODO: use tmp files?
   333  	w, err := os.Create(file)
   334  	if err != nil {
   335  		t.Fatal(err)
   336  	}
   337  	defer w.Close()
   338  
   339  	err = format.Node(w, fset, f)
   340  	if err != nil {
   341  		t.Fatal(err)
   342  	}
   343  }
   344  
   345  func (t *T) updateField(info *info, ci *callInfo, newValue any) {
   346  	info.needsUpdate = true
   347  
   348  	fset := info.testPkg.Fset
   349  
   350  	e, ok := info.table.Elts[t.iter].(*ast.CompositeLit)
   351  	if !ok {
   352  		t.Fatalf("not a composite literal")
   353  	}
   354  
   355  	isZero := false
   356  	var value ast.Expr
   357  	switch x := reflect.ValueOf(newValue); x.Kind() {
   358  	default:
   359  		s := fmt.Sprint(x)
   360  		x = reflect.ValueOf(s)
   361  		fallthrough
   362  	case reflect.String:
   363  		s := x.String()
   364  		isZero = s == ""
   365  		if !strings.ContainsRune(s, '`') && !isZero {
   366  			s = fmt.Sprintf("`%s`", s)
   367  		} else {
   368  			s = strconv.Quote(s)
   369  		}
   370  		value = &ast.BasicLit{Kind: token.STRING, Value: s}
   371  	case reflect.Bool:
   372  		if b := x.Bool(); b {
   373  			value = &ast.BasicLit{Kind: token.IDENT, Value: "true"}
   374  		} else {
   375  			value = &ast.BasicLit{Kind: token.IDENT, Value: "false"}
   376  			isZero = true
   377  		}
   378  	case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int8:
   379  		i := x.Int()
   380  		value = &ast.BasicLit{Kind: token.INT,
   381  			Value: strconv.FormatInt(i, 10)}
   382  		isZero = i == 0
   383  	case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint8:
   384  		i := x.Uint()
   385  		value = &ast.BasicLit{Kind: token.INT,
   386  			Value: strconv.FormatUint(i, 10)}
   387  		isZero = i == 0
   388  	}
   389  
   390  	for _, x := range e.Elts {
   391  		kv, ok := x.(*ast.KeyValueExpr)
   392  		if !ok {
   393  			t.Fatalf("%v: elements must be key value pairs",
   394  				fset.Position(kv.Pos()))
   395  		}
   396  		ident, ok := kv.Key.(*ast.Ident)
   397  		if !ok {
   398  			t.Fatalf("%v: key must be an identifier",
   399  				fset.Position(kv.Pos()))
   400  		}
   401  		if ident.Name == ci.fieldName {
   402  			info.patches[kv.Value] = value
   403  			return
   404  		}
   405  	}
   406  
   407  	if !isZero {
   408  		kv := &ast.KeyValueExpr{
   409  			Key:   &ast.Ident{Name: ci.fieldName},
   410  			Value: value,
   411  		}
   412  		if len(e.Elts) > 0 {
   413  			var key ast.Node = e.Elts[len(e.Elts)-1]
   414  			old := info.patches[key]
   415  			if old != nil {
   416  				info.patches[kv] = old
   417  			}
   418  			info.patches[key] = kv
   419  		} else {
   420  			info.patches[e] = &ast.CompositeLit{Elts: []ast.Expr{kv}}
   421  		}
   422  	}
   423  }