golang.org/x/tools@v0.21.1-0.20240520172518-788d39e776b1/go/ssa/subst.go (about) 1 // Copyright 2022 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssa 6 7 import ( 8 "go/types" 9 10 "golang.org/x/tools/internal/aliases" 11 ) 12 13 // Type substituter for a fixed set of replacement types. 14 // 15 // A nil *subster is an valid, empty substitution map. It always acts as 16 // the identity function. This allows for treating parameterized and 17 // non-parameterized functions identically while compiling to ssa. 18 // 19 // Not concurrency-safe. 20 type subster struct { 21 replacements map[*types.TypeParam]types.Type // values should contain no type params 22 cache map[types.Type]types.Type // cache of subst results 23 ctxt *types.Context // cache for instantiation 24 scope *types.Scope // *types.Named declared within this scope can be substituted (optional) 25 debug bool // perform extra debugging checks 26 // TODO(taking): consider adding Pos 27 // TODO(zpavlinovic): replacements can contain type params 28 // when generating instances inside of a generic function body. 29 } 30 31 // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. 32 // targs should not contain any types in tparams. 33 // scope is the (optional) lexical block of the generic function for which we are substituting. 34 func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster { 35 assert(tparams.Len() == len(targs), "makeSubster argument count must match") 36 37 subst := &subster{ 38 replacements: make(map[*types.TypeParam]types.Type, tparams.Len()), 39 cache: make(map[types.Type]types.Type), 40 ctxt: ctxt, 41 scope: scope, 42 debug: debug, 43 } 44 for i := 0; i < tparams.Len(); i++ { 45 subst.replacements[tparams.At(i)] = targs[i] 46 } 47 if subst.debug { 48 subst.wellFormed() 49 } 50 return subst 51 } 52 53 // wellFormed asserts that subst was properly initialized. 54 func (subst *subster) wellFormed() { 55 if subst == nil { 56 return 57 } 58 // Check that all of the type params do not appear in the arguments. 59 s := make(map[types.Type]bool, len(subst.replacements)) 60 for tparam := range subst.replacements { 61 s[tparam] = true 62 } 63 for _, r := range subst.replacements { 64 if reaches(r, s) { 65 panic(subst) 66 } 67 } 68 } 69 70 // typ returns the type of t with the type parameter tparams[i] substituted 71 // for the type targs[i] where subst was created using tparams and targs. 72 func (subst *subster) typ(t types.Type) (res types.Type) { 73 if subst == nil { 74 return t // A nil subst is type preserving. 75 } 76 if r, ok := subst.cache[t]; ok { 77 return r 78 } 79 defer func() { 80 subst.cache[t] = res 81 }() 82 83 switch t := t.(type) { 84 case *types.TypeParam: 85 r := subst.replacements[t] 86 assert(r != nil, "type param without replacement encountered") 87 return r 88 89 case *types.Basic: 90 return t 91 92 case *types.Array: 93 if r := subst.typ(t.Elem()); r != t.Elem() { 94 return types.NewArray(r, t.Len()) 95 } 96 return t 97 98 case *types.Slice: 99 if r := subst.typ(t.Elem()); r != t.Elem() { 100 return types.NewSlice(r) 101 } 102 return t 103 104 case *types.Pointer: 105 if r := subst.typ(t.Elem()); r != t.Elem() { 106 return types.NewPointer(r) 107 } 108 return t 109 110 case *types.Tuple: 111 return subst.tuple(t) 112 113 case *types.Struct: 114 return subst.struct_(t) 115 116 case *types.Map: 117 key := subst.typ(t.Key()) 118 elem := subst.typ(t.Elem()) 119 if key != t.Key() || elem != t.Elem() { 120 return types.NewMap(key, elem) 121 } 122 return t 123 124 case *types.Chan: 125 if elem := subst.typ(t.Elem()); elem != t.Elem() { 126 return types.NewChan(t.Dir(), elem) 127 } 128 return t 129 130 case *types.Signature: 131 return subst.signature(t) 132 133 case *types.Union: 134 return subst.union(t) 135 136 case *types.Interface: 137 return subst.interface_(t) 138 139 case *aliases.Alias: 140 return subst.alias(t) 141 142 case *types.Named: 143 return subst.named(t) 144 145 case *opaqueType: 146 return t // opaque types are never substituted 147 148 default: 149 panic("unreachable") 150 } 151 } 152 153 // types returns the result of {subst.typ(ts[i])}. 154 func (subst *subster) types(ts []types.Type) []types.Type { 155 res := make([]types.Type, len(ts)) 156 for i := range ts { 157 res[i] = subst.typ(ts[i]) 158 } 159 return res 160 } 161 162 func (subst *subster) tuple(t *types.Tuple) *types.Tuple { 163 if t != nil { 164 if vars := subst.varlist(t); vars != nil { 165 return types.NewTuple(vars...) 166 } 167 } 168 return t 169 } 170 171 type varlist interface { 172 At(i int) *types.Var 173 Len() int 174 } 175 176 // fieldlist is an adapter for structs for the varlist interface. 177 type fieldlist struct { 178 str *types.Struct 179 } 180 181 func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } 182 func (fl fieldlist) Len() int { return fl.str.NumFields() } 183 184 func (subst *subster) struct_(t *types.Struct) *types.Struct { 185 if t != nil { 186 if fields := subst.varlist(fieldlist{t}); fields != nil { 187 tags := make([]string, t.NumFields()) 188 for i, n := 0, t.NumFields(); i < n; i++ { 189 tags[i] = t.Tag(i) 190 } 191 return types.NewStruct(fields, tags) 192 } 193 } 194 return t 195 } 196 197 // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. 198 func (subst *subster) varlist(in varlist) []*types.Var { 199 var out []*types.Var // nil => no updates 200 for i, n := 0, in.Len(); i < n; i++ { 201 v := in.At(i) 202 w := subst.var_(v) 203 if v != w && out == nil { 204 out = make([]*types.Var, n) 205 for j := 0; j < i; j++ { 206 out[j] = in.At(j) 207 } 208 } 209 if out != nil { 210 out[i] = w 211 } 212 } 213 return out 214 } 215 216 func (subst *subster) var_(v *types.Var) *types.Var { 217 if v != nil { 218 if typ := subst.typ(v.Type()); typ != v.Type() { 219 if v.IsField() { 220 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) 221 } 222 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) 223 } 224 } 225 return v 226 } 227 228 func (subst *subster) union(u *types.Union) *types.Union { 229 var out []*types.Term // nil => no updates 230 231 for i, n := 0, u.Len(); i < n; i++ { 232 t := u.Term(i) 233 r := subst.typ(t.Type()) 234 if r != t.Type() && out == nil { 235 out = make([]*types.Term, n) 236 for j := 0; j < i; j++ { 237 out[j] = u.Term(j) 238 } 239 } 240 if out != nil { 241 out[i] = types.NewTerm(t.Tilde(), r) 242 } 243 } 244 245 if out != nil { 246 return types.NewUnion(out) 247 } 248 return u 249 } 250 251 func (subst *subster) interface_(iface *types.Interface) *types.Interface { 252 if iface == nil { 253 return nil 254 } 255 256 // methods for the interface. Initially nil if there is no known change needed. 257 // Signatures for the method where recv is nil. NewInterfaceType fills in the receivers. 258 var methods []*types.Func 259 initMethods := func(n int) { // copy first n explicit methods 260 methods = make([]*types.Func, iface.NumExplicitMethods()) 261 for i := 0; i < n; i++ { 262 f := iface.ExplicitMethod(i) 263 norecv := changeRecv(f.Type().(*types.Signature), nil) 264 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) 265 } 266 } 267 for i := 0; i < iface.NumExplicitMethods(); i++ { 268 f := iface.ExplicitMethod(i) 269 // On interfaces, we need to cycle break on anonymous interface types 270 // being in a cycle with their signatures being in cycles with their receivers 271 // that do not go through a Named. 272 norecv := changeRecv(f.Type().(*types.Signature), nil) 273 sig := subst.typ(norecv) 274 if sig != norecv && methods == nil { 275 initMethods(i) 276 } 277 if methods != nil { 278 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) 279 } 280 } 281 282 var embeds []types.Type 283 initEmbeds := func(n int) { // copy first n embedded types 284 embeds = make([]types.Type, iface.NumEmbeddeds()) 285 for i := 0; i < n; i++ { 286 embeds[i] = iface.EmbeddedType(i) 287 } 288 } 289 for i := 0; i < iface.NumEmbeddeds(); i++ { 290 e := iface.EmbeddedType(i) 291 r := subst.typ(e) 292 if e != r && embeds == nil { 293 initEmbeds(i) 294 } 295 if embeds != nil { 296 embeds[i] = r 297 } 298 } 299 300 if methods == nil && embeds == nil { 301 return iface 302 } 303 if methods == nil { 304 initMethods(iface.NumExplicitMethods()) 305 } 306 if embeds == nil { 307 initEmbeds(iface.NumEmbeddeds()) 308 } 309 return types.NewInterfaceType(methods, embeds).Complete() 310 } 311 312 func (subst *subster) alias(t *aliases.Alias) types.Type { 313 // TODO(go.dev/issues/46477): support TypeParameters once these are available from go/types. 314 u := aliases.Unalias(t) 315 if s := subst.typ(u); s != u { 316 // If there is any change, do not create a new alias. 317 return s 318 } 319 // If there is no change, t did not reach any type parameter. 320 // Keep the Alias. 321 return t 322 } 323 324 func (subst *subster) named(t *types.Named) types.Type { 325 // A named type may be: 326 // (1) ordinary named type (non-local scope, no type parameters, no type arguments), 327 // (2) locally scoped type, 328 // (3) generic (type parameters but no type arguments), or 329 // (4) instantiated (type parameters and type arguments). 330 tparams := t.TypeParams() 331 if tparams.Len() == 0 { 332 if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) { 333 // Outside the current function scope? 334 return t // case (1) ordinary 335 } 336 337 // case (2) locally scoped type. 338 // Create a new named type to represent this instantiation. 339 // We assume that local types of distinct instantiations of a 340 // generic function are distinct, even if they don't refer to 341 // type parameters, but the spec is unclear; see golang/go#58573. 342 // 343 // Subtle: We short circuit substitution and use a newly created type in 344 // subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive 345 // types during traversal. This both breaks infinite cycles and allows for 346 // constructing types with the replacement applied in subst.typ(under). 347 // 348 // Example: 349 // func foo[T any]() { 350 // type linkedlist struct { 351 // next *linkedlist 352 // val T 353 // } 354 // } 355 // 356 // When the field `next *linkedlist` is visited during subst.typ(under), 357 // we want the substituted type for the field `next` to be `*n`. 358 n := types.NewNamed(t.Obj(), nil, nil) 359 subst.cache[t] = n 360 subst.cache[n] = n 361 n.SetUnderlying(subst.typ(t.Underlying())) 362 return n 363 } 364 targs := t.TypeArgs() 365 366 // insts are arguments to instantiate using. 367 insts := make([]types.Type, tparams.Len()) 368 369 // case (3) generic ==> targs.Len() == 0 370 // Instantiating a generic with no type arguments should be unreachable. 371 // Please report a bug if you encounter this. 372 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") 373 374 // case (4) instantiated. 375 // Substitute into the type arguments and instantiate the replacements/ 376 // Example: 377 // type N[A any] func() A 378 // func Foo[T](g N[T]) {} 379 // To instantiate Foo[string], one goes through {T->string}. To get the type of g 380 // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } 381 // to get {N with TypeArgs == {string} and typeparams == {A} }. 382 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") 383 for i, n := 0, targs.Len(); i < n; i++ { 384 inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion 385 insts[i] = inst 386 } 387 r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false) 388 assert(err == nil, "failed to Instantiate Named type") 389 return r 390 } 391 392 func (subst *subster) signature(t *types.Signature) types.Type { 393 tparams := t.TypeParams() 394 395 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. 396 // 397 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 398 // To support tparams.Len() > 0, we just need to do the following [psuedocode]: 399 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) 400 401 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") 402 403 // Either: 404 // (1)non-generic function. 405 // no type params to substitute 406 // (2)generic method and recv needs to be substituted. 407 408 // Receivers can be either: 409 // named 410 // pointer to named 411 // interface 412 // nil 413 // interface is the problematic case. We need to cycle break there! 414 recv := subst.var_(t.Recv()) 415 params := subst.tuple(t.Params()) 416 results := subst.tuple(t.Results()) 417 if recv != t.Recv() || params != t.Params() || results != t.Results() { 418 return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) 419 } 420 return t 421 } 422 423 // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. 424 // It updates c to cache results. 425 // 426 // reaches is currently only part of the wellFormed debug logic, and 427 // in practice c is initially only type parameters. It is not currently 428 // relied on in production. 429 func reaches(t types.Type, c map[types.Type]bool) (res bool) { 430 if c, ok := c[t]; ok { 431 return c 432 } 433 434 // c is populated with temporary false entries as types are visited. 435 // This avoids repeat visits and break cycles. 436 c[t] = false 437 defer func() { 438 c[t] = res 439 }() 440 441 switch t := t.(type) { 442 case *types.TypeParam, *types.Basic: 443 return false 444 case *types.Array: 445 return reaches(t.Elem(), c) 446 case *types.Slice: 447 return reaches(t.Elem(), c) 448 case *types.Pointer: 449 return reaches(t.Elem(), c) 450 case *types.Tuple: 451 for i := 0; i < t.Len(); i++ { 452 if reaches(t.At(i).Type(), c) { 453 return true 454 } 455 } 456 case *types.Struct: 457 for i := 0; i < t.NumFields(); i++ { 458 if reaches(t.Field(i).Type(), c) { 459 return true 460 } 461 } 462 case *types.Map: 463 return reaches(t.Key(), c) || reaches(t.Elem(), c) 464 case *types.Chan: 465 return reaches(t.Elem(), c) 466 case *types.Signature: 467 if t.Recv() != nil && reaches(t.Recv().Type(), c) { 468 return true 469 } 470 return reaches(t.Params(), c) || reaches(t.Results(), c) 471 case *types.Union: 472 for i := 0; i < t.Len(); i++ { 473 if reaches(t.Term(i).Type(), c) { 474 return true 475 } 476 } 477 case *types.Interface: 478 for i := 0; i < t.NumEmbeddeds(); i++ { 479 if reaches(t.Embedded(i), c) { 480 return true 481 } 482 } 483 for i := 0; i < t.NumExplicitMethods(); i++ { 484 if reaches(t.ExplicitMethod(i).Type(), c) { 485 return true 486 } 487 } 488 case *types.Named, *aliases.Alias: 489 return reaches(t.Underlying(), c) 490 default: 491 panic("unreachable") 492 } 493 return false 494 }