github.com/zaquestion/lab@v0.25.1/cmd/util.go (about)

     1  // This file contains common functions that are shared in the lab package
     2  
     3  package cmd
     4  
     5  import (
     6  	"fmt"
     7  	"os"
     8  	"os/exec"
     9  	"strconv"
    10  	"strings"
    11  
    12  	"github.com/pkg/errors"
    13  	"github.com/spf13/cobra"
    14  	flag "github.com/spf13/pflag"
    15  	"github.com/spf13/viper"
    16  	gitconfig "github.com/tcnksm/go-gitconfig"
    17  	giturls "github.com/whilp/git-urls"
    18  	gitlab "github.com/xanzy/go-gitlab"
    19  	"github.com/zaquestion/lab/internal/config"
    20  	"github.com/zaquestion/lab/internal/git"
    21  	lab "github.com/zaquestion/lab/internal/gitlab"
    22  	"golang.org/x/crypto/ssh/terminal"
    23  )
    24  
    25  var (
    26  	commandPrefix string
    27  	// http vs ssh protocol control flag
    28  	useHTTP bool
    29  )
    30  
    31  // flagConfig compares command line flags and the flags set in the config
    32  // files.  The command line value will always override any value set in the
    33  // config files.
    34  func flagConfig(fs *flag.FlagSet) {
    35  	var cmdFlags string
    36  
    37  	fs.VisitAll(func(f *flag.Flag) {
    38  		var (
    39  			configValue  interface{}
    40  			configString string
    41  			flagChanged  bool
    42  		)
    43  
    44  		switch f.Value.Type() {
    45  		case "bool":
    46  			configValue = getMainConfig().GetBool(commandPrefix + f.Name)
    47  			configString = strconv.FormatBool(configValue.(bool))
    48  		case "string":
    49  			configValue = getMainConfig().GetString(commandPrefix + f.Name)
    50  			configString = configValue.(string)
    51  		case "stringSlice":
    52  			configValue = getMainConfig().GetStringSlice(commandPrefix + f.Name)
    53  			configString = strings.Join(configValue.([]string), " ")
    54  		case "int":
    55  			log.Fatal("ERROR: found int flag, use string instead: ", f.Value.Type(), f)
    56  		case "stringArray":
    57  			// viper does not have support for stringArray
    58  			configString = ""
    59  		default:
    60  			log.Fatal("ERROR: found unidentified flag: ", f.Value.Type(), f)
    61  		}
    62  
    63  		if f.Changed {
    64  			flagChanged = true
    65  		}
    66  
    67  		// o/w use the value in the configfile
    68  		if !flagChanged && configString != "" && configString != f.DefValue {
    69  			f.Value.Set(configString)
    70  			flagChanged = true
    71  		}
    72  
    73  		if flagChanged {
    74  			if f.Name != "debug" {
    75  				cmdFlags += fmt.Sprintf("  %s = %s\n", f.Name, f.Value.String())
    76  			}
    77  		}
    78  	})
    79  
    80  	if len(cmdFlags) != 0 {
    81  		log.Debugf("command flags enabled: \n%s", cmdFlags)
    82  	}
    83  }
    84  
    85  // getCurrentBranchMR returns the MR ID associated with the current branch.
    86  // If a MR ID cannot be found, the function returns 0.
    87  func getCurrentBranchMR(rn string) int {
    88  	currentBranch, err := git.CurrentBranch()
    89  	if err != nil {
    90  		return 0
    91  	}
    92  
    93  	return getBranchMR(rn, currentBranch)
    94  }
    95  
    96  func getBranchMR(rn, branch string) int {
    97  	var num int = 0
    98  
    99  	mrBranch, err := git.UpstreamBranch(branch)
   100  	if err != nil {
   101  		// Fall back to local branch
   102  		mrBranch = branch
   103  	}
   104  
   105  	branchRemote, err := determineSourceRemote(branch)
   106  	if err != nil {
   107  		log.Fatal(err)
   108  	}
   109  
   110  	branchProjectName, err := git.PathWithNamespace(branchRemote)
   111  	if err != nil {
   112  		log.Fatal(err)
   113  	}
   114  
   115  	branchProject, err := lab.FindProject(branchProjectName)
   116  	if err != nil {
   117  		log.Fatal(err)
   118  	}
   119  
   120  	mrs, err := lab.MRList(rn, gitlab.ListProjectMergeRequestsOptions{
   121  		Labels:       mrLabels,
   122  		State:        &mrState,
   123  		OrderBy:      gitlab.String("updated_at"),
   124  		SourceBranch: gitlab.String(mrBranch),
   125  	}, -1)
   126  	if err != nil {
   127  		log.Fatal(err)
   128  	}
   129  
   130  	for _, mr := range mrs {
   131  		if mr.SourceProjectID == branchProject.ID {
   132  			num = mr.IID
   133  			break
   134  		}
   135  	}
   136  	return num
   137  }
   138  
   139  // getMainConfig returns the merged config of ~/.config/lab/lab.toml and
   140  // .git/lab/lab.toml
   141  func getMainConfig() *viper.Viper {
   142  	return config.MainConfig
   143  }
   144  
   145  // parseArgsRemoteAndID is used by commands to parse command line arguments.
   146  // This function returns a remote name and number.
   147  func parseArgsRemoteAndID(args []string) (string, int64, error) {
   148  	if !git.InsideGitRepo() {
   149  		return "", 0, nil
   150  	}
   151  
   152  	remote, num, err := parseArgsStringAndID(args)
   153  	if err != nil {
   154  		return "", 0, err
   155  	}
   156  	ok, err := git.IsRemote(remote)
   157  	if err != nil {
   158  		return "", 0, err
   159  	} else if !ok && remote != "" {
   160  		switch len(args) {
   161  		case 1:
   162  			return "", 0, errors.Errorf("%s is not a valid remote or number", args[0])
   163  		default:
   164  			return "", 0, errors.Errorf("%s is not a valid remote", args[0])
   165  		}
   166  	}
   167  	if remote == "" {
   168  		remote = defaultRemote
   169  	}
   170  	rn, err := git.PathWithNamespace(remote)
   171  	if err != nil {
   172  		return "", 0, err
   173  	}
   174  	return rn, num, nil
   175  }
   176  
   177  // parseArgsRemoteAndProject is used by commands to parse command line
   178  // arguments.  This function returns a remote name and the project name.  If no
   179  // remote name is given, the function returns "" and the project name of the
   180  // default remote (ie 'origin').
   181  func parseArgsRemoteAndProject(args []string) (string, string, error) {
   182  	if !git.InsideGitRepo() {
   183  		return "", "", nil
   184  	}
   185  
   186  	remote, str, err := parseArgsRemoteAndString(args)
   187  	if err != nil {
   188  		return "", "", nil
   189  	}
   190  
   191  	remote, err = getRemoteName(remote)
   192  	if err != nil {
   193  		return "", "", err
   194  	}
   195  	return remote, str, nil
   196  }
   197  
   198  // parseArgsRemoteAndBranch is used by commands to parse command line
   199  // arguments.  This function returns a remote name and a branch name.
   200  // If no branch name is given, the function returns the upstream of
   201  // the current branch and the corresponding remote.
   202  func parseArgsRemoteAndBranch(args []string) (string, string, error) {
   203  	if !git.InsideGitRepo() {
   204  		return "", "", nil
   205  	}
   206  
   207  	remote, branch, err := parseArgsRemoteAndString(args)
   208  	if err != nil {
   209  		return "", "", err
   210  	} else if branch == "" {
   211  		branch, err = git.CurrentBranch()
   212  	}
   213  
   214  	remoteBranch, _ := git.UpstreamBranch(branch)
   215  	if remoteBranch != "" {
   216  		branch = remoteBranch
   217  	}
   218  
   219  	if remote == "" {
   220  		remote, err = determineSourceRemote(branch)
   221  		if err != nil {
   222  			return "", "", err
   223  		}
   224  	}
   225  	remote, err = getRemoteName(remote)
   226  	if err != nil {
   227  		return "", "", err
   228  	}
   229  
   230  	return remote, branch, nil
   231  }
   232  
   233  func getPipelineFromArgs(args []string, forMR bool) (string, int, error) {
   234  	if forMR {
   235  		rn, mrNum, err := parseArgsWithGitBranchMR(args)
   236  		if err != nil {
   237  			return "", 0, err
   238  		}
   239  
   240  		mr, err := lab.MRGet(rn, int(mrNum))
   241  		if err != nil {
   242  			return "", 0, err
   243  		}
   244  
   245  		// In this part, we only really care about the latest pipeline that
   246  		// ran, regardless its result.
   247  		if mr.HeadPipeline == nil {
   248  			return "", 0, errors.Errorf("No pipeline found for merge request %d", mrNum)
   249  		}
   250  
   251  		// MR pipelines may run on the source, target or another project
   252  		// (multi-project pipelines), and we don't have a proper way to
   253  		// know which it is. Here we handle the first two cases.
   254  		if strings.Contains(mr.HeadPipeline.WebURL, rn) {
   255  			return rn, mr.HeadPipeline.ID, nil
   256  		}
   257  
   258  		p, err := lab.GetProject(mr.SourceProjectID)
   259  		if err != nil {
   260  			return "", 0, err
   261  		}
   262  
   263  		return p.PathWithNamespace, mr.HeadPipeline.ID, nil
   264  	}
   265  	rn, refName, err := parseArgsRemoteAndBranch(args)
   266  	if err != nil {
   267  		return "", 0, err
   268  	}
   269  
   270  	commit, err := lab.GetCommit(rn, refName)
   271  	if err != nil {
   272  		return "", 0, err
   273  	}
   274  
   275  	if commit.LastPipeline == nil {
   276  		return "", 0, errors.Errorf("No pipeline found for %s", refName)
   277  	}
   278  
   279  	return rn, commit.LastPipeline.ID, nil
   280  }
   281  
   282  func getRemoteName(remote string) (string, error) {
   283  	if remote == "" {
   284  		remote = defaultRemote
   285  	}
   286  
   287  	ok, err := git.IsRemote(remote)
   288  	if err != nil {
   289  		return "", err
   290  	}
   291  	if !ok {
   292  		return "", errors.Errorf("%s is not a valid remote", remote)
   293  	}
   294  
   295  	remote, err = git.PathWithNamespace(remote)
   296  	if err != nil {
   297  		return "", err
   298  	}
   299  
   300  	return remote, nil
   301  }
   302  
   303  // parseArgsStringAndID is used by commands to parse command line arguments.
   304  // This function returns a string and number.
   305  func parseArgsStringAndID(args []string) (string, int64, error) {
   306  	if len(args) == 2 {
   307  		n, err := strconv.ParseInt(args[1], 0, 64)
   308  		if err != nil {
   309  			return args[0], 0, err
   310  		}
   311  		return args[0], n, nil
   312  	}
   313  	if len(args) == 1 {
   314  		n, err := strconv.ParseInt(args[0], 0, 64)
   315  		if err != nil {
   316  			return args[0], 0, nil
   317  		}
   318  		return "", n, nil
   319  	}
   320  	return "", 0, nil
   321  }
   322  
   323  func parseArgsRemoteAndString(args []string) (string, string, error) {
   324  	remote, str := "", ""
   325  
   326  	if len(args) == 1 {
   327  		ok, err := git.IsRemote(args[0])
   328  		if err != nil {
   329  			return "", "", err
   330  		}
   331  		if ok {
   332  			remote = args[0]
   333  		} else {
   334  			str = args[0]
   335  		}
   336  	} else if len(args) > 1 {
   337  		remote, str = args[0], args[1]
   338  	}
   339  
   340  	return remote, str, nil
   341  }
   342  
   343  // parseArgsWithGitBranchMR returns a remote name and a number if parsed.
   344  // If no number is specified, the MR id associated with the given branch
   345  // is returned, using the current branch as fallback.
   346  func parseArgsWithGitBranchMR(args []string) (string, int64, error) {
   347  	rn, id, err := parseArgsRemoteAndID(args)
   348  	if err == nil && id != 0 {
   349  		return rn, id, nil
   350  	}
   351  
   352  	rn, branch, err := parseArgsRemoteAndString(args)
   353  	if err != nil {
   354  		return "", 0, err
   355  	}
   356  
   357  	rn, err = getRemoteName(rn)
   358  	if err != nil {
   359  		return "", 0, err
   360  	}
   361  
   362  	if branch == "" {
   363  		id = int64(getCurrentBranchMR(rn))
   364  	} else {
   365  		id = int64(getBranchMR(rn, branch))
   366  	}
   367  
   368  	if id == 0 {
   369  		err = fmt.Errorf("cannot determine MR id")
   370  		return "", 0, err
   371  	}
   372  
   373  	return rn, id, nil
   374  }
   375  
   376  // filterCommentArg separate the case where a command can have both the
   377  // remote and "<mrID>:<commentID>" at the same time.
   378  func filterCommentArg(args []string) (int, []string, error) {
   379  	branchArgs := []string{}
   380  	idString := ""
   381  
   382  	if len(args) == 1 {
   383  		ok, err := git.IsRemote(args[0])
   384  		if err != nil {
   385  			return 0, branchArgs, err
   386  		}
   387  		if ok {
   388  			branchArgs = append(branchArgs, args[0])
   389  		} else {
   390  			idString = args[0]
   391  		}
   392  	} else if len(args) == 2 {
   393  		branchArgs = append(branchArgs, args[0])
   394  		idString = args[1]
   395  	}
   396  
   397  	if strings.Contains(idString, ":") {
   398  		ps := strings.Split(idString, ":")
   399  		branchArgs = append(branchArgs, ps[0])
   400  		idString = ps[1]
   401  	} else if idString != "" {
   402  		branchArgs = append(branchArgs, idString)
   403  		idString = ""
   404  	}
   405  
   406  	idNum, _ := strconv.Atoi(idString)
   407  	return idNum, branchArgs, nil
   408  }
   409  
   410  // setCommandPrefix returns a concatenated value of some of the commandline.
   411  // For example, 'lab mr show' would return 'mr_show.', and 'lab issue list'
   412  // would return 'issue_list.'
   413  func setCommandPrefix(scmd *cobra.Command) {
   414  	for _, command := range RootCmd.Commands() {
   415  		if commandPrefix != "" {
   416  			break
   417  		}
   418  		commandName := strings.Split(command.Use, " ")[0]
   419  		if scmd == command {
   420  			commandPrefix = commandName + "."
   421  			break
   422  		}
   423  		for _, subcommand := range command.Commands() {
   424  			subCommandName := commandName + "_" + strings.Split(subcommand.Use, " ")[0]
   425  			if scmd == subcommand {
   426  				commandPrefix = subCommandName + "."
   427  				break
   428  			}
   429  		}
   430  	}
   431  }
   432  
   433  // textToMarkdown converts text with markdown friendly line breaks
   434  // See https://gist.github.com/shaunlebron/746476e6e7a4d698b373 for more info.
   435  func textToMarkdown(text string) string {
   436  	text = strings.Replace(text, "\n", "  \n", -1)
   437  	return text
   438  }
   439  
   440  // isOutputTerminal checks if both stdout and stderr are indeed terminals
   441  // to avoid some markdown rendering garbage going to other outputs that
   442  // don't support some control chars.
   443  func isOutputTerminal() bool {
   444  	if !terminal.IsTerminal(sysStdout) ||
   445  		!terminal.IsTerminal(sysStderr) {
   446  		return false
   447  	}
   448  	return true
   449  }
   450  
   451  type pager struct {
   452  	proc   *os.Process
   453  	stdout int
   454  }
   455  
   456  // If standard output is a terminal, redirect output to an external
   457  // pager until the returned object's Close() method is called
   458  func newPager(fs *flag.FlagSet) *pager {
   459  	cmdLine, env := git.PagerCommand()
   460  	args := strings.Split(cmdLine, " ")
   461  
   462  	noPager, _ := fs.GetBool("no-pager")
   463  	if !isOutputTerminal() || noPager || args[0] == "cat" {
   464  		return &pager{}
   465  	}
   466  
   467  	pr, pw, _ := os.Pipe()
   468  	defer pw.Close()
   469  
   470  	name, _ := exec.LookPath(args[0])
   471  	proc, _ := os.StartProcess(name, args, &os.ProcAttr{
   472  		Env:   env,
   473  		Files: []*os.File{pr, os.Stdout, os.Stderr},
   474  	})
   475  
   476  	savedStdout, _ := dupFD(sysStdout)
   477  	_ = dupFD2(int(pw.Fd()), sysStdout)
   478  
   479  	return &pager{
   480  		proc:   proc,
   481  		stdout: savedStdout,
   482  	}
   483  }
   484  
   485  // Close closes the pager
   486  func (p *pager) Close() {
   487  	if p.stdout > 0 {
   488  		_ = dupFD2(p.stdout, sysStdout)
   489  		_ = closeFD(p.stdout)
   490  	}
   491  	if p.proc != nil {
   492  		p.proc.Wait()
   493  	}
   494  }
   495  
   496  func labPersistentPreRun(cmd *cobra.Command, args []string) {
   497  	flagConfig(cmd.Flags())
   498  }
   499  
   500  // labURLToRepo returns the string representing the URL to a certain repo based
   501  // on the protocol used
   502  func labURLToRepo(project *gitlab.Project) string {
   503  	urlToRepo := project.SSHURLToRepo
   504  	if useHTTP {
   505  		urlToRepo = project.HTTPURLToRepo
   506  	}
   507  	return urlToRepo
   508  }
   509  
   510  func determineSourceRemote(branch string) (string, error) {
   511  	// There is a precendence of options that should be considered here:
   512  	// branch.<name>.pushRemote > remote.pushDefault > branch.<name>.remote
   513  	// This rule is placed in git-config(1) manpage
   514  	r, err := gitconfig.Local("branch." + branch + ".pushRemote")
   515  	if err != nil {
   516  		r, err = gitconfig.Local("remote.pushDefault")
   517  		if err != nil {
   518  			r, err = gitconfig.Local("branch." + branch + ".remote")
   519  			if err != nil {
   520  				return forkRemote, nil
   521  			}
   522  		}
   523  	}
   524  
   525  	// Parse the remote name for possible URL.
   526  	u, err := giturls.Parse(r)
   527  	if err != nil {
   528  		return "", err
   529  	}
   530  
   531  	path := strings.TrimPrefix(u.Path, "/")
   532  	return path, nil
   533  }
   534  
   535  // Check of a case-insensitive prefix in a string
   536  func hasPrefix(str, prefix string) bool {
   537  	if len(str) < len(prefix) {
   538  		return false
   539  	}
   540  	return strings.EqualFold(str[0:len(prefix)], prefix)
   541  }
   542  
   543  // Match terms being searched with an existing list of terms, checking its
   544  // ambiguity at the same time
   545  func matchTerms(searchTerms, existentTerms []string) ([]string, error) {
   546  	var ambiguous bool
   547  	matches := make([]string, len(searchTerms))
   548  
   549  	for i, sTerm := range searchTerms {
   550  		ambiguous = false
   551  		lowerSTerm := strings.ToLower(sTerm)
   552  		for _, eTerm := range existentTerms {
   553  			lowerETerm := strings.ToLower(eTerm)
   554  
   555  			// no match
   556  			if !strings.Contains(lowerETerm, lowerSTerm) {
   557  				continue
   558  			}
   559  
   560  			// check for ambiguity on substring level
   561  			if matches[i] != "" && lowerSTerm != lowerETerm {
   562  				ambiguous = true
   563  				continue
   564  			}
   565  
   566  			matches[i] = eTerm
   567  
   568  			// exact match
   569  			// may happen after multiple substring matches
   570  			if lowerETerm == lowerSTerm {
   571  				ambiguous = false
   572  				break
   573  			}
   574  		}
   575  
   576  		if matches[i] == "" {
   577  			return nil, errors.Errorf("'%s' not found", sTerm)
   578  		}
   579  
   580  		// Ambiguous matches should not be returned to avoid
   581  		// manipulating the wrong item.
   582  		if ambiguous {
   583  			return nil, errors.Errorf("'%s' has no exact match and is ambiguous", sTerm)
   584  		}
   585  	}
   586  
   587  	return matches, nil
   588  }
   589  
   590  // union returns all the unique elements in a and b
   591  func union(a, b []string) []string {
   592  	mb := map[string]bool{}
   593  	ab := []string{}
   594  	for _, x := range b {
   595  		mb[x] = true
   596  		// add all of b's elements to ab
   597  		ab = append(ab, x)
   598  	}
   599  	for _, x := range a {
   600  		if _, ok := mb[x]; !ok {
   601  			// if a's elements aren't in b, add them to ab
   602  			// if they are, we don't need to add them
   603  			ab = append(ab, x)
   604  		}
   605  	}
   606  	return ab
   607  }
   608  
   609  // difference returns the elements in a that aren't in b
   610  func difference(a, b []string) []string {
   611  	mb := map[string]bool{}
   612  	for _, x := range b {
   613  		mb[x] = true
   614  	}
   615  	ab := []string{}
   616  	for _, x := range a {
   617  		if _, ok := mb[x]; !ok {
   618  			ab = append(ab, x)
   619  		}
   620  	}
   621  	return ab
   622  }
   623  
   624  // same returns true if a and b contain the same strings (regardless of order)
   625  func same(a, b []string) bool {
   626  	if len(a) != len(b) {
   627  		return false
   628  	}
   629  
   630  	mb := map[string]bool{}
   631  	for _, x := range b {
   632  		mb[x] = true
   633  	}
   634  
   635  	for _, x := range a {
   636  		if _, ok := mb[x]; !ok {
   637  			return false
   638  		}
   639  	}
   640  	return true
   641  }
   642  
   643  // getUser returns the userID for use with other GitLab API calls.
   644  func getUserID(user string) *int {
   645  	var (
   646  		err    error
   647  		userID int
   648  	)
   649  
   650  	if user == "" {
   651  		return nil
   652  	}
   653  
   654  	if user[0] == '@' {
   655  		user = user[1:]
   656  	}
   657  
   658  	if strings.Contains(user, "@") {
   659  		userID, err = lab.UserIDFromEmail(user)
   660  	} else {
   661  		userID, err = lab.UserIDFromUsername(user)
   662  	}
   663  	if err != nil {
   664  		return nil
   665  	}
   666  	if userID == -1 {
   667  		return nil
   668  	}
   669  
   670  	return gitlab.Int(userID)
   671  }
   672  
   673  // getUsers returns the userIDs for use with other GitLab API calls.
   674  func getUserIDs(users []string) []int {
   675  	var ids []int
   676  	for _, user := range users {
   677  		userID := getUserID(user)
   678  		if userID != nil {
   679  			ids = append(ids, *userID)
   680  		} else {
   681  			fmt.Printf("Warning: %s is not a valid username\n", user)
   682  		}
   683  	}
   684  	return ids
   685  }