go-micro.dev/v5@v5.12.0/cmd/protoc-gen-micro/plugin/micro/micro.go (about)

     1  package micro
     2  
     3  import (
     4  	"fmt"
     5  	"path"
     6  	"strconv"
     7  	"strings"
     8  
     9  	"go-micro.dev/v5/cmd/protoc-gen-micro/generator"
    10  	options "google.golang.org/genproto/googleapis/api/annotations"
    11  	"google.golang.org/protobuf/proto"
    12  	pb "google.golang.org/protobuf/types/descriptorpb"
    13  )
    14  
    // contextPkgPath is the import path of the standard context package used
    // by generated code. Like the other *PkgPath constants below, it is joined
    // onto the generator's ImportPrefix by GenerateImports.
    const contextPkgPath = "context"

    // clientPkgPath is the import path of the go-micro client package that
    // generated client stubs depend on.
    const clientPkgPath = "go-micro.dev/v5/client"

    // serverPkgPath is the import path of the go-micro server package that
    // generated handler registration code depends on.
    const serverPkgPath = "go-micro.dev/v5/server"
    22  
    23  func init() {
    24  	generator.RegisterPlugin(new(micro))
    25  }
    26  
    27  // micro is an implementation of the Go protocol buffer compiler's
    28  // plugin architecture.  It generates bindings for go-micro support.
    29  type micro struct {
    30  	gen *generator.Generator
    31  }
    32  
    33  // Name returns the name of this plugin, "micro".
    34  func (g *micro) Name() string {
    35  	return "micro"
    36  }
    37  
    38  // The names for packages imported in the generated code.
    39  // They may vary from the final path component of the import path
    40  // if the name is used by other packages.
    41  var (
    42  	contextPkg string
    43  	clientPkg  string
    44  	serverPkg  string
    45  	pkgImports map[generator.GoPackageName]bool
    46  )
    47  
    48  // Init initializes the plugin.
    49  func (g *micro) Init(gen *generator.Generator) {
    50  	g.gen = gen
    51  	contextPkg = generator.RegisterUniquePackageName("context", nil)
    52  	clientPkg = generator.RegisterUniquePackageName("client", nil)
    53  	serverPkg = generator.RegisterUniquePackageName("server", nil)
    54  }
    55  
    56  // Given a type name defined in a .proto, return its object.
    57  // Also record that we're using it, to guarantee the associated import.
    58  func (g *micro) objectNamed(name string) generator.Object {
    59  	g.gen.RecordTypeUse(name)
    60  	return g.gen.ObjectNamed(name)
    61  }
    62  
    63  // Given a type name defined in a .proto, return its name as we will print it.
    64  func (g *micro) typeName(str string) string {
    65  	return g.gen.TypeName(g.objectNamed(str))
    66  }
    67  
    68  // P forwards to g.gen.P.
    69  func (g *micro) P(args ...interface{}) { g.gen.P(args...) }
    70  
    71  // Generate generates code for the services in the given file.
    72  func (g *micro) Generate(file *generator.FileDescriptor) {
    73  	if len(file.FileDescriptorProto.Service) == 0 {
    74  		return
    75  	}
    76  	g.P("// Reference imports to suppress errors if they are not otherwise used.")
    77  	g.P("var _ ", contextPkg, ".Context")
    78  	g.P("var _ ", clientPkg, ".Option")
    79  	g.P("var _ ", serverPkg, ".Option")
    80  	g.P()
    81  
    82  	for i, service := range file.FileDescriptorProto.Service {
    83  		g.generateService(file, service, i)
    84  	}
    85  }
    86  
    87  // GenerateImports generates the import declaration for this file.
    88  func (g *micro) GenerateImports(file *generator.FileDescriptor, imports map[generator.GoImportPath]generator.GoPackageName) {
    89  	if len(file.FileDescriptorProto.Service) == 0 {
    90  		return
    91  	}
    92  	g.P("import (")
    93  	g.P(contextPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, contextPkgPath)))
    94  	g.P(clientPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, clientPkgPath)))
    95  	g.P(serverPkg, " ", strconv.Quote(path.Join(g.gen.ImportPrefix, serverPkgPath)))
    96  	g.P(")")
    97  	g.P()
    98  
    99  	// We need to keep track of imported packages to make sure we don't produce
   100  	// a name collision when generating types.
   101  	pkgImports = make(map[generator.GoPackageName]bool)
   102  	for _, name := range imports {
   103  		pkgImports[name] = true
   104  	}
   105  }
   106  
   // reservedClientName lists method names that generateClientSignature and
   // generateServerSignature suffix with "_" to avoid clashes. It is empty.
   // TODO: do we need any in go-micro?
   var reservedClientName = map[string]bool{}
   111  
   112  func unexport(s string) string {
   113  	if len(s) == 0 {
   114  		return ""
   115  	}
   116  	name := strings.ToLower(s[:1]) + s[1:]
   117  	if pkgImports[generator.GoPackageName(name)] {
   118  		return name + "_"
   119  	}
   120  	return name
   121  }
   122  
   123  // generateService generates all the code for the named service.
   124  func (g *micro) generateService(file *generator.FileDescriptor, service *pb.ServiceDescriptorProto, index int) {
   125  	path := fmt.Sprintf("6,%d", index) // 6 means service.
   126  
   127  	origServName := service.GetName()
   128  	serviceName := strings.ToLower(service.GetName())
   129  	pkg := file.GetPackage()
   130  	if pkg != "" {
   131  		serviceName = pkg
   132  	}
   133  	servName := generator.CamelCase(origServName)
   134  	servAlias := servName + "Service"
   135  
   136  	// strip suffix
   137  	if strings.HasSuffix(servAlias, "ServiceService") {
   138  		servAlias = strings.TrimSuffix(servAlias, "Service")
   139  	}
   140  
   141  	g.P()
   142  	g.P("// Client API for ", servName, " service")
   143  	g.P()
   144  
   145  	// Client interface.
   146  	g.P("type ", servAlias, " interface {")
   147  	for i, method := range service.Method {
   148  		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
   149  		g.P(g.generateClientSignature(servName, method))
   150  	}
   151  	g.P("}")
   152  	g.P()
   153  
   154  	// Client structure.
   155  	g.P("type ", unexport(servAlias), " struct {")
   156  	g.P("c ", clientPkg, ".Client")
   157  	g.P("name string")
   158  	g.P("}")
   159  	g.P()
   160  
   161  	// NewClient factory.
   162  	g.P("func New", servAlias, " (name string, c ", clientPkg, ".Client) ", servAlias, " {")
   163  	/*
   164  		g.P("if c == nil {")
   165  		g.P("c = ", clientPkg, ".NewClient()")
   166  		g.P("}")
   167  		g.P("if len(name) == 0 {")
   168  		g.P(`name = "`, serviceName, `"`)
   169  		g.P("}")
   170  	*/
   171  	g.P("return &", unexport(servAlias), "{")
   172  	g.P("c: c,")
   173  	g.P("name: name,")
   174  	g.P("}")
   175  	g.P("}")
   176  	g.P()
   177  	var methodIndex, streamIndex int
   178  	serviceDescVar := "_" + servName + "_serviceDesc"
   179  	// Client method implementations.
   180  	for _, method := range service.Method {
   181  		var descExpr string
   182  		if !method.GetServerStreaming() {
   183  			// Unary RPC method
   184  			descExpr = fmt.Sprintf("&%s.Methods[%d]", serviceDescVar, methodIndex)
   185  			methodIndex++
   186  		} else {
   187  			// Streaming RPC method
   188  			descExpr = fmt.Sprintf("&%s.Streams[%d]", serviceDescVar, streamIndex)
   189  			streamIndex++
   190  		}
   191  		g.generateClientMethod(pkg, serviceName, servName, serviceDescVar, method, descExpr)
   192  	}
   193  
   194  	g.P("// Server API for ", servName, " service")
   195  	g.P()
   196  
   197  	// Server interface.
   198  	serverType := servName + "Handler"
   199  	g.P("type ", serverType, " interface {")
   200  	for i, method := range service.Method {
   201  		g.gen.PrintComments(fmt.Sprintf("%s,2,%d", path, i)) // 2 means method in a service.
   202  		g.P(g.generateServerSignature(servName, method))
   203  	}
   204  	g.P("}")
   205  	g.P()
   206  
   207  	// Server registration.
   208  	g.P("func Register", servName, "Handler(s ", serverPkg, ".Server, hdlr ", serverType, ", opts ...", serverPkg, ".HandlerOption) error {")
   209  	g.P("type ", unexport(servName), " interface {")
   210  
   211  	// generate interface methods
   212  	for _, method := range service.Method {
   213  		methName := generator.CamelCase(method.GetName())
   214  		inType := g.typeName(method.GetInputType())
   215  		outType := g.typeName(method.GetOutputType())
   216  
   217  		if !method.GetServerStreaming() && !method.GetClientStreaming() {
   218  			g.P(methName, "(ctx ", contextPkg, ".Context, in *", inType, ", out *", outType, ") error")
   219  			continue
   220  		}
   221  		g.P(methName, "(ctx ", contextPkg, ".Context, stream server.Stream) error")
   222  	}
   223  	g.P("}")
   224  	g.P("type ", servName, " struct {")
   225  	g.P(unexport(servName))
   226  	g.P("}")
   227  	g.P("h := &", unexport(servName), "Handler{hdlr}")
   228  	g.P("return s.Handle(s.NewHandler(&", servName, "{h}, opts...))")
   229  	g.P("}")
   230  	g.P()
   231  
   232  	g.P("type ", unexport(servName), "Handler struct {")
   233  	g.P(serverType)
   234  	g.P("}")
   235  
   236  	// Server handler implementations.
   237  	var handlerNames []string
   238  	for _, method := range service.Method {
   239  		hname := g.generateServerMethod(servName, method)
   240  		handlerNames = append(handlerNames, hname)
   241  	}
   242  }
   243  
   244  // generateEndpoint creates the api endpoint
   245  func (g *micro) generateEndpoint(servName string, method *pb.MethodDescriptorProto) {
   246  	if method.Options == nil || !proto.HasExtension(method.Options, options.E_Http) {
   247  		return
   248  	}
   249  	// http rules
   250  	r := proto.GetExtension(method.Options, options.E_Http)
   251  	rule := r.(*options.HttpRule)
   252  	var meth string
   253  	var path string
   254  	switch {
   255  	case len(rule.GetDelete()) > 0:
   256  		meth = "DELETE"
   257  		path = rule.GetDelete()
   258  	case len(rule.GetGet()) > 0:
   259  		meth = "GET"
   260  		path = rule.GetGet()
   261  	case len(rule.GetPatch()) > 0:
   262  		meth = "PATCH"
   263  		path = rule.GetPatch()
   264  	case len(rule.GetPost()) > 0:
   265  		meth = "POST"
   266  		path = rule.GetPost()
   267  	case len(rule.GetPut()) > 0:
   268  		meth = "PUT"
   269  		path = rule.GetPut()
   270  	}
   271  	if len(meth) == 0 || len(path) == 0 {
   272  		return
   273  	}
   274  	// TODO: process additional bindings
   275  	g.P("Name:", fmt.Sprintf(`"%s.%s",`, servName, method.GetName()))
   276  	g.P("Path:", fmt.Sprintf(`[]string{"%s"},`, path))
   277  	g.P("Method:", fmt.Sprintf(`[]string{"%s"},`, meth))
   278  	if method.GetServerStreaming() || method.GetClientStreaming() {
   279  		g.P("Stream: true,")
   280  	}
   281  	g.P(`Handler: "rpc",`)
   282  }
   283  
   284  // generateClientSignature returns the client-side signature for a method.
   285  func (g *micro) generateClientSignature(servName string, method *pb.MethodDescriptorProto) string {
   286  	origMethName := method.GetName()
   287  	methName := generator.CamelCase(origMethName)
   288  	if reservedClientName[methName] {
   289  		methName += "_"
   290  	}
   291  	reqArg := ", in *" + g.typeName(method.GetInputType())
   292  	if method.GetClientStreaming() {
   293  		reqArg = ""
   294  	}
   295  	respName := "*" + g.typeName(method.GetOutputType())
   296  	if method.GetServerStreaming() || method.GetClientStreaming() {
   297  		respName = servName + "_" + generator.CamelCase(origMethName) + "Service"
   298  	}
   299  
   300  	return fmt.Sprintf("%s(ctx %s.Context%s, opts ...%s.CallOption) (%s, error)", methName, contextPkg, reqArg, clientPkg, respName)
   301  }
   302  
   303  func (g *micro) generateClientMethod(pkg, reqServ, servName, serviceDescVar string, method *pb.MethodDescriptorProto, descExpr string) {
   304  	reqMethod := fmt.Sprintf("%s.%s", servName, method.GetName())
   305  	useGrpc := g.gen.Param["use_grpc"]
   306  	if useGrpc != "" {
   307  		reqMethod = fmt.Sprintf("/%s.%s/%s", pkg, servName, method.GetName())
   308  	}
   309  	methName := generator.CamelCase(method.GetName())
   310  	inType := g.typeName(method.GetInputType())
   311  	outType := g.typeName(method.GetOutputType())
   312  
   313  	servAlias := servName + "Service"
   314  
   315  	// strip suffix
   316  	if strings.HasSuffix(servAlias, "ServiceService") {
   317  		servAlias = strings.TrimSuffix(servAlias, "Service")
   318  	}
   319  
   320  	g.P("func (c *", unexport(servAlias), ") ", g.generateClientSignature(servName, method), "{")
   321  	if !method.GetServerStreaming() && !method.GetClientStreaming() {
   322  		g.P(`req := c.c.NewRequest(c.name, "`, reqMethod, `", in)`)
   323  		g.P("out := new(", outType, ")")
   324  		// TODO: Pass descExpr to Invoke.
   325  		g.P("err := ", `c.c.Call(ctx, req, out, opts...)`)
   326  		g.P("if err != nil { return nil, err }")
   327  		g.P("return out, nil")
   328  		g.P("}")
   329  		g.P()
   330  		return
   331  	}
   332  	streamType := unexport(servAlias) + methName
   333  	g.P(`req := c.c.NewRequest(c.name, "`, reqMethod, `", &`, inType, `{})`)
   334  	g.P("stream, err := c.c.Stream(ctx, req, opts...)")
   335  	g.P("if err != nil { return nil, err }")
   336  
   337  	if !method.GetClientStreaming() {
   338  		g.P("if err := stream.Send(in); err != nil { return nil, err }")
   339  		// TODO: currently only grpc support CloseSend
   340  		// g.P("if err := stream.CloseSend(); err != nil { return nil, err }")
   341  	}
   342  
   343  	g.P("return &", streamType, "{stream}, nil")
   344  	g.P("}")
   345  	g.P()
   346  
   347  	genSend := method.GetClientStreaming()
   348  	genRecv := method.GetServerStreaming()
   349  
   350  	// Stream auxiliary types and methods.
   351  	g.P("type ", servName, "_", methName, "Service interface {")
   352  	g.P("Context() context.Context")
   353  	g.P("SendMsg(interface{}) error")
   354  	g.P("RecvMsg(interface{}) error")
   355  	g.P("CloseSend() error")
   356  	g.P("Close() error")
   357  
   358  	if genSend {
   359  		g.P("Send(*", inType, ") error")
   360  	}
   361  	if genRecv {
   362  		g.P("Recv() (*", outType, ", error)")
   363  	}
   364  	g.P("}")
   365  	g.P()
   366  
   367  	g.P("type ", streamType, " struct {")
   368  	g.P("stream ", clientPkg, ".Stream")
   369  	g.P("}")
   370  	g.P()
   371  
   372  	g.P("func (x *", streamType, ") CloseSend() error {")
   373  	g.P("return x.stream.CloseSend()")
   374  	g.P("}")
   375  	g.P()
   376  
   377  	g.P("func (x *", streamType, ") Close() error {")
   378  	g.P("return x.stream.Close()")
   379  	g.P("}")
   380  	g.P()
   381  
   382  	g.P("func (x *", streamType, ") Context() context.Context {")
   383  	g.P("return x.stream.Context()")
   384  	g.P("}")
   385  	g.P()
   386  
   387  	g.P("func (x *", streamType, ") SendMsg(m interface{}) error {")
   388  	g.P("return x.stream.Send(m)")
   389  	g.P("}")
   390  	g.P()
   391  
   392  	g.P("func (x *", streamType, ") RecvMsg(m interface{}) error {")
   393  	g.P("return x.stream.Recv(m)")
   394  	g.P("}")
   395  	g.P()
   396  
   397  	if genSend {
   398  		g.P("func (x *", streamType, ") Send(m *", inType, ") error {")
   399  		g.P("return x.stream.Send(m)")
   400  		g.P("}")
   401  		g.P()
   402  
   403  	}
   404  
   405  	if genRecv {
   406  		g.P("func (x *", streamType, ") Recv() (*", outType, ", error) {")
   407  		g.P("m := new(", outType, ")")
   408  		g.P("err := x.stream.Recv(m)")
   409  		g.P("if err != nil {")
   410  		g.P("return nil, err")
   411  		g.P("}")
   412  		g.P("return m, nil")
   413  		g.P("}")
   414  		g.P()
   415  	}
   416  }
   417  
   418  // generateServerSignature returns the server-side signature for a method.
   419  func (g *micro) generateServerSignature(servName string, method *pb.MethodDescriptorProto) string {
   420  	origMethName := method.GetName()
   421  	methName := generator.CamelCase(origMethName)
   422  	if reservedClientName[methName] {
   423  		methName += "_"
   424  	}
   425  
   426  	var reqArgs []string
   427  	ret := "error"
   428  	reqArgs = append(reqArgs, contextPkg+".Context")
   429  
   430  	if !method.GetClientStreaming() {
   431  		reqArgs = append(reqArgs, "*"+g.typeName(method.GetInputType()))
   432  	}
   433  	if method.GetServerStreaming() || method.GetClientStreaming() {
   434  		reqArgs = append(reqArgs, servName+"_"+generator.CamelCase(origMethName)+"Stream")
   435  	}
   436  	if !method.GetClientStreaming() && !method.GetServerStreaming() {
   437  		reqArgs = append(reqArgs, "*"+g.typeName(method.GetOutputType()))
   438  	}
   439  	return methName + "(" + strings.Join(reqArgs, ", ") + ") " + ret
   440  }
   441  
   442  func (g *micro) generateServerMethod(servName string, method *pb.MethodDescriptorProto) string {
   443  	methName := generator.CamelCase(method.GetName())
   444  	hname := fmt.Sprintf("_%s_%s_Handler", servName, methName)
   445  	serveType := servName + "Handler"
   446  	inType := g.typeName(method.GetInputType())
   447  	outType := g.typeName(method.GetOutputType())
   448  
   449  	if !method.GetServerStreaming() && !method.GetClientStreaming() {
   450  		g.P("func (h *", unexport(servName), "Handler) ", methName, "(ctx ", contextPkg, ".Context, in *", inType, ", out *", outType, ") error {")
   451  		g.P("return h.", serveType, ".", methName, "(ctx, in, out)")
   452  		g.P("}")
   453  		g.P()
   454  		return hname
   455  	}
   456  	streamType := unexport(servName) + methName + "Stream"
   457  	g.P("func (h *", unexport(servName), "Handler) ", methName, "(ctx ", contextPkg, ".Context, stream server.Stream) error {")
   458  	if !method.GetClientStreaming() {
   459  		g.P("m := new(", inType, ")")
   460  		g.P("if err := stream.Recv(m); err != nil { return err }")
   461  		g.P("return h.", serveType, ".", methName, "(ctx, m, &", streamType, "{stream})")
   462  	} else {
   463  		g.P("return h.", serveType, ".", methName, "(ctx, &", streamType, "{stream})")
   464  	}
   465  	g.P("}")
   466  	g.P()
   467  
   468  	genSend := method.GetServerStreaming()
   469  	genRecv := method.GetClientStreaming()
   470  
   471  	// Stream auxiliary types and methods.
   472  	g.P("type ", servName, "_", methName, "Stream interface {")
   473  	g.P("Context() context.Context")
   474  	g.P("SendMsg(interface{}) error")
   475  	g.P("RecvMsg(interface{}) error")
   476  	g.P("Close() error")
   477  
   478  	if genSend {
   479  		g.P("Send(*", outType, ") error")
   480  	}
   481  
   482  	if genRecv {
   483  		g.P("Recv() (*", inType, ", error)")
   484  	}
   485  
   486  	g.P("}")
   487  	g.P()
   488  
   489  	g.P("type ", streamType, " struct {")
   490  	g.P("stream ", serverPkg, ".Stream")
   491  	g.P("}")
   492  	g.P()
   493  
   494  	g.P("func (x *", streamType, ") Close() error {")
   495  	g.P("return x.stream.Close()")
   496  	g.P("}")
   497  	g.P()
   498  
   499  	g.P("func (x *", streamType, ") Context() context.Context {")
   500  	g.P("return x.stream.Context()")
   501  	g.P("}")
   502  	g.P()
   503  
   504  	g.P("func (x *", streamType, ") SendMsg(m interface{}) error {")
   505  	g.P("return x.stream.Send(m)")
   506  	g.P("}")
   507  	g.P()
   508  
   509  	g.P("func (x *", streamType, ") RecvMsg(m interface{}) error {")
   510  	g.P("return x.stream.Recv(m)")
   511  	g.P("}")
   512  	g.P()
   513  
   514  	if genSend {
   515  		g.P("func (x *", streamType, ") Send(m *", outType, ") error {")
   516  		g.P("return x.stream.Send(m)")
   517  		g.P("}")
   518  		g.P()
   519  	}
   520  
   521  	if genRecv {
   522  		g.P("func (x *", streamType, ") Recv() (*", inType, ", error) {")
   523  		g.P("m := new(", inType, ")")
   524  		g.P("if err := x.stream.Recv(m); err != nil { return nil, err }")
   525  		g.P("return m, nil")
   526  		g.P("}")
   527  		g.P()
   528  	}
   529  
   530  	return hname
   531  }