github.com/portworx/docker@v1.12.1/pkg/stdcopy/stdcopy.go (about)

     1  package stdcopy
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"sync"
    10  
    11  	"github.com/Sirupsen/logrus"
    12  )
    13  
    14  // StdType is the type of standard stream
    15  // a writer can multiplex to.
    16  type StdType byte
    17  
    18  const (
    19  	// Stdin represents standard input stream type.
    20  	Stdin StdType = iota
    21  	// Stdout represents standard output stream type.
    22  	Stdout
    23  	// Stderr represents standard error steam type.
    24  	Stderr
    25  
    26  	stdWriterPrefixLen = 8
    27  	stdWriterFdIndex   = 0
    28  	stdWriterSizeIndex = 4
    29  
    30  	startingBufLen = 32*1024 + stdWriterPrefixLen + 1
    31  )
    32  
    33  var bufPool = &sync.Pool{New: func() interface{} { return bytes.NewBuffer(nil) }}
    34  
    35  // stdWriter is wrapper of io.Writer with extra customized info.
    36  type stdWriter struct {
    37  	io.Writer
    38  	prefix byte
    39  }
    40  
    41  // Write sends the buffer to the underneath writer.
    42  // It inserts the prefix header before the buffer,
    43  // so stdcopy.StdCopy knows where to multiplex the output.
    44  // It makes stdWriter to implement io.Writer.
    45  func (w *stdWriter) Write(p []byte) (n int, err error) {
    46  	if w == nil || w.Writer == nil {
    47  		return 0, errors.New("Writer not instantiated")
    48  	}
    49  	if p == nil {
    50  		return 0, nil
    51  	}
    52  
    53  	header := [stdWriterPrefixLen]byte{stdWriterFdIndex: w.prefix}
    54  	binary.BigEndian.PutUint32(header[stdWriterSizeIndex:], uint32(len(p)))
    55  	buf := bufPool.Get().(*bytes.Buffer)
    56  	buf.Write(header[:])
    57  	buf.Write(p)
    58  
    59  	n, err = w.Writer.Write(buf.Bytes())
    60  	n -= stdWriterPrefixLen
    61  	if n < 0 {
    62  		n = 0
    63  	}
    64  
    65  	buf.Reset()
    66  	bufPool.Put(buf)
    67  	return
    68  }
    69  
    70  // NewStdWriter instantiates a new Writer.
    71  // Everything written to it will be encapsulated using a custom format,
    72  // and written to the underlying `w` stream.
    73  // This allows multiple write streams (e.g. stdout and stderr) to be muxed into a single connection.
    74  // `t` indicates the id of the stream to encapsulate.
    75  // It can be stdcopy.Stdin, stdcopy.Stdout, stdcopy.Stderr.
    76  func NewStdWriter(w io.Writer, t StdType) io.Writer {
    77  	return &stdWriter{
    78  		Writer: w,
    79  		prefix: byte(t),
    80  	}
    81  }
    82  
    83  // StdCopy is a modified version of io.Copy.
    84  //
    85  // StdCopy will demultiplex `src`, assuming that it contains two streams,
    86  // previously multiplexed together using a StdWriter instance.
    87  // As it reads from `src`, StdCopy will write to `dstout` and `dsterr`.
    88  //
    89  // StdCopy will read until it hits EOF on `src`. It will then return a nil error.
    90  // In other words: if `err` is non nil, it indicates a real underlying error.
    91  //
    92  // `written` will hold the total number of bytes written to `dstout` and `dsterr`.
    93  func StdCopy(dstout, dsterr io.Writer, src io.Reader) (written int64, err error) {
    94  	var (
    95  		buf       = make([]byte, startingBufLen)
    96  		bufLen    = len(buf)
    97  		nr, nw    int
    98  		er, ew    error
    99  		out       io.Writer
   100  		frameSize int
   101  	)
   102  
   103  	for {
   104  		// Make sure we have at least a full header
   105  		for nr < stdWriterPrefixLen {
   106  			var nr2 int
   107  			nr2, er = src.Read(buf[nr:])
   108  			nr += nr2
   109  			if er == io.EOF {
   110  				if nr < stdWriterPrefixLen {
   111  					logrus.Debugf("Corrupted prefix: %v", buf[:nr])
   112  					return written, nil
   113  				}
   114  				break
   115  			}
   116  			if er != nil {
   117  				logrus.Debugf("Error reading header: %s", er)
   118  				return 0, er
   119  			}
   120  		}
   121  
   122  		// Check the first byte to know where to write
   123  		switch StdType(buf[stdWriterFdIndex]) {
   124  		case Stdin:
   125  			fallthrough
   126  		case Stdout:
   127  			// Write on stdout
   128  			out = dstout
   129  		case Stderr:
   130  			// Write on stderr
   131  			out = dsterr
   132  		default:
   133  			logrus.Debugf("Error selecting output fd: (%d)", buf[stdWriterFdIndex])
   134  			return 0, fmt.Errorf("Unrecognized input header: %d", buf[stdWriterFdIndex])
   135  		}
   136  
   137  		// Retrieve the size of the frame
   138  		frameSize = int(binary.BigEndian.Uint32(buf[stdWriterSizeIndex : stdWriterSizeIndex+4]))
   139  		logrus.Debugf("framesize: %d", frameSize)
   140  
   141  		// Check if the buffer is big enough to read the frame.
   142  		// Extend it if necessary.
   143  		if frameSize+stdWriterPrefixLen > bufLen {
   144  			logrus.Debugf("Extending buffer cap by %d (was %d)", frameSize+stdWriterPrefixLen-bufLen+1, len(buf))
   145  			buf = append(buf, make([]byte, frameSize+stdWriterPrefixLen-bufLen+1)...)
   146  			bufLen = len(buf)
   147  		}
   148  
   149  		// While the amount of bytes read is less than the size of the frame + header, we keep reading
   150  		for nr < frameSize+stdWriterPrefixLen {
   151  			var nr2 int
   152  			nr2, er = src.Read(buf[nr:])
   153  			nr += nr2
   154  			if er == io.EOF {
   155  				if nr < frameSize+stdWriterPrefixLen {
   156  					logrus.Debugf("Corrupted frame: %v", buf[stdWriterPrefixLen:nr])
   157  					return written, nil
   158  				}
   159  				break
   160  			}
   161  			if er != nil {
   162  				logrus.Debugf("Error reading frame: %s", er)
   163  				return 0, er
   164  			}
   165  		}
   166  
   167  		// Write the retrieved frame (without header)
   168  		nw, ew = out.Write(buf[stdWriterPrefixLen : frameSize+stdWriterPrefixLen])
   169  		if ew != nil {
   170  			logrus.Debugf("Error writing frame: %s", ew)
   171  			return 0, ew
   172  		}
   173  		// If the frame has not been fully written: error
   174  		if nw != frameSize {
   175  			logrus.Debugf("Error Short Write: (%d on %d)", nw, frameSize)
   176  			return 0, io.ErrShortWrite
   177  		}
   178  		written += int64(nw)
   179  
   180  		// Move the rest of the buffer to the beginning
   181  		copy(buf, buf[frameSize+stdWriterPrefixLen:])
   182  		// Move the index
   183  		nr -= frameSize + stdWriterPrefixLen
   184  	}
   185  }