github.com/asynkron/protoactor-go@v0.0.0-20240308120642-ef91a6abee75/protobuf/protoc-gen-go-grain/generate.go (about) 1 package main 2 3 import ( 4 "fmt" 5 "strings" 6 7 "github.com/asynkron/protoactor-go/protobuf/protoc-gen-go-grain/options" 8 "golang.org/x/text/cases" 9 "golang.org/x/text/language" 10 "google.golang.org/protobuf/compiler/protogen" 11 "google.golang.org/protobuf/proto" 12 "google.golang.org/protobuf/types/descriptorpb" 13 ) 14 15 const deprecationComment = "// Deprecated: Do not use." 16 17 const ( 18 timePackage = protogen.GoImportPath("time") 19 errorsPackage = protogen.GoImportPath("errors") 20 fmtPackage = protogen.GoImportPath("fmt") 21 slogPackage = protogen.GoImportPath("log/slog") 22 protoPackage = protogen.GoImportPath("google.golang.org/protobuf/proto") 23 actorPackage = protogen.GoImportPath("github.com/asynkron/protoactor-go/actor") 24 clusterPackage = protogen.GoImportPath("github.com/asynkron/protoactor-go/cluster") 25 ) 26 27 var ( 28 noLowerCaser = cases.Title(language.AmericanEnglish, cases.NoLower) 29 caser = cases.Title(language.AmericanEnglish) 30 ) 31 32 func generateFile(gen *protogen.Plugin, file *protogen.File) { 33 if len(file.Services) == 0 && len(file.Enums) == 0 { 34 return 35 } 36 37 filename := file.GeneratedFilenamePrefix + "_grain.pb.go" 38 g := gen.NewGeneratedFile(filename, file.GoImportPath) 39 40 generateHeader(gen, g, file) 41 generateContent(g, file) 42 } 43 44 func generateHeader(gen *protogen.Plugin, g *protogen.GeneratedFile, file *protogen.File) { 45 g.P("// Code generated by protoc-gen-grain. DO NOT EDIT.") 46 g.P("// versions:") 47 g.P("// protoc-gen-grain ", version) 48 protocVersion := "(unknown)" 49 if v := gen.Request.GetCompilerVersion(); v != nil { 50 protocVersion = fmt.Sprintf("v%v.%v.%v", v.GetMajor(), v.GetMinor(), v.GetPatch()) 51 if s := v.GetSuffix(); s != "" { 52 protocVersion += "-" + s 53 } 54 } 55 g.P("// protoc ", protocVersion) 56 if file.Proto.GetOptions().GetDeprecated() { 57 g.P("// ", file.Desc.Path(), " is a deprecated file.") 58 } else { 59 g.P("// source: ", file.Desc.Path()) 60 } 61 g.P() 62 } 63 64 func generateContent(g *protogen.GeneratedFile, file *protogen.File) { 65 g.P("package ", file.GoPackageName) 66 67 g.QualifiedGoIdent(clusterPackage.Ident("")) 68 g.QualifiedGoIdent(fmtPackage.Ident("")) 69 70 for _, enum := range file.Enums { 71 if enum.Desc.Name() == "ErrorReason" { 72 generateErrorReasons(g, enum) 73 } 74 } 75 76 if len(file.Services) == 0 { 77 return 78 } 79 80 g.QualifiedGoIdent(actorPackage.Ident("")) 81 g.QualifiedGoIdent(protoPackage.Ident("")) 82 g.QualifiedGoIdent(timePackage.Ident("")) 83 g.QualifiedGoIdent(slogPackage.Ident("")) 84 85 for _, service := range file.Services { 86 generateService(service, g) 87 g.P() 88 } 89 90 generateRespond(g) 91 } 92 93 func generateService(service *protogen.Service, g *protogen.GeneratedFile) { 94 if service.Desc.Options().(*descriptorpb.ServiceOptions).GetDeprecated() { 95 g.P("//") 96 g.P(deprecationComment) 97 } 98 99 sd := &serviceDesc{ 100 Name: service.GoName, 101 } 102 103 for i, method := range service.Methods { 104 if method.Desc.IsStreamingClient() || method.Desc.IsStreamingServer() { 105 continue 106 } 107 108 methodOptions, ok := proto.GetExtension(method.Desc.Options(), options.E_MethodOptions).(*options.MethodOptions) 109 if !ok { 110 continue 111 } 112 113 if methodOptions == nil { 114 methodOptions = &options.MethodOptions{} 115 } 116 117 md := &methodDesc{ 118 Name: method.GoName, 119 Input: g.QualifiedGoIdent(method.Input.GoIdent), 120 Output: g.QualifiedGoIdent(method.Output.GoIdent), 121 Index: i, 122 Options: methodOptions, 123 } 124 125 sd.Methods = append(sd.Methods, md) 126 } 127 128 if len(sd.Methods) != 0 { 129 g.P(sd.execute()) 130 } 131 } 132 133 func generateRespond(g *protogen.GeneratedFile) { 134 g.P("func respond[T proto.Message](ctx cluster.GrainContext) func (T) {") 135 g.P("return func (resp T) {") 136 g.P("ctx.Respond(resp)") 137 g.P("}") 138 g.P("}") 139 } 140 141 func generateErrorReasons(g *protogen.GeneratedFile, enum *protogen.Enum) { 142 var es errorsWrapper 143 for _, v := range enum.Values { 144 comment := v.Comments.Leading.String() 145 if comment == "" { 146 comment = v.Comments.Trailing.String() 147 } 148 149 err := &errorDesc{ 150 Name: string(enum.Desc.Name()), 151 Value: string(v.Desc.Name()), 152 CamelValue: toCamel(string(v.Desc.Name())), 153 Comment: comment, 154 HasComment: len(comment) > 0, 155 } 156 es.Errors = append(es.Errors, err) 157 } 158 if len(es.Errors) != 0 { 159 g.P(es.execute()) 160 } 161 } 162 163 func toCamel(s string) string { 164 if !strings.Contains(s, "_") { 165 if s == strings.ToUpper(s) { 166 s = strings.ToLower(s) 167 } 168 return noLowerCaser.String(s) 169 } 170 171 slice := strings.Split(s, "_") 172 for i := 0; i < len(slice); i++ { 173 slice[i] = caser.String(slice[i]) 174 } 175 return strings.Join(slice, "") 176 }