github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/codegen/loaderx/types.go (about) 1 package loaderx 2 3 import ( 4 "go/ast" 5 "go/types" 6 7 "golang.org/x/tools/go/loader" 8 ) 9 10 func FuncDeclOfTypeFunc(pkgInfo *loader.PackageInfo, typeFunc *types.Func) *ast.FuncDecl { 11 for ident, def := range pkgInfo.Defs { 12 if typeFuncDef, ok := def.(*types.Func); ok { 13 if typeFuncDef == typeFunc { 14 return FuncDeclOf(ident, FileOf(ident, pkgInfo.Files...)) 15 } 16 } 17 } 18 return nil 19 } 20 21 func FuncDeclOf(ident *ast.Ident, file *ast.File) (funcDecl *ast.FuncDecl) { 22 ast.Inspect(file, func(node ast.Node) bool { 23 if decl, ok := node.(*ast.FuncDecl); ok { 24 if decl.Name == ident { 25 funcDecl = decl 26 return false 27 } 28 } 29 return true 30 }) 31 return 32 } 33 34 func GetIdentChainCallOfCallFun(expr ast.Expr) (list []*ast.Ident) { 35 switch expr.(type) { 36 case *ast.SelectorExpr: 37 selectorExpr := expr.(*ast.SelectorExpr) 38 list = append(list, GetIdentChainCallOfCallFun(selectorExpr.X)...) 39 list = append(list, selectorExpr.Sel) 40 case *ast.Ident: 41 list = append(list, expr.(*ast.Ident)) 42 } 43 return 44 } 45 46 func ForEachFuncResult(program *loader.Program, typeFunc *types.Func, walker func(resultTypeAndValues ...types.TypeAndValue)) { 47 if typeFunc == nil { 48 return 49 } 50 51 pkgInfo := program.Package(typeFunc.Pkg().Path()) 52 funcDecl := FuncDeclOfTypeFunc(pkgInfo, typeFunc) 53 54 if funcDecl == nil { 55 // todo find way to location interface 56 return 57 } 58 59 signature := typeFunc.Type().(*types.Signature) 60 results := signature.Results() 61 resultLength := results.Len() 62 63 returnStmtList := make([]*ast.ReturnStmt, 0) 64 65 // collect all return stmt 66 ast.Inspect(funcDecl, func(node ast.Node) bool { 67 switch node.(type) { 68 case *ast.FuncLit: 69 // skip func inline declaration 70 return false 71 case *ast.ReturnStmt: 72 returnStmtList = append(returnStmtList, node.(*ast.ReturnStmt)) 73 } 74 return true 75 }) 76 77 for _, returnStmt := range returnStmtList { 78 namedResults := make([]ast.Expr, resultLength) 79 typeAndValues := make([]types.TypeAndValue, resultLength) 80 skip := false 81 82 ast.Inspect(funcDecl, func(node ast.Node) bool { 83 switch node.(type) { 84 // skip for `switch-case` if return no in 85 case *ast.CaseClause: 86 caseBody := node.(*ast.CaseClause) 87 return !hasReturn(node) || caseBody.Pos() <= returnStmt.Pos() && returnStmt.Pos() < caseBody.End() 88 case *ast.IfStmt: 89 ifBody := node.(*ast.IfStmt).Body 90 return !hasReturn(node) || ifBody.Pos() <= returnStmt.Pos() && returnStmt.Pos() < ifBody.End() 91 case *ast.ReturnStmt: 92 currentReturnStmt := node.(*ast.ReturnStmt) 93 if !skip && currentReturnStmt == returnStmt { 94 if currentReturnStmt.Results == nil { 95 for i := 0; i < resultLength; i++ { 96 typeAndValues[i] = MustEvalExpr(program.Fset, pkgInfo.Pkg, namedResults[i]) 97 typeAndValues[i] = patchTypeAndValue(results.At(i).Type(), typeAndValues[i]) 98 } 99 walker(typeAndValues...) 100 } else { 101 if len(currentReturnStmt.Results) < resultLength { 102 if callExpr, ok := currentReturnStmt.Results[0].(*ast.CallExpr); ok { 103 identList := GetIdentChainCallOfCallFun(callExpr.Fun) 104 ForEachFuncResult(program, pkgInfo.ObjectOf(identList[len(identList)-1]).(*types.Func), walker) 105 } 106 } else { 107 for i := 0; i < resultLength; i++ { 108 typeAndValues[i] = MustEvalExpr(program.Fset, pkgInfo.Pkg, currentReturnStmt.Results[i]) 109 typeAndValues[i] = patchTypeAndValue(results.At(i).Type(), typeAndValues[i]) 110 } 111 walker(typeAndValues...) 112 } 113 } 114 } 115 case *ast.AssignStmt: 116 // only scan before return 117 if node != nil && node.Pos() >= returnStmt.Pos() { 118 return false 119 } 120 121 assignStmt := node.(*ast.AssignStmt) 122 123 if len(assignStmt.Rhs) == 1 { 124 if _, ok := assignStmt.Rhs[0].(*ast.FuncLit); ok { 125 return false 126 } 127 } 128 129 if len(assignStmt.Lhs) == len(assignStmt.Rhs) { 130 for i, expr := range assignStmt.Lhs { 131 if ident, ok := expr.(*ast.Ident); ok { 132 for resultIndex := 0; resultIndex < resultLength; resultIndex++ { 133 if pkgInfo.ObjectOf(ident) == results.At(resultIndex) { 134 namedResults[resultIndex] = assignStmt.Rhs[i] 135 } 136 } 137 } 138 } 139 } else { 140 if callExpr, ok := assignStmt.Rhs[0].(*ast.CallExpr); ok { 141 allNamedResult := false 142 for resultIndex := 0; resultIndex < resultLength; resultIndex++ { 143 switch lhs := assignStmt.Lhs[0].(type) { 144 case *ast.Ident: 145 if pkgInfo.ObjectOf(lhs) == results.At(resultIndex) { 146 allNamedResult = true 147 } 148 case *ast.SelectorExpr: 149 if pkgInfo.ObjectOf(lhs.Sel) == results.At(resultIndex) { 150 allNamedResult = true 151 } 152 } 153 } 154 if allNamedResult { 155 identList := GetIdentChainCallOfCallFun(callExpr.Fun) 156 skip = true 157 ForEachFuncResult(program, pkgInfo.ObjectOf(identList[len(identList)-1]).(*types.Func), walker) 158 } 159 } 160 } 161 } 162 return true 163 }) 164 } 165 } 166 167 func hasReturn(node ast.Node) (ok bool) { 168 ast.Inspect(node, func(n ast.Node) bool { 169 switch n.(type) { 170 case *ast.ReturnStmt: 171 ok = true 172 } 173 return true 174 }) 175 return 176 } 177 178 func patchTypeAndValue(tpe types.Type, typeAndValue types.TypeAndValue) types.TypeAndValue { 179 if typeAndValue.IsValue() && typeAndValue.Value == nil { 180 return types.TypeAndValue{ 181 Type: typeAndValue.Type, 182 } 183 } 184 _, isInterface := tpe.(*types.Interface) 185 if !isInterface && typeAndValue.Type == types.Typ[types.UntypedNil] { 186 return types.TypeAndValue{ 187 Type: tpe, 188 } 189 } 190 return typeAndValue 191 } 192 193 func MethodOf(named *types.Named, funcName string) (typeFunc *types.Func) { 194 for i := 0; i < named.NumMethods(); i++ { 195 method := named.Method(i) 196 if method.Name() == funcName { 197 return method 198 } 199 } 200 201 if structType, ok := named.Underlying().(*types.Struct); ok { 202 for i := 0; i < structType.NumFields(); i++ { 203 field := structType.Field(i) 204 if field.Anonymous() { 205 typeFuncAnonymous := MethodOf(IndirectType(field.Type()).(*types.Named), funcName) 206 if typeFunc == nil { 207 typeFunc = typeFuncAnonymous 208 } else if typeFuncAnonymous != nil { 209 typeFunc = nil 210 } 211 } 212 } 213 } 214 return 215 } 216 217 func IndirectType(tpe types.Type) types.Type { 218 switch tpe.(type) { 219 case *types.Pointer: 220 return IndirectType(tpe.(*types.Pointer).Elem()) 221 default: 222 return tpe 223 } 224 }