github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/status_error/gen_from_old/status_error_generator.go (about)

     1  package gen_from_old
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/ast"
     7  	"go/build"
     8  	"go/parser"
     9  	"go/types"
    10  	"net/http"
    11  	"path"
    12  	"path/filepath"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  
    18  	"golang.org/x/tools/go/loader"
    19  
    20  	"github.com/johnnyeven/libtools/codegen"
    21  	"github.com/johnnyeven/libtools/courier/status_error"
    22  )
    23  
    24  type StatusErrorGenerator struct {
    25  	DryRun        bool
    26  	pkgImportPath string
    27  	program       *loader.Program
    28  	statusErrors  map[int64]status_error.StatusError
    29  }
    30  
    31  func (g *StatusErrorGenerator) Load(cwd string) {
    32  	ldr := loader.Config{
    33  		AllowErrors: true,
    34  		ParserMode:  parser.ParseComments,
    35  	}
    36  
    37  	pkgImportPath := codegen.GetPackageImportPath(cwd)
    38  	ldr.Import(pkgImportPath)
    39  
    40  	p, err := ldr.Load()
    41  	if err != nil {
    42  		panic(err)
    43  	}
    44  
    45  	g.program = p
    46  	g.pkgImportPath = pkgImportPath
    47  	g.statusErrors = map[int64]status_error.StatusError{}
    48  }
    49  
    50  func (g *StatusErrorGenerator) Pick() {
    51  	statusErrorType := reflect.TypeOf(status_error.StatusError{})
    52  	statusErrorTypeFullName := fmt.Sprintf("%s.%s", statusErrorType.PkgPath(), statusErrorType.Name())
    53  
    54  	for pkg, pkgInfo := range g.program.AllPackages {
    55  		if pkg.Path() != g.pkgImportPath {
    56  			continue
    57  		}
    58  		for ident, obj := range pkgInfo.Defs {
    59  			if varObj, ok := obj.(*types.Var); ok {
    60  				if strings.HasSuffix(varObj.Type().String(), statusErrorTypeFullName) {
    61  					key := varObj.Name()
    62  					statusErr := status_error.StatusError{}
    63  					statusErr.Key = key
    64  
    65  					if valueSpec, ok := ident.Obj.Decl.(*ast.ValueSpec); ok {
    66  						ast.Inspect(valueSpec, func(node ast.Node) bool {
    67  							if keyValueExpr, ok := node.(*ast.KeyValueExpr); ok {
    68  								switch keyValueExpr.Key.(*ast.Ident).Name {
    69  								case "Code":
    70  									if basicLit, ok := keyValueExpr.Value.(*ast.BasicLit); ok {
    71  										statusErr.Code, _ = strconv.ParseInt(basicLit.Value, 10, 64)
    72  									}
    73  								case "Msg":
    74  									if basicLit, ok := keyValueExpr.Value.(*ast.BasicLit); ok {
    75  										statusErr.Msg, _ = strconv.Unquote(basicLit.Value)
    76  									}
    77  								case "Desc":
    78  									if basicLit, ok := keyValueExpr.Value.(*ast.BasicLit); ok {
    79  										statusErr.Desc, _ = strconv.Unquote(basicLit.Value)
    80  									}
    81  								case "CanBeErrorTalk":
    82  									if ident, ok := keyValueExpr.Value.(*ast.Ident); ok {
    83  										statusErr.CanBeErrorTalk = ident.Name == "true"
    84  									}
    85  								}
    86  							}
    87  							return true
    88  						})
    89  					}
    90  
    91  					if s, ok := g.statusErrors[statusErr.Code]; ok {
    92  						panic(fmt.Errorf("%d already used in %s", statusErr.Code, s.Error()))
    93  					}
    94  					g.statusErrors[statusErr.Code] = statusErr
    95  				}
    96  			}
    97  		}
    98  	}
    99  }
   100  
   101  func (g *StatusErrorGenerator) Output(cwd string) codegen.Outputs {
   102  	outputs := codegen.Outputs{}
   103  	codes := make([]int, 0)
   104  	for code := range g.statusErrors {
   105  		codes = append(codes, int(code))
   106  	}
   107  	sort.Ints(codes)
   108  
   109  	statusErrorGroups := make(map[int][]status_error.StatusError)
   110  
   111  	for _, code := range codes {
   112  		statueErr := g.statusErrors[int64(code)]
   113  		statusErrorGroups[statueErr.Status()] = append(statusErrorGroups[statueErr.Status()], statueErr)
   114  	}
   115  
   116  	p, _ := build.Import(g.pkgImportPath, "", build.ImportComment)
   117  	buf := bytes.NewBufferString(fmt.Sprintf(`package %s
   118  
   119  //go:generate tools gen error
   120  import (
   121  	"net/http"
   122  	"github.com/johnnyeven/libtools/courier/status_error"
   123  )
   124  `, p.Name))
   125  
   126  	for status, statusErrList := range statusErrorGroups {
   127  		buf.WriteString("const (\n")
   128  
   129  		index := 0
   130  		for i, statusErr := range statusErrList {
   131  			count := int(statusErr.Code) - status*1e3
   132  			firstLine := statusErr.Msg
   133  			if statusErr.CanBeErrorTalk {
   134  				firstLine = "@errTalk " + firstLine
   135  			}
   136  			comments := fmt.Sprintf(`
   137  // %s
   138  // %s`, firstLine, statusErr.Desc)
   139  
   140  			if i == 0 {
   141  				index = count - i
   142  				buf.WriteString(fmt.Sprintf(`%s
   143  %s status_error.StatusErrorCode = %s * 1e3 + iota + %d
   144  `, comments, statusErr.Key, httpCode(status), count),
   145  				)
   146  			} else {
   147  				index++
   148  				for count > index {
   149  					buf.WriteString(fmt.Sprintf("_%d\n", index))
   150  					index++
   151  				}
   152  				buf.WriteString(fmt.Sprintf(`%s
   153  %s
   154  `, comments, statusErr.Key))
   155  			}
   156  		}
   157  
   158  		buf.WriteString(")\n\n")
   159  	}
   160  
   161  	if g.DryRun {
   162  		fmt.Println(buf.String())
   163  	} else {
   164  		dir, _ := filepath.Rel(cwd, p.Dir)
   165  		outputs.Add(path.Join(dir, "status_err_codes.go"), buf.String())
   166  	}
   167  	return outputs
   168  }
   169  
   170  var httpCodes = map[int]string{
   171  	http.StatusBadRequest:          "http.StatusBadRequest",
   172  	http.StatusNotFound:            "http.StatusNotFound",
   173  	http.StatusForbidden:           "http.StatusForbidden",
   174  	http.StatusTooManyRequests:     "http.StatusTooManyRequests",
   175  	http.StatusConflict:            "http.StatusConflict",
   176  	http.StatusInternalServerError: "http.StatusInternalServerError",
   177  }
   178  
   179  func httpCode(code int) string {
   180  	if httpCode, ok := httpCodes[code]; ok {
   181  		return httpCode
   182  	}
   183  	return fmt.Sprintf("%d", code)
   184  }