github.com/gopherjs/gopherjs@v1.19.0-beta1.0.20240506212314-27071a8796e4/internal/govendor/subst/subst.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package subst
     6  
     7  import (
     8  	"go/types"
     9  )
    10  
    11  // Type substituter for a fixed set of replacement types.
    12  //
    13  // A nil *subster is an valid, empty substitution map. It always acts as
    14  // the identity function. This allows for treating parameterized and
    15  // non-parameterized functions identically while compiling to ssa.
    16  //
    17  // Not concurrency-safe.
    18  type subster struct {
    19  	replacements map[*types.TypeParam]types.Type // values should contain no type params
    20  	cache        map[types.Type]types.Type       // cache of subst results
    21  	ctxt         *types.Context                  // cache for instantiation
    22  	scope        *types.Scope                    // *types.Named declared within this scope can be substituted (optional)
    23  	debug        bool                            // perform extra debugging checks
    24  	// TODO(taking): consider adding Pos
    25  	// TODO(zpavlinovic): replacements can contain type params
    26  	// when generating instances inside of a generic function body.
    27  }
    28  
    29  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
    30  // targs should not contain any types in tparams.
    31  // scope is the (optional) lexical block of the generic function for which we are substituting.
    32  func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster {
    33  	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
    34  
    35  	subst := &subster{
    36  		replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
    37  		cache:        make(map[types.Type]types.Type),
    38  		ctxt:         ctxt,
    39  		scope:        scope,
    40  		debug:        debug,
    41  	}
    42  	for i := 0; i < tparams.Len(); i++ {
    43  		subst.replacements[tparams.At(i)] = targs[i]
    44  	}
    45  	if subst.debug {
    46  		subst.wellFormed()
    47  	}
    48  	return subst
    49  }
    50  
    51  // wellFormed asserts that subst was properly initialized.
    52  func (subst *subster) wellFormed() {
    53  	if subst == nil {
    54  		return
    55  	}
    56  	// Check that all of the type params do not appear in the arguments.
    57  	s := make(map[types.Type]bool, len(subst.replacements))
    58  	for tparam := range subst.replacements {
    59  		s[tparam] = true
    60  	}
    61  	for _, r := range subst.replacements {
    62  		if reaches(r, s) {
    63  			panic(subst)
    64  		}
    65  	}
    66  }
    67  
    68  // typ returns the type of t with the type parameter tparams[i] substituted
    69  // for the type targs[i] where subst was created using tparams and targs.
    70  func (subst *subster) typ(t types.Type) (res types.Type) {
    71  	if subst == nil {
    72  		return t // A nil subst is type preserving.
    73  	}
    74  	if r, ok := subst.cache[t]; ok {
    75  		return r
    76  	}
    77  	defer func() {
    78  		subst.cache[t] = res
    79  	}()
    80  
    81  	// fall through if result r will be identical to t, types.Identical(r, t).
    82  	switch t := t.(type) {
    83  	case *types.TypeParam:
    84  		r := subst.replacements[t]
    85  		assert(r != nil, "type param without replacement encountered")
    86  		return r
    87  
    88  	case *types.Basic:
    89  		return t
    90  
    91  	case *types.Array:
    92  		if r := subst.typ(t.Elem()); r != t.Elem() {
    93  			return types.NewArray(r, t.Len())
    94  		}
    95  		return t
    96  
    97  	case *types.Slice:
    98  		if r := subst.typ(t.Elem()); r != t.Elem() {
    99  			return types.NewSlice(r)
   100  		}
   101  		return t
   102  
   103  	case *types.Pointer:
   104  		if r := subst.typ(t.Elem()); r != t.Elem() {
   105  			return types.NewPointer(r)
   106  		}
   107  		return t
   108  
   109  	case *types.Tuple:
   110  		return subst.tuple(t)
   111  
   112  	case *types.Struct:
   113  		return subst.struct_(t)
   114  
   115  	case *types.Map:
   116  		key := subst.typ(t.Key())
   117  		elem := subst.typ(t.Elem())
   118  		if key != t.Key() || elem != t.Elem() {
   119  			return types.NewMap(key, elem)
   120  		}
   121  		return t
   122  
   123  	case *types.Chan:
   124  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
   125  			return types.NewChan(t.Dir(), elem)
   126  		}
   127  		return t
   128  
   129  	case *types.Signature:
   130  		return subst.signature(t)
   131  
   132  	case *types.Union:
   133  		return subst.union(t)
   134  
   135  	case *types.Interface:
   136  		return subst.interface_(t)
   137  
   138  	case *types.Named:
   139  		return subst.named(t)
   140  
   141  	default:
   142  		panic("unreachable")
   143  	}
   144  }
   145  
   146  // types returns the result of {subst.typ(ts[i])}.
   147  func (subst *subster) types(ts []types.Type) []types.Type {
   148  	res := make([]types.Type, len(ts))
   149  	for i := range ts {
   150  		res[i] = subst.typ(ts[i])
   151  	}
   152  	return res
   153  }
   154  
   155  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
   156  	if t != nil {
   157  		if vars := subst.varlist(t); vars != nil {
   158  			return types.NewTuple(vars...)
   159  		}
   160  	}
   161  	return t
   162  }
   163  
   164  type varlist interface {
   165  	At(i int) *types.Var
   166  	Len() int
   167  }
   168  
   169  // fieldlist is an adapter for structs for the varlist interface.
   170  type fieldlist struct {
   171  	str *types.Struct
   172  }
   173  
   174  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
   175  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
   176  
   177  func (subst *subster) struct_(t *types.Struct) *types.Struct {
   178  	if t != nil {
   179  		if fields := subst.varlist(fieldlist{t}); fields != nil {
   180  			tags := make([]string, t.NumFields())
   181  			for i, n := 0, t.NumFields(); i < n; i++ {
   182  				tags[i] = t.Tag(i)
   183  			}
   184  			return types.NewStruct(fields, tags)
   185  		}
   186  	}
   187  	return t
   188  }
   189  
   190  // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
   191  func (subst *subster) varlist(in varlist) []*types.Var {
   192  	var out []*types.Var // nil => no updates
   193  	for i, n := 0, in.Len(); i < n; i++ {
   194  		v := in.At(i)
   195  		w := subst.var_(v)
   196  		if v != w && out == nil {
   197  			out = make([]*types.Var, n)
   198  			for j := 0; j < i; j++ {
   199  				out[j] = in.At(j)
   200  			}
   201  		}
   202  		if out != nil {
   203  			out[i] = w
   204  		}
   205  	}
   206  	return out
   207  }
   208  
   209  func (subst *subster) var_(v *types.Var) *types.Var {
   210  	if v != nil {
   211  		if typ := subst.typ(v.Type()); typ != v.Type() {
   212  			if v.IsField() {
   213  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
   214  			}
   215  			return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
   216  		}
   217  	}
   218  	return v
   219  }
   220  
   221  func (subst *subster) union(u *types.Union) *types.Union {
   222  	var out []*types.Term // nil => no updates
   223  
   224  	for i, n := 0, u.Len(); i < n; i++ {
   225  		t := u.Term(i)
   226  		r := subst.typ(t.Type())
   227  		if r != t.Type() && out == nil {
   228  			out = make([]*types.Term, n)
   229  			for j := 0; j < i; j++ {
   230  				out[j] = u.Term(j)
   231  			}
   232  		}
   233  		if out != nil {
   234  			out[i] = types.NewTerm(t.Tilde(), r)
   235  		}
   236  	}
   237  
   238  	if out != nil {
   239  		return types.NewUnion(out)
   240  	}
   241  	return u
   242  }
   243  
   244  func (subst *subster) interface_(iface *types.Interface) *types.Interface {
   245  	if iface == nil {
   246  		return nil
   247  	}
   248  
   249  	// methods for the interface. Initially nil if there is no known change needed.
   250  	// Signatures for the method where recv is nil. NewInterfaceType fills in the receivers.
   251  	var methods []*types.Func
   252  	initMethods := func(n int) { // copy first n explicit methods
   253  		methods = make([]*types.Func, iface.NumExplicitMethods())
   254  		for i := 0; i < n; i++ {
   255  			f := iface.ExplicitMethod(i)
   256  			norecv := changeRecv(f.Type().(*types.Signature), nil)
   257  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
   258  		}
   259  	}
   260  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   261  		f := iface.ExplicitMethod(i)
   262  		// On interfaces, we need to cycle break on anonymous interface types
   263  		// being in a cycle with their signatures being in cycles with their receivers
   264  		// that do not go through a Named.
   265  		norecv := changeRecv(f.Type().(*types.Signature), nil)
   266  		sig := subst.typ(norecv)
   267  		if sig != norecv && methods == nil {
   268  			initMethods(i)
   269  		}
   270  		if methods != nil {
   271  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
   272  		}
   273  	}
   274  
   275  	var embeds []types.Type
   276  	initEmbeds := func(n int) { // copy first n embedded types
   277  		embeds = make([]types.Type, iface.NumEmbeddeds())
   278  		for i := 0; i < n; i++ {
   279  			embeds[i] = iface.EmbeddedType(i)
   280  		}
   281  	}
   282  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   283  		e := iface.EmbeddedType(i)
   284  		r := subst.typ(e)
   285  		if e != r && embeds == nil {
   286  			initEmbeds(i)
   287  		}
   288  		if embeds != nil {
   289  			embeds[i] = r
   290  		}
   291  	}
   292  
   293  	if methods == nil && embeds == nil {
   294  		return iface
   295  	}
   296  	if methods == nil {
   297  		initMethods(iface.NumExplicitMethods())
   298  	}
   299  	if embeds == nil {
   300  		initEmbeds(iface.NumEmbeddeds())
   301  	}
   302  	return types.NewInterfaceType(methods, embeds).Complete()
   303  }
   304  
   305  func (subst *subster) named(t *types.Named) types.Type {
   306  	// A named type may be:
   307  	// (1) ordinary named type (non-local scope, no type parameters, no type arguments),
   308  	// (2) locally scoped type,
   309  	// (3) generic (type parameters but no type arguments), or
   310  	// (4) instantiated (type parameters and type arguments).
   311  	tparams := t.TypeParams()
   312  	if tparams.Len() == 0 {
   313  		if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) {
   314  			// Outside the current function scope?
   315  			return t // case (1) ordinary
   316  		}
   317  
   318  		// case (2) locally scoped type.
   319  		// Create a new named type to represent this instantiation.
   320  		// We assume that local types of distinct instantiations of a
   321  		// generic function are distinct, even if they don't refer to
   322  		// type parameters, but the spec is unclear; see golang/go#58573.
   323  		//
   324  		// Subtle: We short circuit substitution and use a newly created type in
   325  		// subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive
   326  		// types during traversal. This both breaks infinite cycles and allows for
   327  		// constructing types with the replacement applied in subst.typ(under).
   328  		//
   329  		// Example:
   330  		// func foo[T any]() {
   331  		//   type linkedlist struct {
   332  		//     next *linkedlist
   333  		//     val T
   334  		//   }
   335  		// }
   336  		//
   337  		// When the field `next *linkedlist` is visited during subst.typ(under),
   338  		// we want the substituted type for the field `next` to be `*n`.
   339  		n := types.NewNamed(t.Obj(), nil, nil)
   340  		subst.cache[t] = n
   341  		subst.cache[n] = n
   342  		n.SetUnderlying(subst.typ(t.Underlying()))
   343  		return n
   344  	}
   345  	targs := t.TypeArgs()
   346  
   347  	// insts are arguments to instantiate using.
   348  	insts := make([]types.Type, tparams.Len())
   349  
   350  	// case (3) generic ==> targs.Len() == 0
   351  	// Instantiating a generic with no type arguments should be unreachable.
   352  	// Please report a bug if you encounter this.
   353  	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
   354  
   355  	// case (4) instantiated.
   356  	// Substitute into the type arguments and instantiate the replacements/
   357  	// Example:
   358  	//    type N[A any] func() A
   359  	//    func Foo[T](g N[T]) {}
   360  	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
   361  	//  one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} }
   362  	//  to get {N with TypeArgs == {string} and typeparams == {A} }.
   363  	assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present")
   364  	for i, n := 0, targs.Len(); i < n; i++ {
   365  		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
   366  		insts[i] = inst
   367  	}
   368  	r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false)
   369  	assert(err == nil, "failed to Instantiate Named type")
   370  	return r
   371  }
   372  
   373  func (subst *subster) signature(t *types.Signature) types.Type {
   374  	tparams := t.TypeParams()
   375  
   376  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
   377  	//
   378  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
   379  	// To support tparams.Len() > 0, we just need to do the following [psuedocode]:
   380  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
   381  
   382  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
   383  
   384  	// Either:
   385  	// (1)non-generic function.
   386  	//    no type params to substitute
   387  	// (2)generic method and recv needs to be substituted.
   388  
   389  	// Receivers can be either:
   390  	// named
   391  	// pointer to named
   392  	// interface
   393  	// nil
   394  	// interface is the problematic case. We need to cycle break there!
   395  	recv := subst.var_(t.Recv())
   396  	params := subst.tuple(t.Params())
   397  	results := subst.tuple(t.Results())
   398  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
   399  		return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
   400  	}
   401  	return t
   402  }
   403  
   404  // reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
   405  // It updates c to cache results.
   406  //
   407  // reaches is currently only part of the wellFormed debug logic, and
   408  // in practice c is initially only type parameters. It is not currently
   409  // relied on in production.
   410  func reaches(t types.Type, c map[types.Type]bool) (res bool) {
   411  	if c, ok := c[t]; ok {
   412  		return c
   413  	}
   414  
   415  	// c is populated with temporary false entries as types are visited.
   416  	// This avoids repeat visits and break cycles.
   417  	c[t] = false
   418  	defer func() {
   419  		c[t] = res
   420  	}()
   421  
   422  	switch t := t.(type) {
   423  	case *types.TypeParam, *types.Basic:
   424  		return false
   425  	case *types.Array:
   426  		return reaches(t.Elem(), c)
   427  	case *types.Slice:
   428  		return reaches(t.Elem(), c)
   429  	case *types.Pointer:
   430  		return reaches(t.Elem(), c)
   431  	case *types.Tuple:
   432  		for i := 0; i < t.Len(); i++ {
   433  			if reaches(t.At(i).Type(), c) {
   434  				return true
   435  			}
   436  		}
   437  	case *types.Struct:
   438  		for i := 0; i < t.NumFields(); i++ {
   439  			if reaches(t.Field(i).Type(), c) {
   440  				return true
   441  			}
   442  		}
   443  	case *types.Map:
   444  		return reaches(t.Key(), c) || reaches(t.Elem(), c)
   445  	case *types.Chan:
   446  		return reaches(t.Elem(), c)
   447  	case *types.Signature:
   448  		if t.Recv() != nil && reaches(t.Recv().Type(), c) {
   449  			return true
   450  		}
   451  		return reaches(t.Params(), c) || reaches(t.Results(), c)
   452  	case *types.Union:
   453  		for i := 0; i < t.Len(); i++ {
   454  			if reaches(t.Term(i).Type(), c) {
   455  				return true
   456  			}
   457  		}
   458  	case *types.Interface:
   459  		for i := 0; i < t.NumEmbeddeds(); i++ {
   460  			if reaches(t.Embedded(i), c) {
   461  				return true
   462  			}
   463  		}
   464  		for i := 0; i < t.NumExplicitMethods(); i++ {
   465  			if reaches(t.ExplicitMethod(i).Type(), c) {
   466  				return true
   467  			}
   468  		}
   469  	case *types.Named:
   470  		return reaches(t.Underlying(), c)
   471  	default:
   472  		panic("unreachable")
   473  	}
   474  	return false
   475  }