golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/go/ssa/subst.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"go/types"
     9  
    10  	"golang.org/x/tools/internal/aliases"
    11  )
    12  
    13  // Type substituter for a fixed set of replacement types.
    14  //
    15  // A nil *subster is an valid, empty substitution map. It always acts as
    16  // the identity function. This allows for treating parameterized and
    17  // non-parameterized functions identically while compiling to ssa.
    18  //
    19  // Not concurrency-safe.
    20  type subster struct {
    21  	replacements map[*types.TypeParam]types.Type // values should contain no type params
    22  	cache        map[types.Type]types.Type       // cache of subst results
    23  	ctxt         *types.Context                  // cache for instantiation
    24  	scope        *types.Scope                    // *types.Named declared within this scope can be substituted (optional)
    25  	debug        bool                            // perform extra debugging checks
    26  	// TODO(taking): consider adding Pos
    27  	// TODO(zpavlinovic): replacements can contain type params
    28  	// when generating instances inside of a generic function body.
    29  }
    30  
    31  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
    32  // targs should not contain any types in tparams.
    33  // scope is the (optional) lexical block of the generic function for which we are substituting.
    34  func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
    35  	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
    36  
    37  	subst := &subster{
    38  		replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
    39  		cache:        make(map[types.Type]types.Type),
    40  		ctxt:         ctxt,
    41  		scope:        scope,
    42  		debug:        debug,
    43  	}
    44  	for i := 0; i < tparams.Len(); i++ {
    45  		subst.replacements[tparams.At(i)] = targs[i]
    46  	}
    47  	if subst.debug {
    48  		subst.wellFormed()
    49  	}
    50  	return subst
    51  }
    52  
    53  // wellFormed asserts that subst was properly initialized.
    54  func (subst *subster) wellFormed() {
    55  	if subst == nil {
    56  		return
    57  	}
    58  	// Check that all of the type params do not appear in the arguments.
    59  	s := make(map[types.Type]bool, len(subst.replacements))
    60  	for tparam := range subst.replacements {
    61  		s[tparam] = true
    62  	}
    63  	for _, r := range subst.replacements {
    64  		if reaches(r, s) {
    65  			panic(subst)
    66  		}
    67  	}
    68  }
    69  
    70  // typ returns the type of t with the type parameter tparams[i] substituted
    71  // for the type targs[i] where subst was created using tparams and targs.
    72  func (subst *subster) typ(t types.Type) (res types.Type) {
    73  	if subst == nil {
    74  		return t // A nil subst is type preserving.
    75  	}
    76  	if r, ok := subst.cache[t]; ok {
    77  		return r
    78  	}
    79  	defer func() {
    80  		subst.cache[t] = res
    81  	}()
    82  
    83  	switch t := t.(type) {
    84  	case *types.TypeParam:
    85  		r := subst.replacements[t]
    86  		assert(r != nil, "type param without replacement encountered")
    87  		return r
    88  
    89  	case *types.Basic:
    90  		return t
    91  
    92  	case *types.Array:
    93  		if r := subst.typ(t.Elem()); r != t.Elem() {
    94  			return types.NewArray(r, t.Len())
    95  		}
    96  		return t
    97  
    98  	case *types.Slice:
    99  		if r := subst.typ(t.Elem()); r != t.Elem() {
   100  			return types.NewSlice(r)
   101  		}
   102  		return t
   103  
   104  	case *types.Pointer:
   105  		if r := subst.typ(t.Elem()); r != t.Elem() {
   106  			return types.NewPointer(r)
   107  		}
   108  		return t
   109  
   110  	case *types.Tuple:
   111  		return subst.tuple(t)
   112  
   113  	case *types.Struct:
   114  		return subst.struct_(t)
   115  
   116  	case *types.Map:
   117  		key := subst.typ(t.Key())
   118  		elem := subst.typ(t.Elem())
   119  		if key != t.Key() || elem != t.Elem() {
   120  			return types.NewMap(key, elem)
   121  		}
   122  		return t
   123  
   124  	case *types.Chan:
   125  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
   126  			return types.NewChan(t.Dir(), elem)
   127  		}
   128  		return t
   129  
   130  	case *types.Signature:
   131  		return subst.signature(t)
   132  
   133  	case *types.Union:
   134  		return subst.union(t)
   135  
   136  	case *types.Interface:
   137  		return subst.interface_(t)
   138  
   139  	case *aliases.Alias:
   140  		return subst.alias(t)
   141  
   142  	case *types.Named:
   143  		return subst.named(t)
   144  
   145  	case *opaqueType:
   146  		return t // opaque types are never substituted
   147  
   148  	default:
   149  		panic("unreachable")
   150  	}
   151  }
   152  
   153  // types returns the result of {subst.typ(ts[i])}.
   154  func (subst *subster) types(ts []types.Type) []types.Type {
   155  	res := make([]types.Type, len(ts))
   156  	for i := range ts {
   157  		res[i] = subst.typ(ts[i])
   158  	}
   159  	return res
   160  }
   161  
   162  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
   163  	if t != nil {
   164  		if vars := subst.varlist(t); vars != nil {
   165  			return types.NewTuple(vars...)
   166  		}
   167  	}
   168  	return t
   169  }
   170  
   171  type varlist interface {
   172  	At(i int) *types.Var
   173  	Len() int
   174  }
   175  
   176  // fieldlist is an adapter for structs for the varlist interface.
   177  type fieldlist struct {
   178  	str *types.Struct
   179  }
   180  
   181  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
   182  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
   183  
   184  func (subst *subster) struct_(t *types.Struct) *types.Struct {
   185  	if t != nil {
   186  		if fields := subst.varlist(fieldlist{t}); fields != nil {
   187  			tags := make([]string, t.NumFields())
   188  			for i, n := 0, t.NumFields(); i < n; i++ {
   189  				tags[i] = t.Tag(i)
   190  			}
   191  			return types.NewStruct(fields, tags)
   192  		}
   193  	}
   194  	return t
   195  }
   196  
   197  // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
   198  func (subst *subster) varlist(in varlist) []*types.Var {
   199  	var out []*types.Var // nil => no updates
   200  	for i, n := 0, in.Len(); i < n; i++ {
   201  		v := in.At(i)
   202  		w := subst.var_(v)
   203  		if v != w && out == nil {
   204  			out = make([]*types.Var, n)
   205  			for j := 0; j < i; j++ {
   206  				out[j] = in.At(j)
   207  			}
   208  		}
   209  		if out != nil {
   210  			out[i] = w
   211  		}
   212  	}
   213  	return out
   214  }
   215  
   216  func (subst *subster) var_(v *types.Var) *types.Var {
   217  	if v != nil {
   218  		if typ := subst.typ(v.Type()); typ != v.Type() {
   219  			if v.IsField() {
   220  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
   221  			}
   222  			return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
   223  		}
   224  	}
   225  	return v
   226  }
   227  
   228  func (subst *subster) union(u *types.Union) *types.Union {
   229  	var out []*types.Term // nil => no updates
   230  
   231  	for i, n := 0, u.Len(); i < n; i++ {
   232  		t := u.Term(i)
   233  		r := subst.typ(t.Type())
   234  		if r != t.Type() && out == nil {
   235  			out = make([]*types.Term, n)
   236  			for j := 0; j < i; j++ {
   237  				out[j] = u.Term(j)
   238  			}
   239  		}
   240  		if out != nil {
   241  			out[i] = types.NewTerm(t.Tilde(), r)
   242  		}
   243  	}
   244  
   245  	if out != nil {
   246  		return types.NewUnion(out)
   247  	}
   248  	return u
   249  }
   250  
   251  func (subst *subster) interface_(iface *types.Interface) *types.Interface {
   252  	if iface == nil {
   253  		return nil
   254  	}
   255  
   256  	// methods for the interface. Initially nil if there is no known change needed.
   257  	// Signatures for the method where recv is nil. NewInterfaceType fills in the receivers.
   258  	var methods []*types.Func
   259  	initMethods := func(n int) { // copy first n explicit methods
   260  		methods = make([]*types.Func, iface.NumExplicitMethods())
   261  		for i := 0; i < n; i++ {
   262  			f := iface.ExplicitMethod(i)
   263  			norecv := changeRecv(f.Type().(*types.Signature), nil)
   264  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
   265  		}
   266  	}
   267  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   268  		f := iface.ExplicitMethod(i)
   269  		// On interfaces, we need to cycle break on anonymous interface types
   270  		// being in a cycle with their signatures being in cycles with their receivers
   271  		// that do not go through a Named.
   272  		norecv := changeRecv(f.Type().(*types.Signature), nil)
   273  		sig := subst.typ(norecv)
   274  		if sig != norecv && methods == nil {
   275  			initMethods(i)
   276  		}
   277  		if methods != nil {
   278  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
   279  		}
   280  	}
   281  
   282  	var embeds []types.Type
   283  	initEmbeds := func(n int) { // copy first n embedded types
   284  		embeds = make([]types.Type, iface.NumEmbeddeds())
   285  		for i := 0; i < n; i++ {
   286  			embeds[i] = iface.EmbeddedType(i)
   287  		}
   288  	}
   289  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   290  		e := iface.EmbeddedType(i)
   291  		r := subst.typ(e)
   292  		if e != r && embeds == nil {
   293  			initEmbeds(i)
   294  		}
   295  		if embeds != nil {
   296  			embeds[i] = r
   297  		}
   298  	}
   299  
   300  	if methods == nil && embeds == nil {
   301  		return iface
   302  	}
   303  	if methods == nil {
   304  		initMethods(iface.NumExplicitMethods())
   305  	}
   306  	if embeds == nil {
   307  		initEmbeds(iface.NumEmbeddeds())
   308  	}
   309  	return types.NewInterfaceType(methods, embeds).Complete()
   310  }
   311  
   312  func (subst *subster) alias(t *aliases.Alias) types.Type {
   313  	// TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types.
   314  	u := aliases.Unalias(t)
   315  	if s := subst.typ(u); s != u {
   316  		// If there is any change, do not create a new alias.
   317  		return s
   318  	}
   319  	// If there is no change, t did not reach any type parameter.
   320  	// Keep the Alias.
   321  	return t
   322  }
   323  
   324  func (subst *subster) named(t *types.Named) types.Type {
   325  	// A named type may be:
   326  	// (1) ordinary named type (non-local scope, no type parameters, no type arguments),
   327  	// (2) locally scoped type,
   328  	// (3) generic (type parameters but no type arguments), or
   329  	// (4) instantiated (type parameters and type arguments).
   330  	tparams := t.TypeParams()
   331  	if tparams.Len() == 0 {
   332  		if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
   333  			// Outside the current function scope?
   334  			return t // case (1) ordinary
   335  		}
   336  
   337  		// case (2) locally scoped type.
   338  		// Create a new named type to represent this instantiation.
   339  		// We assume that local types of distinct instantiations of a
   340  		// generic function are distinct, even if they don't refer to
   341  		// type parameters, but the spec is unclear; see golang/go#58573.
   342  		//
   343  		// Subtle: We short circuit substitution and use a newly created type in
   344  		// subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive
   345  		// types during traversal. This both breaks infinite cycles and allows for
   346  		// constructing types with the replacement applied in subst.typ(under).
   347  		//
   348  		// Example:
   349  		// func foo[T any]() {
   350  		//   type linkedlist struct {
   351  		//     next *linkedlist
   352  		//     val T
   353  		//   }
   354  		// }
   355  		//
   356  		// When the field `next *linkedlist` is visited during subst.typ(under),
   357  		// we want the substituted type for the field `next` to be `*n`.
   358  		n := types.NewNamed(t.Obj(), nil, nil)
   359  		subst.cache[t] = n
   360  		subst.cache[n] = n
   361  		n.SetUnderlying(subst.typ(t.Underlying()))
   362  		return n
   363  	}
   364  	targs := t.TypeArgs()
   365  
   366  	// insts are arguments to instantiate using.
   367  	insts := make([]types.Type, tparams.Len())
   368  
   369  	// case (3) generic ==> targs.Len() == 0
   370  	// Instantiating a generic with no type arguments should be unreachable.
   371  	// Please report a bug if you encounter this.
   372  	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
   373  
   374  	// case (4) instantiated.
   375  	// Substitute into the type arguments and instantiate the replacements/
   376  	// Example:
   377  	//    type N[A any] func() A
   378  	//    func Foo[T](g N[T]) {}
   379  	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
   380  	//  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
   381  	//  to get {N with TypeArgs == {string} and typeparams == {A} }.
   382  	assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
   383  	for i, n := 0, targs.Len(); i < n; i++ {
   384  		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
   385  		insts[i] = inst
   386  	}
   387  	r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false)
   388  	assert(err == nil, "failed to Instantiate Named type")
   389  	return r
   390  }
   391  
   392  func (subst *subster) signature(t *types.Signature) types.Type {
   393  	tparams := t.TypeParams()
   394  
   395  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
   396  	//
   397  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
   398  	// To support tparams.Len() > 0, we just need to do the following [psuedocode]:
   399  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
   400  
   401  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
   402  
   403  	// Either:
   404  	// (1)non-generic function.
   405  	//    no type params to substitute
   406  	// (2)generic method and recv needs to be substituted.
   407  
   408  	// Receivers can be either:
   409  	// named
   410  	// pointer to named
   411  	// interface
   412  	// nil
   413  	// interface is the problematic case. We need to cycle break there!
   414  	recv := subst.var_(t.Recv())
   415  	params := subst.tuple(t.Params())
   416  	results := subst.tuple(t.Results())
   417  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
   418  		return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
   419  	}
   420  	return t
   421  }
   422  
   423  // reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
   424  // It updates c to cache results.
   425  //
   426  // reaches is currently only part of the wellFormed debug logic, and
   427  // in practice c is initially only type parameters. It is not currently
   428  // relied on in production.
   429  func reaches(t types.Type, c map[types.Type]bool) (res bool) {
   430  	if c, ok := c[t]; ok {
   431  		return c
   432  	}
   433  
   434  	// c is populated with temporary false entries as types are visited.
   435  	// This avoids repeat visits and break cycles.
   436  	c[t] = false
   437  	defer func() {
   438  		c[t] = res
   439  	}()
   440  
   441  	switch t := t.(type) {
   442  	case *types.TypeParam, *types.Basic:
   443  		return false
   444  	case *types.Array:
   445  		return reaches(t.Elem(), c)
   446  	case *types.Slice:
   447  		return reaches(t.Elem(), c)
   448  	case *types.Pointer:
   449  		return reaches(t.Elem(), c)
   450  	case *types.Tuple:
   451  		for i := 0; i < t.Len(); i++ {
   452  			if reaches(t.At(i).Type(), c) {
   453  				return true
   454  			}
   455  		}
   456  	case *types.Struct:
   457  		for i := 0; i < t.NumFields(); i++ {
   458  			if reaches(t.Field(i).Type(), c) {
   459  				return true
   460  			}
   461  		}
   462  	case *types.Map:
   463  		return reaches(t.Key(), c) || reaches(t.Elem(), c)
   464  	case *types.Chan:
   465  		return reaches(t.Elem(), c)
   466  	case *types.Signature:
   467  		if t.Recv() != nil && reaches(t.Recv().Type(), c) {
   468  			return true
   469  		}
   470  		return reaches(t.Params(), c) || reaches(t.Results(), c)
   471  	case *types.Union:
   472  		for i := 0; i < t.Len(); i++ {
   473  			if reaches(t.Term(i).Type(), c) {
   474  				return true
   475  			}
   476  		}
   477  	case *types.Interface:
   478  		for i := 0; i < t.NumEmbeddeds(); i++ {
   479  			if reaches(t.Embedded(i), c) {
   480  				return true
   481  			}
   482  		}
   483  		for i := 0; i < t.NumExplicitMethods(); i++ {
   484  			if reaches(t.ExplicitMethod(i).Type(), c) {
   485  				return true
   486  			}
   487  		}
   488  	case *types.Named, *aliases.Alias:
   489  		return reaches(t.Underlying(), c)
   490  	default:
   491  		panic("unreachable")
   492  	}
   493  	return false
   494  }