github.com/cli/cli@v1.14.1-0.20210902173923-1af6a669e342/pkg/cmdutil/repo_override.go (about)

     1  package cmdutil
     2  
     3  import (
     4  	"os"
     5  	"sort"
     6  	"strings"
     7  
     8  	"github.com/cli/cli/internal/ghrepo"
     9  	"github.com/spf13/cobra"
    10  )
    11  
    12  func executeParentHooks(cmd *cobra.Command, args []string) error {
    13  	for cmd.HasParent() {
    14  		cmd = cmd.Parent()
    15  		if cmd.PersistentPreRunE != nil {
    16  			return cmd.PersistentPreRunE(cmd, args)
    17  		}
    18  	}
    19  	return nil
    20  }
    21  
    22  func EnableRepoOverride(cmd *cobra.Command, f *Factory) {
    23  	cmd.PersistentFlags().StringP("repo", "R", "", "Select another repository using the `[HOST/]OWNER/REPO` format")
    24  	_ = cmd.RegisterFlagCompletionFunc("repo", func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
    25  		remotes, err := f.Remotes()
    26  		if err != nil {
    27  			return nil, cobra.ShellCompDirectiveError
    28  		}
    29  
    30  		config, err := f.Config()
    31  		if err != nil {
    32  			return nil, cobra.ShellCompDirectiveError
    33  		}
    34  		defaultHost, err := config.DefaultHost()
    35  		if err != nil {
    36  			return nil, cobra.ShellCompDirectiveError
    37  		}
    38  
    39  		var results []string
    40  		for _, remote := range remotes {
    41  			repo := remote.RepoOwner() + "/" + remote.RepoName()
    42  			if !strings.EqualFold(remote.RepoHost(), defaultHost) {
    43  				repo = remote.RepoHost() + "/" + repo
    44  			}
    45  			if strings.HasPrefix(repo, toComplete) {
    46  				results = append(results, repo)
    47  			}
    48  		}
    49  		sort.Strings(results)
    50  		return results, cobra.ShellCompDirectiveNoFileComp
    51  	})
    52  
    53  	cmd.PersistentPreRunE = func(cmd *cobra.Command, args []string) error {
    54  		if err := executeParentHooks(cmd, args); err != nil {
    55  			return err
    56  		}
    57  		repoOverride, _ := cmd.Flags().GetString("repo")
    58  		f.BaseRepo = OverrideBaseRepoFunc(f, repoOverride)
    59  		return nil
    60  	}
    61  }
    62  
    63  func OverrideBaseRepoFunc(f *Factory, override string) func() (ghrepo.Interface, error) {
    64  	if override == "" {
    65  		override = os.Getenv("GH_REPO")
    66  	}
    67  	if override != "" {
    68  		return func() (ghrepo.Interface, error) {
    69  			return ghrepo.FromFullName(override)
    70  		}
    71  	}
    72  	return f.BaseRepo
    73  }