go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/grpc/cmd/cproto/service.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
import (
	"errors"
	"fmt"
	"strconv"
	"strings"

	"go/ast"
	"go/token"
)
    25  
    26  // service is the set of data extracted from a generated protobuf file (.pb.go)
    27  // for a single service.
    28  type service struct {
    29  	name string
    30  
    31  	protoPackageName string
    32  
    33  	clientIfaceDecl ast.Decl
    34  	clientIface     *ast.InterfaceType
    35  
    36  	registerServerFunc *ast.FuncType
    37  }
    38  
    39  func getServices(file *ast.File) ([]*service, error) {
    40  	svcs := map[string]*service{}
    41  	var serviceNames []string
    42  	get := func(name string) *service {
    43  		s := svcs[name]
    44  		if s == nil {
    45  			s = &service{name: name}
    46  			svcs[name] = s
    47  			serviceNames = append(serviceNames, name)
    48  		}
    49  		return s
    50  	}
    51  
    52  	for _, decl := range file.Decls {
    53  		switch dt := decl.(type) {
    54  		case *ast.FuncDecl:
    55  			// Identify server types by scanning for Register<NAME>Server functions.
    56  			name := trimPhrase(dt.Name.Name, "Register", "Server")
    57  			if name == "" {
    58  				break
    59  			}
    60  			s := get(name)
    61  			s.registerServerFunc = dt.Type
    62  
    63  		case *ast.GenDecl:
    64  			// Look for:
    65  			// 1) The client interface type: type ...Client
    66  			// 2) The service descriptor:
    67  			//    var ... = grpc.ServiceDesc
    68  			for _, spec := range dt.Specs {
    69  				switch st := spec.(type) {
    70  				case *ast.TypeSpec:
    71  					name := trimPhrase(st.Name.Name, "", "Client")
    72  					if name == "" {
    73  						break
    74  					}
    75  
    76  					iface, ok := st.Type.(*ast.InterfaceType)
    77  					if !ok {
    78  						continue
    79  					}
    80  
    81  					s := get(name)
    82  					s.clientIfaceDecl = decl
    83  					s.clientIface = iface
    84  
    85  				case *ast.ValueSpec:
    86  					if len(st.Values) != 1 {
    87  						continue
    88  					}
    89  					compLit, ok := st.Values[0].(*ast.CompositeLit)
    90  					if !ok {
    91  						continue
    92  					}
    93  
    94  					// Is the assigned type "grpc.ServiceDesc"?
    95  					tsel, ok := compLit.Type.(*ast.SelectorExpr)
    96  					if !ok || tsel.Sel.Name != "ServiceDesc" {
    97  						continue
    98  					}
    99  					pkg, ok := tsel.X.(*ast.Ident)
   100  					if !ok || pkg.Name != "grpc" {
   101  						continue
   102  					}
   103  
   104  					// Get the "ServiceName" struct field and parse it.
   105  					var serviceNameExpr *ast.KeyValueExpr
   106  					for _, e := range compLit.Elts {
   107  						kv, ok := e.(*ast.KeyValueExpr)
   108  						if !ok {
   109  							continue
   110  						}
   111  						kident, ok := kv.Key.(*ast.Ident)
   112  						if !ok {
   113  							continue
   114  						}
   115  						if kident.Name == "ServiceName" {
   116  							serviceNameExpr = kv
   117  							break
   118  						}
   119  					}
   120  					if serviceNameExpr == nil {
   121  						return nil, errors.New("could not find ServiceName member")
   122  					}
   123  					// Get string value.
   124  					serviceNameLit, ok := serviceNameExpr.Value.(*ast.BasicLit)
   125  					if !ok || serviceNameLit.Kind != token.STRING {
   126  						return nil, errors.New("ServiceDesc.ServiceName not a string literal")
   127  					}
   128  					value := trimPhrase(serviceNameLit.Value, `"`, `"`)
   129  					if value == "" {
   130  						return nil, errors.New("ServiceDesc.ServiceName is not properly quoted")
   131  					}
   132  					protoPackage, service, err := parseServiceName(value)
   133  					if err != nil {
   134  						return nil, err
   135  					}
   136  
   137  					s := get(service)
   138  					s.protoPackageName = protoPackage
   139  				}
   140  			}
   141  			break
   142  		}
   143  	}
   144  
   145  	// Export our services as a slice, ordered by when the service was first
   146  	// encountered in the source file.
   147  	// Verify each service is complete.
   148  	services := make([]*service, len(serviceNames))
   149  	for i, k := range serviceNames {
   150  		s := svcs[k]
   151  		if err := s.complete(); err != nil {
   152  			return nil, fmt.Errorf("incomplete service %q: %s", s.name, err)
   153  		}
   154  		services[i] = s
   155  	}
   156  	return services, nil
   157  }
   158  
   159  func (s *service) complete() error {
   160  	if s.protoPackageName == "" {
   161  		return errors.New("missing protobuf package name")
   162  	}
   163  	if s.clientIface == nil {
   164  		return errors.New("missing client iface")
   165  	}
   166  	if s.registerServerFunc == nil {
   167  		return errors.New("missing server registration function")
   168  	}
   169  	return nil
   170  }
   171  
   172  // trimPhrase removes the specified prefix and suffix strings from the supplied
   173  // v. If either prefix is missing, suffix is missing, or v consists entirely of
   174  // prefix and suffix, the empty string is returned.
   175  func trimPhrase(v, prefix, suffix string) string {
   176  	if !strings.HasPrefix(v, prefix) {
   177  		return ""
   178  	}
   179  	v = strings.TrimPrefix(v, prefix)
   180  
   181  	if !strings.HasSuffix(v, suffix) {
   182  		return ""
   183  	}
   184  	return strings.TrimSuffix(v, suffix)
   185  }
   186  
   187  func parseServiceName(v string) (string, string, error) {
   188  	idx := strings.LastIndex(v, ".")
   189  	if idx <= 0 {
   190  		return "", "", errors.New("malformed service name")
   191  	}
   192  	return v[:idx], v[idx+1:], nil
   193  }