github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/status_error/gen/status_error_generator.go (about)

     1  package gen
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"go/build"
     7  	"go/parser"
     8  	"go/types"
     9  	"path"
    10  	"path/filepath"
    11  	"reflect"
    12  	"sort"
    13  	"strconv"
    14  	"strings"
    15  
    16  	"golang.org/x/tools/go/loader"
    17  
    18  	"github.com/johnnyeven/libtools/codegen"
    19  	"github.com/johnnyeven/libtools/codegen/loaderx"
    20  	"github.com/johnnyeven/libtools/courier/status_error"
    21  )
    22  
    23  type StatusErrorGenerator struct {
    24  	pkgImportPath    string
    25  	program          *loader.Program
    26  	statusErrorCodes map[*types.Package]status_error.StatusErrorCodeMap
    27  }
    28  
    29  func (g *StatusErrorGenerator) Load(cwd string) {
    30  	ldr := loader.Config{
    31  		AllowErrors: true,
    32  		ParserMode:  parser.ParseComments,
    33  	}
    34  
    35  	pkgImportPath := codegen.GetPackageImportPath(cwd)
    36  	ldr.Import(pkgImportPath)
    37  
    38  	p, err := ldr.Load()
    39  	if err != nil {
    40  		panic(err)
    41  	}
    42  
    43  	g.program = p
    44  	g.pkgImportPath = pkgImportPath
    45  	g.statusErrorCodes = map[*types.Package]status_error.StatusErrorCodeMap{}
    46  }
    47  
    48  func (g *StatusErrorGenerator) Pick() {
    49  	statusErrorCodeType := reflect.TypeOf(status_error.StatusErrorCode(0))
    50  	statusErrorCodeTypeFullName := fmt.Sprintf("%s.%s", statusErrorCodeType.PkgPath(), statusErrorCodeType.Name())
    51  
    52  	for pkg, pkgInfo := range g.program.AllPackages {
    53  		if pkg.Path() != g.pkgImportPath {
    54  			continue
    55  		}
    56  		for ident, obj := range pkgInfo.Defs {
    57  			if constObj, ok := obj.(*types.Const); ok {
    58  				if strings.HasSuffix(constObj.Type().String(), statusErrorCodeTypeFullName) {
    59  					key := constObj.Name()
    60  					if key == "_" {
    61  						continue
    62  					}
    63  
    64  					doc := loaderx.CommentsOf(g.program.Fset, ident, pkgInfo.Files...)
    65  					code, _ := strconv.ParseInt(constObj.Val().String(), 10, 64)
    66  					msg, desc, canBeErrTalk := ParseStatusErrorDesc(doc)
    67  
    68  					if g.statusErrorCodes[pkg] == nil {
    69  						g.statusErrorCodes[pkg] = status_error.StatusErrorCodeMap{}
    70  					}
    71  					g.statusErrorCodes[pkg].Register(key, code, msg, desc, canBeErrTalk)
    72  				}
    73  			}
    74  		}
    75  	}
    76  }
    77  
    78  func (g *StatusErrorGenerator) Output(cwd string) codegen.Outputs {
    79  	statusErrorCodeType := reflect.TypeOf(status_error.StatusErrorCode(0))
    80  	outputs := codegen.Outputs{}
    81  	for pkg, statusErrorCodeMap := range g.statusErrorCodes {
    82  		p, _ := build.Import(pkg.Path(), "", build.FindOnly)
    83  		dir, _ := filepath.Rel(cwd, p.Dir)
    84  		content := fmt.Sprintf(`
    85  			package %s
    86  
    87  			import(
    88  				%s
    89  			)
    90  
    91  			%s `,
    92  			pkg.Name(),
    93  			strconv.Quote(statusErrorCodeType.PkgPath()),
    94  			g.toRegisterInit(statusErrorCodeMap),
    95  		)
    96  		outputs.Add(codegen.GeneratedSuffix(path.Join(dir, "status_err_codes.go")), content)
    97  	}
    98  	return outputs
    99  }
   100  
   101  func (g *StatusErrorGenerator) toRegisterInit(statusErrorCodeMap status_error.StatusErrorCodeMap) string {
   102  	buffer := bytes.Buffer{}
   103  	buffer.WriteString("func init () {")
   104  
   105  	registerMethod := "status_error.StatusErrorCodes"
   106  	pkgs := strings.Split(g.pkgImportPath, "/")
   107  	if strings.HasPrefix(registerMethod, pkgs[len(pkgs)-1]) {
   108  		registerMethod = "StatusErrorCodes"
   109  	}
   110  
   111  	statusErrorCodeList := []int{}
   112  	for _, statusErrorCode := range statusErrorCodeMap {
   113  		statusErrorCodeList = append(statusErrorCodeList, int(statusErrorCode.Code))
   114  	}
   115  	sort.Ints(statusErrorCodeList)
   116  
   117  	for _, code := range statusErrorCodeList {
   118  		statusErrorCode := statusErrorCodeMap[int64(code)]
   119  
   120  		buffer.WriteString(fmt.Sprintf(
   121  			"%s.Register(%s, %d, %s, %s, %s)\n",
   122  			registerMethod,
   123  			strconv.Quote(statusErrorCode.Key),
   124  			statusErrorCode.Code,
   125  			strconv.Quote(statusErrorCode.Msg),
   126  			strconv.Quote(statusErrorCode.Desc),
   127  			strconv.FormatBool(statusErrorCode.CanBeErrorTalk),
   128  		))
   129  	}
   130  
   131  	buffer.WriteString("}")
   132  	return buffer.String()
   133  }
   134  
   135  func ParseStatusErrorDesc(str string) (msg string, desc string, canBeErrTalk bool) {
   136  	lines := strings.Split(str, "\n")
   137  	firstLine := strings.Split(lines[0], "@errTalk")
   138  
   139  	if len(firstLine) > 1 {
   140  		canBeErrTalk = true
   141  		msg = strings.TrimSpace(firstLine[1])
   142  	} else {
   143  		canBeErrTalk = false
   144  		msg = strings.TrimSpace(firstLine[0])
   145  	}
   146  
   147  	if len(lines) > 1 {
   148  		desc = strings.TrimSpace(strings.Join(lines[1:], "\n"))
   149  	}
   150  	return
   151  }