// Origin: github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/go/ssa/subst.go

// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package ssa

import (
	"go/types"

	"golang.org/x/tools/internal/typeparams"
)

// Type substituter for a fixed set of replacement types.
//
// A nil *subster is a valid, empty substitution map. It always acts as
// the identity function. This allows for treating parameterized and
// non-parameterized functions identically while compiling to ssa.
//
// Not concurrency-safe.
type subster struct {
	// replacements maps each type parameter to the type substituted for it.
	replacements map[*typeparams.TypeParam]types.Type // values should contain no type params
	// cache memoizes substitution results, keyed by the input type.
	cache map[types.Type]types.Type // cache of subst results
	// ctxt is passed to typeparams.Instantiate when re-instantiating named types.
	ctxt *typeparams.Context // cache for instantiation
	// scope, when non-nil, limits which parameterless named types are re-created:
	// only those whose object position lies inside it.
	scope *types.Scope // *types.Named declared within this scope can be substituted (optional)
	// debug enables the wellFormed check when the subster is built.
	debug bool // perform extra debugging checks
	// TODO(taking): consider adding Pos
	// TODO(zpavlinovic): replacements can contain type params
	// when generating instances inside of a generic function body.
}

// Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
// targs should not contain any types in tparams.
// scope is the (optional) lexical block of the generic function for which we are substituting.
34 func makeSubster(ctxt *typeparams.Context, scope *types.Scope, tparams *typeparams.TypeParamList, targs []types.Type, debug bool) *subster { 35 assert(tparams.Len() == len(targs), "makeSubster argument count must match") 36 37 subst := &subster{ 38 replacements: make(map[*typeparams.TypeParam]types.Type, tparams.Len()), 39 cache: make(map[types.Type]types.Type), 40 ctxt: ctxt, 41 scope: scope, 42 debug: debug, 43 } 44 for i := 0; i < tparams.Len(); i++ { 45 subst.replacements[tparams.At(i)] = targs[i] 46 } 47 if subst.debug { 48 subst.wellFormed() 49 } 50 return subst 51 } 52 53 // wellFormed asserts that subst was properly initialized. 54 func (subst *subster) wellFormed() { 55 if subst == nil { 56 return 57 } 58 // Check that all of the type params do not appear in the arguments. 59 s := make(map[types.Type]bool, len(subst.replacements)) 60 for tparam := range subst.replacements { 61 s[tparam] = true 62 } 63 for _, r := range subst.replacements { 64 if reaches(r, s) { 65 panic(subst) 66 } 67 } 68 } 69 70 // typ returns the type of t with the type parameter tparams[i] substituted 71 // for the type targs[i] where subst was created using tparams and targs. 72 func (subst *subster) typ(t types.Type) (res types.Type) { 73 if subst == nil { 74 return t // A nil subst is type preserving. 75 } 76 if r, ok := subst.cache[t]; ok { 77 return r 78 } 79 defer func() { 80 subst.cache[t] = res 81 }() 82 83 // fall through if result r will be identical to t, types.Identical(r, t). 
84 switch t := t.(type) { 85 case *typeparams.TypeParam: 86 r := subst.replacements[t] 87 assert(r != nil, "type param without replacement encountered") 88 return r 89 90 case *types.Basic: 91 return t 92 93 case *types.Array: 94 if r := subst.typ(t.Elem()); r != t.Elem() { 95 return types.NewArray(r, t.Len()) 96 } 97 return t 98 99 case *types.Slice: 100 if r := subst.typ(t.Elem()); r != t.Elem() { 101 return types.NewSlice(r) 102 } 103 return t 104 105 case *types.Pointer: 106 if r := subst.typ(t.Elem()); r != t.Elem() { 107 return types.NewPointer(r) 108 } 109 return t 110 111 case *types.Tuple: 112 return subst.tuple(t) 113 114 case *types.Struct: 115 return subst.struct_(t) 116 117 case *types.Map: 118 key := subst.typ(t.Key()) 119 elem := subst.typ(t.Elem()) 120 if key != t.Key() || elem != t.Elem() { 121 return types.NewMap(key, elem) 122 } 123 return t 124 125 case *types.Chan: 126 if elem := subst.typ(t.Elem()); elem != t.Elem() { 127 return types.NewChan(t.Dir(), elem) 128 } 129 return t 130 131 case *types.Signature: 132 return subst.signature(t) 133 134 case *typeparams.Union: 135 return subst.union(t) 136 137 case *types.Interface: 138 return subst.interface_(t) 139 140 case *types.Named: 141 return subst.named(t) 142 143 default: 144 panic("unreachable") 145 } 146 } 147 148 // types returns the result of {subst.typ(ts[i])}. 149 func (subst *subster) types(ts []types.Type) []types.Type { 150 res := make([]types.Type, len(ts)) 151 for i := range ts { 152 res[i] = subst.typ(ts[i]) 153 } 154 return res 155 } 156 157 func (subst *subster) tuple(t *types.Tuple) *types.Tuple { 158 if t != nil { 159 if vars := subst.varlist(t); vars != nil { 160 return types.NewTuple(vars...) 161 } 162 } 163 return t 164 } 165 166 type varlist interface { 167 At(i int) *types.Var 168 Len() int 169 } 170 171 // fieldlist is an adapter for structs for the varlist interface. 
172 type fieldlist struct { 173 str *types.Struct 174 } 175 176 func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } 177 func (fl fieldlist) Len() int { return fl.str.NumFields() } 178 179 func (subst *subster) struct_(t *types.Struct) *types.Struct { 180 if t != nil { 181 if fields := subst.varlist(fieldlist{t}); fields != nil { 182 tags := make([]string, t.NumFields()) 183 for i, n := 0, t.NumFields(); i < n; i++ { 184 tags[i] = t.Tag(i) 185 } 186 return types.NewStruct(fields, tags) 187 } 188 } 189 return t 190 } 191 192 // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. 193 func (subst *subster) varlist(in varlist) []*types.Var { 194 var out []*types.Var // nil => no updates 195 for i, n := 0, in.Len(); i < n; i++ { 196 v := in.At(i) 197 w := subst.var_(v) 198 if v != w && out == nil { 199 out = make([]*types.Var, n) 200 for j := 0; j < i; j++ { 201 out[j] = in.At(j) 202 } 203 } 204 if out != nil { 205 out[i] = w 206 } 207 } 208 return out 209 } 210 211 func (subst *subster) var_(v *types.Var) *types.Var { 212 if v != nil { 213 if typ := subst.typ(v.Type()); typ != v.Type() { 214 if v.IsField() { 215 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) 216 } 217 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) 218 } 219 } 220 return v 221 } 222 223 func (subst *subster) union(u *typeparams.Union) *typeparams.Union { 224 var out []*typeparams.Term // nil => no updates 225 226 for i, n := 0, u.Len(); i < n; i++ { 227 t := u.Term(i) 228 r := subst.typ(t.Type()) 229 if r != t.Type() && out == nil { 230 out = make([]*typeparams.Term, n) 231 for j := 0; j < i; j++ { 232 out[j] = u.Term(j) 233 } 234 } 235 if out != nil { 236 out[i] = typeparams.NewTerm(t.Tilde(), r) 237 } 238 } 239 240 if out != nil { 241 return typeparams.NewUnion(out) 242 } 243 return u 244 } 245 246 func (subst *subster) interface_(iface *types.Interface) *types.Interface { 247 if iface == nil { 248 return nil 249 } 250 251 // methods 
for the interface. Initially nil if there is no known change needed. 252 // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers. 253 var methods []*types.Func 254 initMethods := func(n int) { // copy first n explicit methods 255 methods = make([]*types.Func, iface.NumExplicitMethods()) 256 for i := 0; i < n; i++ { 257 f := iface.ExplicitMethod(i) 258 norecv := changeRecv(f.Type().(*types.Signature), nil) 259 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) 260 } 261 } 262 for i := 0; i < iface.NumExplicitMethods(); i++ { 263 f := iface.ExplicitMethod(i) 264 // On interfaces, we need to cycle break on anonymous interface types 265 // being in a cycle with their signatures being in cycles with their recievers 266 // that do not go through a Named. 267 norecv := changeRecv(f.Type().(*types.Signature), nil) 268 sig := subst.typ(norecv) 269 if sig != norecv && methods == nil { 270 initMethods(i) 271 } 272 if methods != nil { 273 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) 274 } 275 } 276 277 var embeds []types.Type 278 initEmbeds := func(n int) { // copy first n embedded types 279 embeds = make([]types.Type, iface.NumEmbeddeds()) 280 for i := 0; i < n; i++ { 281 embeds[i] = iface.EmbeddedType(i) 282 } 283 } 284 for i := 0; i < iface.NumEmbeddeds(); i++ { 285 e := iface.EmbeddedType(i) 286 r := subst.typ(e) 287 if e != r && embeds == nil { 288 initEmbeds(i) 289 } 290 if embeds != nil { 291 embeds[i] = r 292 } 293 } 294 295 if methods == nil && embeds == nil { 296 return iface 297 } 298 if methods == nil { 299 initMethods(iface.NumExplicitMethods()) 300 } 301 if embeds == nil { 302 initEmbeds(iface.NumEmbeddeds()) 303 } 304 return types.NewInterfaceType(methods, embeds).Complete() 305 } 306 307 func (subst *subster) named(t *types.Named) types.Type { 308 // A named type may be: 309 // (1) ordinary named type (non-local scope, no type parameters, no type arguments), 310 // (2) locally scoped 
type, 311 // (3) generic (type parameters but no type arguments), or 312 // (4) instantiated (type parameters and type arguments). 313 tparams := typeparams.ForNamed(t) 314 if tparams.Len() == 0 { 315 if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) { 316 // Outside the current function scope? 317 return t // case (1) ordinary 318 } 319 320 // case (2) locally scoped type. 321 // Create a new named type to represent this instantiation. 322 // We assume that local types of distinct instantiations of a 323 // generic function are distinct, even if they don't refer to 324 // type parameters, but the spec is unclear; see golang/go#58573. 325 // 326 // Subtle: We short circuit substitution and use a newly created type in 327 // subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive 328 // types during traversal. This both breaks infinite cycles and allows for 329 // constructing types with the replacement applied in subst.typ(under). 330 // 331 // Example: 332 // func foo[T any]() { 333 // type linkedlist struct { 334 // next *linkedlist 335 // val T 336 // } 337 // } 338 // 339 // When the field `next *linkedlist` is visited during subst.typ(under), 340 // we want the substituted type for the field `next` to be `*n`. 341 n := types.NewNamed(t.Obj(), nil, nil) 342 subst.cache[t] = n 343 subst.cache[n] = n 344 n.SetUnderlying(subst.typ(t.Underlying())) 345 return n 346 } 347 targs := typeparams.NamedTypeArgs(t) 348 349 // insts are arguments to instantiate using. 350 insts := make([]types.Type, tparams.Len()) 351 352 // case (3) generic ==> targs.Len() == 0 353 // Instantiating a generic with no type arguments should be unreachable. 354 // Please report a bug if you encounter this. 355 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") 356 357 // case (4) instantiated. 
358 // Substitute into the type arguments and instantiate the replacements/ 359 // Example: 360 // type N[A any] func() A 361 // func Foo[T](g N[T]) {} 362 // To instantiate Foo[string], one goes through {T->string}. To get the type of g 363 // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } 364 // to get {N with TypeArgs == {string} and typeparams == {A} }. 365 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") 366 for i, n := 0, targs.Len(); i < n; i++ { 367 inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion 368 insts[i] = inst 369 } 370 r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false) 371 assert(err == nil, "failed to Instantiate Named type") 372 return r 373 } 374 375 func (subst *subster) signature(t *types.Signature) types.Type { 376 tparams := typeparams.ForSignature(t) 377 378 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. 379 // 380 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 381 // To support tparams.Len() > 0, we just need to do the following [psuedocode]: 382 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) 383 384 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") 385 386 // Either: 387 // (1)non-generic function. 388 // no type params to substitute 389 // (2)generic method and recv needs to be substituted. 390 391 // Recievers can be either: 392 // named 393 // pointer to named 394 // interface 395 // nil 396 // interface is the problematic case. We need to cycle break there! 
397 recv := subst.var_(t.Recv()) 398 params := subst.tuple(t.Params()) 399 results := subst.tuple(t.Results()) 400 if recv != t.Recv() || params != t.Params() || results != t.Results() { 401 return typeparams.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) 402 } 403 return t 404 } 405 406 // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. 407 // It updates c to cache results. 408 // 409 // reaches is currently only part of the wellFormed debug logic, and 410 // in practice c is initially only type parameters. It is not currently 411 // relied on in production. 412 func reaches(t types.Type, c map[types.Type]bool) (res bool) { 413 if c, ok := c[t]; ok { 414 return c 415 } 416 417 // c is populated with temporary false entries as types are visited. 418 // This avoids repeat visits and break cycles. 419 c[t] = false 420 defer func() { 421 c[t] = res 422 }() 423 424 switch t := t.(type) { 425 case *typeparams.TypeParam, *types.Basic: 426 return false 427 case *types.Array: 428 return reaches(t.Elem(), c) 429 case *types.Slice: 430 return reaches(t.Elem(), c) 431 case *types.Pointer: 432 return reaches(t.Elem(), c) 433 case *types.Tuple: 434 for i := 0; i < t.Len(); i++ { 435 if reaches(t.At(i).Type(), c) { 436 return true 437 } 438 } 439 case *types.Struct: 440 for i := 0; i < t.NumFields(); i++ { 441 if reaches(t.Field(i).Type(), c) { 442 return true 443 } 444 } 445 case *types.Map: 446 return reaches(t.Key(), c) || reaches(t.Elem(), c) 447 case *types.Chan: 448 return reaches(t.Elem(), c) 449 case *types.Signature: 450 if t.Recv() != nil && reaches(t.Recv().Type(), c) { 451 return true 452 } 453 return reaches(t.Params(), c) || reaches(t.Results(), c) 454 case *typeparams.Union: 455 for i := 0; i < t.Len(); i++ { 456 if reaches(t.Term(i).Type(), c) { 457 return true 458 } 459 } 460 case *types.Interface: 461 for i := 0; i < t.NumEmbeddeds(); i++ { 462 if reaches(t.Embedded(i), c) { 463 return true 464 } 465 } 466 for i := 0; i 
< t.NumExplicitMethods(); i++ { 467 if reaches(t.ExplicitMethod(i).Type(), c) { 468 return true 469 } 470 } 471 case *types.Named: 472 return reaches(t.Underlying(), c) 473 default: 474 panic("unreachable") 475 } 476 return false 477 }