github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/codegen/loaderx/types.go (about)

     1  package loaderx
     2  
     3  import (
     4  	"go/ast"
     5  	"go/types"
     6  
     7  	"golang.org/x/tools/go/loader"
     8  )
     9  
    10  func FuncDeclOfTypeFunc(pkgInfo *loader.PackageInfo, typeFunc *types.Func) *ast.FuncDecl {
    11  	for ident, def := range pkgInfo.Defs {
    12  		if typeFuncDef, ok := def.(*types.Func); ok {
    13  			if typeFuncDef == typeFunc {
    14  				return FuncDeclOf(ident, FileOf(ident, pkgInfo.Files...))
    15  			}
    16  		}
    17  	}
    18  	return nil
    19  }
    20  
    21  func FuncDeclOf(ident *ast.Ident, file *ast.File) (funcDecl *ast.FuncDecl) {
    22  	ast.Inspect(file, func(node ast.Node) bool {
    23  		if decl, ok := node.(*ast.FuncDecl); ok {
    24  			if decl.Name == ident {
    25  				funcDecl = decl
    26  				return false
    27  			}
    28  		}
    29  		return true
    30  	})
    31  	return
    32  }
    33  
    34  func GetIdentChainCallOfCallFun(expr ast.Expr) (list []*ast.Ident) {
    35  	switch expr.(type) {
    36  	case *ast.SelectorExpr:
    37  		selectorExpr := expr.(*ast.SelectorExpr)
    38  		list = append(list, GetIdentChainCallOfCallFun(selectorExpr.X)...)
    39  		list = append(list, selectorExpr.Sel)
    40  	case *ast.Ident:
    41  		list = append(list, expr.(*ast.Ident))
    42  	}
    43  	return
    44  }
    45  
    46  func ForEachFuncResult(program *loader.Program, typeFunc *types.Func, walker func(resultTypeAndValues ...types.TypeAndValue)) {
    47  	if typeFunc == nil {
    48  		return
    49  	}
    50  
    51  	pkgInfo := program.Package(typeFunc.Pkg().Path())
    52  	funcDecl := FuncDeclOfTypeFunc(pkgInfo, typeFunc)
    53  
    54  	if funcDecl == nil {
    55  		// todo find way to location interface
    56  		return
    57  	}
    58  
    59  	signature := typeFunc.Type().(*types.Signature)
    60  	results := signature.Results()
    61  	resultLength := results.Len()
    62  
    63  	returnStmtList := make([]*ast.ReturnStmt, 0)
    64  
    65  	// collect all return stmt
    66  	ast.Inspect(funcDecl, func(node ast.Node) bool {
    67  		switch node.(type) {
    68  		case *ast.FuncLit:
    69  			// skip func inline declaration
    70  			return false
    71  		case *ast.ReturnStmt:
    72  			returnStmtList = append(returnStmtList, node.(*ast.ReturnStmt))
    73  		}
    74  		return true
    75  	})
    76  
    77  	for _, returnStmt := range returnStmtList {
    78  		namedResults := make([]ast.Expr, resultLength)
    79  		typeAndValues := make([]types.TypeAndValue, resultLength)
    80  		skip := false
    81  
    82  		ast.Inspect(funcDecl, func(node ast.Node) bool {
    83  			switch node.(type) {
    84  			// skip for `switch-case` if return no in
    85  			case *ast.CaseClause:
    86  				caseBody := node.(*ast.CaseClause)
    87  				return !hasReturn(node) || caseBody.Pos() <= returnStmt.Pos() && returnStmt.Pos() < caseBody.End()
    88  			case *ast.IfStmt:
    89  				ifBody := node.(*ast.IfStmt).Body
    90  				return !hasReturn(node) || ifBody.Pos() <= returnStmt.Pos() && returnStmt.Pos() < ifBody.End()
    91  			case *ast.ReturnStmt:
    92  				currentReturnStmt := node.(*ast.ReturnStmt)
    93  				if !skip && currentReturnStmt == returnStmt {
    94  					if currentReturnStmt.Results == nil {
    95  						for i := 0; i < resultLength; i++ {
    96  							typeAndValues[i] = MustEvalExpr(program.Fset, pkgInfo.Pkg, namedResults[i])
    97  							typeAndValues[i] = patchTypeAndValue(results.At(i).Type(), typeAndValues[i])
    98  						}
    99  						walker(typeAndValues...)
   100  					} else {
   101  						if len(currentReturnStmt.Results) < resultLength {
   102  							if callExpr, ok := currentReturnStmt.Results[0].(*ast.CallExpr); ok {
   103  								identList := GetIdentChainCallOfCallFun(callExpr.Fun)
   104  								ForEachFuncResult(program, pkgInfo.ObjectOf(identList[len(identList)-1]).(*types.Func), walker)
   105  							}
   106  						} else {
   107  							for i := 0; i < resultLength; i++ {
   108  								typeAndValues[i] = MustEvalExpr(program.Fset, pkgInfo.Pkg, currentReturnStmt.Results[i])
   109  								typeAndValues[i] = patchTypeAndValue(results.At(i).Type(), typeAndValues[i])
   110  							}
   111  							walker(typeAndValues...)
   112  						}
   113  					}
   114  				}
   115  			case *ast.AssignStmt:
   116  				// only scan before return
   117  				if node != nil && node.Pos() >= returnStmt.Pos() {
   118  					return false
   119  				}
   120  
   121  				assignStmt := node.(*ast.AssignStmt)
   122  
   123  				if len(assignStmt.Rhs) == 1 {
   124  					if _, ok := assignStmt.Rhs[0].(*ast.FuncLit); ok {
   125  						return false
   126  					}
   127  				}
   128  
   129  				if len(assignStmt.Lhs) == len(assignStmt.Rhs) {
   130  					for i, expr := range assignStmt.Lhs {
   131  						if ident, ok := expr.(*ast.Ident); ok {
   132  							for resultIndex := 0; resultIndex < resultLength; resultIndex++ {
   133  								if pkgInfo.ObjectOf(ident) == results.At(resultIndex) {
   134  									namedResults[resultIndex] = assignStmt.Rhs[i]
   135  								}
   136  							}
   137  						}
   138  					}
   139  				} else {
   140  					if callExpr, ok := assignStmt.Rhs[0].(*ast.CallExpr); ok {
   141  						allNamedResult := false
   142  						for resultIndex := 0; resultIndex < resultLength; resultIndex++ {
   143  							switch lhs := assignStmt.Lhs[0].(type) {
   144  							case *ast.Ident:
   145  								if pkgInfo.ObjectOf(lhs) == results.At(resultIndex) {
   146  									allNamedResult = true
   147  								}
   148  							case *ast.SelectorExpr:
   149  								if pkgInfo.ObjectOf(lhs.Sel) == results.At(resultIndex) {
   150  									allNamedResult = true
   151  								}
   152  							}
   153  						}
   154  						if allNamedResult {
   155  							identList := GetIdentChainCallOfCallFun(callExpr.Fun)
   156  							skip = true
   157  							ForEachFuncResult(program, pkgInfo.ObjectOf(identList[len(identList)-1]).(*types.Func), walker)
   158  						}
   159  					}
   160  				}
   161  			}
   162  			return true
   163  		})
   164  	}
   165  }
   166  
   167  func hasReturn(node ast.Node) (ok bool) {
   168  	ast.Inspect(node, func(n ast.Node) bool {
   169  		switch n.(type) {
   170  		case *ast.ReturnStmt:
   171  			ok = true
   172  		}
   173  		return true
   174  	})
   175  	return
   176  }
   177  
   178  func patchTypeAndValue(tpe types.Type, typeAndValue types.TypeAndValue) types.TypeAndValue {
   179  	if typeAndValue.IsValue() && typeAndValue.Value == nil {
   180  		return types.TypeAndValue{
   181  			Type: typeAndValue.Type,
   182  		}
   183  	}
   184  	_, isInterface := tpe.(*types.Interface)
   185  	if !isInterface && typeAndValue.Type == types.Typ[types.UntypedNil] {
   186  		return types.TypeAndValue{
   187  			Type: tpe,
   188  		}
   189  	}
   190  	return typeAndValue
   191  }
   192  
   193  func MethodOf(named *types.Named, funcName string) (typeFunc *types.Func) {
   194  	for i := 0; i < named.NumMethods(); i++ {
   195  		method := named.Method(i)
   196  		if method.Name() == funcName {
   197  			return method
   198  		}
   199  	}
   200  
   201  	if structType, ok := named.Underlying().(*types.Struct); ok {
   202  		for i := 0; i < structType.NumFields(); i++ {
   203  			field := structType.Field(i)
   204  			if field.Anonymous() {
   205  				typeFuncAnonymous := MethodOf(IndirectType(field.Type()).(*types.Named), funcName)
   206  				if typeFunc == nil {
   207  					typeFunc = typeFuncAnonymous
   208  				} else if typeFuncAnonymous != nil {
   209  					typeFunc = nil
   210  				}
   211  			}
   212  		}
   213  	}
   214  	return
   215  }
   216  
   217  func IndirectType(tpe types.Type) types.Type {
   218  	switch tpe.(type) {
   219  	case *types.Pointer:
   220  		return IndirectType(tpe.(*types.Pointer).Elem())
   221  	default:
   222  		return tpe
   223  	}
   224  }