github.com/sandwich-go/boost@v1.3.29/misc/goformat/func.go (about) 1 package goformat 2 3 import ( 4 "go/ast" 5 "go/token" 6 ) 7 8 // hasSingleCallReturnVal call 函数是否有单个返回值 9 func hasSingleCallReturnVal(ce *ast.CallExpr) bool { 10 if id, ok0 := ce.Fun.(*ast.Ident); ok0 && id.Obj != nil { 11 if fn, ok1 := id.Obj.Decl.(*ast.FuncDecl); ok1 { 12 return len(fn.Type.Results.List) == 1 13 } 14 } 15 return false 16 } 17 18 type visitor struct { 19 enclosing *ast.FuncType // innermost enclosing func 20 returns map[*ast.ReturnStmt]*ast.FuncType // potentially incomplete returns 21 } 22 23 func (v visitor) Visit(node ast.Node) ast.Visitor { 24 if node == nil { 25 return v 26 } 27 switch n := node.(type) { 28 case *ast.FuncDecl: 29 return visitor{enclosing: n.Type, returns: v.returns} 30 case *ast.FuncLit: 31 return visitor{enclosing: n.Type, returns: v.returns} 32 case *ast.ReturnStmt: 33 v.returns[n] = v.enclosing 34 } 35 return v 36 } 37 38 // fillReturnValues 补充call 函数返回值 39 func fillReturnValues(f *ast.File) error { 40 incReturns := map[*ast.ReturnStmt]*ast.FuncType{} 41 ast.Walk(visitor{returns: incReturns}, f) 42 43 returnsLoop: 44 for ret, ftyp := range incReturns { 45 if ftyp.Results == nil { 46 continue 47 } 48 numRVs := len(ret.Results) 49 if numRVs == len(ftyp.Results.List) { 50 continue 51 } 52 if numRVs == 0 { 53 continue 54 } 55 if numRVs > len(ftyp.Results.List) { 56 continue 57 } 58 if e, ok := ret.Results[0].(*ast.CallExpr); ok { 59 if !hasSingleCallReturnVal(e) { 60 continue 61 } 62 } 63 zvs := make([]ast.Expr, len(ftyp.Results.List)-numRVs) 64 for i, rt := range ftyp.Results.List[:len(zvs)] { 65 zv := newZeroValueNode(rt.Type) 66 if zv == nil { 67 continue returnsLoop 68 } 69 zvs[i] = zv 70 } 71 ret.Results = append(zvs, ret.Results...) 72 } 73 return nil 74 } 75 76 // newZeroValueNode 新建零值节点 77 func newZeroValueNode(typ ast.Expr) ast.Expr { 78 switch v := typ.(type) { 79 case *ast.Ident: 80 switch v.Name { 81 case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "byte", "rune", "uint", "int", "uintptr": 82 return &ast.BasicLit{Kind: token.INT, Value: "0"} 83 case "float32", "float64": 84 return &ast.BasicLit{Kind: token.FLOAT, Value: "0"} 85 case "complex64", "complex128": 86 return &ast.BasicLit{Kind: token.IMAG, Value: "0"} 87 case "bool": 88 return &ast.Ident{Name: "false"} 89 case "string": 90 return &ast.BasicLit{Kind: token.STRING, Value: `""`} 91 case "error": 92 return &ast.Ident{Name: "nil"} 93 } 94 case *ast.ArrayType: 95 if v.Len == nil { 96 // slice 97 return &ast.Ident{Name: "nil"} 98 } 99 return &ast.CompositeLit{Type: v} 100 case *ast.StarExpr: 101 return &ast.Ident{Name: "nil"} 102 } 103 return nil 104 } 105 106 func removeBareReturns(f *ast.File) error { 107 returns := map[*ast.ReturnStmt]*ast.FuncType{} 108 ast.Walk(visitor{returns: returns}, f) 109 110 returnsLoop: 111 for ret, ftyp := range returns { 112 if ftyp.Results == nil { 113 continue 114 } 115 numRVs := len(ret.Results) 116 if numRVs == len(ftyp.Results.List) { 117 continue 118 } 119 120 if numRVs == 0 && len(ftyp.Results.List) > 0 { 121 zvs := make([]ast.Expr, len(ftyp.Results.List)) 122 for i, rt := range ftyp.Results.List { 123 if len(rt.Names) == 0 { 124 continue returnsLoop 125 } 126 zv := &ast.Ident{Name: rt.Names[0].Name} 127 zvs[i] = zv 128 } 129 ret.Results = append(zvs, ret.Results...) 130 } 131 } 132 return nil 133 } 134 135 func containsMainFunc(file *ast.File) bool { 136 for _, decl := range file.Decls { 137 if f, ok := decl.(*ast.FuncDecl); ok { 138 if f.Name.Name != "main" { 139 continue 140 } 141 if len(f.Type.Params.List) != 0 { 142 continue 143 } 144 if f.Type.Results != nil && len(f.Type.Results.List) != 0 { 145 continue 146 } 147 return true 148 } 149 } 150 return false 151 }