github.com/traefik/yaegi@v0.15.1/interp/generic.go (about) 1 package interp 2 3 import ( 4 "strings" 5 "sync/atomic" 6 ) 7 8 // adot produces an AST dot(1) directed acyclic graph for the given node. For debugging only. 9 // func (n *node) adot() { n.astDot(dotWriter(n.interp.dotCmd), n.ident) } 10 11 // genAST returns a new AST where generic types are replaced by instantiated types. 12 func genAST(sc *scope, root *node, types []*itype) (*node, bool, error) { 13 typeParam := map[string]*node{} 14 pindex := 0 15 tname := "" 16 rtname := "" 17 recvrPtr := false 18 fixNodes := []*node{} 19 var gtree func(*node, *node) (*node, error) 20 sname := root.child[0].ident + "[" 21 if root.kind == funcDecl { 22 sname = root.child[1].ident + "[" 23 } 24 25 // Input type parameters must be resolved prior AST generation, as compilation 26 // of generated AST may occur in a different scope. 27 for _, t := range types { 28 sname += t.id() + "," 29 } 30 sname = strings.TrimSuffix(sname, ",") + "]" 31 32 gtree = func(n, anc *node) (*node, error) { 33 nod := copyNode(n, anc, false) 34 switch n.kind { 35 case funcDecl, funcType: 36 nod.val = nod 37 38 case identExpr: 39 // Replace generic type by instantiated one. 40 nt, ok := typeParam[n.ident] 41 if !ok { 42 break 43 } 44 nod = copyNode(nt, anc, true) 45 nod.typ = nt.typ 46 47 case indexExpr: 48 // Catch a possible recursive generic type definition 49 if root.kind != typeSpec { 50 break 51 } 52 if root.child[0].ident != n.child[0].ident { 53 break 54 } 55 nod := copyNode(n.child[0], anc, false) 56 fixNodes = append(fixNodes, nod) 57 return nod, nil 58 59 case fieldList: 60 // Node is the type parameters list of a generic function. 61 if root.kind == funcDecl && n.anc == root.child[2] && childPos(n) == 0 { 62 // Fill the types lookup table used for type substitution. 63 for _, c := range n.child { 64 l := len(c.child) - 1 65 for _, cc := range c.child[:l] { 66 if pindex >= len(types) { 67 return nil, cc.cfgErrorf("undefined type for %s", cc.ident) 68 } 69 t, err := nodeType(c.interp, sc, c.child[l]) 70 if err != nil { 71 return nil, err 72 } 73 if err := checkConstraint(types[pindex], t); err != nil { 74 return nil, err 75 } 76 typeParam[cc.ident] = copyNode(cc, cc.anc, false) 77 typeParam[cc.ident].ident = types[pindex].id() 78 typeParam[cc.ident].typ = types[pindex] 79 pindex++ 80 } 81 } 82 // Skip type parameters specification, so generated func doesn't look generic. 83 return nod, nil 84 } 85 86 // Node is the receiver of a generic method. 87 if root.kind == funcDecl && n.anc == root && childPos(n) == 0 && len(n.child) > 0 { 88 rtn := n.child[0].child[1] 89 // Method receiver is a generic type if it takes some type parameters. 90 if rtn.kind == indexExpr || rtn.kind == indexListExpr || (rtn.kind == starExpr && (rtn.child[0].kind == indexExpr || rtn.child[0].kind == indexListExpr)) { 91 if rtn.kind == starExpr { 92 // Method receiver is a pointer on a generic type. 93 rtn = rtn.child[0] 94 recvrPtr = true 95 } 96 rtname = rtn.child[0].ident + "[" 97 for _, cc := range rtn.child[1:] { 98 if pindex >= len(types) { 99 return nil, cc.cfgErrorf("undefined type for %s", cc.ident) 100 } 101 it := types[pindex] 102 typeParam[cc.ident] = copyNode(cc, cc.anc, false) 103 typeParam[cc.ident].ident = it.id() 104 typeParam[cc.ident].typ = it 105 rtname += it.id() + "," 106 pindex++ 107 } 108 rtname = strings.TrimSuffix(rtname, ",") + "]" 109 } 110 } 111 112 // Node is the type parameters list of a generic type. 113 if root.kind == typeSpec && n.anc == root && childPos(n) == 1 { 114 // Fill the types lookup table used for type substitution. 115 tname = n.anc.child[0].ident + "[" 116 for _, c := range n.child { 117 l := len(c.child) - 1 118 for _, cc := range c.child[:l] { 119 if pindex >= len(types) { 120 return nil, cc.cfgErrorf("undefined type for %s", cc.ident) 121 } 122 it := types[pindex] 123 t, err := nodeType(c.interp, sc, c.child[l]) 124 if err != nil { 125 return nil, err 126 } 127 if err := checkConstraint(types[pindex], t); err != nil { 128 return nil, err 129 } 130 typeParam[cc.ident] = copyNode(cc, cc.anc, false) 131 typeParam[cc.ident].ident = it.id() 132 typeParam[cc.ident].typ = it 133 tname += it.id() + "," 134 pindex++ 135 } 136 } 137 tname = strings.TrimSuffix(tname, ",") + "]" 138 return nod, nil 139 } 140 } 141 142 for _, c := range n.child { 143 gn, err := gtree(c, nod) 144 if err != nil { 145 return nil, err 146 } 147 nod.child = append(nod.child, gn) 148 } 149 return nod, nil 150 } 151 152 if nod, found := root.interp.generic[sname]; found { 153 return nod, true, nil 154 } 155 156 r, err := gtree(root, root.anc) 157 if err != nil { 158 return nil, false, err 159 } 160 root.interp.generic[sname] = r 161 r.param = append(r.param, types...) 162 if tname != "" { 163 for _, nod := range fixNodes { 164 nod.ident = tname 165 } 166 r.child[0].ident = tname 167 } 168 if rtname != "" { 169 // Replace method receiver type by synthetized ident. 170 nod := r.child[0].child[0].child[1] 171 if recvrPtr { 172 nod = nod.child[0] 173 } 174 nod.kind = identExpr 175 nod.ident = rtname 176 nod.child = nil 177 } 178 // r.adot() // Used for debugging only. 179 return r, false, nil 180 } 181 182 func copyNode(n, anc *node, recursive bool) *node { 183 var i interface{} 184 nindex := atomic.AddInt64(&n.interp.nindex, 1) 185 nod := &node{ 186 debug: n.debug, 187 anc: anc, 188 interp: n.interp, 189 index: nindex, 190 level: n.level, 191 nleft: n.nleft, 192 nright: n.nright, 193 kind: n.kind, 194 pos: n.pos, 195 action: n.action, 196 gen: n.gen, 197 val: &i, 198 rval: n.rval, 199 ident: n.ident, 200 meta: n.meta, 201 } 202 nod.start = nod 203 if recursive { 204 for _, c := range n.child { 205 nod.child = append(nod.child, copyNode(c, nod, true)) 206 } 207 } 208 return nod 209 } 210 211 func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*itype, error) { 212 ftn := fun.typ.node 213 // Fill the map of parameter types, indexed by type param ident. 214 paramTypes := map[string]*itype{} 215 for _, c := range ftn.child[0].child { 216 typ, err := nodeType(fun.interp, sc, c.lastChild()) 217 if err != nil { 218 return nil, err 219 } 220 for _, cc := range c.child[:len(c.child)-1] { 221 paramTypes[cc.ident] = typ 222 } 223 } 224 225 var inferTypes func(*itype, *itype) ([]*itype, error) 226 inferTypes = func(param, input *itype) ([]*itype, error) { 227 switch param.cat { 228 case chanT, ptrT, sliceT: 229 return inferTypes(param.val, input.val) 230 231 case mapT: 232 k, err := inferTypes(param.key, input.key) 233 if err != nil { 234 return nil, err 235 } 236 v, err := inferTypes(param.val, input.val) 237 if err != nil { 238 return nil, err 239 } 240 return append(k, v...), nil 241 242 case structT: 243 lt := []*itype{} 244 for i, f := range param.field { 245 nl, err := inferTypes(f.typ, input.field[i].typ) 246 if err != nil { 247 return nil, err 248 } 249 lt = append(lt, nl...) 250 } 251 return lt, nil 252 253 case funcT: 254 lt := []*itype{} 255 for i, t := range param.arg { 256 if i >= len(input.arg) { 257 break 258 } 259 nl, err := inferTypes(t, input.arg[i]) 260 if err != nil { 261 return nil, err 262 } 263 lt = append(lt, nl...) 264 } 265 for i, t := range param.ret { 266 if i >= len(input.ret) { 267 break 268 } 269 nl, err := inferTypes(t, input.ret[i]) 270 if err != nil { 271 return nil, err 272 } 273 lt = append(lt, nl...) 274 } 275 return lt, nil 276 277 case nilT: 278 if paramTypes[param.name] != nil { 279 return []*itype{input}, nil 280 } 281 282 case genericT: 283 return []*itype{input}, nil 284 } 285 return nil, nil 286 } 287 288 types := []*itype{} 289 for i, c := range ftn.child[1].child { 290 typ, err := nodeType(fun.interp, sc, c.lastChild()) 291 if err != nil { 292 return nil, err 293 } 294 lt, err := inferTypes(typ, args[i].typ) 295 if err != nil { 296 return nil, err 297 } 298 types = append(types, lt...) 299 } 300 301 return types, nil 302 } 303 304 func checkConstraint(it, ct *itype) error { 305 if len(ct.constraint) == 0 && len(ct.ulconstraint) == 0 { 306 return nil 307 } 308 for _, c := range ct.constraint { 309 if it.equals(c) { 310 return nil 311 } 312 } 313 for _, c := range ct.ulconstraint { 314 if it.underlying().equals(c) { 315 return nil 316 } 317 } 318 return it.node.cfgErrorf("%s does not implement %s", it.id(), ct.id()) 319 }