github.com/qiuhoude/go-web@v0.0.0-20220223060959-ab545e78f20d/prepare/13_generate/myenumstr/myenumstr.go (about)

     1  //+build ignore
     2  
     3  package main
     4  
     5  import (
     6  	"bytes"
     7  	"flag"
     8  	"fmt"
     9  	"go/ast"
    10  	"go/build"
    11  	"go/format"
    12  	"go/parser"
    13  	"go/token"
    14  	"io/ioutil"
    15  	"log"
    16  	"os"
    17  	"path/filepath"
    18  	"strings"
    19  	"text/template"
    20  )
    21  
    22  var (
    23  	pkgInfo *build.Package
    24  )
    25  var (
    26  	typeNames = flag.String("type", "", "必填,逗号连接的多个Type名")
    27  )
    28  
    29  func main() {
    30  	flag.Parse()
    31  	if len(*typeNames) == 0 {
    32  		log.Fatal("-type 必填")
    33  	}
    34  	consts := getConsts()    // 获取数据
    35  	src := genString(consts) // 转成[]byte
    36  	// 保存到文件
    37  	outputName := ""
    38  	if outputName == "" {
    39  		types := strings.Split(*typeNames, ",")
    40  		baseName := fmt.Sprintf("%s_string.go", types[0])
    41  		outputName = filepath.Join(".", strings.ToLower(baseName))
    42  	}
    43  	err := ioutil.WriteFile(outputName, src, 0644)
    44  	if err != nil {
    45  		log.Fatalf("writing output: %s", err)
    46  	}
    47  }
    48  func getConsts() map[string][]string {
    49  	//获得待处理的Type
    50  	types := strings.Split(*typeNames, ",")
    51  	typesMap := make(map[string][]string, len(types))
    52  	for _, v := range types {
    53  		typesMap[strings.TrimSpace(v)] = []string{}
    54  	}
    55  	//解析当前目录下包信息
    56  	var err error
    57  	pkgInfo, err = build.ImportDir(".", 0)
    58  	if err != nil {
    59  		log.Fatal(err)
    60  	}
    61  	//解析Go文件语法树,提取Status相关信息
    62  	// 我们约定所定义的枚举信息实际应该全部是Const。需从语法树中 提取出所有的Const,并判断类型是否符合条件。
    63  	fset := token.NewFileSet()
    64  	//解析go文件
    65  	for _, file := range pkgInfo.GoFiles {
    66  		// Go的 语法树库go/ast(abstract syntax tree)和解析库go/parser 语法树是按语句块()形成树结构
    67  		f, err := parser.ParseFile(fset, file, nil, 0)
    68  		if err != nil {
    69  			log.Fatal(err)
    70  		}
    71  		typ := ""
    72  		//遍历每个树节点
    73  		ast.Inspect(f, func(n ast.Node) bool {
    74  			decl, ok := n.(*ast.GenDecl)
    75  			// 只需要const
    76  			if !ok || decl.Tok != token.CONST {
    77  				return true
    78  			}
    79  			for _, spec := range decl.Specs {
    80  				vspec := spec.(*ast.ValueSpec)
    81  				if vspec.Type == nil && len(vspec.Values) > 0 {
    82  					// 排除 v = 1 这种结构
    83  					typ = ""
    84  					continue
    85  				}
    86  				//如果Type不为空,则确认typ
    87  				if vspec.Type != nil {
    88  					ident, ok := vspec.Type.(*ast.Ident)
    89  					if !ok {
    90  						continue
    91  					}
    92  					typ = ident.Name
    93  				}
    94  				//typ是否是需处理的类型
    95  				consts, ok := typesMap[typ]
    96  				if !ok {
    97  					continue
    98  				}
    99  				//将所有const变量名保存
   100  				for _, n := range vspec.Names {
   101  					consts = append(consts, n.Name)
   102  				}
   103  				typesMap[typ] = consts
   104  			}
   105  			return true
   106  		})
   107  	}
   108  	return typesMap
   109  }
   110  func genString(types map[string][]string) []byte {
   111  	const strTmp = `
   112  	package {{.pkg}}
   113  	import "fmt"
   114  	
   115  	{{range $typ,$consts :=.types}}
   116  	func (c {{$typ}}) String() string{
   117  		switch c { {{range $consts}}
   118  			case {{.}}:return "{{.}}"{{end}}
   119  		}
   120  		return fmt.Sprintf("Status(%d)", c)	
   121  	}
   122  	{{end}}
   123  	`
   124  	pkgName := os.Getenv("GOPACKAGE")
   125  	if pkgName == "" {
   126  		pkgName = pkgInfo.Name
   127  	}
   128  
   129  	data := map[string]interface{}{
   130  		"pkg":   pkgName,
   131  		"types": types,
   132  	}
   133  	//利用模板库,生成代码文件
   134  	t, err := template.New("").Parse(strTmp)
   135  	if err != nil {
   136  		log.Fatal(err)
   137  	}
   138  	buff := bytes.NewBufferString("")
   139  	err = t.Execute(buff, data)
   140  	if err != nil {
   141  		log.Fatal(err)
   142  	}
   143  	//格式化
   144  	src, err := format.Source(buff.Bytes())
   145  	if err != nil {
   146  		log.Fatal(err)
   147  	}
   148  	return src
   149  }