trpc.group/trpc-go/trpc-cmdline@v1.0.9/cmd/create/options.go (about)

     1  // Tencent is pleased to support the open source community by making tRPC available.
     2  //
     3  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     4  // All rights reserved.
     5  //
     6  // If you have downloaded a copy of the tRPC source code from Tencent,
     7  // please note that tRPC source code is licensed under the  Apache 2.0 License,
     8  // A copy of the Apache 2.0 License is included in this file.
     9  
    10  package create
    11  
    12  import (
    13  	"encoding/json"
    14  	"fmt"
    15  	"os"
    16  	"path/filepath"
    17  	"strings"
    18  
    19  	"github.com/pkg/errors"
    20  	"github.com/spf13/pflag"
    21  
    22  	"trpc.group/trpc-go/trpc-cmdline/config"
    23  	"trpc.group/trpc-go/trpc-cmdline/plugin"
    24  	"trpc.group/trpc-go/trpc-cmdline/util/fs"
    25  	"trpc.group/trpc-go/trpc-cmdline/util/lang"
    26  	"trpc.group/trpc-go/trpc-cmdline/util/paths"
    27  	"trpc.group/trpc-go/trpc-cmdline/util/pb"
    28  )
    29  
    30  // loadOptions loads options from flags.
    31  func (c *Create) loadOptions(flags *pflag.FlagSet) error {
    32  	if err := c.parse(flags); err != nil {
    33  		return fmt.Errorf("load options err: %w", err)
    34  	}
    35  	if err := c.fixOptions(); err != nil {
    36  		return fmt.Errorf("fix create options inside prerun err: %w", err)
    37  	}
    38  	return nil
    39  }
    40  
    41  // parse verifies and parses the flags.
    42  func (c *Create) parse(flags *pflag.FlagSet) error {
    43  	if err := c.parseIDLOptions(flags); err != nil {
    44  		return fmt.Errorf("parse idl options err: %w", err)
    45  	}
    46  
    47  	if err := c.parseGeneralOptions(flags); err != nil {
    48  		return fmt.Errorf("parse general options err: %w", err)
    49  	}
    50  
    51  	if err := c.parseSwaggerOptions(flags); err != nil {
    52  		return fmt.Errorf("parse swagger options err: %w", err)
    53  	}
    54  
    55  	if err := c.parseProtocolOptions(flags); err != nil {
    56  		return fmt.Errorf("parse protocol options err: %w", err)
    57  	}
    58  
    59  	if err := c.parseAuxiliaryOptions(flags); err != nil {
    60  		return fmt.Errorf("parse auxiliary options err: %w", err)
    61  	}
    62  
    63  	if err := c.parseSyncGitOptions(flags); err != nil {
    64  		return fmt.Errorf("parse sync git options err: %w", err)
    65  	}
    66  	return nil
    67  }
    68  
    69  // fixOptions fixes the options.
    70  func (c *Create) fixOptions() error {
    71  	c.fixGoMod()
    72  	if err := c.fixIDL(); err != nil {
    73  		return fmt.Errorf("fix idl err: %w", err)
    74  	}
    75  	return nil
    76  }
    77  
    78  // fixGoMod fixes the module name.
    79  func (c *Create) fixGoMod() {
    80  	//  1. Use module name specified by -mod.
    81  	//  2. Use local go.mod if -mod is not specified (for backward compatibility).
    82  	//  3. Use package name defined in pb (implemented by template).
    83  	if c.options.GoMod != "" {
    84  		return
    85  	}
    86  	mod, err := lang.LoadGoMod()
    87  	if err != nil {
    88  		return
    89  	}
    90  	if mod == "" {
    91  		return
    92  	}
    93  	c.options.GoModEx = mod
    94  	c.options.GoMod = mod
    95  	return
    96  }
    97  
    98  // fixIDL fixes the IDL.
    99  func (c *Create) fixIDL() error {
   100  	if c.options.OtherType != "" {
   101  		return c.fixOtherType() // Non-IDL type, such as kafka, HTTP.
   102  	}
   103  	if c.options.Protofile == "" {
   104  		return errors.New("protobuf/flatbuffers file both empty")
   105  	}
   106  	if err := c.fixProtoDirs(); err != nil {
   107  		return fmt.Errorf("fix proto dirs err: %w", err)
   108  	}
   109  	return c.fixProtocolType()
   110  }
   111  
   112  // fixOtherType updates the options related to "OtherType".
   113  func (c *Create) fixOtherType() error {
   114  	installPath, err := config.CurrentTemplatePath()
   115  	if err != nil {
   116  		return fmt.Errorf("failed to get current template path for other type err: %w", err)
   117  	}
   118  
   119  	if c.options.Assetdir == "" { // Do not override options provided by the user.
   120  		c.options.Assetdir = filepath.Join(installPath, "without_idl", c.options.Language, c.options.OtherType)
   121  	}
   122  
   123  	c.options.Protocol = c.options.OtherType
   124  	// May consider using c.options.OutputDir here.
   125  
   126  	c.options.Protofile = c.options.OtherType
   127  
   128  	// Plugins for code generation.
   129  	plugin.Plugins = []plugin.Plugin{
   130  		&plugin.GoImports{}, // goimports, runs before mockgen, to eliminate `package import but unused` errors
   131  		&plugin.Formatter{}, // gofmt
   132  	}
   133  
   134  	plugin.PluginsExt[c.options.Language] = nil
   135  
   136  	return nil
   137  }
   138  
   139  // fixProtoDirs updates the options related to proto directories.
   140  // If DescriptorSetIn is provided, it locates the file and updates the file path.
   141  // Otherwise, it locates the proto file and updates the options accordingly.
   142  func (c *Create) fixProtoDirs() error {
   143  	if c.options.DescriptorSetIn != "" {
   144  		var err error
   145  		filePath, err := fs.LocateFile(c.options.DescriptorSetIn, nil)
   146  		if err != nil {
   147  			return fmt.Errorf("fs locate file %s err: %w", c.options.DescriptorSetIn, err)
   148  		}
   149  		c.options.DescriptorSetIn = filePath
   150  		return nil
   151  	}
   152  
   153  	p, err := paths.Locate(pb.ProtoTRPC)
   154  	if err != nil {
   155  		return fmt.Errorf("paths locate %s failed err: %w", pb.ProtoTRPC, err)
   156  	}
   157  
   158  	c.options.Protodirs = fs.UniqFilePath(append(append(c.options.Protodirs, p),
   159  		paths.ExpandSearch(p)...,
   160  	))
   161  
   162  	target, err := fs.LocateFile(c.options.Protofile, c.options.Protodirs)
   163  	if err != nil {
   164  		return fmt.Errorf("locate file in proto dirs failed err: %w", err)
   165  	}
   166  
   167  	if c.options.UseBaseName {
   168  		c.options.Protofile = filepath.Base(target)
   169  	} else if filepath.IsAbs(c.options.Protofile) {
   170  		c.options.Protofile = strings.TrimPrefix(c.options.Protofile, "/")
   171  	} else {
   172  		c.options.Protofile = strings.TrimPrefix(c.options.Protofile, "./")
   173  	}
   174  
   175  	c.options.ProtofileAbs = target
   176  	c.options.Protodirs = append(c.options.Protodirs, filepath.Dir(target))
   177  
   178  	return nil
   179  }
   180  
   181  // fixProtocolType updates the options related to the protocol type.
   182  // It loads configurations from trpc.yaml and updates the options accordingly.
   183  func (c *Create) fixProtocolType() error {
   184  	// Load configurations from trpc.yaml.
   185  	cfg, err := config.GetTemplate(c.options.IDLType, c.options.Language)
   186  	if err != nil {
   187  		return fmt.Errorf("config get template failed err: %w", err)
   188  	}
   189  	if c.options.Assetdir == "" {
   190  		c.options.Assetdir = cfg.AssetDir
   191  	}
   192  	if c.options.Domain == "" {
   193  		c.options.Domain = config.GlobalConfig().Domain
   194  	}
   195  	if c.options.VersionSuffix != "" {
   196  		c.options.VersionSuffix = "/" + c.options.VersionSuffix
   197  	}
   198  	return nil
   199  }
   200  
   201  // parseIDLOptions parses the IDL-related options from the command line flags.
   202  // It parses the "usebasename" flag and delegates to other functions to parse protobuf/flatbuffers options.
   203  func (c *Create) parseIDLOptions(flags *pflag.FlagSet) error {
   204  	var err error
   205  	c.options.UseBaseName, err = flags.GetBool("usebasename")
   206  	if err != nil {
   207  		return fmt.Errorf("flags parse usebasename %w", err)
   208  	}
   209  	// Parse protobuf/flatbuffers options.
   210  	if err := c.parsePBIDLOptions(flags); err != nil {
   211  		return fmt.Errorf("flags parse pb idl options err: %w", err)
   212  	}
   213  	// If protofile field is empty, try parse flatbuffers related flags.
   214  	if c.options.Protofile == "" {
   215  		if err := c.parseFBIDLOptions(flags); err != nil {
   216  			return fmt.Errorf("flags parse fb idl options, err: %w", err)
   217  		}
   218  	}
   219  	return nil
   220  }
   221  
   222  // parseGeneralOptions parses the general options from the command line flags.
   223  // It parses the "verbose" flag and delegates to other functions to parse input/output options.
   224  func (c *Create) parseGeneralOptions(flags *pflag.FlagSet) error {
   225  	var err error
   226  	c.options.Verbose, err = flags.GetBool("verbose")
   227  	if err != nil {
   228  		return fmt.Errorf("flags parse verbose string err: %w", err)
   229  	}
   230  	if err := c.parseInputOptions(flags); err != nil {
   231  		return err
   232  	}
   233  	if err := c.parseOutputOptions(flags); err != nil {
   234  		return err
   235  	}
   236  	return nil
   237  }
   238  
   239  // parseAuxiliaryOptions parses the auxiliary options from the command line flags.
   240  // It parses various boolean and string flags related to auxiliary options.
   241  func (c *Create) parseAuxiliaryOptions(flags *pflag.FlagSet) error {
   242  	var err error
   243  	c.options.MultiVersion, err = flags.GetBool("multi-version")
   244  	if err != nil {
   245  		return fmt.Errorf("flags parse multi-version bool err: %w", err)
   246  	}
   247  	c.options.NoServiceSuffix, err = flags.GetBool("noservicesuffix")
   248  	if err != nil {
   249  		return fmt.Errorf("flags parse noservicesuffix bool err: %w", err)
   250  	}
   251  	return nil
   252  }
   253  
   254  // parseInputOptions parses the input options from the command line flags.
   255  // It parses various string and boolean flags related to input options.
   256  func (c *Create) parseInputOptions(flags *pflag.FlagSet) error {
   257  	var err error
   258  	c.options.Assetdir, err = flags.GetString("assetdir")
   259  	if err != nil {
   260  		return fmt.Errorf("flags parse assetdir string err: %w", err)
   261  	}
   262  	c.options.Language, err = flags.GetString("lang")
   263  	if err != nil {
   264  		return fmt.Errorf("flags parse lang string err: %w", err)
   265  	}
   266  	c.options.AliasOn, err = flags.GetBool("alias")
   267  	if err != nil {
   268  		return fmt.Errorf("flags parse alias bool err: %w", err)
   269  	}
   270  	c.options.AliasAsClientRPCName, err = flags.GetBool("alias-as-client-rpcname")
   271  	if err != nil {
   272  		return fmt.Errorf("flags parse alias-as-client-rpcname bool err: %w", err)
   273  	}
   274  	c.options.GoMod, err = flags.GetString("mod")
   275  	if err != nil {
   276  		return fmt.Errorf("flags parse mod string err: %w", err)
   277  	}
   278  	c.options.GoVersion, err = flags.GetString("goversion")
   279  	if err != nil {
   280  		return fmt.Errorf("flags parse goversion string err: %w", err)
   281  	}
   282  	c.options.TRPCGoVersion, err = flags.GetString("trpcgoversion")
   283  	if err != nil {
   284  		return fmt.Errorf("flags parse trpcgoversion string err: %w", err)
   285  	}
   286  	c.options.CustomAPPName, err = flags.GetString("app")
   287  	if err != nil {
   288  		return fmt.Errorf("flags parse app string err: %w", err)
   289  	}
   290  	c.options.CustomServerName, err = flags.GetString("server")
   291  	if err != nil {
   292  		return fmt.Errorf("flags parse server string err: %w", err)
   293  	}
   294  	c.options.DescriptorSetIn, err = flags.GetString("descriptor_set_in")
   295  	if err != nil {
   296  		return fmt.Errorf("flags parse descriptor_set_in string err: %w", err)
   297  	}
   298  	return nil
   299  }
   300  
   301  // parseOutputOptions parses the output options from the command line flags.
   302  // It parses various string and boolean flags related to output options.
   303  func (c *Create) parseOutputOptions(flags *pflag.FlagSet) error {
   304  	var err error
   305  	c.options.OutputDir, err = flags.GetString("output")
   306  	if err != nil {
   307  		return fmt.Errorf("flags parse output string err: %w", err)
   308  	}
   309  	c.options.RPCOnly, err = flags.GetBool("rpconly")
   310  	if err != nil {
   311  		return fmt.Errorf("flags parse rpconly bool err: %w", err)
   312  	}
   313  	c.options.DependencyStub, err = flags.GetBool("dependencystub")
   314  	if err != nil {
   315  		return fmt.Errorf("flags parse dependencystub %w", err)
   316  	}
   317  	c.options.NoGoMod, err = flags.GetBool("nogomod")
   318  	if err != nil {
   319  		return fmt.Errorf("flags parse nogomod bool err: %w", err)
   320  	}
   321  	c.options.KeepOrigRPCName = true // Always true.
   322  	c.options.SecvEnabled, err = flags.GetBool("secvenabled")
   323  	if err != nil {
   324  		return fmt.Errorf("flags parse secvenabled bool err: %w", err)
   325  	}
   326  	c.options.ValidateEnabled, err = flags.GetBool("validate")
   327  	if err != nil {
   328  		return fmt.Errorf("flags parse validate bool err: %w", err)
   329  	}
   330  	kvFile, err := flags.GetString("kvfile")
   331  	if err != nil {
   332  		return fmt.Errorf("flags parse kvfile string err: %w", err)
   333  	}
   334  	if kvFile != "" {
   335  		bs, err := os.ReadFile(kvFile)
   336  		if err != nil {
   337  			return fmt.Errorf("read kv file %s err: %w", kvFile, err)
   338  		}
   339  		if err := json.Unmarshal(bs, &c.options.KVs); err != nil {
   340  			return fmt.Errorf("json unmarshal kv file %s into %T err: %w", kvFile, c.options.KVs, err)
   341  		}
   342  	}
   343  	kvRawJSON, err := flags.GetString("kvrawjson")
   344  	if err != nil {
   345  		return fmt.Errorf("flags parse kvrawjson string err: %w", err)
   346  	}
   347  	if kvRawJSON != "" {
   348  		if err := json.Unmarshal([]byte(kvRawJSON), &c.options.KVs); err != nil {
   349  			return fmt.Errorf("json unmarshal kv raw json %s into %T err: %w", kvRawJSON, c.options.KVs, err)
   350  		}
   351  	}
   352  	c.options.Force, err = flags.GetBool("force")
   353  	if err != nil {
   354  		return fmt.Errorf("flags parse force bool err: %w", err)
   355  	}
   356  	c.options.Mockgen, err = flags.GetBool("mock")
   357  	if err != nil {
   358  		return fmt.Errorf("flags parse mock bool err: %w", err)
   359  	}
   360  	c.options.PerMethod, err = flags.GetBool("split-by-method")
   361  	if err != nil {
   362  		return fmt.Errorf("flags parse split-by-method bool err: %w", err)
   363  	}
   364  	c.options.Domain, err = flags.GetString("domain")
   365  	if err != nil {
   366  		return fmt.Errorf("flags parse domain string err: %w", err)
   367  	}
   368  	c.options.GroupName, err = flags.GetString("groupname")
   369  	if err != nil {
   370  		return fmt.Errorf("flags parse groupname string err: %w", err)
   371  	}
   372  	c.options.VersionSuffix, err = flags.GetString("versionsuffix")
   373  	if err != nil {
   374  		return fmt.Errorf("flags parse versionsuffix string err: %w", err)
   375  	}
   376  	return nil
   377  }
   378  
   379  // parseSwaggerOptions parses the swagger options from the command line flags.
   380  // It parses various string and boolean flags related to swagger options.
   381  func (c *Create) parseSwaggerOptions(flags *pflag.FlagSet) error {
   382  	var err error
   383  	c.options.SwaggerOn, err = flags.GetBool("swagger")
   384  	if err != nil {
   385  		return fmt.Errorf("flags parse swagger bool err: %w", err)
   386  	}
   387  	c.options.SwaggerOut, err = flags.GetString("swagger-out")
   388  	if err != nil {
   389  		return fmt.Errorf("flags parse swagger-out string err: %w", err)
   390  	}
   391  	c.options.SwaggerOptJSONParam, err = flags.GetBool("swagger-json-param")
   392  	if err != nil {
   393  		return fmt.Errorf("flags parse swagger-json-param bool err: %w", err)
   394  	}
   395  	return nil
   396  }
   397  
   398  // parseProtocolOptions parses the protocol options from the command line flags.
   399  // It parses various string flags related to protocol options.
   400  func (c *Create) parseProtocolOptions(flags *pflag.FlagSet) error {
   401  	var err error
   402  	c.options.Protocol, err = flags.GetString("protocol")
   403  	if err != nil {
   404  		return fmt.Errorf("flags parse protocol string err: %w", err)
   405  	}
   406  	c.options.OtherType, err = flags.GetString("non-protocol-type")
   407  	if err != nil {
   408  		return fmt.Errorf("flags parse non-protocol-type string err: %w", err)
   409  	}
   410  	return nil
   411  }
   412  
   413  // parsePBIDLOptions parses the protobuf IDL options from the command line flags.
   414  // It parses various string flags and an array of strings related to protobuf options.
   415  func (c *Create) parsePBIDLOptions(flags *pflag.FlagSet) error {
   416  	dirs, err := flags.GetStringArray("protodir")
   417  	if err != nil {
   418  		return fmt.Errorf("flags get protodir string array failed err: %w", err)
   419  	}
   420  	// Always append the current working directory and root directory.
   421  	c.options.Protodirs = fs.UniqFilePath(append(dirs, ".", "/"))
   422  	c.options.Protofile, err = flags.GetString("protofile")
   423  	if err != nil {
   424  		return fmt.Errorf("flags get protofile string failed err: %w", err)
   425  	}
   426  	c.options.Gotag, err = flags.GetBool("gotag")
   427  	if err != nil {
   428  		return fmt.Errorf("flags get gotag bool failed err: %w", err)
   429  	}
   430  	c.options.IDLType = config.IDLTypeProtobuf
   431  	return nil
   432  }
   433  
   434  // parseFBIDLOptions parses the FlatBuffers IDL options from the command line flags.
   435  // It parses various string flags and an array of strings related to FlatBuffers options.
   436  func (c *Create) parseFBIDLOptions(flags *pflag.FlagSet) error {
   437  	dirs, err := flags.GetStringArray("fbsdir")
   438  	if err != nil {
   439  		return fmt.Errorf("flags get fbsdir string array failed err: %w", err)
   440  	}
   441  	// Always append the current working directory.
   442  	c.options.Protodirs = fs.UniqFilePath(append(dirs, "."))
   443  	c.options.Protofile, err = flags.GetString("fbs")
   444  	if err != nil {
   445  		return fmt.Errorf("flags get fbs string failed err: %w", err)
   446  	}
   447  	c.options.IDLType = config.IDLTypeFlatBuffers
   448  	return nil
   449  }
   450  
   451  // parseSyncGitOptions parses the synchronization and git options from the command line flags.
   452  // It parses various string and boolean flags related to git synchronization options.
   453  func (c *Create) parseSyncGitOptions(flags *pflag.FlagSet) error {
   454  	sync, err := flags.GetBool("sync")
   455  	if err != nil {
   456  		return fmt.Errorf("flags get git sync bool failed err: %w", err)
   457  	}
   458  	c.options.Sync = sync
   459  	remote, err := flags.GetString("remote")
   460  	if err != nil {
   461  		return fmt.Errorf("flags get git remote address url failed err: %w", err)
   462  	}
   463  	c.options.Remote = remote
   464  	newTag, err := flags.GetBool("newtag")
   465  	if err != nil {
   466  		return fmt.Errorf("flags get git new tag bool failed err: %w", err)
   467  	}
   468  	c.options.NewTag = newTag
   469  	tag, err := flags.GetString("tag")
   470  	if err != nil {
   471  		return fmt.Errorf("flags get git tag failed err: %w", err)
   472  	}
   473  	c.options.Tag = tag
   474  	return nil
   475  }