github.com/psexton/git-lfs@v2.1.1-0.20170517224304-289a18b2bc53+incompatible/lfs/extension.go (about)

     1  package lfs
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/hex"
     7  	"fmt"
     8  	"hash"
     9  	"io"
    10  	"os"
    11  	"os/exec"
    12  	"strings"
    13  
    14  	"github.com/git-lfs/git-lfs/config"
    15  )
    16  
// pipeRequest describes one pipeExtensions run: which command of each
// extension to execute, the data fed into the pipeline, the file name used
// for "%f" substitution, and the extensions to chain.
type pipeRequest struct {
	action     string             // "clean" or "smudge"; pipeExtensions rejects anything else
	reader     io.Reader          // input copied into the first extension's stdin
	fileName   string             // replaces every "%f" in extension command arguments
	extensions []config.Extension // chained in slice order: each stdout feeds the next stdin
}
    23  
// pipeResponse is what pipeExtensions produces.
type pipeResponse struct {
	file    *os.File         // temp file receiving the last extension's output; pipeExtensions closes it before returning
	results []*pipeExtResult // one entry per extension, in pipeline order; filled only on success
}
    28  
    29  type pipeExtResult struct {
    30  	name   string
    31  	oidIn  string
    32  	oidOut string
    33  }
    34  
// extCommand bundles one extension process of a pipeExtensions pipeline with
// the plumbing attached to it.
type extCommand struct {
	cmd    *exec.Cmd      // the extension process
	out    io.WriteCloser // destination of cmd's stdout (next command's stdin, or the response temp file); closed once cmd has exited
	err    *bytes.Buffer  // captured stderr used in the failure message; nil when stderr was not captured for this command
	hasher hash.Hash      // digest of everything cmd writes to stdout (SHA-256 as set up by pipeExtensions)
	result *pipeExtResult // filled in with name/oidIn/oidOut for the response
}
    42  
    43  func pipeExtensions(request *pipeRequest) (response pipeResponse, err error) {
    44  	var extcmds []*extCommand
    45  	defer func() {
    46  		// In the case of an early return before the end of this
    47  		// function (in response to an error, etc), kill all running
    48  		// processes. Errors are ignored since the function has already
    49  		// returned.
    50  		//
    51  		// In the happy path, the commands will have already been
    52  		// `Wait()`-ed upon and e.cmd.Process.Kill() will return an
    53  		// error, but we can ignore it.
    54  		for _, e := range extcmds {
    55  			if e.cmd.Process != nil {
    56  				e.cmd.Process.Kill()
    57  			}
    58  		}
    59  	}()
    60  
    61  	for _, e := range request.extensions {
    62  		var pieces []string
    63  		switch request.action {
    64  		case "clean":
    65  			pieces = strings.Split(e.Clean, " ")
    66  		case "smudge":
    67  			pieces = strings.Split(e.Smudge, " ")
    68  		default:
    69  			err = fmt.Errorf("Invalid action: " + request.action)
    70  			return
    71  		}
    72  		name := strings.Trim(pieces[0], " ")
    73  		var args []string
    74  		for _, value := range pieces[1:] {
    75  			arg := strings.Replace(value, "%f", request.fileName, -1)
    76  			args = append(args, arg)
    77  		}
    78  		cmd := exec.Command(name, args...)
    79  		ec := &extCommand{cmd: cmd, result: &pipeExtResult{name: e.Name}}
    80  		extcmds = append(extcmds, ec)
    81  	}
    82  
    83  	hasher := sha256.New()
    84  	pipeReader, pipeWriter := io.Pipe()
    85  	multiWriter := io.MultiWriter(hasher, pipeWriter)
    86  
    87  	var input io.Reader
    88  	var output io.WriteCloser
    89  	input = pipeReader
    90  	extcmds[0].cmd.Stdin = input
    91  	if response.file, err = TempFile(""); err != nil {
    92  		return
    93  	}
    94  	defer response.file.Close()
    95  	output = response.file
    96  
    97  	last := len(extcmds) - 1
    98  	for i, ec := range extcmds {
    99  		ec.hasher = sha256.New()
   100  
   101  		if i == last {
   102  			ec.cmd.Stdout = io.MultiWriter(ec.hasher, output)
   103  			ec.out = output
   104  			continue
   105  		}
   106  
   107  		nextec := extcmds[i+1]
   108  		var nextStdin io.WriteCloser
   109  		var stdout io.ReadCloser
   110  		if nextStdin, err = nextec.cmd.StdinPipe(); err != nil {
   111  			return
   112  		}
   113  		if stdout, err = ec.cmd.StdoutPipe(); err != nil {
   114  			return
   115  		}
   116  
   117  		ec.cmd.Stdin = input
   118  		ec.cmd.Stdout = io.MultiWriter(ec.hasher, nextStdin)
   119  		ec.out = nextStdin
   120  
   121  		input = stdout
   122  
   123  		var errBuff bytes.Buffer
   124  		ec.err = &errBuff
   125  		ec.cmd.Stderr = ec.err
   126  	}
   127  
   128  	for _, ec := range extcmds {
   129  		if err = ec.cmd.Start(); err != nil {
   130  			return
   131  		}
   132  	}
   133  
   134  	if _, err = io.Copy(multiWriter, request.reader); err != nil {
   135  		return
   136  	}
   137  	if err = pipeWriter.Close(); err != nil {
   138  		return
   139  	}
   140  
   141  	for _, ec := range extcmds {
   142  		if err = ec.cmd.Wait(); err != nil {
   143  			if ec.err != nil {
   144  				errStr := ec.err.String()
   145  				err = fmt.Errorf("Extension '%s' failed with: %s", ec.result.name, errStr)
   146  			}
   147  			return
   148  		}
   149  		if err = ec.out.Close(); err != nil {
   150  			return
   151  		}
   152  	}
   153  
   154  	oid := hex.EncodeToString(hasher.Sum(nil))
   155  	for _, ec := range extcmds {
   156  		ec.result.oidIn = oid
   157  		oid = hex.EncodeToString(ec.hasher.Sum(nil))
   158  		ec.result.oidOut = oid
   159  		response.results = append(response.results, ec.result)
   160  	}
   161  	return
   162  }