github.com/traefik/yaegi@v0.15.1/interp/generic.go (about)

     1  package interp
     2  
import (
	"fmt"
	"strings"
	"sync/atomic"
)
     7  
     8  // adot produces an AST dot(1) directed acyclic graph for the given node. For debugging only.
     9  // func (n *node) adot() { n.astDot(dotWriter(n.interp.dotCmd), n.ident) }
    10  
    11  // genAST returns a new AST where generic types are replaced by instantiated types.
    12  func genAST(sc *scope, root *node, types []*itype) (*node, bool, error) {
    13  	typeParam := map[string]*node{}
    14  	pindex := 0
    15  	tname := ""
    16  	rtname := ""
    17  	recvrPtr := false
    18  	fixNodes := []*node{}
    19  	var gtree func(*node, *node) (*node, error)
    20  	sname := root.child[0].ident + "["
    21  	if root.kind == funcDecl {
    22  		sname = root.child[1].ident + "["
    23  	}
    24  
    25  	// Input type parameters must be resolved prior AST generation, as compilation
    26  	// of generated AST may occur in a different scope.
    27  	for _, t := range types {
    28  		sname += t.id() + ","
    29  	}
    30  	sname = strings.TrimSuffix(sname, ",") + "]"
    31  
    32  	gtree = func(n, anc *node) (*node, error) {
    33  		nod := copyNode(n, anc, false)
    34  		switch n.kind {
    35  		case funcDecl, funcType:
    36  			nod.val = nod
    37  
    38  		case identExpr:
    39  			// Replace generic type by instantiated one.
    40  			nt, ok := typeParam[n.ident]
    41  			if !ok {
    42  				break
    43  			}
    44  			nod = copyNode(nt, anc, true)
    45  			nod.typ = nt.typ
    46  
    47  		case indexExpr:
    48  			// Catch a possible recursive generic type definition
    49  			if root.kind != typeSpec {
    50  				break
    51  			}
    52  			if root.child[0].ident != n.child[0].ident {
    53  				break
    54  			}
    55  			nod := copyNode(n.child[0], anc, false)
    56  			fixNodes = append(fixNodes, nod)
    57  			return nod, nil
    58  
    59  		case fieldList:
    60  			//  Node is the type parameters list of a generic function.
    61  			if root.kind == funcDecl && n.anc == root.child[2] && childPos(n) == 0 {
    62  				// Fill the types lookup table used for type substitution.
    63  				for _, c := range n.child {
    64  					l := len(c.child) - 1
    65  					for _, cc := range c.child[:l] {
    66  						if pindex >= len(types) {
    67  							return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
    68  						}
    69  						t, err := nodeType(c.interp, sc, c.child[l])
    70  						if err != nil {
    71  							return nil, err
    72  						}
    73  						if err := checkConstraint(types[pindex], t); err != nil {
    74  							return nil, err
    75  						}
    76  						typeParam[cc.ident] = copyNode(cc, cc.anc, false)
    77  						typeParam[cc.ident].ident = types[pindex].id()
    78  						typeParam[cc.ident].typ = types[pindex]
    79  						pindex++
    80  					}
    81  				}
    82  				// Skip type parameters specification, so generated func doesn't look generic.
    83  				return nod, nil
    84  			}
    85  
    86  			// Node is the receiver of a generic method.
    87  			if root.kind == funcDecl && n.anc == root && childPos(n) == 0 && len(n.child) > 0 {
    88  				rtn := n.child[0].child[1]
    89  				// Method receiver is a generic type if it takes some type parameters.
    90  				if rtn.kind == indexExpr || rtn.kind == indexListExpr || (rtn.kind == starExpr && (rtn.child[0].kind == indexExpr || rtn.child[0].kind == indexListExpr)) {
    91  					if rtn.kind == starExpr {
    92  						// Method receiver is a pointer on a generic type.
    93  						rtn = rtn.child[0]
    94  						recvrPtr = true
    95  					}
    96  					rtname = rtn.child[0].ident + "["
    97  					for _, cc := range rtn.child[1:] {
    98  						if pindex >= len(types) {
    99  							return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
   100  						}
   101  						it := types[pindex]
   102  						typeParam[cc.ident] = copyNode(cc, cc.anc, false)
   103  						typeParam[cc.ident].ident = it.id()
   104  						typeParam[cc.ident].typ = it
   105  						rtname += it.id() + ","
   106  						pindex++
   107  					}
   108  					rtname = strings.TrimSuffix(rtname, ",") + "]"
   109  				}
   110  			}
   111  
   112  			// Node is the type parameters list of a generic type.
   113  			if root.kind == typeSpec && n.anc == root && childPos(n) == 1 {
   114  				// Fill the types lookup table used for type substitution.
   115  				tname = n.anc.child[0].ident + "["
   116  				for _, c := range n.child {
   117  					l := len(c.child) - 1
   118  					for _, cc := range c.child[:l] {
   119  						if pindex >= len(types) {
   120  							return nil, cc.cfgErrorf("undefined type for %s", cc.ident)
   121  						}
   122  						it := types[pindex]
   123  						t, err := nodeType(c.interp, sc, c.child[l])
   124  						if err != nil {
   125  							return nil, err
   126  						}
   127  						if err := checkConstraint(types[pindex], t); err != nil {
   128  							return nil, err
   129  						}
   130  						typeParam[cc.ident] = copyNode(cc, cc.anc, false)
   131  						typeParam[cc.ident].ident = it.id()
   132  						typeParam[cc.ident].typ = it
   133  						tname += it.id() + ","
   134  						pindex++
   135  					}
   136  				}
   137  				tname = strings.TrimSuffix(tname, ",") + "]"
   138  				return nod, nil
   139  			}
   140  		}
   141  
   142  		for _, c := range n.child {
   143  			gn, err := gtree(c, nod)
   144  			if err != nil {
   145  				return nil, err
   146  			}
   147  			nod.child = append(nod.child, gn)
   148  		}
   149  		return nod, nil
   150  	}
   151  
   152  	if nod, found := root.interp.generic[sname]; found {
   153  		return nod, true, nil
   154  	}
   155  
   156  	r, err := gtree(root, root.anc)
   157  	if err != nil {
   158  		return nil, false, err
   159  	}
   160  	root.interp.generic[sname] = r
   161  	r.param = append(r.param, types...)
   162  	if tname != "" {
   163  		for _, nod := range fixNodes {
   164  			nod.ident = tname
   165  		}
   166  		r.child[0].ident = tname
   167  	}
   168  	if rtname != "" {
   169  		// Replace method receiver type by synthetized ident.
   170  		nod := r.child[0].child[0].child[1]
   171  		if recvrPtr {
   172  			nod = nod.child[0]
   173  		}
   174  		nod.kind = identExpr
   175  		nod.ident = rtname
   176  		nod.child = nil
   177  	}
   178  	// r.adot() // Used for debugging only.
   179  	return r, false, nil
   180  }
   181  
   182  func copyNode(n, anc *node, recursive bool) *node {
   183  	var i interface{}
   184  	nindex := atomic.AddInt64(&n.interp.nindex, 1)
   185  	nod := &node{
   186  		debug:  n.debug,
   187  		anc:    anc,
   188  		interp: n.interp,
   189  		index:  nindex,
   190  		level:  n.level,
   191  		nleft:  n.nleft,
   192  		nright: n.nright,
   193  		kind:   n.kind,
   194  		pos:    n.pos,
   195  		action: n.action,
   196  		gen:    n.gen,
   197  		val:    &i,
   198  		rval:   n.rval,
   199  		ident:  n.ident,
   200  		meta:   n.meta,
   201  	}
   202  	nod.start = nod
   203  	if recursive {
   204  		for _, c := range n.child {
   205  			nod.child = append(nod.child, copyNode(c, nod, true))
   206  		}
   207  	}
   208  	return nod
   209  }
   210  
   211  func inferTypesFromCall(sc *scope, fun *node, args []*node) ([]*itype, error) {
   212  	ftn := fun.typ.node
   213  	// Fill the map of parameter types, indexed by type param ident.
   214  	paramTypes := map[string]*itype{}
   215  	for _, c := range ftn.child[0].child {
   216  		typ, err := nodeType(fun.interp, sc, c.lastChild())
   217  		if err != nil {
   218  			return nil, err
   219  		}
   220  		for _, cc := range c.child[:len(c.child)-1] {
   221  			paramTypes[cc.ident] = typ
   222  		}
   223  	}
   224  
   225  	var inferTypes func(*itype, *itype) ([]*itype, error)
   226  	inferTypes = func(param, input *itype) ([]*itype, error) {
   227  		switch param.cat {
   228  		case chanT, ptrT, sliceT:
   229  			return inferTypes(param.val, input.val)
   230  
   231  		case mapT:
   232  			k, err := inferTypes(param.key, input.key)
   233  			if err != nil {
   234  				return nil, err
   235  			}
   236  			v, err := inferTypes(param.val, input.val)
   237  			if err != nil {
   238  				return nil, err
   239  			}
   240  			return append(k, v...), nil
   241  
   242  		case structT:
   243  			lt := []*itype{}
   244  			for i, f := range param.field {
   245  				nl, err := inferTypes(f.typ, input.field[i].typ)
   246  				if err != nil {
   247  					return nil, err
   248  				}
   249  				lt = append(lt, nl...)
   250  			}
   251  			return lt, nil
   252  
   253  		case funcT:
   254  			lt := []*itype{}
   255  			for i, t := range param.arg {
   256  				if i >= len(input.arg) {
   257  					break
   258  				}
   259  				nl, err := inferTypes(t, input.arg[i])
   260  				if err != nil {
   261  					return nil, err
   262  				}
   263  				lt = append(lt, nl...)
   264  			}
   265  			for i, t := range param.ret {
   266  				if i >= len(input.ret) {
   267  					break
   268  				}
   269  				nl, err := inferTypes(t, input.ret[i])
   270  				if err != nil {
   271  					return nil, err
   272  				}
   273  				lt = append(lt, nl...)
   274  			}
   275  			return lt, nil
   276  
   277  		case nilT:
   278  			if paramTypes[param.name] != nil {
   279  				return []*itype{input}, nil
   280  			}
   281  
   282  		case genericT:
   283  			return []*itype{input}, nil
   284  		}
   285  		return nil, nil
   286  	}
   287  
   288  	types := []*itype{}
   289  	for i, c := range ftn.child[1].child {
   290  		typ, err := nodeType(fun.interp, sc, c.lastChild())
   291  		if err != nil {
   292  			return nil, err
   293  		}
   294  		lt, err := inferTypes(typ, args[i].typ)
   295  		if err != nil {
   296  			return nil, err
   297  		}
   298  		types = append(types, lt...)
   299  	}
   300  
   301  	return types, nil
   302  }
   303  
   304  func checkConstraint(it, ct *itype) error {
   305  	if len(ct.constraint) == 0 && len(ct.ulconstraint) == 0 {
   306  		return nil
   307  	}
   308  	for _, c := range ct.constraint {
   309  		if it.equals(c) {
   310  			return nil
   311  		}
   312  	}
   313  	for _, c := range ct.ulconstraint {
   314  		if it.underlying().equals(c) {
   315  			return nil
   316  		}
   317  	}
   318  	return it.node.cfgErrorf("%s does not implement %s", it.id(), ct.id())
   319  }