go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/cmd/cproto/transform.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // This file implements .go code transformation.
    16  
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"fmt"
    22  	"os"
    23  	"strings"
    24  	"text/template"
    25  	"unicode/utf8"
    26  
    27  	"go/ast"
    28  	"go/format"
    29  	"go/parser"
    30  	"go/printer"
    31  	"go/token"
    32  )
    33  
    34  const (
    35  	prpcPackagePath = `go.chromium.org/luci/grpc/prpc`
    36  )
    37  
    38  var (
    39  	serverPrpcPkg = ast.NewIdent("prpc")
    40  )
    41  
// transformer holds the state used while rewriting a single .go file.
// All fields are populated by transformGoFile.
type transformer struct {
	// fset is the file set the file under transformation was parsed into;
	// it is also used when printing nodes and parsing generated code.
	fset          *token.FileSet
	// inPRPCPackage is the result of isInPackage for the pRPC package path.
	// When true, pRPC symbols are referenced unqualified and no import is added.
	inPRPCPackage bool
	// services are the services returned by getServices for the file.
	services      []*service
	// PackageName is the name from the file's package clause.
	PackageName   string
}
    48  
    49  // transformGoFile rewrites a .go file to work with prpc.
    50  func (t *transformer) transformGoFile(filename string) error {
    51  	t.fset = token.NewFileSet()
    52  	file, err := parser.ParseFile(t.fset, filename, nil, parser.ParseComments)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	t.PackageName = file.Name.Name
    58  	t.services, err = getServices(file)
    59  	if err != nil {
    60  		return err
    61  	}
    62  	if len(t.services) == 0 {
    63  		return nil
    64  	}
    65  
    66  	t.inPRPCPackage, err = isInPackage(filename, prpcPackagePath)
    67  	if err != nil {
    68  		return err
    69  	}
    70  
    71  	if err := t.transformFile(file); err != nil {
    72  		return err
    73  	}
    74  
    75  	var buf bytes.Buffer
    76  	if err := printer.Fprint(&buf, t.fset, file); err != nil {
    77  		return err
    78  	}
    79  	formatted, err := gofmt(buf.Bytes())
    80  	if err != nil {
    81  		return err
    82  	}
    83  
    84  	return os.WriteFile(filename, formatted, 0666)
    85  }
    86  
    87  func (t *transformer) transformFile(file *ast.File) error {
    88  	var includePrpc bool
    89  	for _, s := range t.services {
    90  		t.transformRegisterServerFuncs(s)
    91  		if !t.inPRPCPackage {
    92  			includePrpc = true
    93  		}
    94  
    95  		if err := t.generateClients(file, s); err != nil {
    96  			return err
    97  		}
    98  	}
    99  	if includePrpc {
   100  		t.insertImport(file, serverPrpcPkg, prpcPackagePath)
   101  	}
   102  	return nil
   103  }
   104  
   105  // transformRegisterServerFuncs finds RegisterXXXServer functions and
   106  // checks its first parameter type to prpc.Registrar.
   107  // Returns true if modified ast.
   108  func (t *transformer) transformRegisterServerFuncs(s *service) {
   109  	registrarName := ast.NewIdent("Registrar")
   110  	var registrarType ast.Expr = registrarName
   111  	if !t.inPRPCPackage {
   112  		registrarType = &ast.SelectorExpr{X: serverPrpcPkg, Sel: registrarName}
   113  	}
   114  	s.registerServerFunc.Params.List[0].Type = registrarType
   115  }
   116  
   117  // generateClients finds client interface declarations
   118  // and inserts pRPC implementations after them.
   119  func (t *transformer) generateClients(file *ast.File, s *service) error {
   120  	switch newDecls, err := t.generateClient(s.protoPackageName, s.name, s.clientIface); {
   121  	case err != nil:
   122  		return err
   123  	case len(newDecls) > 0:
   124  		insertAST(file, s.clientIfaceDecl, newDecls)
   125  		return nil
   126  	default:
   127  		return nil
   128  	}
   129  }
   130  
   131  func insertAST(file *ast.File, after ast.Decl, newDecls []ast.Decl) {
   132  	for i, d := range file.Decls {
   133  		if d == after {
   134  			file.Decls = append(file.Decls[:i+1], append(newDecls, file.Decls[i+1:]...)...)
   135  			return
   136  		}
   137  	}
   138  	panic("unable to find after node")
   139  }
   140  
// clientCodeTemplate renders the Go source of a pRPC client for one service:
// a struct wrapping a pRPC Client, a New<Service>PRPCClient constructor, and
// one method per entry of .Methods that forwards the call to client.Call.
//
// The template data must provide Service, ProtoPkg, StructName,
// PRPCSymbolPrefix and Methods, where each method has Name, InputMessage and
// OutputMessage. The output begins with a placeholder package clause so that
// it can be parsed as a complete Go file.
//
// NOTE(review): the rendered code refers to the context and grpc packages
// without importing them; presumably the target file already imports both.
var clientCodeTemplate = template.Must(template.New("").Parse(`
package template

type {{$.StructName}} struct {
	client *{{.PRPCSymbolPrefix}}Client
}

func New{{.Service}}PRPCClient(client *{{.PRPCSymbolPrefix}}Client) {{.Service}}Client {
	return &{{$.StructName}}{client}
}

{{range .Methods}}
func (c *{{$.StructName}}) {{.Name}}(ctx context.Context, in *{{.InputMessage}}, opts ...grpc.CallOption) (*{{.OutputMessage}}, error) {
	out := new({{.OutputMessage}})
	err := c.client.Call(ctx, "{{$.ProtoPkg}}.{{$.Service}}", "{{.Name}}", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}
{{end}}
`))
   163  
   164  // generateClient generates pRPC implementation of a client interface.
   165  func (t *transformer) generateClient(protoPackage, serviceName string, iface *ast.InterfaceType) ([]ast.Decl, error) {
   166  	// This function used to construct an AST. It was a lot of code.
   167  	// Now it generates code via a template and parses back to AST.
   168  	// Slower, but saner and easier to make changes.
   169  
   170  	type Method struct {
   171  		Name          string
   172  		InputMessage  string
   173  		OutputMessage string
   174  	}
   175  	methods := make([]Method, 0, len(iface.Methods.List))
   176  
   177  	var buf bytes.Buffer
   178  	toGoCode := func(n ast.Node) (string, error) {
   179  		defer buf.Reset()
   180  		err := format.Node(&buf, t.fset, n)
   181  		if err != nil {
   182  			return "", err
   183  		}
   184  		return buf.String(), nil
   185  	}
   186  
   187  	for _, m := range iface.Methods.List {
   188  		signature, ok := m.Type.(*ast.FuncType)
   189  		if !ok {
   190  			return nil, fmt.Errorf("unexpected embedded interface in %sClient", serviceName)
   191  		}
   192  
   193  		inStructPtr := signature.Params.List[1].Type.(*ast.StarExpr)
   194  		inStruct, err := toGoCode(inStructPtr.X)
   195  		if err != nil {
   196  			return nil, err
   197  		}
   198  
   199  		outStructPtr := signature.Results.List[0].Type.(*ast.StarExpr)
   200  		outStruct, err := toGoCode(outStructPtr.X)
   201  		if err != nil {
   202  			return nil, err
   203  		}
   204  
   205  		methods = append(methods, Method{
   206  			Name:          m.Names[0].Name,
   207  			InputMessage:  inStruct,
   208  			OutputMessage: outStruct,
   209  		})
   210  	}
   211  
   212  	prpcSymbolPrefix := "prpc."
   213  	if t.inPRPCPackage {
   214  		prpcSymbolPrefix = ""
   215  	}
   216  	err := clientCodeTemplate.Execute(&buf, map[string]any{
   217  		"Service":          serviceName,
   218  		"ProtoPkg":         protoPackage,
   219  		"StructName":       firstLower(serviceName) + "PRPCClient",
   220  		"Methods":          methods,
   221  		"PRPCSymbolPrefix": prpcSymbolPrefix,
   222  	})
   223  	if err != nil {
   224  		return nil, fmt.Errorf("client template execution: %s", err)
   225  	}
   226  
   227  	f, err := parser.ParseFile(t.fset, "", buf.String(), 0)
   228  	if err != nil {
   229  		return nil, fmt.Errorf("client template result parsing: %s. Code: %#v", err, buf.String())
   230  	}
   231  	return f.Decls, nil
   232  }
   233  
   234  func (t *transformer) insertImport(file *ast.File, name *ast.Ident, path string) {
   235  	spec := &ast.ImportSpec{
   236  		Name: name,
   237  		Path: &ast.BasicLit{
   238  			Kind:  token.STRING,
   239  			Value: `"` + path + `"`,
   240  		},
   241  	}
   242  	importDecl := &ast.GenDecl{
   243  		Tok:   token.IMPORT,
   244  		Specs: []ast.Spec{spec},
   245  	}
   246  	file.Decls = append([]ast.Decl{importDecl}, file.Decls...)
   247  }
   248  
   249  func firstLower(s string) string {
   250  	_, w := utf8.DecodeRuneInString(s)
   251  	return strings.ToLower(s[:w]) + s[w:]
   252  }