github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ssa/subst.go (about)

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  package ssa
     5  
     6  import (
     7  	"fmt"
     8  	"go/types"
     9  
    10  	"github.com/powerman/golang-tools/internal/typeparams"
    11  )
    12  
// subster is a type substituter for a fixed set of replacement types.
//
// A nil *subster is a valid, empty substitution map. It always acts as
// the identity function. This allows for treating parameterized and
// non-parameterized functions identically while compiling to ssa.
//
// replacements holds the substitution itself, cache memoizes the results
// of typ, and ctxt is the context passed to typeparams.Instantiate when
// instantiated Named types are rebuilt. When debug is set, makeSubster
// validates the substitution with wellFormed.
//
// Not concurrency-safe.
type subster struct {
	replacements map[*typeparams.TypeParam]types.Type // values should contain no type params
	cache        map[types.Type]types.Type            // cache of subst results
	ctxt         *typeparams.Context
	debug        bool // perform extra debugging checks
	// TODO(taking): consider adding Pos
}
    27  
    28  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
    29  // targs should not contain any types in tparams.
    30  func makeSubster(ctxt *typeparams.Context, tparams []*typeparams.TypeParam, targs []types.Type, debug bool) *subster {
    31  	assert(len(tparams) == len(targs), "makeSubster argument count must match")
    32  
    33  	subst := &subster{
    34  		replacements: make(map[*typeparams.TypeParam]types.Type, len(tparams)),
    35  		cache:        make(map[types.Type]types.Type),
    36  		ctxt:         ctxt,
    37  		debug:        debug,
    38  	}
    39  	for i, tpar := range tparams {
    40  		subst.replacements[tpar] = targs[i]
    41  	}
    42  	if subst.debug {
    43  		if err := subst.wellFormed(); err != nil {
    44  			panic(err)
    45  		}
    46  	}
    47  	return subst
    48  }
    49  
    50  // wellFormed returns an error if subst was not properly initialized.
    51  func (subst *subster) wellFormed() error {
    52  	if subst == nil || len(subst.replacements) == 0 {
    53  		return nil
    54  	}
    55  	// Check that all of the type params do not appear in the arguments.
    56  	s := make(map[types.Type]bool, len(subst.replacements))
    57  	for tparam := range subst.replacements {
    58  		s[tparam] = true
    59  	}
    60  	for _, r := range subst.replacements {
    61  		if reaches(r, s) {
    62  			return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements)
    63  		}
    64  	}
    65  	return nil
    66  }
    67  
    68  // typ returns the type of t with the type parameter tparams[i] substituted
    69  // for the type targs[i] where subst was created using tparams and targs.
    70  func (subst *subster) typ(t types.Type) (res types.Type) {
    71  	if subst == nil {
    72  		return t // A nil subst is type preserving.
    73  	}
    74  	if r, ok := subst.cache[t]; ok {
    75  		return r
    76  	}
    77  	defer func() {
    78  		subst.cache[t] = res
    79  	}()
    80  
    81  	// fall through if result r will be identical to t, types.Identical(r, t).
    82  	switch t := t.(type) {
    83  	case *typeparams.TypeParam:
    84  		r := subst.replacements[t]
    85  		assert(r != nil, "type param without replacement encountered")
    86  		return r
    87  
    88  	case *types.Basic:
    89  		return t
    90  
    91  	case *types.Array:
    92  		if r := subst.typ(t.Elem()); r != t.Elem() {
    93  			return types.NewArray(r, t.Len())
    94  		}
    95  		return t
    96  
    97  	case *types.Slice:
    98  		if r := subst.typ(t.Elem()); r != t.Elem() {
    99  			return types.NewSlice(r)
   100  		}
   101  		return t
   102  
   103  	case *types.Pointer:
   104  		if r := subst.typ(t.Elem()); r != t.Elem() {
   105  			return types.NewPointer(r)
   106  		}
   107  		return t
   108  
   109  	case *types.Tuple:
   110  		return subst.tuple(t)
   111  
   112  	case *types.Struct:
   113  		return subst.struct_(t)
   114  
   115  	case *types.Map:
   116  		key := subst.typ(t.Key())
   117  		elem := subst.typ(t.Elem())
   118  		if key != t.Key() || elem != t.Elem() {
   119  			return types.NewMap(key, elem)
   120  		}
   121  		return t
   122  
   123  	case *types.Chan:
   124  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
   125  			return types.NewChan(t.Dir(), elem)
   126  		}
   127  		return t
   128  
   129  	case *types.Signature:
   130  		return subst.signature(t)
   131  
   132  	case *typeparams.Union:
   133  		return subst.union(t)
   134  
   135  	case *types.Interface:
   136  		return subst.interface_(t)
   137  
   138  	case *types.Named:
   139  		return subst.named(t)
   140  
   141  	default:
   142  		panic("unreachable")
   143  	}
   144  }
   145  
   146  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
   147  	if t != nil {
   148  		if vars := subst.varlist(t); vars != nil {
   149  			return types.NewTuple(vars...)
   150  		}
   151  	}
   152  	return t
   153  }
   154  
   155  type varlist interface {
   156  	At(i int) *types.Var
   157  	Len() int
   158  }
   159  
   160  // fieldlist is an adapter for structs for the varlist interface.
   161  type fieldlist struct {
   162  	str *types.Struct
   163  }
   164  
   165  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
   166  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
   167  
   168  func (subst *subster) struct_(t *types.Struct) *types.Struct {
   169  	if t != nil {
   170  		if fields := subst.varlist(fieldlist{t}); fields != nil {
   171  			tags := make([]string, t.NumFields())
   172  			for i, n := 0, t.NumFields(); i < n; i++ {
   173  				tags[i] = t.Tag(i)
   174  			}
   175  			return types.NewStruct(fields, tags)
   176  		}
   177  	}
   178  	return t
   179  }
   180  
   181  // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
   182  func (subst *subster) varlist(in varlist) []*types.Var {
   183  	var out []*types.Var // nil => no updates
   184  	for i, n := 0, in.Len(); i < n; i++ {
   185  		v := in.At(i)
   186  		w := subst.var_(v)
   187  		if v != w && out == nil {
   188  			out = make([]*types.Var, n)
   189  			for j := 0; j < i; j++ {
   190  				out[j] = in.At(j)
   191  			}
   192  		}
   193  		if out != nil {
   194  			out[i] = w
   195  		}
   196  	}
   197  	return out
   198  }
   199  
   200  func (subst *subster) var_(v *types.Var) *types.Var {
   201  	if v != nil {
   202  		if typ := subst.typ(v.Type()); typ != v.Type() {
   203  			if v.IsField() {
   204  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
   205  			}
   206  			return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ)
   207  		}
   208  	}
   209  	return v
   210  }
   211  
   212  func (subst *subster) union(u *typeparams.Union) *typeparams.Union {
   213  	var out []*typeparams.Term // nil => no updates
   214  
   215  	for i, n := 0, u.Len(); i < n; i++ {
   216  		t := u.Term(i)
   217  		r := subst.typ(t.Type())
   218  		if r != t.Type() && out == nil {
   219  			out = make([]*typeparams.Term, n)
   220  			for j := 0; j < i; j++ {
   221  				out[j] = u.Term(j)
   222  			}
   223  		}
   224  		if out != nil {
   225  			out[i] = typeparams.NewTerm(t.Tilde(), r)
   226  		}
   227  	}
   228  
   229  	if out != nil {
   230  		return typeparams.NewUnion(out)
   231  	}
   232  	return u
   233  }
   234  
   235  func (subst *subster) interface_(iface *types.Interface) *types.Interface {
   236  	if iface == nil {
   237  		return nil
   238  	}
   239  
   240  	// methods for the interface. Initially nil if there is no known change needed.
   241  	// Signatures for the method where recv is nil. NewInterfaceType fills in the recievers.
   242  	var methods []*types.Func
   243  	initMethods := func(n int) { // copy first n explicit methods
   244  		methods = make([]*types.Func, iface.NumExplicitMethods())
   245  		for i := 0; i < n; i++ {
   246  			f := iface.ExplicitMethod(i)
   247  			norecv := changeRecv(f.Type().(*types.Signature), nil)
   248  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
   249  		}
   250  	}
   251  	for i := 0; i < iface.NumExplicitMethods(); i++ {
   252  		f := iface.ExplicitMethod(i)
   253  		// On interfaces, we need to cycle break on anonymous interface types
   254  		// being in a cycle with their signatures being in cycles with their recievers
   255  		// that do not go through a Named.
   256  		norecv := changeRecv(f.Type().(*types.Signature), nil)
   257  		sig := subst.typ(norecv)
   258  		if sig != norecv && methods == nil {
   259  			initMethods(i)
   260  		}
   261  		if methods != nil {
   262  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
   263  		}
   264  	}
   265  
   266  	var embeds []types.Type
   267  	initEmbeds := func(n int) { // copy first n embedded types
   268  		embeds = make([]types.Type, iface.NumEmbeddeds())
   269  		for i := 0; i < n; i++ {
   270  			embeds[i] = iface.EmbeddedType(i)
   271  		}
   272  	}
   273  	for i := 0; i < iface.NumEmbeddeds(); i++ {
   274  		e := iface.EmbeddedType(i)
   275  		r := subst.typ(e)
   276  		if e != r && embeds == nil {
   277  			initEmbeds(i)
   278  		}
   279  		if embeds != nil {
   280  			embeds[i] = r
   281  		}
   282  	}
   283  
   284  	if methods == nil && embeds == nil {
   285  		return iface
   286  	}
   287  	if methods == nil {
   288  		initMethods(iface.NumExplicitMethods())
   289  	}
   290  	if embeds == nil {
   291  		initEmbeds(iface.NumEmbeddeds())
   292  	}
   293  	return types.NewInterfaceType(methods, embeds).Complete()
   294  }
   295  
   296  func (subst *subster) named(t *types.Named) types.Type {
   297  	// A name type may be:
   298  	// (1) ordinary (no type parameters, no type arguments),
   299  	// (2) generic (type parameters but no type arguments), or
   300  	// (3) instantiated (type parameters and type arguments).
   301  	tparams := typeparams.ForNamed(t)
   302  	if tparams.Len() == 0 {
   303  		// case (1) ordinary
   304  
   305  		// Note: If Go allows for local type declarations in generic
   306  		// functions we may need to descend into underlying as well.
   307  		return t
   308  	}
   309  	targs := typeparams.NamedTypeArgs(t)
   310  
   311  	// insts are arguments to instantiate using.
   312  	insts := make([]types.Type, tparams.Len())
   313  
   314  	// case (2) generic ==> targs.Len() == 0
   315  	// Instantiating a generic with no type arguments should be unreachable.
   316  	// Please report a bug if you encounter this.
   317  	assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported")
   318  
   319  	// case (3) instantiated.
   320  	// Substitute into the type arguments and instantiate the replacements/
   321  	// Example:
   322  	//    type N[A any] func() A
   323  	//    func Foo[T](g N[T]) {}
   324  	//  To instantiate Foo[string], one goes through {T->string}. To get the type of g
   325  	//  one subsitutes T with string in {N with TypeArgs == {T} and TypeParams == {A} }
   326  	//  to get {N with TypeArgs == {string} and TypeParams == {A} }.
   327  	assert(targs.Len() == tparams.Len(), "TypeArgs().Len() must match TypeParams().Len() if present")
   328  	for i, n := 0, targs.Len(); i < n; i++ {
   329  		inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion
   330  		insts[i] = inst
   331  	}
   332  	r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false)
   333  	assert(err == nil, "failed to Instantiate Named type")
   334  	return r
   335  }
   336  
   337  func (subst *subster) signature(t *types.Signature) types.Type {
   338  	tparams := typeparams.ForSignature(t)
   339  
   340  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
   341  	//
   342  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
   343  	// To support tparams.Len() > 0, we just need to do the following [psuedocode]:
   344  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
   345  
   346  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
   347  
   348  	// Either:
   349  	// (1)non-generic function.
   350  	//    no type params to substitute
   351  	// (2)generic method and recv needs to be substituted.
   352  
   353  	// Recievers can be either:
   354  	// named
   355  	// pointer to named
   356  	// interface
   357  	// nil
   358  	// interface is the problematic case. We need to cycle break there!
   359  	recv := subst.var_(t.Recv())
   360  	params := subst.tuple(t.Params())
   361  	results := subst.tuple(t.Results())
   362  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
   363  		return types.NewSignature(recv, params, results, t.Variadic())
   364  	}
   365  	return t
   366  }
   367  
   368  // reaches returns true if a type t reaches any type t' s.t. c[t'] == true.
   369  // Updates c to cache results.
   370  func reaches(t types.Type, c map[types.Type]bool) (res bool) {
   371  	if c, ok := c[t]; ok {
   372  		return c
   373  	}
   374  	c[t] = false // prevent cycles
   375  	defer func() {
   376  		c[t] = res
   377  	}()
   378  
   379  	switch t := t.(type) {
   380  	case *typeparams.TypeParam, *types.Basic:
   381  		// no-op => c == false
   382  	case *types.Array:
   383  		return reaches(t.Elem(), c)
   384  	case *types.Slice:
   385  		return reaches(t.Elem(), c)
   386  	case *types.Pointer:
   387  		return reaches(t.Elem(), c)
   388  	case *types.Tuple:
   389  		for i := 0; i < t.Len(); i++ {
   390  			if reaches(t.At(i).Type(), c) {
   391  				return true
   392  			}
   393  		}
   394  	case *types.Struct:
   395  		for i := 0; i < t.NumFields(); i++ {
   396  			if reaches(t.Field(i).Type(), c) {
   397  				return true
   398  			}
   399  		}
   400  	case *types.Map:
   401  		return reaches(t.Key(), c) || reaches(t.Elem(), c)
   402  	case *types.Chan:
   403  		return reaches(t.Elem(), c)
   404  	case *types.Signature:
   405  		if t.Recv() != nil && reaches(t.Recv().Type(), c) {
   406  			return true
   407  		}
   408  		return reaches(t.Params(), c) || reaches(t.Results(), c)
   409  	case *typeparams.Union:
   410  		for i := 0; i < t.Len(); i++ {
   411  			if reaches(t.Term(i).Type(), c) {
   412  				return true
   413  			}
   414  		}
   415  	case *types.Interface:
   416  		for i := 0; i < t.NumEmbeddeds(); i++ {
   417  			if reaches(t.Embedded(i), c) {
   418  				return true
   419  			}
   420  		}
   421  		for i := 0; i < t.NumExplicitMethods(); i++ {
   422  			if reaches(t.ExplicitMethod(i).Type(), c) {
   423  				return true
   424  			}
   425  		}
   426  	case *types.Named:
   427  		return reaches(t.Underlying(), c)
   428  	default:
   429  		panic("unreachable")
   430  	}
   431  	return false
   432  }