github.com/powerman/golang-tools@v0.1.11-0.20220410185822-5ad214d8d803/go/ssa/subst.go (about) 1 // Copyright 2022 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 package ssa 5 6 import ( 7 "fmt" 8 "go/types" 9 10 "github.com/powerman/golang-tools/internal/typeparams" 11 ) 12 13 // Type substituter for a fixed set of replacement types. 14 // 15 // A nil *subster is an valid, empty substitution map. It always acts as 16 // the identity function. This allows for treating parameterized and 17 // non-parameterized functions identically while compiling to ssa. 18 // 19 // Not concurrency-safe. 20 type subster struct { 21 replacements map[*typeparams.TypeParam]types.Type // values should contain no type params 22 cache map[types.Type]types.Type // cache of subst results 23 ctxt *typeparams.Context 24 debug bool // perform extra debugging checks 25 // TODO(taking): consider adding Pos 26 } 27 28 // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache. 29 // targs should not contain any types in tparams. 30 func makeSubster(ctxt *typeparams.Context, tparams []*typeparams.TypeParam, targs []types.Type, debug bool) *subster { 31 assert(len(tparams) == len(targs), "makeSubster argument count must match") 32 33 subst := &subster{ 34 replacements: make(map[*typeparams.TypeParam]types.Type, len(tparams)), 35 cache: make(map[types.Type]types.Type), 36 ctxt: ctxt, 37 debug: debug, 38 } 39 for i, tpar := range tparams { 40 subst.replacements[tpar] = targs[i] 41 } 42 if subst.debug { 43 if err := subst.wellFormed(); err != nil { 44 panic(err) 45 } 46 } 47 return subst 48 } 49 50 // wellFormed returns an error if subst was not properly initialized. 51 func (subst *subster) wellFormed() error { 52 if subst == nil || len(subst.replacements) == 0 { 53 return nil 54 } 55 // Check that all of the type params do not appear in the arguments. 56 s := make(map[types.Type]bool, len(subst.replacements)) 57 for tparam := range subst.replacements { 58 s[tparam] = true 59 } 60 for _, r := range subst.replacements { 61 if reaches(r, s) { 62 return fmt.Errorf("\n‰r %s s %v replacements %v\n", r, s, subst.replacements) 63 } 64 } 65 return nil 66 } 67 68 // typ returns the type of t with the type parameter tparams[i] substituted 69 // for the type targs[i] where subst was created using tparams and targs. 70 func (subst *subster) typ(t types.Type) (res types.Type) { 71 if subst == nil { 72 return t // A nil subst is type preserving. 73 } 74 if r, ok := subst.cache[t]; ok { 75 return r 76 } 77 defer func() { 78 subst.cache[t] = res 79 }() 80 81 // fall through if result r will be identical to t, types.Identical(r, t). 82 switch t := t.(type) { 83 case *typeparams.TypeParam: 84 r := subst.replacements[t] 85 assert(r != nil, "type param without replacement encountered") 86 return r 87 88 case *types.Basic: 89 return t 90 91 case *types.Array: 92 if r := subst.typ(t.Elem()); r != t.Elem() { 93 return types.NewArray(r, t.Len()) 94 } 95 return t 96 97 case *types.Slice: 98 if r := subst.typ(t.Elem()); r != t.Elem() { 99 return types.NewSlice(r) 100 } 101 return t 102 103 case *types.Pointer: 104 if r := subst.typ(t.Elem()); r != t.Elem() { 105 return types.NewPointer(r) 106 } 107 return t 108 109 case *types.Tuple: 110 return subst.tuple(t) 111 112 case *types.Struct: 113 return subst.struct_(t) 114 115 case *types.Map: 116 key := subst.typ(t.Key()) 117 elem := subst.typ(t.Elem()) 118 if key != t.Key() || elem != t.Elem() { 119 return types.NewMap(key, elem) 120 } 121 return t 122 123 case *types.Chan: 124 if elem := subst.typ(t.Elem()); elem != t.Elem() { 125 return types.NewChan(t.Dir(), elem) 126 } 127 return t 128 129 case *types.Signature: 130 return subst.signature(t) 131 132 case *typeparams.Union: 133 return subst.union(t) 134 135 case *types.Interface: 136 return subst.interface_(t) 137 138 case *types.Named: 139 return subst.named(t) 140 141 default: 142 panic("unreachable") 143 } 144 } 145 146 func (subst *subster) tuple(t *types.Tuple) *types.Tuple { 147 if t != nil { 148 if vars := subst.varlist(t); vars != nil { 149 return types.NewTuple(vars...) 150 } 151 } 152 return t 153 } 154 155 type varlist interface { 156 At(i int) *types.Var 157 Len() int 158 } 159 160 // fieldlist is an adapter for structs for the varlist interface. 161 type fieldlist struct { 162 str *types.Struct 163 } 164 165 func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) } 166 func (fl fieldlist) Len() int { return fl.str.NumFields() } 167 168 func (subst *subster) struct_(t *types.Struct) *types.Struct { 169 if t != nil { 170 if fields := subst.varlist(fieldlist{t}); fields != nil { 171 tags := make([]string, t.NumFields()) 172 for i, n := 0, t.NumFields(); i < n; i++ { 173 tags[i] = t.Tag(i) 174 } 175 return types.NewStruct(fields, tags) 176 } 177 } 178 return t 179 } 180 181 // varlist reutrns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i. 182 func (subst *subster) varlist(in varlist) []*types.Var { 183 var out []*types.Var // nil => no updates 184 for i, n := 0, in.Len(); i < n; i++ { 185 v := in.At(i) 186 w := subst.var_(v) 187 if v != w && out == nil { 188 out = make([]*types.Var, n) 189 for j := 0; j < i; j++ { 190 out[j] = in.At(j) 191 } 192 } 193 if out != nil { 194 out[i] = w 195 } 196 } 197 return out 198 } 199 200 func (subst *subster) var_(v *types.Var) *types.Var { 201 if v != nil { 202 if typ := subst.typ(v.Type()); typ != v.Type() { 203 if v.IsField() { 204 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded()) 205 } 206 return types.NewVar(v.Pos(), v.Pkg(), v.Name(), typ) 207 } 208 } 209 return v 210 } 211 212 func (subst *subster) union(u *typeparams.Union) *typeparams.Union { 213 var out []*typeparams.Term // nil => no updates 214 215 for i, n := 0, u.Len(); i < n; i++ { 216 t := u.Term(i) 217 r := subst.typ(t.Type()) 218 if r != t.Type() && out == nil { 219 out = make([]*typeparams.Term, n) 220 for j := 0; j < i; j++ { 221 out[j] = u.Term(j) 222 } 223 } 224 if out != nil { 225 out[i] = typeparams.NewTerm(t.Tilde(), r) 226 } 227 } 228 229 if out != nil { 230 return typeparams.NewUnion(out) 231 } 232 return u 233 } 234 235 func (subst *subster) interface_(iface *types.Interface) *types.Interface { 236 if iface == nil { 237 return nil 238 } 239 240 // methods for the interface. Initially nil if there is no known change needed. 241 // Signatures for the method where recv is nil. NewInterfaceType fills in the recievers. 242 var methods []*types.Func 243 initMethods := func(n int) { // copy first n explicit methods 244 methods = make([]*types.Func, iface.NumExplicitMethods()) 245 for i := 0; i < n; i++ { 246 f := iface.ExplicitMethod(i) 247 norecv := changeRecv(f.Type().(*types.Signature), nil) 248 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv) 249 } 250 } 251 for i := 0; i < iface.NumExplicitMethods(); i++ { 252 f := iface.ExplicitMethod(i) 253 // On interfaces, we need to cycle break on anonymous interface types 254 // being in a cycle with their signatures being in cycles with their recievers 255 // that do not go through a Named. 256 norecv := changeRecv(f.Type().(*types.Signature), nil) 257 sig := subst.typ(norecv) 258 if sig != norecv && methods == nil { 259 initMethods(i) 260 } 261 if methods != nil { 262 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature)) 263 } 264 } 265 266 var embeds []types.Type 267 initEmbeds := func(n int) { // copy first n embedded types 268 embeds = make([]types.Type, iface.NumEmbeddeds()) 269 for i := 0; i < n; i++ { 270 embeds[i] = iface.EmbeddedType(i) 271 } 272 } 273 for i := 0; i < iface.NumEmbeddeds(); i++ { 274 e := iface.EmbeddedType(i) 275 r := subst.typ(e) 276 if e != r && embeds == nil { 277 initEmbeds(i) 278 } 279 if embeds != nil { 280 embeds[i] = r 281 } 282 } 283 284 if methods == nil && embeds == nil { 285 return iface 286 } 287 if methods == nil { 288 initMethods(iface.NumExplicitMethods()) 289 } 290 if embeds == nil { 291 initEmbeds(iface.NumEmbeddeds()) 292 } 293 return types.NewInterfaceType(methods, embeds).Complete() 294 } 295 296 func (subst *subster) named(t *types.Named) types.Type { 297 // A name type may be: 298 // (1) ordinary (no type parameters, no type arguments), 299 // (2) generic (type parameters but no type arguments), or 300 // (3) instantiated (type parameters and type arguments). 301 tparams := typeparams.ForNamed(t) 302 if tparams.Len() == 0 { 303 // case (1) ordinary 304 305 // Note: If Go allows for local type declarations in generic 306 // functions we may need to descend into underlying as well. 307 return t 308 } 309 targs := typeparams.NamedTypeArgs(t) 310 311 // insts are arguments to instantiate using. 312 insts := make([]types.Type, tparams.Len()) 313 314 // case (2) generic ==> targs.Len() == 0 315 // Instantiating a generic with no type arguments should be unreachable. 316 // Please report a bug if you encounter this. 317 assert(targs.Len() != 0, "substition into a generic Named type is currently unsupported") 318 319 // case (3) instantiated. 320 // Substitute into the type arguments and instantiate the replacements/ 321 // Example: 322 // type N[A any] func() A 323 // func Foo[T](g N[T]) {} 324 // To instantiate Foo[string], one goes through {T->string}. To get the type of g 325 // one subsitutes T with string in {N with TypeArgs == {T} and TypeParams == {A} } 326 // to get {N with TypeArgs == {string} and TypeParams == {A} }. 327 assert(targs.Len() == tparams.Len(), "TypeArgs().Len() must match TypeParams().Len() if present") 328 for i, n := 0, targs.Len(); i < n; i++ { 329 inst := subst.typ(targs.At(i)) // TODO(generic): Check with rfindley for mutual recursion 330 insts[i] = inst 331 } 332 r, err := typeparams.Instantiate(subst.ctxt, typeparams.NamedTypeOrigin(t), insts, false) 333 assert(err == nil, "failed to Instantiate Named type") 334 return r 335 } 336 337 func (subst *subster) signature(t *types.Signature) types.Type { 338 tparams := typeparams.ForSignature(t) 339 340 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice. 341 // 342 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}. 343 // To support tparams.Len() > 0, we just need to do the following [psuedocode]: 344 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false) 345 346 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.") 347 348 // Either: 349 // (1)non-generic function. 350 // no type params to substitute 351 // (2)generic method and recv needs to be substituted. 352 353 // Recievers can be either: 354 // named 355 // pointer to named 356 // interface 357 // nil 358 // interface is the problematic case. We need to cycle break there! 359 recv := subst.var_(t.Recv()) 360 params := subst.tuple(t.Params()) 361 results := subst.tuple(t.Results()) 362 if recv != t.Recv() || params != t.Params() || results != t.Results() { 363 return types.NewSignature(recv, params, results, t.Variadic()) 364 } 365 return t 366 } 367 368 // reaches returns true if a type t reaches any type t' s.t. c[t'] == true. 369 // Updates c to cache results. 370 func reaches(t types.Type, c map[types.Type]bool) (res bool) { 371 if c, ok := c[t]; ok { 372 return c 373 } 374 c[t] = false // prevent cycles 375 defer func() { 376 c[t] = res 377 }() 378 379 switch t := t.(type) { 380 case *typeparams.TypeParam, *types.Basic: 381 // no-op => c == false 382 case *types.Array: 383 return reaches(t.Elem(), c) 384 case *types.Slice: 385 return reaches(t.Elem(), c) 386 case *types.Pointer: 387 return reaches(t.Elem(), c) 388 case *types.Tuple: 389 for i := 0; i < t.Len(); i++ { 390 if reaches(t.At(i).Type(), c) { 391 return true 392 } 393 } 394 case *types.Struct: 395 for i := 0; i < t.NumFields(); i++ { 396 if reaches(t.Field(i).Type(), c) { 397 return true 398 } 399 } 400 case *types.Map: 401 return reaches(t.Key(), c) || reaches(t.Elem(), c) 402 case *types.Chan: 403 return reaches(t.Elem(), c) 404 case *types.Signature: 405 if t.Recv() != nil && reaches(t.Recv().Type(), c) { 406 return true 407 } 408 return reaches(t.Params(), c) || reaches(t.Results(), c) 409 case *typeparams.Union: 410 for i := 0; i < t.Len(); i++ { 411 if reaches(t.Term(i).Type(), c) { 412 return true 413 } 414 } 415 case *types.Interface: 416 for i := 0; i < t.NumEmbeddeds(); i++ { 417 if reaches(t.Embedded(i), c) { 418 return true 419 } 420 } 421 for i := 0; i < t.NumExplicitMethods(); i++ { 422 if reaches(t.ExplicitMethod(i).Type(), c) { 423 return true 424 } 425 } 426 case *types.Named: 427 return reaches(t.Underlying(), c) 428 default: 429 panic("unreachable") 430 } 431 return false 432 }