// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"go/types"

	"golang.org/x/tools/internal/aliases"
)

// subster is a type substituter for a fixed set of replacement types.
//
// A nil *subster is a valid, empty substitution map. It always acts as
// the identity function. This allows for treating parameterized and
// non-parameterized functions identically while compiling to ssa.
//
// Not concurrency-safe: substitution reads and writes the cache map
// without synchronization.
type subster struct {
	replacements map[*types.TypeParam]types.Type // values should contain no type params
	cache        map[types.Type]types.Type       // cache of subst results
	ctxt         *types.Context                  // cache for instantiation
	scope        *types.Scope                    // *types.Named declared within this scope can be substituted (optional; if nil, every non-generic *types.Named is treated as local)
	debug        bool                            // perform extra debugging checks
	// TODO(taking): consider adding Pos
	// TODO(zpavlinovic): replacements can contain type params
	// when generating instances inside of a generic function body.
}

// makeSubster returns a subster that replaces tparams[i] with targs[i].
// It uses ctxt as a cache for instantiation.
// targs should not contain any types in tparams. When debug is set, this
// invariant is checked eagerly.
// scope is the (optional) lexical block of the generic function for which we are substituting.
34 func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster { 35 assert(tparams.Len() == len(targs), "makeSubster argument count must match") 36 37 subst := &subster{ 38 replacements: make(map[*types.TypeParam]types.Type, tparams.Len()), 39 cache: make(map[types.Type]types.Type), 40 ctxt: ctxt, 41 scope: scope, 42 debug: debug, 43 } 44 for i := 0; i < tparams.Len(); i++ { 45 subst.replacements[tparams.At(i)] = targs[i] 46 } 47 if subst.debug { 48 subst.wellFormed() 49 } 50 return subst 51 } 52 53 // wellFormed asserts that subst was properly initialized. 54 func (subst *subster) wellFormed() { 55 if subst == nil { 56 return 57 } 58 // Check that all of the type params do not appear in the arguments. 59 s := make(map[types.Type]bool, len(subst.replacements)) 60 for tparam := range subst.replacements { 61 s[tparam] = true 62 } 63 for _, r := range subst.replacements { 64 if reaches(r, s) { 65 panic(subst) 66 } 67 } 68 } 69 70 // typ returns the type of t with the type parameter tparams[i] substituted 71 // for the type targs[i] where subst was created using tparams and targs. 72 func (subst *subster) typ(t types.Type) (res types.Type) { 73 if subst == nil { 74 return t // A nil subst is type preserving. 
75 } 76 if r, ok := subst.cache[t]; ok { 77 return r 78 } 79 defer func() { 80 subst.cache[t] = res 81 }() 82 83 switch t := t.(type) { 84 case *types.TypeParam: 85 r := subst.replacements[t] 86 assert(r != nil, "type param without replacement encountered") 87 return r 88 89 case *types.Basic: 90 return t 91 92 case *types.Array: 93 if r := subst.typ(t.Elem()); r != t.Elem() { 94 return types.NewArray(r, t.Len()) 95 } 96 return t 97 98 case *types.Slice: 99 if r := subst.typ(t.Elem()); r != t.Elem() { 100 return types.NewSlice(r) 101 } 102 return t 103 104 case *types.Pointer: 105 if r := subst.typ(t.Elem()); r != t.Elem() { 106 return types.NewPointer(r) 107 } 108 return t 109 110 case *types.Tuple: 111 return subst.tuple(t) 112 113 case *types.Struct: 114 return subst.struct_(t) 115 116 case *types.Map: 117 key := subst.typ(t.Key()) 118 elem := subst.typ(t.Elem()) 119 if key != t.Key() || elem != t.Elem() { 120 return types.NewMap(key, elem) 121 } 122 return t 123 124 case *types.Chan: 125 if elem := subst.typ(t.Elem()); elem != t.Elem() { 126 return types.NewChan(t.Dir(), elem) 127 } 128 return t 129 130 case *types.Signature: 131 return subst.signature(t) 132 133 case *types.Union: 134 return subst.union(t) 135 136 case *types.Interface: 137 return subst.interface_(t) 138 139 case *aliases.Alias: 140 return subst.alias(t) 141 142 case *types.Named: 143 return subst.named(t) 144 145 default: 146 panic("unreachable") 147 } 148 } 149 150 // types returns the result of {subst.typ(ts[i])}. 151 func (subst *subster) types(ts []types.Type) []types.Type { 152 res := make([]types.Type, len(ts)) 153 for i := range ts { 154 res[i] = subst.typ(ts[i]) 155 } 156 return res 157 } 158 159 func (subst *subster) tuple(t *types.Tuple) *types.Tuple { 160 if t != nil { 161 if vars := subst.varlist(t); vars != nil { 162 return types.NewTuple(vars...) 
163 } 164 } 165 return t 166 } 167 168 type varlist interface { 169 At(i int) *types.Var 170 Len() int 171 } 172 173 // fieldlist is an adapter for structs for the varlist interface. 174 type fieldlist struct { 175 str *types.Struct 176 } 177 178 func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } 179 func (fl fieldlist) Len() int { return fl.str.NumFields() } 180 181 func (subst *subster) struct_(t *types.Struct) *types.Struct { 182 if t != nil { 183 if fields := subst.varlist(fieldlist{t}); fields != nil { 184 tags := make([]string, t.NumFields()) 185 for i, n := 0, t.NumFields(); i < n; i++ { 186 tags[i] = t.Tag(i) 187 } 188 return types.NewStruct(fields, tags) 189 } 190 } 191 return t 192 } 193 194 // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. 195 func (subst *subster) varlist(in varlist) []*types.Var { 196 var out []*types.Var // nil => no updates 197 for i, n := 0, in.Len(); i < n; i++ { 198 v := in.At(i) 199 w := subst.var_(v) 200 if v != w && out == nil { 201 out = make([]*types.Var, n) 202 for j := 0; j < i; j++ { 203 out[j] = in.At(j) 204 } 205 } 206 if out != nil { 207 out[i] = w 208 } 209 } 210 return out 211 } 212 213 func (subst *subster) var_(v *types.Var) *types.Var { 214 if v != nil { 215 if typ := subst.typ(v.Type()); typ != v.Type() { 216 if v.IsField() { 217 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) 218 } 219 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) 220 } 221 } 222 return v 223 } 224 225 func (subst *subster) union(u *types.Union) *types.Union { 226 var out []*types.Term // nil => no updates 227 228 for i, n := 0, u.Len(); i < n; i++ { 229 t := u.Term(i) 230 r := subst.typ(t.Type()) 231 if r != t.Type() && out == nil { 232 out = make([]*types.Term, n) 233 for j := 0; j < i; j++ { 234 out[j] = u.Term(j) 235 } 236 } 237 if out != nil { 238 out[i] = types.NewTerm(t.Tilde(), r) 239 } 240 } 241 242 if out != nil { 243 return types.NewUnion(out) 244 } 245 return 
u 246 } 247 248 func (subst *subster) interface_(iface *types.Interface) *types.Interface { 249 if iface == nil { 250 return nil 251 } 252 253 // methods for the interface. Initially nil if there is no known change needed. 254 // Signatures for the method where recv is nil. NewInterfaceType fills in the receivers. 255 var methods []*types.Func 256 initMethods := func(n int) { // copy first n explicit methods 257 methods = make([]*types.Func, iface.NumExplicitMethods()) 258 for i := 0; i < n; i++ { 259 f := iface.ExplicitMethod(i) 260 norecv := changeRecv(f.Type().(*types.Signature), nil) 261 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) 262 } 263 } 264 for i := 0; i < iface.NumExplicitMethods(); i++ { 265 f := iface.ExplicitMethod(i) 266 // On interfaces, we need to cycle break on anonymous interface types 267 // being in a cycle with their signatures being in cycles with their receivers 268 // that do not go through a Named. 269 norecv := changeRecv(f.Type().(*types.Signature), nil) 270 sig := subst.typ(norecv) 271 if sig != norecv && methods == nil { 272 initMethods(i) 273 } 274 if methods != nil { 275 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) 276 } 277 } 278 279 var embeds []types.Type 280 initEmbeds := func(n int) { // copy first n embedded types 281 embeds = make([]types.Type, iface.NumEmbeddeds()) 282 for i := 0; i < n; i++ { 283 embeds[i] = iface.EmbeddedType(i) 284 } 285 } 286 for i := 0; i < iface.NumEmbeddeds(); i++ { 287 e := iface.EmbeddedType(i) 288 r := subst.typ(e) 289 if e != r && embeds == nil { 290 initEmbeds(i) 291 } 292 if embeds != nil { 293 embeds[i] = r 294 } 295 } 296 297 if methods == nil && embeds == nil { 298 return iface 299 } 300 if methods == nil { 301 initMethods(iface.NumExplicitMethods()) 302 } 303 if embeds == nil { 304 initEmbeds(iface.NumEmbeddeds()) 305 } 306 return types.NewInterfaceType(methods, embeds).Complete() 307 } 308 309 func (subst *subster) alias(t *aliases.Alias) 
types.Type { 310 // TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types. 311 u := aliases.Unalias(t) 312 if s := subst.typ(u); s != u { 313 // If there is any change, do not create a new alias. 314 return s 315 } 316 // If there is no change, t did not reach any type parameter. 317 // Keep the Alias. 318 return t 319 } 320 321 func (subst *subster) named(t *types.Named) types.Type { 322 // A named type may be: 323 // (1) ordinary named type (non-local scope, no type parameters, no type arguments), 324 // (2) locally scoped type, 325 // (3) generic (type parameters but no type arguments), or 326 // (4) instantiated (type parameters and type arguments). 327 tparams := t.TypeParams() 328 if tparams.Len() == 0 { 329 if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) { 330 // Outside the current function scope? 331 return t // case (1) ordinary 332 } 333 334 // case (2) locally scoped type. 335 // Create a new named type to represent this instantiation. 336 // We assume that local types of distinct instantiations of a 337 // generic function are distinct, even if they don't refer to 338 // type parameters, but the spec is unclear; see golang/go#58573. 339 // 340 // Subtle: We short circuit substitution and use a newly created type in 341 // subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive 342 // types during traversal. This both breaks infinite cycles and allows for 343 // constructing types with the replacement applied in subst.typ(under). 344 // 345 // Example: 346 // func foo[T any]() { 347 // type linkedlist struct { 348 // next *linkedlist 349 // val T 350 // } 351 // } 352 // 353 // When the field `next *linkedlist` is visited during subst.typ(under), 354 // we want the substituted type for the field `next` to be `*n`. 
355 n := types.NewNamed(t.Obj(), nil, nil) 356 subst.cache[t] = n 357 subst.cache[n] = n 358 n.SetUnderlying(subst.typ(t.Underlying())) 359 return n 360 } 361 targs := t.TypeArgs() 362 363 // insts are arguments to instantiate using. 364 insts := make([]types.Type, tparams.Len()) 365 366 // case (3) generic ==> targs.Len() == 0 367 // Instantiating a generic with no type arguments should be unreachable. 368 // Please report a bug if you encounter this. 369 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") 370 371 // case (4) instantiated. 372 // Substitute into the type arguments and instantiate the replacements/ 373 // Example: 374 // type N[A any] func() A 375 // func Foo[T](g N[T]) {} 376 // To instantiate Foo[string], one goes through {T->string}. To get the type of g 377 // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } 378 // to get {N with TypeArgs == {string} and typeparams == {A} }. 379 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") 380 for i, n := 0, targs.Len(); i < n; i++ { 381 inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion 382 insts[i] = inst 383 } 384 r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false) 385 assert(err == nil, "failed to Instantiate Named type") 386 return r 387 } 388 389 func (subst *subster) signature(t *types.Signature) types.Type { 390 tparams := t.TypeParams() 391 392 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. 393 // 394 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 
395 // To support tparams.Len() > 0, we just need to do the following [psuedocode]: 396 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) 397 398 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") 399 400 // Either: 401 // (1)non-generic function. 402 // no type params to substitute 403 // (2)generic method and recv needs to be substituted. 404 405 // Receivers can be either: 406 // named 407 // pointer to named 408 // interface 409 // nil 410 // interface is the problematic case. We need to cycle break there! 411 recv := subst.var_(t.Recv()) 412 params := subst.tuple(t.Params()) 413 results := subst.tuple(t.Results()) 414 if recv != t.Recv() || params != t.Params() || results != t.Results() { 415 return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) 416 } 417 return t 418 } 419 420 // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. 421 // It updates c to cache results. 422 // 423 // reaches is currently only part of the wellFormed debug logic, and 424 // in practice c is initially only type parameters. It is not currently 425 // relied on in production. 426 func reaches(t types.Type, c map[types.Type]bool) (res bool) { 427 if c, ok := c[t]; ok { 428 return c 429 } 430 431 // c is populated with temporary false entries as types are visited. 432 // This avoids repeat visits and break cycles. 
433 c[t] = false 434 defer func() { 435 c[t] = res 436 }() 437 438 switch t := t.(type) { 439 case *types.TypeParam, *types.Basic: 440 return false 441 case *types.Array: 442 return reaches(t.Elem(), c) 443 case *types.Slice: 444 return reaches(t.Elem(), c) 445 case *types.Pointer: 446 return reaches(t.Elem(), c) 447 case *types.Tuple: 448 for i := 0; i < t.Len(); i++ { 449 if reaches(t.At(i).Type(), c) { 450 return true 451 } 452 } 453 case *types.Struct: 454 for i := 0; i < t.NumFields(); i++ { 455 if reaches(t.Field(i).Type(), c) { 456 return true 457 } 458 } 459 case *types.Map: 460 return reaches(t.Key(), c) || reaches(t.Elem(), c) 461 case *types.Chan: 462 return reaches(t.Elem(), c) 463 case *types.Signature: 464 if t.Recv() != nil && reaches(t.Recv().Type(), c) { 465 return true 466 } 467 return reaches(t.Params(), c) || reaches(t.Results(), c) 468 case *types.Union: 469 for i := 0; i < t.Len(); i++ { 470 if reaches(t.Term(i).Type(), c) { 471 return true 472 } 473 } 474 case *types.Interface: 475 for i := 0; i < t.NumEmbeddeds(); i++ { 476 if reaches(t.Embedded(i), c) { 477 return true 478 } 479 } 480 for i := 0; i < t.NumExplicitMethods(); i++ { 481 if reaches(t.ExplicitMethod(i).Type(), c) { 482 return true 483 } 484 } 485 case *types.Named, *aliases.Alias: 486 return reaches(t.Underlying(), c) 487 default: 488 panic("unreachable") 489 } 490 return false 491 }