github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/courier/swagger/gen/router_scanner.go (about)

     1  package gen
     2  
     3  import (
     4  	"go/ast"
     5  	"go/types"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/julienschmidt/httprouter"
    10  	"golang.org/x/tools/go/loader"
    11  
    12  	"github.com/johnnyeven/libtools/codegen/loaderx"
    13  	"github.com/johnnyeven/libtools/courier"
    14  )
    15  
    16  var (
    17  	courierPkgImportPath = "github.com/johnnyeven/libtools/courier"
    18  	routerTypeString     = reflectTypeString(reflect.TypeOf(new(courier.Router)))
    19  )
    20  
    21  func isRouterType(tpe types.Type) bool {
    22  	return strings.HasSuffix(tpe.String(), routerTypeString)
    23  }
    24  
    25  func NewRouterScanner(program *loader.Program) *RouterScanner {
    26  	routerScanner := &RouterScanner{
    27  		program: program,
    28  		routers: map[*types.Var]*Router{},
    29  	}
    30  
    31  	routerScanner.init()
    32  
    33  	return routerScanner
    34  }
    35  
    36  type RouterScanner struct {
    37  	program *loader.Program
    38  	routers map[*types.Var]*Router
    39  }
    40  
    41  func (scanner *RouterScanner) Router(typeName *types.Var) *Router {
    42  	return scanner.routers[typeName]
    43  }
    44  
    45  type OperatorTypeNames []*types.TypeName
    46  
    47  func (ops OperatorTypeNames) Append(tpe types.Type) OperatorTypeNames {
    48  	switch tpe.(type) {
    49  	case *types.Named:
    50  		return append(ops, tpe.(*types.Named).Obj())
    51  	case *types.Pointer:
    52  		return append(ops, tpe.(*types.Pointer).Elem().(*types.Named).Obj())
    53  	}
    54  	return ops
    55  }
    56  
    57  func FromArgs(pkgInfo *loader.PackageInfo, args ...ast.Expr) OperatorTypeNames {
    58  	opTypeNames := OperatorTypeNames{}
    59  	for _, arg := range args {
    60  		opTypeNames = opTypeNames.Append(pkgInfo.TypeOf(arg))
    61  	}
    62  	return opTypeNames
    63  }
    64  
    65  func (scanner *RouterScanner) init() {
    66  	for _, pkgInfo := range scanner.program.AllPackages {
    67  		for ident, obj := range pkgInfo.Defs {
    68  			if typeVar, ok := obj.(*types.Var); ok {
    69  				if typeVar != nil && !strings.HasSuffix(typeVar.Pkg().Path(), courierPkgImportPath) {
    70  					if isRouterType(typeVar.Type()) {
    71  						router := NewRouter()
    72  
    73  						ast.Inspect(ident.Obj.Decl.(ast.Node), func(node ast.Node) bool {
    74  							switch node.(type) {
    75  							case *ast.CallExpr:
    76  								callExpr := node.(*ast.CallExpr)
    77  								router.AppendOperators(FromArgs(pkgInfo, callExpr.Args...)...)
    78  								return false
    79  							}
    80  							return true
    81  						})
    82  
    83  						scanner.routers[typeVar] = router
    84  					}
    85  				}
    86  			}
    87  		}
    88  	}
    89  
    90  	for _, pkgInfo := range scanner.program.AllPackages {
    91  		for selectExpr, selection := range pkgInfo.Selections {
    92  			if selection.Obj() != nil {
    93  				if typeFunc, ok := selection.Obj().(*types.Func); ok {
    94  					recv := typeFunc.Type().(*types.Signature).Recv()
    95  					if recv != nil && isRouterType(recv.Type()) {
    96  						for typeVar, router := range scanner.routers {
    97  							switch selectExpr.Sel.Name {
    98  							case "Register":
    99  								if typeVar == pkgInfo.ObjectOf(IdentOfCallExprSelectExpr(selectExpr)) {
   100  									file := loaderx.FileOf(selectExpr, pkgInfo.Files...)
   101  									ast.Inspect(file, func(node ast.Node) bool {
   102  										switch node.(type) {
   103  										case *ast.CallExpr:
   104  											callExpr := node.(*ast.CallExpr)
   105  											if callExpr.Fun == selectExpr {
   106  												routerIdent := callExpr.Args[0]
   107  												switch routerIdent.(type) {
   108  												case *ast.Ident:
   109  													argTypeVar := pkgInfo.ObjectOf(routerIdent.(*ast.Ident)).(*types.Var)
   110  													if r, ok := scanner.routers[argTypeVar]; ok {
   111  														router.Register(r)
   112  													}
   113  												case *ast.SelectorExpr:
   114  													argTypeVar := pkgInfo.ObjectOf(routerIdent.(*ast.SelectorExpr).Sel).(*types.Var)
   115  													if r, ok := scanner.routers[argTypeVar]; ok {
   116  														router.Register(r)
   117  													}
   118  												case *ast.CallExpr:
   119  													callExprForRegister := routerIdent.(*ast.CallExpr)
   120  													router.With(FromArgs(pkgInfo, callExprForRegister.Args...)...)
   121  												}
   122  												return false
   123  											}
   124  										}
   125  										return true
   126  									})
   127  								}
   128  							}
   129  						}
   130  					}
   131  				}
   132  			}
   133  		}
   134  	}
   135  }
   136  
   137  func IdentOfCallExprSelectExpr(selectExpr *ast.SelectorExpr) *ast.Ident {
   138  	switch selectExpr.X.(type) {
   139  	case *ast.Ident:
   140  		return selectExpr.X.(*ast.Ident)
   141  	case *ast.SelectorExpr:
   142  		return selectExpr.X.(*ast.SelectorExpr).Sel
   143  	}
   144  	return nil
   145  }
   146  
   147  func NewRouter(operators ...*types.TypeName) *Router {
   148  	return &Router{
   149  		operators: operators,
   150  	}
   151  }
   152  
   153  type Router struct {
   154  	parent    *Router
   155  	operators []*types.TypeName
   156  	children  map[*Router]bool
   157  }
   158  
   159  func (router *Router) AppendOperators(operators ...*types.TypeName) {
   160  	router.operators = append(router.operators, operators...)
   161  }
   162  
   163  func (router *Router) With(operators ...*types.TypeName) {
   164  	router.Register(NewRouter(operators...))
   165  }
   166  
   167  func (router *Router) Register(r *Router) {
   168  	if router.children == nil {
   169  		router.children = map[*Router]bool{}
   170  	}
   171  	r.parent = router
   172  	router.children[r] = true
   173  }
   174  
   175  func (router *Router) Route(program *loader.Program) *Route {
   176  	parent := router.parent
   177  	operators := router.operators
   178  
   179  	for parent != nil {
   180  		operators = append(parent.operators, operators...)
   181  		parent = parent.parent
   182  	}
   183  
   184  	route := Route{
   185  		last:      router.children == nil,
   186  		operators: operators,
   187  	}
   188  
   189  	route.SetMethod(program)
   190  	route.SetPath(program)
   191  
   192  	return &route
   193  }
   194  
   195  func (router *Router) Routes(program *loader.Program) (routes []*Route) {
   196  	for child := range router.children {
   197  		route := child.Route(program)
   198  		if route.last {
   199  			routes = append(routes, route)
   200  		}
   201  		if child.children != nil {
   202  			routes = append(routes, child.Routes(program)...)
   203  		}
   204  	}
   205  	return routes
   206  }
   207  
   208  type Route struct {
   209  	Method    string
   210  	Path      string
   211  	last      bool
   212  	operators []*types.TypeName
   213  }
   214  
   215  func (route *Route) SetPath(program *loader.Program) {
   216  	p := "/"
   217  	for _, operator := range route.operators {
   218  		typeFunc := loaderx.MethodOf(operator.Type().(*types.Named), "Path")
   219  
   220  		loaderx.ForEachFuncResult(program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   221  			if resultTypeAndValues[0].IsValue() {
   222  				p += getConstVal(resultTypeAndValues[0].Value).(string)
   223  			}
   224  		})
   225  	}
   226  	route.Path = httprouter.CleanPath(p)
   227  }
   228  
   229  func (route *Route) SetMethod(program *loader.Program) {
   230  	if len(route.operators) > 0 {
   231  		operator := route.operators[len(route.operators)-1]
   232  		typeFunc := loaderx.MethodOf(operator.Type().(*types.Named), "Method")
   233  
   234  		loaderx.ForEachFuncResult(program, typeFunc, func(resultTypeAndValues ...types.TypeAndValue) {
   235  			if resultTypeAndValues[0].IsValue() {
   236  				route.Method = getConstVal(resultTypeAndValues[0].Value).(string)
   237  			}
   238  		})
   239  	}
   240  }