github.com/astaxie/beego@v1.12.3/parser.go (about)

     1  // Copyright 2014 beego Author. All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package beego
    16  
    17  import (
    18  	"encoding/json"
    19  	"errors"
    20  	"fmt"
    21  	"go/ast"
    22  	"go/parser"
    23  	"go/token"
    24  	"io/ioutil"
    25  	"os"
    26  	"path/filepath"
    27  	"regexp"
    28  	"sort"
    29  	"strconv"
    30  	"strings"
    31  	"unicode"
    32  
    33  	"github.com/astaxie/beego/context/param"
    34  	"github.com/astaxie/beego/logs"
    35  	"github.com/astaxie/beego/utils"
    36  )
    37  
// globalRouterTemplate is the skeleton of each generated commentsRouter_*.go
// file. genRouterCode fills the {{.routersDir}}, {{.globalimport}} and
// {{.globalinfo}} placeholders with plain strings.Replace calls (this is not
// a text/template).
var globalRouterTemplate = `package {{.routersDir}}

import (
	"github.com/astaxie/beego"
	"github.com/astaxie/beego/context/param"{{.globalimport}}
)

func init() {
{{.globalinfo}}
}
`
    49  
    50  var (
    51  	lastupdateFilename = "lastupdate.tmp"
    52  	commentFilename    string
    53  	pkgLastupdate      map[string]int64
    54  	genInfoList        map[string][]ControllerComments
    55  
    56  	routerHooks = map[string]int{
    57  		"beego.BeforeStatic": BeforeStatic,
    58  		"beego.BeforeRouter": BeforeRouter,
    59  		"beego.BeforeExec":   BeforeExec,
    60  		"beego.AfterExec":    AfterExec,
    61  		"beego.FinishRouter": FinishRouter,
    62  	}
    63  
    64  	routerHooksMapping = map[int]string{
    65  		BeforeStatic: "beego.BeforeStatic",
    66  		BeforeRouter: "beego.BeforeRouter",
    67  		BeforeExec:   "beego.BeforeExec",
    68  		AfterExec:    "beego.AfterExec",
    69  		FinishRouter: "beego.FinishRouter",
    70  	}
    71  )
    72  
// commentPrefix starts the name of every generated router file; parserPkg
// appends the package path relative to AppPath with separators and dots
// replaced by underscores, plus ".go".
const commentPrefix = "commentsRouter_"
    74  
    75  func init() {
    76  	pkgLastupdate = make(map[string]int64)
    77  }
    78  
    79  func parserPkg(pkgRealpath, pkgpath string) error {
    80  	rep := strings.NewReplacer("\\", "_", "/", "_", ".", "_")
    81  	commentFilename, _ = filepath.Rel(AppPath, pkgRealpath)
    82  	commentFilename = commentPrefix + rep.Replace(commentFilename) + ".go"
    83  	if !compareFile(pkgRealpath) {
    84  		logs.Info(pkgRealpath + " no changed")
    85  		return nil
    86  	}
    87  	genInfoList = make(map[string][]ControllerComments)
    88  	fileSet := token.NewFileSet()
    89  	astPkgs, err := parser.ParseDir(fileSet, pkgRealpath, func(info os.FileInfo) bool {
    90  		name := info.Name()
    91  		return !info.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
    92  	}, parser.ParseComments)
    93  
    94  	if err != nil {
    95  		return err
    96  	}
    97  	for _, pkg := range astPkgs {
    98  		for _, fl := range pkg.Files {
    99  			for _, d := range fl.Decls {
   100  				switch specDecl := d.(type) {
   101  				case *ast.FuncDecl:
   102  					if specDecl.Recv != nil {
   103  						exp, ok := specDecl.Recv.List[0].Type.(*ast.StarExpr) // Check that the type is correct first beforing throwing to parser
   104  						if ok {
   105  							parserComments(specDecl, fmt.Sprint(exp.X), pkgpath)
   106  						}
   107  					}
   108  				}
   109  			}
   110  		}
   111  	}
   112  	genRouterCode(pkgRealpath)
   113  	savetoFile(pkgRealpath)
   114  	return nil
   115  }
   116  
   117  type parsedComment struct {
   118  	routerPath string
   119  	methods    []string
   120  	params     map[string]parsedParam
   121  	filters    []parsedFilter
   122  	imports    []parsedImport
   123  }
   124  
   125  type parsedImport struct {
   126  	importPath  string
   127  	importAlias string
   128  }
   129  
   130  type parsedFilter struct {
   131  	pattern string
   132  	pos     int
   133  	filter  string
   134  	params  []bool
   135  }
   136  
   137  type parsedParam struct {
   138  	name     string
   139  	datatype string
   140  	location string
   141  	defValue string
   142  	required bool
   143  }
   144  
   145  func parserComments(f *ast.FuncDecl, controllerName, pkgpath string) error {
   146  	if f.Doc != nil {
   147  		parsedComments, err := parseComment(f.Doc.List)
   148  		if err != nil {
   149  			return err
   150  		}
   151  		for _, parsedComment := range parsedComments {
   152  			if parsedComment.routerPath != "" {
   153  				key := pkgpath + ":" + controllerName
   154  				cc := ControllerComments{}
   155  				cc.Method = f.Name.String()
   156  				cc.Router = parsedComment.routerPath
   157  				cc.AllowHTTPMethods = parsedComment.methods
   158  				cc.MethodParams = buildMethodParams(f.Type.Params.List, parsedComment)
   159  				cc.FilterComments = buildFilters(parsedComment.filters)
   160  				cc.ImportComments = buildImports(parsedComment.imports)
   161  				genInfoList[key] = append(genInfoList[key], cc)
   162  			}
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  func buildImports(pis []parsedImport) []*ControllerImportComments {
   169  	var importComments []*ControllerImportComments
   170  
   171  	for _, pi := range pis {
   172  		importComments = append(importComments, &ControllerImportComments{
   173  			ImportPath:  pi.importPath,
   174  			ImportAlias: pi.importAlias,
   175  		})
   176  	}
   177  
   178  	return importComments
   179  }
   180  
   181  func buildFilters(pfs []parsedFilter) []*ControllerFilterComments {
   182  	var filterComments []*ControllerFilterComments
   183  
   184  	for _, pf := range pfs {
   185  		var (
   186  			returnOnOutput bool
   187  			resetParams    bool
   188  		)
   189  
   190  		if len(pf.params) >= 1 {
   191  			returnOnOutput = pf.params[0]
   192  		}
   193  
   194  		if len(pf.params) >= 2 {
   195  			resetParams = pf.params[1]
   196  		}
   197  
   198  		filterComments = append(filterComments, &ControllerFilterComments{
   199  			Filter:         pf.filter,
   200  			Pattern:        pf.pattern,
   201  			Pos:            pf.pos,
   202  			ReturnOnOutput: returnOnOutput,
   203  			ResetParams:    resetParams,
   204  		})
   205  	}
   206  
   207  	return filterComments
   208  }
   209  
   210  func buildMethodParams(funcParams []*ast.Field, pc *parsedComment) []*param.MethodParam {
   211  	result := make([]*param.MethodParam, 0, len(funcParams))
   212  	for _, fparam := range funcParams {
   213  		for _, pName := range fparam.Names {
   214  			methodParam := buildMethodParam(fparam, pName.Name, pc)
   215  			result = append(result, methodParam)
   216  		}
   217  	}
   218  	return result
   219  }
   220  
   221  func buildMethodParam(fparam *ast.Field, name string, pc *parsedComment) *param.MethodParam {
   222  	options := []param.MethodParamOption{}
   223  	if cparam, ok := pc.params[name]; ok {
   224  		//Build param from comment info
   225  		name = cparam.name
   226  		if cparam.required {
   227  			options = append(options, param.IsRequired)
   228  		}
   229  		switch cparam.location {
   230  		case "body":
   231  			options = append(options, param.InBody)
   232  		case "header":
   233  			options = append(options, param.InHeader)
   234  		case "path":
   235  			options = append(options, param.InPath)
   236  		}
   237  		if cparam.defValue != "" {
   238  			options = append(options, param.Default(cparam.defValue))
   239  		}
   240  	} else {
   241  		if paramInPath(name, pc.routerPath) {
   242  			options = append(options, param.InPath)
   243  		}
   244  	}
   245  	return param.New(name, options...)
   246  }
   247  
   248  func paramInPath(name, route string) bool {
   249  	return strings.HasSuffix(route, ":"+name) ||
   250  		strings.Contains(route, ":"+name+"/")
   251  }
   252  
// routeRegex matches "@router <path> [<methods>]". Group 1 is the route path.
// Group 2 is the optional method list written inside square brackets; it is
// empty when the brackets are omitted. \S+ means the list cannot contain
// spaces.
var routeRegex = regexp.MustCompile(`@router\s+(\S+)(?:\s+\[(\S+)\])?`)
   254  
   255  func parseComment(lines []*ast.Comment) (pcs []*parsedComment, err error) {
   256  	pcs = []*parsedComment{}
   257  	params := map[string]parsedParam{}
   258  	filters := []parsedFilter{}
   259  	imports := []parsedImport{}
   260  
   261  	for _, c := range lines {
   262  		t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
   263  		if strings.HasPrefix(t, "@Param") {
   264  			pv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Param")))
   265  			if len(pv) < 4 {
   266  				logs.Error("Invalid @Param format. Needs at least 4 parameters")
   267  			}
   268  			p := parsedParam{}
   269  			names := strings.SplitN(pv[0], "=>", 2)
   270  			p.name = names[0]
   271  			funcParamName := p.name
   272  			if len(names) > 1 {
   273  				funcParamName = names[1]
   274  			}
   275  			p.location = pv[1]
   276  			p.datatype = pv[2]
   277  			switch len(pv) {
   278  			case 5:
   279  				p.required, _ = strconv.ParseBool(pv[3])
   280  			case 6:
   281  				p.defValue = pv[3]
   282  				p.required, _ = strconv.ParseBool(pv[4])
   283  			}
   284  			params[funcParamName] = p
   285  		}
   286  	}
   287  
   288  	for _, c := range lines {
   289  		t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
   290  		if strings.HasPrefix(t, "@Import") {
   291  			iv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Import")))
   292  			if len(iv) == 0 || len(iv) > 2 {
   293  				logs.Error("Invalid @Import format. Only accepts 1 or 2 parameters")
   294  				continue
   295  			}
   296  
   297  			p := parsedImport{}
   298  			p.importPath = iv[0]
   299  
   300  			if len(iv) == 2 {
   301  				p.importAlias = iv[1]
   302  			}
   303  
   304  			imports = append(imports, p)
   305  		}
   306  	}
   307  
   308  filterLoop:
   309  	for _, c := range lines {
   310  		t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
   311  		if strings.HasPrefix(t, "@Filter") {
   312  			fv := getparams(strings.TrimSpace(strings.TrimLeft(t, "@Filter")))
   313  			if len(fv) < 3 {
   314  				logs.Error("Invalid @Filter format. Needs at least 3 parameters")
   315  				continue filterLoop
   316  			}
   317  
   318  			p := parsedFilter{}
   319  			p.pattern = fv[0]
   320  			posName := fv[1]
   321  			if pos, exists := routerHooks[posName]; exists {
   322  				p.pos = pos
   323  			} else {
   324  				logs.Error("Invalid @Filter pos: ", posName)
   325  				continue filterLoop
   326  			}
   327  
   328  			p.filter = fv[2]
   329  			fvParams := fv[3:]
   330  			for _, fvParam := range fvParams {
   331  				switch fvParam {
   332  				case "true":
   333  					p.params = append(p.params, true)
   334  				case "false":
   335  					p.params = append(p.params, false)
   336  				default:
   337  					logs.Error("Invalid @Filter param: ", fvParam)
   338  					continue filterLoop
   339  				}
   340  			}
   341  
   342  			filters = append(filters, p)
   343  		}
   344  	}
   345  
   346  	for _, c := range lines {
   347  		var pc = &parsedComment{}
   348  		pc.params = params
   349  		pc.filters = filters
   350  		pc.imports = imports
   351  
   352  		t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
   353  		if strings.HasPrefix(t, "@router") {
   354  			t := strings.TrimSpace(strings.TrimLeft(c.Text, "//"))
   355  			matches := routeRegex.FindStringSubmatch(t)
   356  			if len(matches) == 3 {
   357  				pc.routerPath = matches[1]
   358  				methods := matches[2]
   359  				if methods == "" {
   360  					pc.methods = []string{"get"}
   361  					//pc.hasGet = true
   362  				} else {
   363  					pc.methods = strings.Split(methods, ",")
   364  					//pc.hasGet = strings.Contains(methods, "get")
   365  				}
   366  				pcs = append(pcs, pc)
   367  			} else {
   368  				return nil, errors.New("Router information is missing")
   369  			}
   370  		}
   371  	}
   372  	return
   373  }
   374  
   375  // direct copy from bee\g_docs.go
   376  // analysis params return []string
   377  // @Param	query		form	 string	true		"The email for login"
   378  // [query form string true "The email for login"]
   379  func getparams(str string) []string {
   380  	var s []rune
   381  	var j int
   382  	var start bool
   383  	var r []string
   384  	var quoted int8
   385  	for _, c := range str {
   386  		if unicode.IsSpace(c) && quoted == 0 {
   387  			if !start {
   388  				continue
   389  			} else {
   390  				start = false
   391  				j++
   392  				r = append(r, string(s))
   393  				s = make([]rune, 0)
   394  				continue
   395  			}
   396  		}
   397  
   398  		start = true
   399  		if c == '"' {
   400  			quoted ^= 1
   401  			continue
   402  		}
   403  		s = append(s, c)
   404  	}
   405  	if len(s) > 0 {
   406  		r = append(r, string(s))
   407  	}
   408  	return r
   409  }
   410  
   411  func genRouterCode(pkgRealpath string) {
   412  	os.Mkdir(getRouterDir(pkgRealpath), 0755)
   413  	logs.Info("generate router from comments")
   414  	var (
   415  		globalinfo   string
   416  		globalimport string
   417  		sortKey      []string
   418  	)
   419  	for k := range genInfoList {
   420  		sortKey = append(sortKey, k)
   421  	}
   422  	sort.Strings(sortKey)
   423  	for _, k := range sortKey {
   424  		cList := genInfoList[k]
   425  		sort.Sort(ControllerCommentsSlice(cList))
   426  		for _, c := range cList {
   427  			allmethod := "nil"
   428  			if len(c.AllowHTTPMethods) > 0 {
   429  				allmethod = "[]string{"
   430  				for _, m := range c.AllowHTTPMethods {
   431  					allmethod += `"` + m + `",`
   432  				}
   433  				allmethod = strings.TrimRight(allmethod, ",") + "}"
   434  			}
   435  
   436  			params := "nil"
   437  			if len(c.Params) > 0 {
   438  				params = "[]map[string]string{"
   439  				for _, p := range c.Params {
   440  					for k, v := range p {
   441  						params = params + `map[string]string{` + k + `:"` + v + `"},`
   442  					}
   443  				}
   444  				params = strings.TrimRight(params, ",") + "}"
   445  			}
   446  
   447  			methodParams := "param.Make("
   448  			if len(c.MethodParams) > 0 {
   449  				lines := make([]string, 0, len(c.MethodParams))
   450  				for _, m := range c.MethodParams {
   451  					lines = append(lines, fmt.Sprint(m))
   452  				}
   453  				methodParams += "\n				" +
   454  					strings.Join(lines, ",\n				") +
   455  					",\n			"
   456  			}
   457  			methodParams += ")"
   458  
   459  			imports := ""
   460  			if len(c.ImportComments) > 0 {
   461  				for _, i := range c.ImportComments {
   462  					var s string
   463  					if i.ImportAlias != "" {
   464  						s = fmt.Sprintf(`
   465  	%s "%s"`, i.ImportAlias, i.ImportPath)
   466  					} else {
   467  						s = fmt.Sprintf(`
   468  	"%s"`, i.ImportPath)
   469  					}
   470  					if !strings.Contains(globalimport, s) {
   471  						imports += s
   472  					}
   473  				}
   474  			}
   475  
   476  			filters := ""
   477  			if len(c.FilterComments) > 0 {
   478  				for _, f := range c.FilterComments {
   479  					filters += fmt.Sprintf(`                &beego.ControllerFilter{
   480                      Pattern: "%s",
   481                      Pos: %s,
   482                      Filter: %s,
   483                      ReturnOnOutput: %v,
   484                      ResetParams: %v,
   485                  },`, f.Pattern, routerHooksMapping[f.Pos], f.Filter, f.ReturnOnOutput, f.ResetParams)
   486  				}
   487  			}
   488  
   489  			if filters == "" {
   490  				filters = "nil"
   491  			} else {
   492  				filters = fmt.Sprintf(`[]*beego.ControllerFilter{
   493  %s
   494              }`, filters)
   495  			}
   496  
   497  			globalimport += imports
   498  
   499  			globalinfo = globalinfo + `
   500      beego.GlobalControllerRouter["` + k + `"] = append(beego.GlobalControllerRouter["` + k + `"],
   501          beego.ControllerComments{
   502              Method: "` + strings.TrimSpace(c.Method) + `",
   503              ` + `Router: "` + c.Router + `"` + `,
   504              AllowHTTPMethods: ` + allmethod + `,
   505              MethodParams: ` + methodParams + `,
   506              Filters: ` + filters + `,
   507              Params: ` + params + `})
   508  `
   509  		}
   510  	}
   511  
   512  	if globalinfo != "" {
   513  		f, err := os.Create(filepath.Join(getRouterDir(pkgRealpath), commentFilename))
   514  		if err != nil {
   515  			panic(err)
   516  		}
   517  		defer f.Close()
   518  
   519  		routersDir := AppConfig.DefaultString("routersdir", "routers")
   520  		content := strings.Replace(globalRouterTemplate, "{{.globalinfo}}", globalinfo, -1)
   521  		content = strings.Replace(content, "{{.routersDir}}", routersDir, -1)
   522  		content = strings.Replace(content, "{{.globalimport}}", globalimport, -1)
   523  		f.WriteString(content)
   524  	}
   525  }
   526  
   527  func compareFile(pkgRealpath string) bool {
   528  	if !utils.FileExists(filepath.Join(getRouterDir(pkgRealpath), commentFilename)) {
   529  		return true
   530  	}
   531  	if utils.FileExists(lastupdateFilename) {
   532  		content, err := ioutil.ReadFile(lastupdateFilename)
   533  		if err != nil {
   534  			return true
   535  		}
   536  		json.Unmarshal(content, &pkgLastupdate)
   537  		lastupdate, err := getpathTime(pkgRealpath)
   538  		if err != nil {
   539  			return true
   540  		}
   541  		if v, ok := pkgLastupdate[pkgRealpath]; ok {
   542  			if lastupdate <= v {
   543  				return false
   544  			}
   545  		}
   546  	}
   547  	return true
   548  }
   549  
   550  func savetoFile(pkgRealpath string) {
   551  	lastupdate, err := getpathTime(pkgRealpath)
   552  	if err != nil {
   553  		return
   554  	}
   555  	pkgLastupdate[pkgRealpath] = lastupdate
   556  	d, err := json.Marshal(pkgLastupdate)
   557  	if err != nil {
   558  		return
   559  	}
   560  	ioutil.WriteFile(lastupdateFilename, d, os.ModePerm)
   561  }
   562  
   563  func getpathTime(pkgRealpath string) (lastupdate int64, err error) {
   564  	fl, err := ioutil.ReadDir(pkgRealpath)
   565  	if err != nil {
   566  		return lastupdate, err
   567  	}
   568  	for _, f := range fl {
   569  		if lastupdate < f.ModTime().UnixNano() {
   570  			lastupdate = f.ModTime().UnixNano()
   571  		}
   572  	}
   573  	return lastupdate, nil
   574  }
   575  
   576  func getRouterDir(pkgRealpath string) string {
   577  	dir := filepath.Dir(pkgRealpath)
   578  	for {
   579  		routersDir := AppConfig.DefaultString("routersdir", "routers")
   580  		d := filepath.Join(dir, routersDir)
   581  		if utils.FileExists(d) {
   582  			return d
   583  		}
   584  
   585  		if r, _ := filepath.Rel(dir, AppPath); r == "." {
   586  			return d
   587  		}
   588  		// Parent dir.
   589  		dir = filepath.Dir(dir)
   590  	}
   591  }