github.com/abhinav/git-fu@v0.6.1-0.20171029234004-54218d68c11b/git/gateway.go (about)

     1  package git
     2  
import (
	"bytes"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"sort"
	"strings"
	"sync"

	"github.com/abhinav/git-pr/gateway"

	"go.uber.org/multierr"
)
    16  
    17  // Gateway is a git gateway.
    18  type Gateway struct {
    19  	mu sync.RWMutex
    20  
    21  	dir string
    22  }
    23  
    24  var _ gateway.Git = (*Gateway)(nil)
    25  
    26  // NewGateway builds a new Git gateway.
    27  func NewGateway(startDir string) (*Gateway, error) {
    28  	if startDir == "" {
    29  		dir, err := os.Getwd()
    30  		if err != nil {
    31  			return nil, fmt.Errorf(
    32  				"failed to determine current working directory: %v", err)
    33  		}
    34  		startDir = dir
    35  	} else {
    36  		dir, err := filepath.Abs(startDir)
    37  		if err != nil {
    38  			return nil, fmt.Errorf(
    39  				"failed to determine absolute path of %v: %v", startDir, err)
    40  		}
    41  		startDir = dir
    42  	}
    43  
    44  	dir := startDir
    45  	for {
    46  		_, err := os.Stat(filepath.Join(dir, ".git"))
    47  		if err == nil {
    48  			break
    49  		}
    50  		newDir := filepath.Dir(dir)
    51  		if dir == newDir {
    52  			return nil, fmt.Errorf(
    53  				"could not find git repository at %v", startDir)
    54  		}
    55  		dir = newDir
    56  	}
    57  
    58  	return &Gateway{dir: dir}, nil
    59  }
    60  
    61  // CurrentBranch determines the current branch name.
    62  func (g *Gateway) CurrentBranch() (string, error) {
    63  	g.mu.RLock()
    64  	defer g.mu.RUnlock()
    65  
    66  	out, err := g.output("rev-parse", "--abbrev-ref", "HEAD")
    67  	if err != nil {
    68  		return "", fmt.Errorf("could not determine current branch: %v", err)
    69  	}
    70  	return strings.TrimSpace(out), nil
    71  }
    72  
    73  // DoesBranchExist checks if this branch exists locally.
    74  func (g *Gateway) DoesBranchExist(name string) bool {
    75  	g.mu.RLock()
    76  	defer g.mu.RUnlock()
    77  
    78  	err := g.cmd("show-ref", "--verify", "--quiet", "refs/heads/"+name).Run()
    79  	return err == nil
    80  }
    81  
    82  // CreateBranchAndCheckout creates a branch with the given name and head and
    83  // switches to it.
    84  func (g *Gateway) CreateBranchAndCheckout(name, head string) error {
    85  	g.mu.Lock()
    86  	defer g.mu.Unlock()
    87  
    88  	if err := g.cmd("checkout", "-b", name, head).Run(); err != nil {
    89  		return fmt.Errorf(
    90  			"failed to create and checkout branch %q at ref %q: %v", name, head, err)
    91  	}
    92  	return nil
    93  }
    94  
    95  // CreateBranch creates a branch with the given name and head but does not
    96  // check it out.
    97  func (g *Gateway) CreateBranch(name, head string) error {
    98  	g.mu.Lock()
    99  	defer g.mu.Unlock()
   100  
   101  	if err := g.cmd("branch", name, head).Run(); err != nil {
   102  		return fmt.Errorf("failed to create branch %q at ref %q: %v", name, head, err)
   103  	}
   104  	return nil
   105  }
   106  
   107  // SHA1 gets the SHA1 hash for the given ref.
   108  func (g *Gateway) SHA1(ref string) (string, error) {
   109  	g.mu.RLock()
   110  	defer g.mu.RUnlock()
   111  
   112  	out, err := g.output("rev-parse", "--verify", "-q", ref)
   113  	if err != nil {
   114  		return "", fmt.Errorf("could not resolve ref %q: %v", ref, err)
   115  	}
   116  	return strings.TrimSpace(out), nil
   117  }
   118  
   119  // DeleteBranch deletes the given branch.
   120  func (g *Gateway) DeleteBranch(name string) error {
   121  	g.mu.Lock()
   122  	defer g.mu.Unlock()
   123  
   124  	if err := g.cmd("branch", "-D", name).Run(); err != nil {
   125  		return fmt.Errorf("failed to delete branch %q: %v", name, err)
   126  	}
   127  	return nil
   128  }
   129  
   130  // DeleteRemoteTrackingBranch deletes the remote tracking branch with the
   131  // given name.
   132  func (g *Gateway) DeleteRemoteTrackingBranch(remote, name string) error {
   133  	g.mu.Lock()
   134  	defer g.mu.Unlock()
   135  
   136  	if err := g.cmd("branch", "-dr", remote+"/"+name).Run(); err != nil {
   137  		return fmt.Errorf("failed to delete remote tracking branch %q: %v", name, err)
   138  	}
   139  	return nil
   140  }
   141  
   142  // Checkout checks the given branch out.
   143  func (g *Gateway) Checkout(name string) error {
   144  	g.mu.Lock()
   145  	defer g.mu.Unlock()
   146  
   147  	if err := g.cmd("checkout", name).Run(); err != nil {
   148  		err = fmt.Errorf("failed to checkout branch %q: %v", name, err)
   149  	}
   150  	return nil
   151  }
   152  
   153  // Fetch a git ref
   154  func (g *Gateway) Fetch(req *gateway.FetchRequest) error {
   155  	ref := req.RemoteRef
   156  	if req.LocalRef != "" {
   157  		ref = ref + ":" + req.LocalRef
   158  	}
   159  
   160  	g.mu.Lock()
   161  	defer g.mu.Unlock()
   162  
   163  	if err := g.cmd("fetch", req.Remote, ref).Run(); err != nil {
   164  		return fmt.Errorf("failed to fetch %q from %q: %v", ref, req.Remote, err)
   165  	}
   166  	return nil
   167  }
   168  
   169  // Push pushes refs to a remote.
   170  func (g *Gateway) Push(req *gateway.PushRequest) error {
   171  	if len(req.Refs) == 0 {
   172  		return nil
   173  	}
   174  
   175  	args := append(make([]string, 0, len(req.Refs)+2), "push")
   176  	if req.Force {
   177  		args = append(args, "-f")
   178  	}
   179  	args = append(args, req.Remote)
   180  
   181  	for ref, remote := range req.Refs {
   182  		if remote != "" {
   183  			ref = ref + ":" + remote
   184  		}
   185  		args = append(args, ref)
   186  	}
   187  
   188  	g.mu.Lock()
   189  	defer g.mu.Unlock()
   190  
   191  	if err := g.cmd(args...).Run(); err != nil {
   192  		return fmt.Errorf("failed to push refs to %q: %v", req.Remote, err)
   193  	}
   194  	return nil
   195  }
   196  
   197  // Pull pulls the given branch.
   198  func (g *Gateway) Pull(remote, name string) error {
   199  	g.mu.Lock()
   200  	defer g.mu.Unlock()
   201  
   202  	if err := g.cmd("pull", remote, name).Run(); err != nil {
   203  		return fmt.Errorf("failed to pull %q from %q: %v", name, remote, err)
   204  	}
   205  	return nil
   206  }
   207  
   208  // Rebase a branch.
   209  func (g *Gateway) Rebase(req *gateway.RebaseRequest) error {
   210  	var _args [5]string
   211  
   212  	args := append(_args[:0], "rebase")
   213  	if req.Onto != "" {
   214  		args = append(args, "--onto", req.Onto)
   215  	}
   216  	if req.From != "" {
   217  		args = append(args, req.From)
   218  	}
   219  	args = append(args, req.Branch)
   220  
   221  	g.mu.Lock()
   222  	defer g.mu.Unlock()
   223  
   224  	if err := g.cmd(args...).Run(); err != nil {
   225  		return multierr.Append(
   226  			fmt.Errorf("failed to rebase %q: %v", req.Branch, err),
   227  			// If this failed, abort the rebase so that we're not left in a
   228  			// bad state.
   229  			g.cmd("rebase", "--abort").Run(),
   230  		)
   231  	}
   232  	return nil
   233  }
   234  
   235  // ResetBranch resets the given branch to the given head.
   236  func (g *Gateway) ResetBranch(branch, head string) error {
   237  	curr, err := g.CurrentBranch()
   238  	if err != nil {
   239  		return fmt.Errorf("could not reset %q to %q: %v", branch, head, err)
   240  	}
   241  
   242  	g.mu.Lock()
   243  	defer g.mu.Unlock()
   244  
   245  	if curr == branch {
   246  		err = g.cmd("reset", "--hard", head).Run()
   247  	} else {
   248  		err = g.cmd("branch", "-f", branch, head).Run()
   249  	}
   250  
   251  	if err != nil {
   252  		err = fmt.Errorf("could not reset %q to %q: %v", branch, head, err)
   253  	}
   254  	return err
   255  }
   256  
   257  // RemoteURL gets the URL for the given remote.
   258  func (g *Gateway) RemoteURL(name string) (string, error) {
   259  	g.mu.RLock()
   260  	defer g.mu.RUnlock()
   261  
   262  	out, err := g.output("remote", "get-url", name)
   263  	if err != nil {
   264  		return "", fmt.Errorf("failed to get URL for remote %q: %v", name, err)
   265  	}
   266  	return strings.TrimSpace(out), nil
   267  }
   268  
   269  // run the given git command.
   270  func (g *Gateway) cmd(args ...string) *exec.Cmd {
   271  	cmd := exec.Command("git", args...)
   272  	cmd.Dir = g.dir
   273  	cmd.Stderr = os.Stderr
   274  	cmd.Stdout = os.Stdout
   275  	return cmd
   276  }
   277  
   278  func (g *Gateway) output(args ...string) (string, error) {
   279  	var stdout bytes.Buffer
   280  	cmd := g.cmd(args...)
   281  	cmd.Stdout = &stdout
   282  	err := cmd.Run()
   283  	return stdout.String(), err
   284  }