go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/cmd/cproto/transform.go (about) 1 // Copyright 2016 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // This file implements .go code transformation. 16 17 package main 18 19 import ( 20 "bytes" 21 "fmt" 22 "os" 23 "strings" 24 "text/template" 25 "unicode/utf8" 26 27 "go/ast" 28 "go/format" 29 "go/parser" 30 "go/printer" 31 "go/token" 32 ) 33 34 const ( 35 prpcPackagePath = `go.chromium.org/luci/grpc/prpc` 36 ) 37 38 var ( 39 serverPrpcPkg = ast.NewIdent("prpc") 40 ) 41 42 type transformer struct { 43 fset *token.FileSet 44 inPRPCPackage bool 45 services []*service 46 PackageName string 47 } 48 49 // transformGoFile rewrites a .go file to work with prpc. 50 func (t *transformer) transformGoFile(filename string) error { 51 t.fset = token.NewFileSet() 52 file, err := parser.ParseFile(t.fset, filename, nil, parser.ParseComments) 53 if err != nil { 54 return err 55 } 56 57 t.PackageName = file.Name.Name 58 t.services, err = getServices(file) 59 if err != nil { 60 return err 61 } 62 if len(t.services) == 0 { 63 return nil 64 } 65 66 t.inPRPCPackage, err = isInPackage(filename, prpcPackagePath) 67 if err != nil { 68 return err 69 } 70 71 if err := t.transformFile(file); err != nil { 72 return err 73 } 74 75 var buf bytes.Buffer 76 if err := printer.Fprint(&buf, t.fset, file); err != nil { 77 return err 78 } 79 formatted, err := gofmt(buf.Bytes()) 80 if err != nil { 81 return err 82 } 83 84 return os.WriteFile(filename, formatted, 0666) 85 } 86 87 func (t *transformer) transformFile(file *ast.File) error { 88 var includePrpc bool 89 for _, s := range t.services { 90 t.transformRegisterServerFuncs(s) 91 if !t.inPRPCPackage { 92 includePrpc = true 93 } 94 95 if err := t.generateClients(file, s); err != nil { 96 return err 97 } 98 } 99 if includePrpc { 100 t.insertImport(file, serverPrpcPkg, prpcPackagePath) 101 } 102 return nil 103 } 104 105 // transformRegisterServerFuncs finds RegisterXXXServer functions and 106 // checks its first parameter type to prpc.Registrar. 107 // Returns true if modified ast. 108 func (t *transformer) transformRegisterServerFuncs(s *service) { 109 registrarName := ast.NewIdent("Registrar") 110 var registrarType ast.Expr = registrarName 111 if !t.inPRPCPackage { 112 registrarType = &ast.SelectorExpr{X: serverPrpcPkg, Sel: registrarName} 113 } 114 s.registerServerFunc.Params.List[0].Type = registrarType 115 } 116 117 // generateClients finds client interface declarations 118 // and inserts pRPC implementations after them. 119 func (t *transformer) generateClients(file *ast.File, s *service) error { 120 switch newDecls, err := t.generateClient(s.protoPackageName, s.name, s.clientIface); { 121 case err != nil: 122 return err 123 case len(newDecls) > 0: 124 insertAST(file, s.clientIfaceDecl, newDecls) 125 return nil 126 default: 127 return nil 128 } 129 } 130 131 func insertAST(file *ast.File, after ast.Decl, newDecls []ast.Decl) { 132 for i, d := range file.Decls { 133 if d == after { 134 file.Decls = append(file.Decls[:i+1], append(newDecls, file.Decls[i+1:]...)...) 135 return 136 } 137 } 138 panic("unable to find after node") 139 } 140 141 var clientCodeTemplate = template.Must(template.New("").Parse(` 142 package template 143 144 type {{$.StructName}} struct { 145 client *{{.PRPCSymbolPrefix}}Client 146 } 147 148 func New{{.Service}}PRPCClient(client *{{.PRPCSymbolPrefix}}Client) {{.Service}}Client { 149 return &{{$.StructName}}{client} 150 } 151 152 {{range .Methods}} 153 func (c *{{$.StructName}}) {{.Name}}(ctx context.Context, in *{{.InputMessage}}, opts ...grpc.CallOption) (*{{.OutputMessage}}, error) { 154 out := new({{.OutputMessage}}) 155 err := c.client.Call(ctx, "{{$.ProtoPkg}}.{{$.Service}}", "{{.Name}}", in, out, opts...) 156 if err != nil { 157 return nil, err 158 } 159 return out, nil 160 } 161 {{end}} 162 `)) 163 164 // generateClient generates pRPC implementation of a client interface. 165 func (t *transformer) generateClient(protoPackage, serviceName string, iface *ast.InterfaceType) ([]ast.Decl, error) { 166 // This function used to construct an AST. It was a lot of code. 167 // Now it generates code via a template and parses back to AST. 168 // Slower, but saner and easier to make changes. 169 170 type Method struct { 171 Name string 172 InputMessage string 173 OutputMessage string 174 } 175 methods := make([]Method, 0, len(iface.Methods.List)) 176 177 var buf bytes.Buffer 178 toGoCode := func(n ast.Node) (string, error) { 179 defer buf.Reset() 180 err := format.Node(&buf, t.fset, n) 181 if err != nil { 182 return "", err 183 } 184 return buf.String(), nil 185 } 186 187 for _, m := range iface.Methods.List { 188 signature, ok := m.Type.(*ast.FuncType) 189 if !ok { 190 return nil, fmt.Errorf("unexpected embedded interface in %sClient", serviceName) 191 } 192 193 inStructPtr := signature.Params.List[1].Type.(*ast.StarExpr) 194 inStruct, err := toGoCode(inStructPtr.X) 195 if err != nil { 196 return nil, err 197 } 198 199 outStructPtr := signature.Results.List[0].Type.(*ast.StarExpr) 200 outStruct, err := toGoCode(outStructPtr.X) 201 if err != nil { 202 return nil, err 203 } 204 205 methods = append(methods, Method{ 206 Name: m.Names[0].Name, 207 InputMessage: inStruct, 208 OutputMessage: outStruct, 209 }) 210 } 211 212 prpcSymbolPrefix := "prpc." 213 if t.inPRPCPackage { 214 prpcSymbolPrefix = "" 215 } 216 err := clientCodeTemplate.Execute(&buf, map[string]any{ 217 "Service": serviceName, 218 "ProtoPkg": protoPackage, 219 "StructName": firstLower(serviceName) + "PRPCClient", 220 "Methods": methods, 221 "PRPCSymbolPrefix": prpcSymbolPrefix, 222 }) 223 if err != nil { 224 return nil, fmt.Errorf("client template execution: %s", err) 225 } 226 227 f, err := parser.ParseFile(t.fset, "", buf.String(), 0) 228 if err != nil { 229 return nil, fmt.Errorf("client template result parsing: %s. Code: %#v", err, buf.String()) 230 } 231 return f.Decls, nil 232 } 233 234 func (t *transformer) insertImport(file *ast.File, name *ast.Ident, path string) { 235 spec := &ast.ImportSpec{ 236 Name: name, 237 Path: &ast.BasicLit{ 238 Kind: token.STRING, 239 Value: `"` + path + `"`, 240 }, 241 } 242 importDecl := &ast.GenDecl{ 243 Tok: token.IMPORT, 244 Specs: []ast.Spec{spec}, 245 } 246 file.Decls = append([]ast.Decl{importDecl}, file.Decls...) 247 } 248 249 func firstLower(s string) string { 250 _, w := utf8.DecodeRuneInString(s) 251 return strings.ToLower(s[:w]) + s[w:] 252 }