github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/protobuf/protoc-gen-go-grain/generate.go (about)

     1  package main
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"github.com/asynkron/protoactor-go/protobuf/protoc-gen-go-grain/options"
     8  	"golang.org/x/text/cases"
     9  	"golang.org/x/text/language"
    10  	"google.golang.org/protobuf/compiler/protogen"
    11  	"google.golang.org/protobuf/proto"
    12  	"google.golang.org/protobuf/types/descriptorpb"
    13  )
    14  
    15  const deprecationComment = "// Deprecated: Do not use."
    16  
    17  const (
    18  	timePackage    = protogen.GoImportPath("time")
    19  	errorsPackage  = protogen.GoImportPath("errors")
    20  	fmtPackage     = protogen.GoImportPath("fmt")
    21  	slogPackage    = protogen.GoImportPath("log/slog")
    22  	protoPackage   = protogen.GoImportPath("google.golang.org/protobuf/proto")
    23  	actorPackage   = protogen.GoImportPath("github.com/asynkron/protoactor-go/actor")
    24  	clusterPackage = protogen.GoImportPath("github.com/asynkron/protoactor-go/cluster")
    25  )
    26  
    27  var (
    28  	noLowerCaser = cases.Title(language.AmericanEnglish, cases.NoLower)
    29  	caser        = cases.Title(language.AmericanEnglish)
    30  )
    31  
    32  func generateFile(gen *protogen.Plugin, file *protogen.File) {
    33  	if len(file.Services) == 0 && len(file.Enums) == 0 {
    34  		return
    35  	}
    36  
    37  	filename := file.GeneratedFilenamePrefix + "_grain.pb.go"
    38  	g := gen.NewGeneratedFile(filename, file.GoImportPath)
    39  
    40  	generateHeader(gen, g, file)
    41  	generateContent(g, file)
    42  }
    43  
    44  func generateHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, file *protogen.File) {
    45  	g.P("// Code generated by protoc-gen-grain. DO NOT EDIT.")
    46  	g.P("// versions:")
    47  	g.P("//  protoc-gen-grain ", version)
    48  	protocVersion := "(unknown)"
    49  	if v := gen.Request.GetCompilerVersion(); v != nil {
    50  		protocVersion = fmt.Sprintf("v%v.%v.%v", v.GetMajor(), v.GetMinor(), v.GetPatch())
    51  		if s := v.GetSuffix(); s != "" {
    52  			protocVersion += "-" + s
    53  		}
    54  	}
    55  	g.P("//  protoc           ", protocVersion)
    56  	if file.Proto.GetOptions().GetDeprecated() {
    57  		g.P("// ", file.Desc.Path(), " is a deprecated file.")
    58  	} else {
    59  		g.P("// source: ", file.Desc.Path())
    60  	}
    61  	g.P()
    62  }
    63  
    64  func generateContent(g *protogen.GeneratedFile, file *protogen.File) {
    65  	g.P("package ", file.GoPackageName)
    66  
    67  	g.QualifiedGoIdent(clusterPackage.Ident(""))
    68  	g.QualifiedGoIdent(fmtPackage.Ident(""))
    69  
    70  	for _, enum := range file.Enums {
    71  		if enum.Desc.Name() == "ErrorReason" {
    72  			generateErrorReasons(g, enum)
    73  		}
    74  	}
    75  
    76  	if len(file.Services) == 0 {
    77  		return
    78  	}
    79  
    80  	g.QualifiedGoIdent(actorPackage.Ident(""))
    81  	g.QualifiedGoIdent(protoPackage.Ident(""))
    82  	g.QualifiedGoIdent(timePackage.Ident(""))
    83  	g.QualifiedGoIdent(slogPackage.Ident(""))
    84  
    85  	for _, service := range file.Services {
    86  		generateService(service, g)
    87  		g.P()
    88  	}
    89  
    90  	generateRespond(g)
    91  }
    92  
    93  func generateService(service *protogen.Service, g *protogen.GeneratedFile) {
    94  	if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() {
    95  		g.P("//")
    96  		g.P(deprecationComment)
    97  	}
    98  
    99  	sd := &serviceDesc{
   100  		Name: service.GoName,
   101  	}
   102  
   103  	for i, method := range service.Methods {
   104  		if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() {
   105  			continue
   106  		}
   107  
   108  		methodOptions, ok := proto.GetExtension(method.Desc.Options(), options.E_MethodOptions).(*options.MethodOptions)
   109  		if !ok {
   110  			continue
   111  		}
   112  
   113  		if methodOptions == nil {
   114  			methodOptions = &options.MethodOptions{}
   115  		}
   116  
   117  		md := &methodDesc{
   118  			Name:    method.GoName,
   119  			Input:   g.QualifiedGoIdent(method.Input.GoIdent),
   120  			Output:  g.QualifiedGoIdent(method.Output.GoIdent),
   121  			Index:   i,
   122  			Options: methodOptions,
   123  		}
   124  
   125  		sd.Methods = append(sd.Methods, md)
   126  	}
   127  
   128  	if len(sd.Methods) != 0 {
   129  		g.P(sd.execute())
   130  	}
   131  }
   132  
   133  func generateRespond(g *protogen.GeneratedFile) {
   134  	g.P("func respond[T proto.Message](ctx cluster.GrainContext) func (T) {")
   135  	g.P("return func (resp T) {")
   136  	g.P("ctx.Respond(resp)")
   137  	g.P("}")
   138  	g.P("}")
   139  }
   140  
   141  func generateErrorReasons(g *protogen.GeneratedFile, enum *protogen.Enum) {
   142  	var es errorsWrapper
   143  	for _, v := range enum.Values {
   144  		comment := v.Comments.Leading.String()
   145  		if comment == "" {
   146  			comment = v.Comments.Trailing.String()
   147  		}
   148  
   149  		err := &errorDesc{
   150  			Name:       string(enum.Desc.Name()),
   151  			Value:      string(v.Desc.Name()),
   152  			CamelValue: toCamel(string(v.Desc.Name())),
   153  			Comment:    comment,
   154  			HasComment: len(comment) > 0,
   155  		}
   156  		es.Errors = append(es.Errors, err)
   157  	}
   158  	if len(es.Errors) != 0 {
   159  		g.P(es.execute())
   160  	}
   161  }
   162  
   163  func toCamel(s string) string {
   164  	if !strings.Contains(s, "_") {
   165  		if s == strings.ToUpper(s) {
   166  			s = strings.ToLower(s)
   167  		}
   168  		return noLowerCaser.String(s)
   169  	}
   170  
   171  	slice := strings.Split(s, "_")
   172  	for i := 0; i < len(slice); i++ {
   173  		slice[i] = caser.String(slice[i])
   174  	}
   175  	return strings.Join(slice, "")
   176  }