github.com/songzhibin97/gkit@v1.2.13/parser/parse_go/model.go (about)

     1  package parse_go
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  	"io/ioutil"
     8  	"strings"
     9  	"text/template"
    10  
    11  	"github.com/songzhibin97/gkit/cache/buffer"
    12  	"github.com/songzhibin97/gkit/options"
    13  )
    14  
    15  type (
    16  	ParseStruct func(file *File)
    17  	ParseFunc   func(server *Server)
    18  	CheckFunc   func(g *GoParsePB) error
    19  )
    20  
    21  // GoParsePB .go 文件转成 pb文件
    22  type GoParsePB struct {
    23  	PkgName      string            // PkgName: 包名
    24  	FilePath     string            // FilePath: 文件的路径
    25  	Server       []*Server         // Server: 解析出来function的信息
    26  	Message      []*Message        // Message: 解析出struct的信息
    27  	Note         []*Note           // Note: 其他注释
    28  	Metas        map[string]string // Metas: 其他元信息
    29  	ParseStructs []ParseStruct
    30  	ParseFuncS   []ParseFunc
    31  	CheckFuncS   []CheckFunc
    32  }
    33  
    34  type Note struct {
    35  	IsUse bool // 判断作用域, 如果是 struct中 或者 func中代表已经使用
    36  	*ast.Comment
    37  }
    38  
    39  // Server Server对应Go func
    40  type Server struct {
    41  	Pos             int            // Pos: 函数的起始字节位置
    42  	End             int            // End: 函数的结束字节为止
    43  	Name            string         // Name: 函数名
    44  	ServerName      string         // ServerName: server name 通过 parseFunc 绑定
    45  	Method          string         // Method: method 通过 parseFunc 绑定
    46  	Router          string         // Router: router 通过 parseFunc 绑定
    47  	InputParameter  string         // InputParameter: 入参
    48  	OutputParameter string         // OutputParameter: 出参
    49  	Doc             []string       // Doc: 函数注释信息,可以通过自定义的 parseFunc 去进行解析
    50  	Notes           []*ast.Comment // Notes: 函数中的注释信息,用于埋点打桩
    51  
    52  }
    53  
    54  // CreateServer 创建Server
    55  func CreateServer(name string, pos, end int, doc []string, inputParameter string, outputParameter string) *Server {
    56  	return &Server{
    57  		Pos:             pos,
    58  		End:             end,
    59  		Name:            name,
    60  		Doc:             doc,
    61  		InputParameter:  inputParameter,
    62  		OutputParameter: outputParameter,
    63  	}
    64  }
    65  
    66  // Message Message对应struct
    67  type Message struct {
    68  	Pos   int            // Pos: struct的起始字节位置
    69  	End   int            // End: struct的结束字节为止
    70  	Name  string         // Name: struct name
    71  	Files []*File        // Files: 字段信息
    72  	Notes []*ast.Comment // Notes: struct的注释信息,用于埋点打桩
    73  }
    74  
    75  // AddFiles 添加字段信息
    76  func (m *Message) AddFiles(files ...*File) {
    77  	m.Files = append(m.Files, files...)
    78  }
    79  
    80  // CreateMessage 创建Message
    81  func CreateMessage(name string, pos, end int) *Message {
    82  	return &Message{
    83  		Name: name,
    84  		Pos:  pos,
    85  		End:  end,
    86  	}
    87  }
    88  
    89  // File 字段信息
    90  type File struct {
    91  	Tag    string // Tag: 字段的tag标记
    92  	Name   string // Name: 字段名
    93  	TypeGo string // TypeGo: 字段的原始类型
    94  	TypePB string // TypePB: 字段在proto中的类型
    95  }
    96  
    97  // CreateFile 创建字段信息
    98  func CreateFile(tag string, name string, tGo string, tPb string) *File {
    99  	return &File{
   100  		Tag:    tag,
   101  		Name:   name,
   102  		TypeGo: tGo,
   103  		TypePB: tPb,
   104  	}
   105  }
   106  
   107  // CreateGoParsePB 创建 GoParsePB Metas
   108  func CreateGoParsePB(pkgName string, filepath string, notes []*Note) *GoParsePB {
   109  	return &GoParsePB{
   110  		PkgName:  pkgName,
   111  		FilePath: filepath,
   112  		Metas:    make(map[string]string),
   113  		Note:     notes,
   114  	}
   115  }
   116  
   117  // addParseStruct 添加自定义解析struct内容
   118  func (g *GoParsePB) addParseStruct(parseTag ...ParseStruct) {
   119  	g.ParseStructs = append(g.ParseStructs, parseTag...)
   120  }
   121  
   122  // addParseFunc 添加自定义解析Func
   123  func (g *GoParsePB) addParseFunc(parseDocs ...ParseFunc) {
   124  	g.ParseFuncS = append(g.ParseFuncS, parseDocs...)
   125  }
   126  
   127  // addCheck 添加后续校验信息
   128  func (g *GoParsePB) addCheck(checkFunc ...CheckFunc) {
   129  	g.CheckFuncS = append(g.CheckFuncS, checkFunc...)
   130  }
   131  
   132  // parseStruct 解析struct信息
   133  func (g *GoParsePB) parseStruct(st *ast.GenDecl) {
   134  	for _, spec := range st.Specs {
   135  		if v, ok := spec.(*ast.TypeSpec); ok {
   136  			ret := CreateMessage(v.Name.Name, int(v.Pos()), int(v.End()))
   137  			if sType, ok := v.Type.(*ast.StructType); ok {
   138  
   139  				for _, field := range sType.Fields.List {
   140  					var tag, name string
   141  					if field.Tag != nil {
   142  						tag = field.Tag.Value
   143  					}
   144  					if field.Names != nil {
   145  						name = field.Names[0].Name
   146  					}
   147  
   148  					if field.Type != nil {
   149  						switch tType := field.Type.(type) {
   150  						case *ast.InterfaceType:
   151  							// 去除接口类型
   152  							continue
   153  
   154  						case *ast.Ident:
   155  							if tType.Obj != nil {
   156  								// 去除接口类型
   157  								continue
   158  							}
   159  							tGo := fmt.Sprintf(`%s`, tType.Name)
   160  							tPb := fmt.Sprintf("%s", GoTypeToPB(tType.Name))
   161  							if name == "" {
   162  								ret.AddFiles(CreateFile(tag, tPb, tGo, tPb))
   163  							} else {
   164  								ret.AddFiles(CreateFile(tag, name, tGo, tPb))
   165  							}
   166  
   167  						case *ast.ArrayType:
   168  							if aType, ok := tType.Elt.(*ast.Ident); ok {
   169  								tGo := fmt.Sprintf(`[]%s`, aType.Name)
   170  								if aType.Name == "byte" {
   171  									tPb := `bytes`
   172  									ret.AddFiles(CreateFile(tag, name, tGo, tPb))
   173  								} else {
   174  									tPb := fmt.Sprintf(`repeated %s`, GoTypeToPB(aType.Name))
   175  									ret.AddFiles(CreateFile(tag, name, tGo, tPb))
   176  								}
   177  
   178  							}
   179  						case *ast.MapType:
   180  							// 判断是否是 Ident
   181  							mKey, ok := tType.Key.(*ast.Ident)
   182  							if !ok {
   183  								continue
   184  							}
   185  							mValue, ok := tType.Key.(*ast.Ident)
   186  							if !ok {
   187  								continue
   188  							}
   189  							if IsMappingKey(GoTypeToPB(mKey.Name)) {
   190  								mk, mv := GoTypeToPB(mKey.Name), GoTypeToPB(mValue.Name)
   191  								tGo := fmt.Sprintf(`map[%s]%s`, mk, mv)
   192  								tPb := fmt.Sprintf(`map<%s,%s>`, mk, mv)
   193  								ret.AddFiles(CreateFile(tag, name, tGo, tPb))
   194  							}
   195  						}
   196  					}
   197  				}
   198  				// 执行tag解析
   199  				for _, f := range g.ParseStructs {
   200  					for _, file := range ret.Files {
   201  						f(file)
   202  					}
   203  				}
   204  
   205  				g.AddMessages(ret)
   206  			}
   207  		}
   208  	}
   209  }
   210  
   211  // parseFunc 解析函数信息
   212  func (g *GoParsePB) parseFunc(fn *ast.FuncDecl) {
   213  	var (
   214  		tags            []string
   215  		name            string
   216  		inputParameter  string
   217  		outputParameter string
   218  	)
   219  	if fn.Doc != nil {
   220  		tags = make([]string, len(fn.Doc.List))
   221  		for i, v := range fn.Doc.List {
   222  			tags[i] = v.Text
   223  		}
   224  	}
   225  	name = fn.Name.Name
   226  
   227  	if fn.Type != nil {
   228  		t := fn.Type
   229  		if t.Params != nil && t.Params.List != nil {
   230  			switch parameter := t.Params.List[0].Type.(type) {
   231  			case *ast.Ident:
   232  				inputParameter = parameter.Name
   233  			}
   234  		}
   235  		if t.Results != nil && t.Results.List != nil {
   236  			switch parameter := t.Results.List[0].Type.(type) {
   237  			case *ast.Ident:
   238  				outputParameter = parameter.Name
   239  			}
   240  		}
   241  	}
   242  	ret := CreateServer(name, int(fn.Pos()), int(fn.End()), tags, inputParameter, outputParameter)
   243  	for _, f := range g.ParseFuncS {
   244  		f(ret)
   245  	}
   246  	g.AddServers(ret)
   247  }
   248  
   249  // checkFormat 简单处理meta信息,将对应func、server中的注释移入
   250  func (g *GoParsePB) checkFormat() error {
   251  	// 之前已经调用过了,就直接返回了
   252  	if _, ok := g.Metas["ServerName"]; ok {
   253  		return nil
   254  	}
   255  	msgHashSet := make(map[string]struct{})
   256  
   257  	for _, message := range g.Message {
   258  		if _, ok := msgHashSet[message.Name]; ok {
   259  			return errors.New("message repeat")
   260  		}
   261  		for _, note := range g.Note {
   262  			if !note.IsUse && int(note.Pos()) > message.Pos && int(note.End()) <= message.End {
   263  				message.Notes = append(message.Notes, note.Comment)
   264  				note.IsUse = true
   265  			}
   266  		}
   267  		msgHashSet[message.Name] = struct{}{}
   268  	}
   269  	// serverHashSet := make(map[string]struct{})
   270  	for _, serve := range g.Server {
   271  		if serve.ServerName != "" {
   272  			g.Metas["ServerName"] = serve.ServerName
   273  		}
   274  		//if serve.Router == "" || serve.Method == "" {
   275  		//	return errors.New("server router or method is empty")
   276  		//}
   277  		//if _, ok := serverHashSet[serve.Router+serve.Method]; ok {
   278  		//	return errors.New("server router method repeat")
   279  		//}
   280  		//if _, ok := msgHashSet[serve.InputParameter]; !ok {
   281  		//	return errors.New("server input Parameters is empty")
   282  		//}
   283  		//if _, ok := msgHashSet[serve.OutputParameter]; !ok {
   284  		//	return errors.New("server output Parameters is empty")
   285  		//}
   286  		for _, note := range g.Note {
   287  			if !note.IsUse && int(note.Pos()) > serve.Pos && int(note.End()) <= serve.End {
   288  				serve.Notes = append(serve.Notes, note.Comment)
   289  				note.IsUse = true
   290  			}
   291  		}
   292  	}
   293  
   294  	for _, checkFunc := range g.CheckFuncS {
   295  		if err := checkFunc(g); err != nil {
   296  			return err
   297  		}
   298  	}
   299  
   300  	return nil
   301  }
   302  
   303  // Servers 返回解析后的所有Server对象
   304  func (g *GoParsePB) Servers() []*Server {
   305  	return g.Server
   306  }
   307  
   308  // Messages 返回解析后的所有Message对象
   309  func (g *GoParsePB) Messages() []*Message {
   310  	return g.Message
   311  }
   312  
   313  // AddServers 添加server信息
   314  func (g *GoParsePB) AddServers(servers ...*Server) {
   315  	g.Server = append(g.Server, servers...)
   316  }
   317  
   318  // AddMessages 添加message信息
   319  func (g *GoParsePB) AddMessages(messages ...*Message) {
   320  	g.Message = append(g.Message, messages...)
   321  }
   322  
   323  // Notes 获取注释消息
   324  func (g *GoParsePB) Notes() []*Note {
   325  	return g.Note
   326  }
   327  
   328  func (g *GoParsePB) AddNotes(notes ...*Note) {
   329  	g.Note = append(g.Note, notes...)
   330  }
   331  
   332  // PackageName 返回包名
   333  func (g *GoParsePB) PackageName() string {
   334  	return g.PkgName
   335  }
   336  
   337  // Generate 生成pb文件
   338  func (g *GoParsePB) Generate() string {
   339  	temp := `syntax = "proto3";
   340  package {{.PackageName}};
   341  
   342  // message{{range .Message}}
   343  message {{.Name}}{
   344  {{range  $index, $Message :=.Files}}   {{$Message.TypePB}} {{$Message.Name}} = {{addOne $index}};
   345  {{end}}}
   346  {{end}}
   347  
   348  // server
   349  service {{.Metas.ServerName}}{
   350  {{range .Server}}  rpc {{.Name }} ({{.InputParameter}}) returns ({{.OutputParameter}}) {
   351      option (google.api.http) = {
   352        {{.Method}} : "{{.Router}}"
   353      };
   354    }
   355  {{end}}}
   356  `
   357  	tmpl, err := template.New("GeneratePB").Funcs(template.FuncMap{"addOne": addOne}).Parse(temp)
   358  	if err != nil {
   359  		return ""
   360  	}
   361  	b := buffer.NewIoBuffer(1024)
   362  	err = tmpl.Execute(b, g)
   363  	if err != nil {
   364  		return ""
   365  	}
   366  	return b.String()
   367  }
   368  
   369  // PileDriving 源文件打桩
   370  // functionName: 指定函数内打桩,选传
   371  // startNotes,endNotes: 可以传两个打桩点,startNotes,endNotes中必填一个
   372  // insertCode: 插入代码段
   373  func (g *GoParsePB) PileDriving(functionName string, startNotes, endNotes string, insertCode string) error {
   374  	// srcData: 源文件内容
   375  	srcData, err := ioutil.ReadFile(g.FilePath)
   376  	if err != nil {
   377  		return err
   378  	}
   379  	startNotesPos, endNotesPos, err := g.pileFind(srcData, functionName, startNotes, endNotes)
   380  	srcData, err = g.pileDriving(srcData, startNotesPos, endNotesPos, insertCode)
   381  	if err != nil {
   382  		return err
   383  	}
   384  	return ioutil.WriteFile(g.FilePath, srcData, 0o600)
   385  }
   386  
   387  func (g *GoParsePB) PileDismantle(clearCode string) error {
   388  	// srcData: 源文件内容
   389  	srcData, err := ioutil.ReadFile(g.FilePath)
   390  	if err != nil {
   391  		return err
   392  	}
   393  
   394  	srcData, err = g.pileDismantle(srcData, clearCode)
   395  	if err != nil {
   396  		return err
   397  	}
   398  	return ioutil.WriteFile(g.FilePath, srcData, 0o600)
   399  }
   400  
   401  // pileFind: 找到打桩点,返回 startNotesPos、endNotesPos
   402  func (g *GoParsePB) pileFind(srcData []byte, functionName string, startNotes, endNotes string) (int, int, error) {
   403  	var (
   404  		startNotesPos = -1
   405  		endNotesPos   = len(srcData) + 1
   406  	)
   407  	// 判断是否指定functionName
   408  	if len(functionName) > 0 {
   409  		// 从函数中找桩
   410  		for _, server := range g.Server {
   411  			if server.Name != functionName {
   412  				continue
   413  			}
   414  			// 遍历notes看是否匹配
   415  			for _, note := range server.Notes {
   416  				if startNotesPos != -1 && endNotesPos != len(srcData)+1 {
   417  					break
   418  				}
   419  				if startNotesPos == -1 && strings.Contains(note.Text, startNotes) {
   420  					startNotesPos = int(note.Pos())
   421  				}
   422  				if endNotesPos == len(srcData)+1 && strings.Contains(note.Text, endNotes) {
   423  					endNotesPos = int(note.Pos())
   424  				}
   425  			}
   426  		}
   427  	} else {
   428  		// 从全局注释里面找
   429  		for _, note := range g.Note {
   430  			if startNotesPos != -1 && endNotesPos != len(srcData)+1 {
   431  				break
   432  			}
   433  			if startNotesPos == -1 && strings.Contains(note.Text, startNotes) {
   434  				startNotesPos = int(note.Pos())
   435  			}
   436  			if endNotesPos == len(srcData)+1 && strings.Contains(note.Text, endNotes) {
   437  				endNotesPos = int(note.Pos())
   438  			}
   439  		}
   440  	}
   441  	// 判断是否找到桩点
   442  	if startNotesPos == -1 && endNotesPos == len(srcData)+1 {
   443  		return 0, 0, errors.New("startNotes and endNotes is not find")
   444  	}
   445  	// 判断是否两个都找到
   446  	if startNotesPos != -1 && endNotesPos != len(srcData)+1 {
   447  		// 如果是同一行,需要处理
   448  		if startNotesPos == endNotesPos {
   449  			endNotesPos = startNotesPos + strings.Index(string(srcData[startNotesPos:]), startNotes)
   450  			for srcData[endNotesPos] != '/' {
   451  				endNotesPos--
   452  			}
   453  		}
   454  	}
   455  
   456  	return startNotesPos, endNotesPos, nil
   457  }
   458  
   459  // pileDriving: 打桩,返回已经打好的 srcData数据
   460  func (g *GoParsePB) pileDriving(srcData []byte, startNotesPos, endNotesPos int, insertCode string) ([]byte, error) {
   461  	var (
   462  		sym     []byte
   463  		oldTail []byte
   464  	)
   465  	if endNotesPos == len(srcData)+1 {
   466  		endNotesPos = startNotesPos
   467  		if checkRepeat(insertCode, string(srcData[endNotesPos:])) {
   468  			return nil, errors.New("重复添加")
   469  		}
   470  
   471  		// 收集标记符
   472  		endNotesPos--
   473  		symStart := endNotesPos
   474  		for symStart > 0 && srcData[symStart] != '\n' {
   475  			symStart--
   476  		}
   477  		sym = make([]byte, endNotesPos-symStart)
   478  		copy(sym, srcData[symStart:])
   479  
   480  		for endNotesPos < len(srcData)-1 && srcData[endNotesPos] != '\n' {
   481  			endNotesPos++
   482  		}
   483  
   484  		oldTail = make([]byte, len(srcData)-endNotesPos)
   485  		copy(oldTail, srcData[endNotesPos:])
   486  
   487  		srcData = srcData[:endNotesPos]
   488  	} else {
   489  		if checkRepeat(insertCode, string(srcData[:endNotesPos])) {
   490  			return nil, errors.New("重复添加")
   491  		}
   492  		endNotesPos--
   493  		symStart := endNotesPos
   494  		for symStart > 0 && srcData[symStart] != '\n' {
   495  			symStart--
   496  		}
   497  		sym = make([]byte, endNotesPos-symStart)
   498  		copy(sym, srcData[symStart:])
   499  
   500  		oldTail = make([]byte, len(srcData)-symStart)
   501  		copy(oldTail, srcData[symStart:])
   502  
   503  		srcData = srcData[:symStart]
   504  	}
   505  	srcData = append(srcData, sym...)
   506  	srcData = append(srcData, []byte(insertCode)...)
   507  	srcData = append(srcData, oldTail...)
   508  	return srcData, nil
   509  }
   510  
   511  func (g *GoParsePB) pileDismantle(srcData []byte, clearCode string) ([]byte, error) {
   512  	return cleanCode(clearCode, string(srcData))
   513  }
   514  
   515  // cleanCode: 清除桩内内容
   516  func cleanCode(clearCode string, srcData string) ([]byte, error) {
   517  	bf := make([]rune, 0, 1024)
   518  	for i, v := range srcData {
   519  		if v == '\n' {
   520  			if strings.TrimSpace(string(bf)) == clearCode {
   521  				return append([]byte(srcData[:i-len(bf)]), []byte(srcData[i+1:])...), nil
   522  			}
   523  			bf = (bf)[:0]
   524  			continue
   525  		}
   526  		bf = append(bf, v)
   527  	}
   528  	return []byte(srcData), errors.New("未找到内容")
   529  }
   530  
   531  // checkRepeat 检查是否重复
   532  func checkRepeat(code string, context string) bool {
   533  	bf := make([]rune, 0, 1024)
   534  	for _, v := range context {
   535  		if v == '\n' {
   536  			if strings.TrimSpace(string(bf)) == code {
   537  				return true
   538  			}
   539  			bf = (bf)[:0]
   540  			continue
   541  		}
   542  		bf = append(bf, v)
   543  	}
   544  	return false
   545  }
   546  
   547  // AddParseStruct 添加自定义解析struct内容
   548  func AddParseStruct(parseTag ...ParseStruct) options.Option {
   549  	return func(o interface{}) {
   550  		o.(*GoParsePB).addParseStruct(parseTag...)
   551  	}
   552  }
   553  
   554  // AddParseFunc 添加自定义解析Func
   555  func AddParseFunc(parseDocs ...ParseFunc) options.Option {
   556  	return func(o interface{}) {
   557  		o.(*GoParsePB).addParseFunc(parseDocs...)
   558  	}
   559  }
   560  
   561  // AddCheck 添加后续校验信息
   562  func AddCheck(checkFuncs ...CheckFunc) options.Option {
   563  	return func(o interface{}) {
   564  		o.(*GoParsePB).addCheck(checkFuncs...)
   565  	}
   566  }