github.com/gopherjs/gopherjs@v1.19.0-beta1.0.20240506212314-27071a8796e4/internal/govendor/subst/subst.go (about) 1 // Copyright 2022 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package subst 6 7 import ( 8 "go/types" 9 ) 10 11 // Type substituter for a fixed set of replacement types. 12 // 13 // A nil *subster is an valid, empty substitution map. It always acts as 14 // the identity function. This allows for treating parameterized and 15 // non-parameterized functions identically while compiling to ssa. 16 // 17 // Not concurrency-safe. 18 type subster struct { 19 replacements map[*types.TypeParam]types.Type // values should contain no type params 20 cache map[types.Type]types.Type // cache of subst results 21 ctxt *types.Context // cache for instantiation 22 scope *types.Scope // *types.Named declared within this scope can be substituted (optional) 23 debug bool // perform extra debugging checks 24 // TODO(taking): consider adding Pos 25 // TODO(zpavlinovic): replacements can contain type params 26 // when generating instances inside of a generic function body. 27 } 28 29 // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. 30 // targs should not contain any types in tparams. 31 // scope is the (optional) lexical block of the generic function for which we are substituting. 32 func makeSubster(ctxt *types.Context, scope *types.Scope, tparams *types.TypeParamList, targs []types.Type, debug bool) *subster { 33 assert(tparams.Len() == len(targs), "makeSubster argument count must match") 34 35 subst := &subster{ 36 replacements: make(map[*types.TypeParam]types.Type, tparams.Len()), 37 cache: make(map[types.Type]types.Type), 38 ctxt: ctxt, 39 scope: scope, 40 debug: debug, 41 } 42 for i := 0; i < tparams.Len(); i++ { 43 subst.replacements[tparams.At(i)] = targs[i] 44 } 45 if subst.debug { 46 subst.wellFormed() 47 } 48 return subst 49 } 50 51 // wellFormed asserts that subst was properly initialized. 52 func (subst *subster) wellFormed() { 53 if subst == nil { 54 return 55 } 56 // Check that all of the type params do not appear in the arguments. 57 s := make(map[types.Type]bool, len(subst.replacements)) 58 for tparam := range subst.replacements { 59 s[tparam] = true 60 } 61 for _, r := range subst.replacements { 62 if reaches(r, s) { 63 panic(subst) 64 } 65 } 66 } 67 68 // typ returns the type of t with the type parameter tparams[i] substituted 69 // for the type targs[i] where subst was created using tparams and targs. 70 func (subst *subster) typ(t types.Type) (res types.Type) { 71 if subst == nil { 72 return t // A nil subst is type preserving. 73 } 74 if r, ok := subst.cache[t]; ok { 75 return r 76 } 77 defer func() { 78 subst.cache[t] = res 79 }() 80 81 // fall through if result r will be identical to t, types.Identical(r, t). 82 switch t := t.(type) { 83 case *types.TypeParam: 84 r := subst.replacements[t] 85 assert(r != nil, "type param without replacement encountered") 86 return r 87 88 case *types.Basic: 89 return t 90 91 case *types.Array: 92 if r := subst.typ(t.Elem()); r != t.Elem() { 93 return types.NewArray(r, t.Len()) 94 } 95 return t 96 97 case *types.Slice: 98 if r := subst.typ(t.Elem()); r != t.Elem() { 99 return types.NewSlice(r) 100 } 101 return t 102 103 case *types.Pointer: 104 if r := subst.typ(t.Elem()); r != t.Elem() { 105 return types.NewPointer(r) 106 } 107 return t 108 109 case *types.Tuple: 110 return subst.tuple(t) 111 112 case *types.Struct: 113 return subst.struct_(t) 114 115 case *types.Map: 116 key := subst.typ(t.Key()) 117 elem := subst.typ(t.Elem()) 118 if key != t.Key() || elem != t.Elem() { 119 return types.NewMap(key, elem) 120 } 121 return t 122 123 case *types.Chan: 124 if elem := subst.typ(t.Elem()); elem != t.Elem() { 125 return types.NewChan(t.Dir(), elem) 126 } 127 return t 128 129 case *types.Signature: 130 return subst.signature(t) 131 132 case *types.Union: 133 return subst.union(t) 134 135 case *types.Interface: 136 return subst.interface_(t) 137 138 case *types.Named: 139 return subst.named(t) 140 141 default: 142 panic("unreachable") 143 } 144 } 145 146 // types returns the result of {subst.typ(ts[i])}. 147 func (subst *subster) types(ts []types.Type) []types.Type { 148 res := make([]types.Type, len(ts)) 149 for i := range ts { 150 res[i] = subst.typ(ts[i]) 151 } 152 return res 153 } 154 155 func (subst *subster) tuple(t *types.Tuple) *types.Tuple { 156 if t != nil { 157 if vars := subst.varlist(t); vars != nil { 158 return types.NewTuple(vars...) 159 } 160 } 161 return t 162 } 163 164 type varlist interface { 165 At(i int) *types.Var 166 Len() int 167 } 168 169 // fieldlist is an adapter for structs for the varlist interface. 170 type fieldlist struct { 171 str *types.Struct 172 } 173 174 func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } 175 func (fl fieldlist) Len() int { return fl.str.NumFields() } 176 177 func (subst *subster) struct_(t *types.Struct) *types.Struct { 178 if t != nil { 179 if fields := subst.varlist(fieldlist{t}); fields != nil { 180 tags := make([]string, t.NumFields()) 181 for i, n := 0, t.NumFields(); i < n; i++ { 182 tags[i] = t.Tag(i) 183 } 184 return types.NewStruct(fields, tags) 185 } 186 } 187 return t 188 } 189 190 // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. 191 func (subst *subster) varlist(in varlist) []*types.Var { 192 var out []*types.Var // nil => no updates 193 for i, n := 0, in.Len(); i < n; i++ { 194 v := in.At(i) 195 w := subst.var_(v) 196 if v != w && out == nil { 197 out = make([]*types.Var, n) 198 for j := 0; j < i; j++ { 199 out[j] = in.At(j) 200 } 201 } 202 if out != nil { 203 out[i] = w 204 } 205 } 206 return out 207 } 208 209 func (subst *subster) var_(v *types.Var) *types.Var { 210 if v != nil { 211 if typ := subst.typ(v.Type()); typ != v.Type() { 212 if v.IsField() { 213 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) 214 } 215 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) 216 } 217 } 218 return v 219 } 220 221 func (subst *subster) union(u *types.Union) *types.Union { 222 var out []*types.Term // nil => no updates 223 224 for i, n := 0, u.Len(); i < n; i++ { 225 t := u.Term(i) 226 r := subst.typ(t.Type()) 227 if r != t.Type() && out == nil { 228 out = make([]*types.Term, n) 229 for j := 0; j < i; j++ { 230 out[j] = u.Term(j) 231 } 232 } 233 if out != nil { 234 out[i] = types.NewTerm(t.Tilde(), r) 235 } 236 } 237 238 if out != nil { 239 return types.NewUnion(out) 240 } 241 return u 242 } 243 244 func (subst *subster) interface_(iface *types.Interface) *types.Interface { 245 if iface == nil { 246 return nil 247 } 248 249 // methods for the interface. Initially nil if there is no known change needed. 250 // Signatures for the method where recv is nil. NewInterfaceType fills in the receivers. 251 var methods []*types.Func 252 initMethods := func(n int) { // copy first n explicit methods 253 methods = make([]*types.Func, iface.NumExplicitMethods()) 254 for i := 0; i < n; i++ { 255 f := iface.ExplicitMethod(i) 256 norecv := changeRecv(f.Type().(*types.Signature), nil) 257 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) 258 } 259 } 260 for i := 0; i < iface.NumExplicitMethods(); i++ { 261 f := iface.ExplicitMethod(i) 262 // On interfaces, we need to cycle break on anonymous interface types 263 // being in a cycle with their signatures being in cycles with their receivers 264 // that do not go through a Named. 265 norecv := changeRecv(f.Type().(*types.Signature), nil) 266 sig := subst.typ(norecv) 267 if sig != norecv && methods == nil { 268 initMethods(i) 269 } 270 if methods != nil { 271 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) 272 } 273 } 274 275 var embeds []types.Type 276 initEmbeds := func(n int) { // copy first n embedded types 277 embeds = make([]types.Type, iface.NumEmbeddeds()) 278 for i := 0; i < n; i++ { 279 embeds[i] = iface.EmbeddedType(i) 280 } 281 } 282 for i := 0; i < iface.NumEmbeddeds(); i++ { 283 e := iface.EmbeddedType(i) 284 r := subst.typ(e) 285 if e != r && embeds == nil { 286 initEmbeds(i) 287 } 288 if embeds != nil { 289 embeds[i] = r 290 } 291 } 292 293 if methods == nil && embeds == nil { 294 return iface 295 } 296 if methods == nil { 297 initMethods(iface.NumExplicitMethods()) 298 } 299 if embeds == nil { 300 initEmbeds(iface.NumEmbeddeds()) 301 } 302 return types.NewInterfaceType(methods, embeds).Complete() 303 } 304 305 func (subst *subster) named(t *types.Named) types.Type { 306 // A named type may be: 307 // (1) ordinary named type (non-local scope, no type parameters, no type arguments), 308 // (2) locally scoped type, 309 // (3) generic (type parameters but no type arguments), or 310 // (4) instantiated (type parameters and type arguments). 311 tparams := t.TypeParams() 312 if tparams.Len() == 0 { 313 if subst.scope != nil && !subst.scope.Contains(t.Obj().Pos()) { 314 // Outside the current function scope? 315 return t // case (1) ordinary 316 } 317 318 // case (2) locally scoped type. 319 // Create a new named type to represent this instantiation. 320 // We assume that local types of distinct instantiations of a 321 // generic function are distinct, even if they don't refer to 322 // type parameters, but the spec is unclear; see golang/go#58573. 323 // 324 // Subtle: We short circuit substitution and use a newly created type in 325 // subst, i.e. cache[t]=n, to pre-emptively replace t with n in recursive 326 // types during traversal. This both breaks infinite cycles and allows for 327 // constructing types with the replacement applied in subst.typ(under). 328 // 329 // Example: 330 // func foo[T any]() { 331 // type linkedlist struct { 332 // next *linkedlist 333 // val T 334 // } 335 // } 336 // 337 // When the field `next *linkedlist` is visited during subst.typ(under), 338 // we want the substituted type for the field `next` to be `*n`. 339 n := types.NewNamed(t.Obj(), nil, nil) 340 subst.cache[t] = n 341 subst.cache[n] = n 342 n.SetUnderlying(subst.typ(t.Underlying())) 343 return n 344 } 345 targs := t.TypeArgs() 346 347 // insts are arguments to instantiate using. 348 insts := make([]types.Type, tparams.Len()) 349 350 // case (3) generic ==> targs.Len() == 0 351 // Instantiating a generic with no type arguments should be unreachable. 352 // Please report a bug if you encounter this. 353 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") 354 355 // case (4) instantiated. 356 // Substitute into the type arguments and instantiate the replacements/ 357 // Example: 358 // type N[A any] func() A 359 // func Foo[T](g N[T]) {} 360 // To instantiate Foo[string], one goes through {T->string}. To get the type of g 361 // one subsitutes T with string in {N with typeargs == {T} and typeparams == {A} } 362 // to get {N with TypeArgs == {string} and typeparams == {A} }. 363 assert(targs.Len() == tparams.Len(), "typeargs.Len() must match typeparams.Len() if present") 364 for i, n := 0, targs.Len(); i < n; i++ { 365 inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion 366 insts[i] = inst 367 } 368 r, err := types.Instantiate(subst.ctxt, t.Origin(), insts, false) 369 assert(err == nil, "failed to Instantiate Named type") 370 return r 371 } 372 373 func (subst *subster) signature(t *types.Signature) types.Type { 374 tparams := t.TypeParams() 375 376 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. 377 // 378 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 379 // To support tparams.Len() > 0, we just need to do the following [psuedocode]: 380 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) 381 382 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") 383 384 // Either: 385 // (1)non-generic function. 386 // no type params to substitute 387 // (2)generic method and recv needs to be substituted. 388 389 // Receivers can be either: 390 // named 391 // pointer to named 392 // interface 393 // nil 394 // interface is the problematic case. We need to cycle break there! 395 recv := subst.var_(t.Recv()) 396 params := subst.tuple(t.Params()) 397 results := subst.tuple(t.Results()) 398 if recv != t.Recv() || params != t.Params() || results != t.Results() { 399 return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic()) 400 } 401 return t 402 } 403 404 // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. 405 // It updates c to cache results. 406 // 407 // reaches is currently only part of the wellFormed debug logic, and 408 // in practice c is initially only type parameters. It is not currently 409 // relied on in production. 410 func reaches(t types.Type, c map[types.Type]bool) (res bool) { 411 if c, ok := c[t]; ok { 412 return c 413 } 414 415 // c is populated with temporary false entries as types are visited. 416 // This avoids repeat visits and break cycles. 417 c[t] = false 418 defer func() { 419 c[t] = res 420 }() 421 422 switch t := t.(type) { 423 case *types.TypeParam, *types.Basic: 424 return false 425 case *types.Array: 426 return reaches(t.Elem(), c) 427 case *types.Slice: 428 return reaches(t.Elem(), c) 429 case *types.Pointer: 430 return reaches(t.Elem(), c) 431 case *types.Tuple: 432 for i := 0; i < t.Len(); i++ { 433 if reaches(t.At(i).Type(), c) { 434 return true 435 } 436 } 437 case *types.Struct: 438 for i := 0; i < t.NumFields(); i++ { 439 if reaches(t.Field(i).Type(), c) { 440 return true 441 } 442 } 443 case *types.Map: 444 return reaches(t.Key(), c) || reaches(t.Elem(), c) 445 case *types.Chan: 446 return reaches(t.Elem(), c) 447 case *types.Signature: 448 if t.Recv() != nil && reaches(t.Recv().Type(), c) { 449 return true 450 } 451 return reaches(t.Params(), c) || reaches(t.Results(), c) 452 case *types.Union: 453 for i := 0; i < t.Len(); i++ { 454 if reaches(t.Term(i).Type(), c) { 455 return true 456 } 457 } 458 case *types.Interface: 459 for i := 0; i < t.NumEmbeddeds(); i++ { 460 if reaches(t.Embedded(i), c) { 461 return true 462 } 463 } 464 for i := 0; i < t.NumExplicitMethods(); i++ { 465 if reaches(t.ExplicitMethod(i).Type(), c) { 466 return true 467 } 468 } 469 case *types.Named: 470 return reaches(t.Underlying(), c) 471 default: 472 panic("unreachable") 473 } 474 return false 475 }