github.com/mloves0824/enron/cmd/enron@v0.0.0-20230830012320-113bbf6be3c8/internal/proto/server/server.go (about)

     1  package server
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"os"
     7  	"path/filepath"
     8  	"strings"
     9  
    10  	"github.com/emicklei/proto"
    11  	"github.com/spf13/cobra"
    12  	"golang.org/x/text/cases"
    13  	"golang.org/x/text/language"
    14  )
    15  
// CmdServer is the cobra command for the "server" sub-command, which
// generates server implementations from a proto file. The proto file is given
// as the first positional argument and the command is executed by run.
var CmdServer = &cobra.Command{
	Use:   "server",
	Short: "Generate the proto server implementations",
	Long:  "Generate the proto server implementations. Example: enron proto server api/xxx.proto --target-dir=internal/service",
	Run:   run,
}
// targetDir is the directory the generated files are written into. It is
// bound to the --target-dir / -t flag and defaults to "internal/service".
var targetDir string
    24  
    25  func init() {
    26  	CmdServer.Flags().StringVarP(&targetDir, "target-dir", "t", "internal/service", "generate target directory")
    27  }
    28  
    29  func run(_ *cobra.Command, args []string) {
    30  	if len(args) == 0 {
    31  		fmt.Fprintln(os.Stderr, "Please specify the proto file. Example: enron proto server api/xxx.proto")
    32  		return
    33  	}
    34  	reader, err := os.Open(args[0])
    35  	if err != nil {
    36  		log.Fatal(err)
    37  	}
    38  	defer reader.Close()
    39  
    40  	parser := proto.NewParser(reader)
    41  	definition, err := parser.Parse()
    42  	if err != nil {
    43  		log.Fatal(err)
    44  	}
    45  
    46  	var (
    47  		pkg string
    48  		res []*Service
    49  	)
    50  	proto.Walk(definition,
    51  		proto.WithOption(func(o *proto.Option) {
    52  			if o.Name == "go_package" {
    53  				pkg = strings.Split(o.Constant.Source, ";")[0]
    54  			}
    55  		}),
    56  		proto.WithService(func(s *proto.Service) {
    57  			cs := &Service{
    58  				Package: pkg,
    59  				Service: serviceName(s.Name),
    60  			}
    61  			for _, e := range s.Elements {
    62  				r, ok := e.(*proto.RPC)
    63  				if !ok {
    64  					continue
    65  				}
    66  				cs.Methods = append(cs.Methods, &Method{
    67  					Service: serviceName(s.Name), Name: serviceName(r.Name), Request: parametersName(r.RequestType),
    68  					Reply: parametersName(r.ReturnsType), Type: getMethodType(r.StreamsRequest, r.StreamsReturns),
    69  				})
    70  			}
    71  			res = append(res, cs)
    72  		}),
    73  	)
    74  	if _, err := os.Stat(targetDir); os.IsNotExist(err) {
    75  		fmt.Printf("Target directory: %s does not exsit\n", targetDir)
    76  		return
    77  	}
    78  	for _, s := range res {
    79  		to := filepath.Join(targetDir, strings.ToLower(s.Service)+".go")
    80  		if _, err := os.Stat(to); !os.IsNotExist(err) {
    81  			fmt.Fprintf(os.Stderr, "%s already exists: %s\n", s.Service, to)
    82  			continue
    83  		}
    84  		b, err := s.execute()
    85  		if err != nil {
    86  			log.Fatal(err)
    87  		}
    88  		if err := os.WriteFile(to, b, 0o644); err != nil {
    89  			log.Fatal(err)
    90  		}
    91  		fmt.Println(to)
    92  	}
    93  }
    94  
    95  func getMethodType(streamsRequest, streamsReturns bool) MethodType {
    96  	if !streamsRequest && !streamsReturns {
    97  		return unaryType
    98  	} else if streamsRequest && streamsReturns {
    99  		return twoWayStreamsType
   100  	} else if streamsRequest {
   101  		return requestStreamsType
   102  	} else if streamsReturns {
   103  		return returnsStreamsType
   104  	}
   105  	return unaryType
   106  }
   107  
   108  func parametersName(name string) string {
   109  	return strings.ReplaceAll(name, ".", "_")
   110  }
   111  
   112  func serviceName(name string) string {
   113  	return toUpperCamelCase(strings.Split(name, ".")[0])
   114  }
   115  
   116  func toUpperCamelCase(s string) string {
   117  	s = strings.ReplaceAll(s, "_", " ")
   118  	s = cases.Title(language.Und, cases.NoLower).String(s)
   119  	return strings.ReplaceAll(s, " ", "")
   120  }