github.com/sandwich-go/boost@v1.3.29/misc/goformat/func.go (about)

     1  package goformat
     2  
     3  import (
     4  	"go/ast"
     5  	"go/token"
     6  )
     7  
     8  // hasSingleCallReturnVal call 函数是否有单个返回值
     9  func hasSingleCallReturnVal(ce *ast.CallExpr) bool {
    10  	if id, ok0 := ce.Fun.(*ast.Ident); ok0 && id.Obj != nil {
    11  		if fn, ok1 := id.Obj.Decl.(*ast.FuncDecl); ok1 {
    12  			return len(fn.Type.Results.List) == 1
    13  		}
    14  	}
    15  	return false
    16  }
    17  
    18  type visitor struct {
    19  	enclosing *ast.FuncType                     // innermost enclosing func
    20  	returns   map[*ast.ReturnStmt]*ast.FuncType // potentially incomplete returns
    21  }
    22  
    23  func (v visitor) Visit(node ast.Node) ast.Visitor {
    24  	if node == nil {
    25  		return v
    26  	}
    27  	switch n := node.(type) {
    28  	case *ast.FuncDecl:
    29  		return visitor{enclosing: n.Type, returns: v.returns}
    30  	case *ast.FuncLit:
    31  		return visitor{enclosing: n.Type, returns: v.returns}
    32  	case *ast.ReturnStmt:
    33  		v.returns[n] = v.enclosing
    34  	}
    35  	return v
    36  }
    37  
    38  // fillReturnValues 补充call 函数返回值
    39  func fillReturnValues(f *ast.File) error {
    40  	incReturns := map[*ast.ReturnStmt]*ast.FuncType{}
    41  	ast.Walk(visitor{returns: incReturns}, f)
    42  
    43  returnsLoop:
    44  	for ret, ftyp := range incReturns {
    45  		if ftyp.Results == nil {
    46  			continue
    47  		}
    48  		numRVs := len(ret.Results)
    49  		if numRVs == len(ftyp.Results.List) {
    50  			continue
    51  		}
    52  		if numRVs == 0 {
    53  			continue
    54  		}
    55  		if numRVs > len(ftyp.Results.List) {
    56  			continue
    57  		}
    58  		if e, ok := ret.Results[0].(*ast.CallExpr); ok {
    59  			if !hasSingleCallReturnVal(e) {
    60  				continue
    61  			}
    62  		}
    63  		zvs := make([]ast.Expr, len(ftyp.Results.List)-numRVs)
    64  		for i, rt := range ftyp.Results.List[:len(zvs)] {
    65  			zv := newZeroValueNode(rt.Type)
    66  			if zv == nil {
    67  				continue returnsLoop
    68  			}
    69  			zvs[i] = zv
    70  		}
    71  		ret.Results = append(zvs, ret.Results...)
    72  	}
    73  	return nil
    74  }
    75  
    76  // newZeroValueNode 新建零值节点
    77  func newZeroValueNode(typ ast.Expr) ast.Expr {
    78  	switch v := typ.(type) {
    79  	case *ast.Ident:
    80  		switch v.Name {
    81  		case "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64", "byte", "rune", "uint", "int", "uintptr":
    82  			return &ast.BasicLit{Kind: token.INT, Value: "0"}
    83  		case "float32", "float64":
    84  			return &ast.BasicLit{Kind: token.FLOAT, Value: "0"}
    85  		case "complex64", "complex128":
    86  			return &ast.BasicLit{Kind: token.IMAG, Value: "0"}
    87  		case "bool":
    88  			return &ast.Ident{Name: "false"}
    89  		case "string":
    90  			return &ast.BasicLit{Kind: token.STRING, Value: `""`}
    91  		case "error":
    92  			return &ast.Ident{Name: "nil"}
    93  		}
    94  	case *ast.ArrayType:
    95  		if v.Len == nil {
    96  			// slice
    97  			return &ast.Ident{Name: "nil"}
    98  		}
    99  		return &ast.CompositeLit{Type: v}
   100  	case *ast.StarExpr:
   101  		return &ast.Ident{Name: "nil"}
   102  	}
   103  	return nil
   104  }
   105  
   106  func removeBareReturns(f *ast.File) error {
   107  	returns := map[*ast.ReturnStmt]*ast.FuncType{}
   108  	ast.Walk(visitor{returns: returns}, f)
   109  
   110  returnsLoop:
   111  	for ret, ftyp := range returns {
   112  		if ftyp.Results == nil {
   113  			continue
   114  		}
   115  		numRVs := len(ret.Results)
   116  		if numRVs == len(ftyp.Results.List) {
   117  			continue
   118  		}
   119  
   120  		if numRVs == 0 && len(ftyp.Results.List) > 0 {
   121  			zvs := make([]ast.Expr, len(ftyp.Results.List))
   122  			for i, rt := range ftyp.Results.List {
   123  				if len(rt.Names) == 0 {
   124  					continue returnsLoop
   125  				}
   126  				zv := &ast.Ident{Name: rt.Names[0].Name}
   127  				zvs[i] = zv
   128  			}
   129  			ret.Results = append(zvs, ret.Results...)
   130  		}
   131  	}
   132  	return nil
   133  }
   134  
   135  func containsMainFunc(file *ast.File) bool {
   136  	for _, decl := range file.Decls {
   137  		if f, ok := decl.(*ast.FuncDecl); ok {
   138  			if f.Name.Name != "main" {
   139  				continue
   140  			}
   141  			if len(f.Type.Params.List) != 0 {
   142  				continue
   143  			}
   144  			if f.Type.Results != nil && len(f.Type.Results.List) != 0 {
   145  				continue
   146  			}
   147  			return true
   148  		}
   149  	}
   150  	return false
   151  }