github.com/artisanhe/tools@v1.0.1-0.20210607022958-19a8fef2eb04/courier/swagger/gen/router_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"go/ast"
     5  	"go/types"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/julienschmidt/httprouter"
    11  	"golang.org/x/tools/go/loader"
    12  
    13  	"github.com/artisanhe/tools/codegen/loaderx"
    14  	"github.com/artisanhe/tools/courier"
    15  )
    16  
    17  var (
    18  	courierPkgImportPath = "github.com/artisanhe/tools/courier"
    19  	routerTypeString     = reflectTypeString(reflect.TypeOf(new(courier.Router)))
    20  )
    21  
    22  func isRouterType(tpe types.Type) bool {
    23  	return strings.HasSuffix(tpe.String(), courierPkgImportPath+".Router")
    24  }
    25  
    26  func NewRouterScanner(program *loader.Program) *RouterScanner {
    27  	routerScanner := &RouterScanner{
    28  		program: program,
    29  		routers: map[*types.Var]*Router{},
    30  	}
    31  
    32  	routerScanner.init()
    33  
    34  	return routerScanner
    35  }
    36  
    37  type RouterScanner struct {
    38  	program *loader.Program
    39  	routers map[*types.Var]*Router
    40  }
    41  
    42  func (scanner *RouterScanner) Router(typeName *types.Var) *Router {
    43  	return scanner.routers[typeName]
    44  }
    45  
    46  type OperatorTypeName struct {
    47  	Path string
    48  	*types.TypeName
    49  }
    50  
    51  type OperatorTypeNames []*OperatorTypeName
    52  
    53  func OperatorTypeNameFromType(tpe types.Type) *OperatorTypeName {
    54  	switch tpe.(type) {
    55  	case *types.Named:
    56  		return &OperatorTypeName{
    57  			TypeName: tpe.(*types.Named).Obj(),
    58  		}
    59  	case *types.Pointer:
    60  		return &OperatorTypeName{
    61  			TypeName: tpe.(*types.Pointer).Elem().(*types.Named).Obj(),
    62  		}
    63  	}
    64  	return nil
    65  }
    66  
    67  func FromArgs(pkgInfo *loader.PackageInfo, args ...ast.Expr) OperatorTypeNames {
    68  	opTypeNames := OperatorTypeNames{}
    69  	for _, arg := range args {
    70  		opTypeName := OperatorTypeNameFromType(pkgInfo.TypeOf(arg))
    71  		if opTypeName == nil {
    72  			continue
    73  		}
    74  		if callExpr, ok := arg.(*ast.CallExpr); ok {
    75  			if selectorExpr, ok := callExpr.Fun.(*ast.SelectorExpr); ok {
    76  				if selectorExpr.Sel.Name == "Group" {
    77  					if strings.Contains(pkgInfo.ObjectOf(selectorExpr.Sel).Pkg().String(), courierPkgImportPath) {
    78  						switch v := callExpr.Args[0].(type) {
    79  						case *ast.BasicLit:
    80  							opTypeName.Path, _ = strconv.Unquote(v.Value)
    81  						}
    82  					}
    83  				}
    84  			}
    85  		}
    86  		opTypeNames = append(opTypeNames, opTypeName)
    87  
    88  	}
    89  	return opTypeNames
    90  }
    91  
    92  func (scanner *RouterScanner) init() {
    93  	for _, pkgInfo := range scanner.program.AllPackages {
    94  		for ident, obj := range pkgInfo.Defs {
    95  			if typeVar, ok := obj.(*types.Var); ok {
    96  				if typeVar != nil && !strings.HasSuffix(typeVar.Pkg().Path(), courierPkgImportPath) {
    97  					if isRouterType(typeVar.Type()) {
    98  						router := NewRouter()
    99  
   100  						ast.Inspect(ident.Obj.Decl.(ast.Node), func(node ast.Node) bool {
   101  							switch node.(type) {
   102  							case *ast.CallExpr:
   103  								callExpr := node.(*ast.CallExpr)
   104  								router.AppendOperators(FromArgs(pkgInfo, callExpr.Args...)...)
   105  								return false
   106  							}
   107  							return true
   108  						})
   109  
   110  						scanner.routers[typeVar] = router
   111  					}
   112  				}
   113  			}
   114  		}
   115  	}
   116  
   117  	for _, pkgInfo := range scanner.program.AllPackages {
   118  		for selectExpr, selection := range pkgInfo.Selections {
   119  			if selection.Obj() != nil {
   120  				if typeFunc, ok := selection.Obj().(*types.Func); ok {
   121  					recv := typeFunc.Type().(*types.Signature).Recv()
   122  					if recv != nil && isRouterType(recv.Type()) {
   123  						for typeVar, router := range scanner.routers {
   124  							switch selectExpr.Sel.Name {
   125  							case "Register":
   126  								if typeVar == pkgInfo.ObjectOf(IdentOfCallExprSelectExpr(selectExpr)) {
   127  									file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   128  									ast.Inspect(file, func(node ast.Node) bool {
   129  										switch node.(type) {
   130  										case *ast.CallExpr:
   131  											callExpr := node.(*ast.CallExpr)
   132  											if callExpr.Fun == selectExpr {
   133  												routerIdent := callExpr.Args[0]
   134  												switch routerIdent.(type) {
   135  												case *ast.Ident:
   136  													argTypeVar := pkgInfo.ObjectOf(routerIdent.(*ast.Ident)).(*types.Var)
   137  													if r, ok := scanner.routers[argTypeVar]; ok {
   138  														router.Register(r)
   139  													}
   140  												case *ast.SelectorExpr:
   141  													argTypeVar := pkgInfo.ObjectOf(routerIdent.(*ast.SelectorExpr).Sel).(*types.Var)
   142  													if r, ok := scanner.routers[argTypeVar]; ok {
   143  														router.Register(r)
   144  													}
   145  												case *ast.CallExpr:
   146  													callExprForRegister := routerIdent.(*ast.CallExpr)
   147  													router.With(FromArgs(pkgInfo, callExprForRegister.Args...)...)
   148  												}
   149  												return false
   150  											}
   151  										}
   152  										return true
   153  									})
   154  								}
   155  							}
   156  						}
   157  					}
   158  				}
   159  			}
   160  		}
   161  	}
   162  }
   163  
   164  func IdentOfCallExprSelectExpr(selectExpr *ast.SelectorExpr) *ast.Ident {
   165  	switch selectExpr.X.(type) {
   166  	case *ast.Ident:
   167  		return selectExpr.X.(*ast.Ident)
   168  	case *ast.SelectorExpr:
   169  		return selectExpr.X.(*ast.SelectorExpr).Sel
   170  	}
   171  	return nil
   172  }
   173  
   174  func NewRouter(operators ...*OperatorTypeName) *Router {
   175  	return &Router{
   176  		operators: operators,
   177  	}
   178  }
   179  
   180  type Router struct {
   181  	parent    *Router
   182  	operators []*OperatorTypeName
   183  	children  map[*Router]bool
   184  }
   185  
   186  func (router *Router) AppendOperators(operators ...*OperatorTypeName) {
   187  	router.operators = append(router.operators, operators...)
   188  }
   189  
   190  func (router *Router) With(operators ...*OperatorTypeName) {
   191  	router.Register(NewRouter(operators...))
   192  }
   193  
   194  func (router *Router) Register(r *Router) {
   195  	if router.children == nil {
   196  		router.children = map[*Router]bool{}
   197  	}
   198  	r.parent = router
   199  	router.children[r] = true
   200  }
   201  
   202  func (router *Router) Route(program *loader.Program) *Route {
   203  	parent := router.parent
   204  	operators := router.operators
   205  
   206  	for parent != nil {
   207  		operators = append(parent.operators, operators...)
   208  		parent = parent.parent
   209  	}
   210  
   211  	route := Route{
   212  		last:      router.children == nil,
   213  		operators: operators,
   214  	}
   215  
   216  	route.SetMethod(program)
   217  	route.SetPath(program)
   218  
   219  	return &route
   220  }
   221  
   222  func (router *Router) Routes(program *loader.Program) (routes []*Route) {
   223  	for child := range router.children {
   224  		route := child.Route(program)
   225  		if route.last {
   226  			routes = append(routes, route)
   227  		}
   228  		if child.children != nil {
   229  			routes = append(routes, child.Routes(program)...)
   230  		}
   231  	}
   232  	return routes
   233  }
   234  
   235  type Route struct {
   236  	Method    string
   237  	Path      string
   238  	last      bool
   239  	operators []*OperatorTypeName
   240  }
   241  
   242  func (route *Route) SetPath(program *loader.Program) {
   243  	p := "/"
   244  	for _, operator := range route.operators {
   245  		if operator.Path != "" {
   246  			p += operator.Path
   247  			continue
   248  		}
   249  
   250  		typeFunc := loaderx.MethodOf(operator.Type().(*types.Named), "Path")
   251  
   252  		loaderx.ForEachFuncResult(program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   253  			if resultTypeAndValues[0].IsValue() {
   254  				p += getConstVal(resultTypeAndValues[0].Value).(string)
   255  			}
   256  		})
   257  	}
   258  	route.Path = httprouter.CleanPath(p)
   259  }
   260  
   261  func (route *Route) SetMethod(program *loader.Program) {
   262  	if len(route.operators) > 0 {
   263  		operator := route.operators[len(route.operators)-1]
   264  		typeFunc := loaderx.MethodOf(operator.Type().(*types.Named), "Method")
   265  
   266  		loaderx.ForEachFuncResult(program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   267  			if resultTypeAndValues[0].IsValue() {
   268  				route.Method = getConstVal(resultTypeAndValues[0].Value).(string)
   269  			}
   270  		})
   271  	}
   272  }