gitee.com/sasukebo/go-micro/v4@v4.7.1/cmd/protoc-gen-micro/plugin/micro/micro.go (about)

     1  package micro
     2  
     3  import (
     4  	"fmt"
     5  	"path"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"gitee.com/sasukebo/go-micro/v4/cmd/protoc-gen-micro/generator"
    10  	options "google.golang.org/genproto/googleapis/api/annotations"
    11  	"google.golang.org/protobuf/proto"
    12  	pb "google.golang.org/protobuf/types/descriptorpb"
    13  )
    14  
// Paths for packages used by code generated in this file,
// relative to the import_prefix of the generator.Generator.
//
// GenerateImports joins each of these with the generator's ImportPrefix
// and emits one import line per path, aliased with the matching package
// name variable (apiPkg, contextPkg, clientPkg, serverPkg).
const (
	apiPkgPath     = "gitee.com/sasukebo/go-micro/v4/api"
	contextPkgPath = "context"
	clientPkgPath  = "gitee.com/sasukebo/go-micro/v4/client"
	serverPkgPath  = "gitee.com/sasukebo/go-micro/v4/server"
)
    23  
    24  func init() {
    25  	generator.RegisterPlugin(new(micro))
    26  }
    27  
// micro is an implementation of the Go protocol buffer compiler's
// plugin architecture.  It generates bindings for go-micro support.
type micro struct {
	// gen is the generator handed to Init; all output (P), comment printing
	// and type-name lookups performed by the plugin go through it.
	gen *generator.Generator
}
    33  
    34  // Name returns the name of this plugin, "micro".
    35  func (g *micro) Name() string {
    36  	return "micro"
    37  }
    38  
// The names for packages imported in the generated code.
// They may vary from the final path component of the import path
// if the name is used by other packages.
//
// The four package names are assigned in Init from
// generator.RegisterUniquePackageName. pkgImports is rebuilt by
// GenerateImports from the import map it receives and is consulted by
// unexport to avoid colliding with an imported package name.
var (
	apiPkg     string
	contextPkg string
	clientPkg  string
	serverPkg  string
	pkgImports map[generator.GoPackageName]bool
)
    49  
    50  // Init initializes the plugin.
    51  func (g *micro) Init(gen *generator.Generator) {
    52  	g.gen = gen
    53  	apiPkg = generator.RegisterUniquePackageName("api", nil)
    54  	contextPkg = generator.RegisterUniquePackageName("context", nil)
    55  	clientPkg = generator.RegisterUniquePackageName("client", nil)
    56  	serverPkg = generator.RegisterUniquePackageName("server", nil)
    57  }
    58  
    59  // Given a type name defined in a .proto, return its object.
    60  // Also record that we're using it, to guarantee the associated import.
    61  func (g *micro) objectNamed(name string) generator.Object {
    62  	g.gen.RecordTypeUse(name)
    63  	return g.gen.ObjectNamed(name)
    64  }
    65  
    66  // Given a type name defined in a .proto, return its name as we will print it.
    67  func (g *micro) typeName(str string) string {
    68  	return g.gen.TypeName(g.objectNamed(str))
    69  }
    70  
    71  // P forwards to g.gen.P.
    72  func (g *micro) P(args ...interface{}) { g.gen.P(args...) }
    73  
    74  // Generate generates code for the services in the given file.
    75  func (g *micro) Generate(file *generator.FileDescriptor) {
    76  	if len(file.FileDescriptorProto.Service) == 0 {
    77  		return
    78  	}
    79  	g.P("// Reference imports to suppress errors if they are not otherwise used.")
    80  	g.P("var _ ", apiPkg, ".Endpoint")
    81  	g.P("var _ ", contextPkg, ".Context")
    82  	g.P("var _ ", clientPkg, ".Option")
    83  	g.P("var _ ", serverPkg, ".Option")
    84  	g.P()
    85  
    86  	for i, service := range file.FileDescriptorProto.Service {
    87  		g.generateService(file, service, i)
    88  	}
    89  }
    90  
    91  // GenerateImports generates the import declaration for this file.
    92  func (g *micro) GenerateImports(file *generator.FileDescriptor, imports map[generator.GoImportPath]generator.GoPackageName) {
    93  	if len(file.FileDescriptorProto.Service) == 0 {
    94  		return
    95  	}
    96  	g.P("import (")
    97  	g.P(apiPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, apiPkgPath)))
    98  	g.P(contextPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, contextPkgPath)))
    99  	g.P(clientPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, clientPkgPath)))
   100  	g.P(serverPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, serverPkgPath)))
   101  	g.P(")")
   102  	g.P()
   103  
   104  	// We need to keep track of imported packages to make sure we don't produce
   105  	// a name collision when generating types.
   106  	pkgImports = make(map[generator.GoPackageName]bool)
   107  	for _, name := range imports {
   108  		pkgImports[name] = true
   109  	}
   110  }
   111  
// reservedClientName records whether a client name is reserved on the client side.
// generateClientSignature and generateServerSignature append "_" to any
// CamelCased method name found in this set. The set is currently empty.
var reservedClientName = map[string]bool{
	// TODO: do we need any in go-micro?
}
   116  
   117  func unexport(s string) string {
   118  	if len(s) == 0 {
   119  		return ""
   120  	}
   121  	name := strings.ToLower(s[:1]) + s[1:]
   122  	if pkgImports[generator.GoPackageName(name)] {
   123  		return name + "_"
   124  	}
   125  	return name
   126  }
   127  
   128  // generateService generates all the code for the named service.
   129  func (g *micro) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
   130  	path := fmt.Sprintf("6,%d", index) // 6 means service.
   131  
   132  	origServName := service.GetName()
   133  	serviceName := strings.ToLower(service.GetName())
   134  	pkg := file.GetPackage()
   135  	if pkg != "" {
   136  		serviceName = pkg
   137  	}
   138  	servName := generator.CamelCase(origServName)
   139  	servAlias := servName + "Service"
   140  
   141  	// strip suffix
   142  	if strings.HasSuffix(servAlias, "ServiceService") {
   143  		servAlias = strings.TrimSuffix(servAlias, "Service")
   144  	}
   145  
   146  	g.P()
   147  	g.P("// Api Endpoints for ", servName, " service")
   148  	g.P()
   149  
   150  	g.P("func New", servName, "Endpoints () []*", apiPkg, ".Endpoint {")
   151  	g.P("return []*", apiPkg, ".Endpoint{")
   152  	for _, method := range service.Method {
   153  		if method.Options != nil && proto.HasExtension(method.Options, options.E_Http) {
   154  			g.P("{")
   155  			g.generateEndpoint(servName, method)
   156  			g.P("},")
   157  		}
   158  	}
   159  	g.P("}")
   160  	g.P("}")
   161  	g.P()
   162  
   163  	g.P()
   164  	g.P("// Client API for ", servName, " service")
   165  	g.P()
   166  
   167  	// Client interface.
   168  	g.P("type ", servAlias, " interface {")
   169  	for i, method := range service.Method {
   170  		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
   171  		g.P(g.generateClientSignature(servName, method))
   172  	}
   173  	g.P("}")
   174  	g.P()
   175  
   176  	// Client structure.
   177  	g.P("type ", unexport(servAlias), " struct {")
   178  	g.P("c ", clientPkg, ".Client")
   179  	g.P("name string")
   180  	g.P("}")
   181  	g.P()
   182  
   183  	// NewClient factory.
   184  	g.P("func New", servAlias, " (name string, c ", clientPkg, ".Client) ", servAlias, " {")
   185  	/*
   186  		g.P("if c == nil {")
   187  		g.P("c = ", clientPkg, ".NewClient()")
   188  		g.P("}")
   189  		g.P("if len(name) == 0 {")
   190  		g.P(`name = "`, serviceName, `"`)
   191  		g.P("}")
   192  	*/
   193  	g.P("return &", unexport(servAlias), "{")
   194  	g.P("c: c,")
   195  	g.P("name: name,")
   196  	g.P("}")
   197  	g.P("}")
   198  	g.P()
   199  	var methodIndex, streamIndex int
   200  	serviceDescVar := "_" + servName + "_serviceDesc"
   201  	// Client method implementations.
   202  	for _, method := range service.Method {
   203  		var descExpr string
   204  		if !method.GetServerStreaming() {
   205  			// Unary RPC method
   206  			descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex)
   207  			methodIndex++
   208  		} else {
   209  			// Streaming RPC method
   210  			descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex)
   211  			streamIndex++
   212  		}
   213  		g.generateClientMethod(pkg, serviceName, servName, serviceDescVar, method, descExpr)
   214  	}
   215  
   216  	g.P("// Server API for ", servName, " service")
   217  	g.P()
   218  
   219  	// Server interface.
   220  	serverType := servName + "Handler"
   221  	g.P("type ", serverType, " interface {")
   222  	for i, method := range service.Method {
   223  		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
   224  		g.P(g.generateServerSignature(servName, method))
   225  	}
   226  	g.P("}")
   227  	g.P()
   228  
   229  	// Server registration.
   230  	g.P("func Register", servName, "Handler(s ", serverPkg, ".Server, hdlr ", serverType, ", opts ...", serverPkg, ".HandlerOption) error {")
   231  	g.P("type ", unexport(servName), " interface {")
   232  
   233  	// generate interface methods
   234  	for _, method := range service.Method {
   235  		methName := generator.CamelCase(method.GetName())
   236  		inType := g.typeName(method.GetInputType())
   237  		outType := g.typeName(method.GetOutputType())
   238  
   239  		if !method.GetServerStreaming() && !method.GetClientStreaming() {
   240  			g.P(methName, "(ctx ", contextPkg, ".Context, in *", inType, ", out *", outType, ") error")
   241  			continue
   242  		}
   243  		g.P(methName, "(ctx ", contextPkg, ".Context, stream server.Stream) error")
   244  	}
   245  	g.P("}")
   246  	g.P("type ", servName, " struct {")
   247  	g.P(unexport(servName))
   248  	g.P("}")
   249  	g.P("h := &", unexport(servName), "Handler{hdlr}")
   250  	for _, method := range service.Method {
   251  		if method.Options != nil && proto.HasExtension(method.Options, options.E_Http) {
   252  			g.P("opts = append(opts, ", apiPkg, ".WithEndpoint(&", apiPkg, ".Endpoint{")
   253  			g.generateEndpoint(servName, method)
   254  			g.P("}))")
   255  		}
   256  	}
   257  	g.P("return s.Handle(s.NewHandler(&", servName, "{h}, opts...))")
   258  	g.P("}")
   259  	g.P()
   260  
   261  	g.P("type ", unexport(servName), "Handler struct {")
   262  	g.P(serverType)
   263  	g.P("}")
   264  
   265  	// Server handler implementations.
   266  	var handlerNames []string
   267  	for _, method := range service.Method {
   268  		hname := g.generateServerMethod(servName, method)
   269  		handlerNames = append(handlerNames, hname)
   270  	}
   271  }
   272  
   273  // generateEndpoint creates the api endpoint
   274  func (g *micro) generateEndpoint(servName string, method *pb.MethodDescriptorProto) {
   275  	if method.Options == nil || !proto.HasExtension(method.Options, options.E_Http) {
   276  		return
   277  	}
   278  	// http rules
   279  	r := proto.GetExtension(method.Options, options.E_Http)
   280  	rule := r.(*options.HttpRule)
   281  	var meth string
   282  	var path string
   283  	switch {
   284  	case len(rule.GetDelete()) > 0:
   285  		meth = "DELETE"
   286  		path = rule.GetDelete()
   287  	case len(rule.GetGet()) > 0:
   288  		meth = "GET"
   289  		path = rule.GetGet()
   290  	case len(rule.GetPatch()) > 0:
   291  		meth = "PATCH"
   292  		path = rule.GetPatch()
   293  	case len(rule.GetPost()) > 0:
   294  		meth = "POST"
   295  		path = rule.GetPost()
   296  	case len(rule.GetPut()) > 0:
   297  		meth = "PUT"
   298  		path = rule.GetPut()
   299  	}
   300  	if len(meth) == 0 || len(path) == 0 {
   301  		return
   302  	}
   303  	// TODO: process additional bindings
   304  	g.P("Name:", fmt.Sprintf(`"%s.%s",`, servName, method.GetName()))
   305  	g.P("Path:", fmt.Sprintf(`[]string{"%s"},`, path))
   306  	g.P("Method:", fmt.Sprintf(`[]string{"%s"},`, meth))
   307  	if len(rule.GetGet()) == 0 {
   308  		g.P("Body:", fmt.Sprintf(`"%s",`, rule.GetBody()))
   309  	}
   310  	if method.GetServerStreaming() || method.GetClientStreaming() {
   311  		g.P("Stream: true,")
   312  	}
   313  	g.P(`Handler: "rpc",`)
   314  }
   315  
   316  // generateClientSignature returns the client-side signature for a method.
   317  func (g *micro) generateClientSignature(servName string, method *pb.MethodDescriptorProto) string {
   318  	origMethName := method.GetName()
   319  	methName := generator.CamelCase(origMethName)
   320  	if reservedClientName[methName] {
   321  		methName += "_"
   322  	}
   323  	reqArg := ", in *" + g.typeName(method.GetInputType())
   324  	if method.GetClientStreaming() {
   325  		reqArg = ""
   326  	}
   327  	respName := "*" + g.typeName(method.GetOutputType())
   328  	if method.GetServerStreaming() || method.GetClientStreaming() {
   329  		respName = servName + "_" + generator.CamelCase(origMethName) + "Service"
   330  	}
   331  
   332  	return fmt.Sprintf("%s(ctx %s.Context%s, opts ...%s.CallOption) (%s, error)", methName, contextPkg, reqArg, clientPkg, respName)
   333  }
   334  
   335  func (g *micro) generateClientMethod(pkg, reqServ, servName, serviceDescVar string, method *pb.MethodDescriptorProto, descExpr string) {
   336  	reqMethod := fmt.Sprintf("%s.%s", servName, method.GetName())
   337  	useGrpc := g.gen.Param["use_grpc"]
   338  	if useGrpc != "" {
   339  		reqMethod = fmt.Sprintf("/%s.%s/%s", pkg, servName, method.GetName())
   340  	}
   341  	methName := generator.CamelCase(method.GetName())
   342  	inType := g.typeName(method.GetInputType())
   343  	outType := g.typeName(method.GetOutputType())
   344  
   345  	servAlias := servName + "Service"
   346  
   347  	// strip suffix
   348  	if strings.HasSuffix(servAlias, "ServiceService") {
   349  		servAlias = strings.TrimSuffix(servAlias, "Service")
   350  	}
   351  
   352  	g.P("func (c *", unexport(servAlias), ") ", g.generateClientSignature(servName, method), "{")
   353  	if !method.GetServerStreaming() && !method.GetClientStreaming() {
   354  		g.P(`req := c.c.NewRequest(c.name, "`, reqMethod, `", in)`)
   355  		g.P("out := new(", outType, ")")
   356  		// TODO: Pass descExpr to Invoke.
   357  		g.P("err := ", `c.c.Call(ctx, req, out, opts...)`)
   358  		g.P("if err != nil { return nil, err }")
   359  		g.P("return out, nil")
   360  		g.P("}")
   361  		g.P()
   362  		return
   363  	}
   364  	streamType := unexport(servAlias) + methName
   365  	g.P(`req := c.c.NewRequest(c.name, "`, reqMethod, `", &`, inType, `{})`)
   366  	g.P("stream, err := c.c.Stream(ctx, req, opts...)")
   367  	g.P("if err != nil { return nil, err }")
   368  
   369  	if !method.GetClientStreaming() {
   370  		g.P("if err := stream.Send(in); err != nil { return nil, err }")
   371  		// TODO: currently only grpc support CloseSend
   372  		// g.P("if err := stream.CloseSend(); err != nil { return nil, err }")
   373  	}
   374  
   375  	g.P("return &", streamType, "{stream}, nil")
   376  	g.P("}")
   377  	g.P()
   378  
   379  	genSend := method.GetClientStreaming()
   380  	genRecv := method.GetServerStreaming()
   381  
   382  	// Stream auxiliary types and methods.
   383  	g.P("type ", servName, "_", methName, "Service interface {")
   384  	g.P("Context() context.Context")
   385  	g.P("SendMsg(interface{}) error")
   386  	g.P("RecvMsg(interface{}) error")
   387  	g.P("CloseSend() error")
   388  	g.P("Close() error")
   389  
   390  	if genSend {
   391  		g.P("Send(*", inType, ") error")
   392  	}
   393  	if genRecv {
   394  		g.P("Recv() (*", outType, ", error)")
   395  	}
   396  	g.P("}")
   397  	g.P()
   398  
   399  	g.P("type ", streamType, " struct {")
   400  	g.P("stream ", clientPkg, ".Stream")
   401  	g.P("}")
   402  	g.P()
   403  
   404  	g.P("func (x *", streamType, ") CloseSend() error {")
   405  	g.P("return x.stream.CloseSend()")
   406  	g.P("}")
   407  	g.P()
   408  
   409  	g.P("func (x *", streamType, ") Close() error {")
   410  	g.P("return x.stream.Close()")
   411  	g.P("}")
   412  	g.P()
   413  
   414  	g.P("func (x *", streamType, ") Context() context.Context {")
   415  	g.P("return x.stream.Context()")
   416  	g.P("}")
   417  	g.P()
   418  
   419  	g.P("func (x *", streamType, ") SendMsg(m interface{}) error {")
   420  	g.P("return x.stream.Send(m)")
   421  	g.P("}")
   422  	g.P()
   423  
   424  	g.P("func (x *", streamType, ") RecvMsg(m interface{}) error {")
   425  	g.P("return x.stream.Recv(m)")
   426  	g.P("}")
   427  	g.P()
   428  
   429  	if genSend {
   430  		g.P("func (x *", streamType, ") Send(m *", inType, ") error {")
   431  		g.P("return x.stream.Send(m)")
   432  		g.P("}")
   433  		g.P()
   434  
   435  	}
   436  
   437  	if genRecv {
   438  		g.P("func (x *", streamType, ") Recv() (*", outType, ", error) {")
   439  		g.P("m := new(", outType, ")")
   440  		g.P("err := x.stream.Recv(m)")
   441  		g.P("if err != nil {")
   442  		g.P("return nil, err")
   443  		g.P("}")
   444  		g.P("return m, nil")
   445  		g.P("}")
   446  		g.P()
   447  	}
   448  }
   449  
   450  // generateServerSignature returns the server-side signature for a method.
   451  func (g *micro) generateServerSignature(servName string, method *pb.MethodDescriptorProto) string {
   452  	origMethName := method.GetName()
   453  	methName := generator.CamelCase(origMethName)
   454  	if reservedClientName[methName] {
   455  		methName += "_"
   456  	}
   457  
   458  	var reqArgs []string
   459  	ret := "error"
   460  	reqArgs = append(reqArgs, contextPkg+".Context")
   461  
   462  	if !method.GetClientStreaming() {
   463  		reqArgs = append(reqArgs, "*"+g.typeName(method.GetInputType()))
   464  	}
   465  	if method.GetServerStreaming() || method.GetClientStreaming() {
   466  		reqArgs = append(reqArgs, servName+"_"+generator.CamelCase(origMethName)+"Stream")
   467  	}
   468  	if !method.GetClientStreaming() && !method.GetServerStreaming() {
   469  		reqArgs = append(reqArgs, "*"+g.typeName(method.GetOutputType()))
   470  	}
   471  	return methName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
   472  }
   473  
   474  func (g *micro) generateServerMethod(servName string, method *pb.MethodDescriptorProto) string {
   475  	methName := generator.CamelCase(method.GetName())
   476  	hname := fmt.Sprintf("_%s_%s_Handler", servName, methName)
   477  	serveType := servName + "Handler"
   478  	inType := g.typeName(method.GetInputType())
   479  	outType := g.typeName(method.GetOutputType())
   480  
   481  	if !method.GetServerStreaming() && !method.GetClientStreaming() {
   482  		g.P("func (h *", unexport(servName), "Handler) ", methName, "(ctx ", contextPkg, ".Context, in *", inType, ", out *", outType, ") error {")
   483  		g.P("return h.", serveType, ".", methName, "(ctx, in, out)")
   484  		g.P("}")
   485  		g.P()
   486  		return hname
   487  	}
   488  	streamType := unexport(servName) + methName + "Stream"
   489  	g.P("func (h *", unexport(servName), "Handler) ", methName, "(ctx ", contextPkg, ".Context, stream server.Stream) error {")
   490  	if !method.GetClientStreaming() {
   491  		g.P("m := new(", inType, ")")
   492  		g.P("if err := stream.Recv(m); err != nil { return err }")
   493  		g.P("return h.", serveType, ".", methName, "(ctx, m, &", streamType, "{stream})")
   494  	} else {
   495  		g.P("return h.", serveType, ".", methName, "(ctx, &", streamType, "{stream})")
   496  	}
   497  	g.P("}")
   498  	g.P()
   499  
   500  	genSend := method.GetServerStreaming()
   501  	genRecv := method.GetClientStreaming()
   502  
   503  	// Stream auxiliary types and methods.
   504  	g.P("type ", servName, "_", methName, "Stream interface {")
   505  	g.P("Context() context.Context")
   506  	g.P("SendMsg(interface{}) error")
   507  	g.P("RecvMsg(interface{}) error")
   508  	g.P("Close() error")
   509  
   510  	if genSend {
   511  		g.P("Send(*", outType, ") error")
   512  	}
   513  
   514  	if genRecv {
   515  		g.P("Recv() (*", inType, ", error)")
   516  	}
   517  
   518  	g.P("}")
   519  	g.P()
   520  
   521  	g.P("type ", streamType, " struct {")
   522  	g.P("stream ", serverPkg, ".Stream")
   523  	g.P("}")
   524  	g.P()
   525  
   526  	g.P("func (x *", streamType, ") Close() error {")
   527  	g.P("return x.stream.Close()")
   528  	g.P("}")
   529  	g.P()
   530  
   531  	g.P("func (x *", streamType, ") Context() context.Context {")
   532  	g.P("return x.stream.Context()")
   533  	g.P("}")
   534  	g.P()
   535  
   536  	g.P("func (x *", streamType, ") SendMsg(m interface{}) error {")
   537  	g.P("return x.stream.Send(m)")
   538  	g.P("}")
   539  	g.P()
   540  
   541  	g.P("func (x *", streamType, ") RecvMsg(m interface{}) error {")
   542  	g.P("return x.stream.Recv(m)")
   543  	g.P("}")
   544  	g.P()
   545  
   546  	if genSend {
   547  		g.P("func (x *", streamType, ") Send(m *", outType, ") error {")
   548  		g.P("return x.stream.Send(m)")
   549  		g.P("}")
   550  		g.P()
   551  	}
   552  
   553  	if genRecv {
   554  		g.P("func (x *", streamType, ") Recv() (*", inType, ", error) {")
   555  		g.P("m := new(", inType, ")")
   556  		g.P("if err := x.stream.Recv(m); err != nil { return nil, err }")
   557  		g.P("return m, nil")
   558  		g.P("}")
   559  		g.P()
   560  	}
   561  
   562  	return hname
   563  }