github.com/dlintw/docker@v1.5.0-rc4/pkg/tarsum/tarsum.go (about)

     1  package tarsum
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"crypto"
     7  	"crypto/sha256"
     8  	"encoding/hex"
     9  	"errors"
    10  	"fmt"
    11  	"hash"
    12  	"io"
    13  	"strings"
    14  
    15  	"github.com/docker/docker/vendor/src/code.google.com/p/go/src/pkg/archive/tar"
    16  )
    17  
    18  const (
    19  	buf8K  = 8 * 1024
    20  	buf16K = 16 * 1024
    21  	buf32K = 32 * 1024
    22  )
    23  
    24  // NewTarSum creates a new interface for calculating a fixed time checksum of a
    25  // tar archive.
    26  //
    27  // This is used for calculating checksums of layers of an image, in some cases
    28  // including the byte payload of the image's json metadata as well, and for
    29  // calculating the checksums for buildcache.
    30  func NewTarSum(r io.Reader, dc bool, v Version) (TarSum, error) {
    31  	return NewTarSumHash(r, dc, v, DefaultTHash)
    32  }
    33  
    34  // Create a new TarSum, providing a THash to use rather than the DefaultTHash
    35  func NewTarSumHash(r io.Reader, dc bool, v Version, tHash THash) (TarSum, error) {
    36  	headerSelector, err := getTarHeaderSelector(v)
    37  	if err != nil {
    38  		return nil, err
    39  	}
    40  	ts := &tarSum{Reader: r, DisableCompression: dc, tarSumVersion: v, headerSelector: headerSelector, tHash: tHash}
    41  	err = ts.initTarSum()
    42  	return ts, err
    43  }
    44  
    45  // Create a new TarSum using the provided TarSum version+hash label.
    46  func NewTarSumForLabel(r io.Reader, disableCompression bool, label string) (TarSum, error) {
    47  	parts := strings.SplitN(label, "+", 2)
    48  	if len(parts) != 2 {
    49  		return nil, errors.New("tarsum label string should be of the form: {tarsum_version}+{hash_name}")
    50  	}
    51  
    52  	versionName, hashName := parts[0], parts[1]
    53  
    54  	version, ok := tarSumVersionsByName[versionName]
    55  	if !ok {
    56  		return nil, fmt.Errorf("unknown TarSum version name: %q", versionName)
    57  	}
    58  
    59  	hashConfig, ok := standardHashConfigs[hashName]
    60  	if !ok {
    61  		return nil, fmt.Errorf("unknown TarSum hash name: %q", hashName)
    62  	}
    63  
    64  	tHash := NewTHash(hashConfig.name, hashConfig.hash.New)
    65  
    66  	return NewTarSumHash(r, disableCompression, version, tHash)
    67  }
    68  
    69  // TarSum is the generic interface for calculating fixed time
    70  // checksums of a tar archive
    71  type TarSum interface {
    72  	io.Reader
    73  	GetSums() FileInfoSums
    74  	Sum([]byte) string
    75  	Version() Version
    76  	Hash() THash
    77  }
    78  
// tarSum is the TarSum implementation returned by NewTarSumHash. It is not
// specific to a single version: headerSelector, chosen from tarSumVersion,
// decides which tar headers are hashed for each entry.
type tarSum struct {
	// Reader is the source tar archive being checksummed.
	io.Reader
	tarR               *tar.Reader       // parses entries out of Reader
	tarW               *tar.Writer       // re-encodes entries into bufTar
	writer             writeCloseFlusher // gzip (or pass-through) writer, fed from bufTar, writing into bufWriter
	bufTar             *bytes.Buffer     // staging buffer for tarW output
	bufWriter          *bytes.Buffer     // final output handed back by Read
	bufData            []byte            // reusable scratch buffer for reads from tarR
	h                  hash.Hash         // running hash for the entry currently being read
	tHash              THash             // hash name/constructor used for all digests
	sums               FileInfoSums      // per-entry checksums recorded so far
	fileCounter        int64             // position assigned to the next recorded entry
	currentFile        string            // entry name with leading "./" and trailing "/" stripped
	finished           bool              // set once the archive is exhausted and all output is in bufWriter
	first              bool              // true until Read first reaches an entry boundary; suppresses recording a sum before any entry started
	DisableCompression bool              // false by default. When false, the output is gzip compressed.
	tarSumVersion      Version           // this field is not exported so it can not be mutated during use
	headerSelector     tarHeaderSelector // handles selecting and ordering headers for files in the archive
}
    99  
   100  func (ts tarSum) Hash() THash {
   101  	return ts.tHash
   102  }
   103  
   104  func (ts tarSum) Version() Version {
   105  	return ts.tarSumVersion
   106  }
   107  
   108  // A hash.Hash type generator and its name
   109  type THash interface {
   110  	Hash() hash.Hash
   111  	Name() string
   112  }
   113  
   114  // Convenience method for creating a THash
   115  func NewTHash(name string, h func() hash.Hash) THash {
   116  	return simpleTHash{n: name, h: h}
   117  }
   118  
// tHashConfig describes one of the standard hash algorithms that can be
// selected by name in a TarSum label.
type tHashConfig struct {
	name string      // name reported by the resulting THash
	hash crypto.Hash // standard-library hash identifier; its New method builds instances
}
   123  
   124  var (
   125  	// NOTE: DO NOT include MD5 or SHA1, which are considered insecure.
   126  	standardHashConfigs = map[string]tHashConfig{
   127  		"sha256": {name: "sha256", hash: crypto.SHA256},
   128  		"sha512": {name: "sha512", hash: crypto.SHA512},
   129  	}
   130  )
   131  
   132  // TarSum default is "sha256"
   133  var DefaultTHash = NewTHash("sha256", sha256.New)
   134  
   135  type simpleTHash struct {
   136  	n string
   137  	h func() hash.Hash
   138  }
   139  
   140  func (sth simpleTHash) Name() string    { return sth.n }
   141  func (sth simpleTHash) Hash() hash.Hash { return sth.h() }
   142  
   143  func (ts *tarSum) encodeHeader(h *tar.Header) error {
   144  	for _, elem := range ts.headerSelector.selectHeaders(h) {
   145  		if _, err := ts.h.Write([]byte(elem[0] + elem[1])); err != nil {
   146  			return err
   147  		}
   148  	}
   149  	return nil
   150  }
   151  
   152  func (ts *tarSum) initTarSum() error {
   153  	ts.bufTar = bytes.NewBuffer([]byte{})
   154  	ts.bufWriter = bytes.NewBuffer([]byte{})
   155  	ts.tarR = tar.NewReader(ts.Reader)
   156  	ts.tarW = tar.NewWriter(ts.bufTar)
   157  	if !ts.DisableCompression {
   158  		ts.writer = gzip.NewWriter(ts.bufWriter)
   159  	} else {
   160  		ts.writer = &nopCloseFlusher{Writer: ts.bufWriter}
   161  	}
   162  	if ts.tHash == nil {
   163  		ts.tHash = DefaultTHash
   164  	}
   165  	ts.h = ts.tHash.Hash()
   166  	ts.h.Reset()
   167  	ts.first = true
   168  	ts.sums = FileInfoSums{}
   169  	return nil
   170  }
   171  
// Read implements io.Reader for tarSum. Each call does the following:
//  1. Pulls up to len(buf) bytes of the current entry's payload from tarR.
//  2. Feeds them into the per-entry hash ts.h.
//  3. Re-encodes them through tarW into bufTar.
//  4. Copies bufTar through ts.writer (gzip unless DisableCompression) into
//     bufWriter.
//  5. Serves the caller from bufWriter.
//
// When the current entry is exhausted, its checksum is appended to ts.sums,
// and the next header is hashed (encodeHeader) and re-encoded.
//
// When the archive itself is exhausted:
//   - tarW and ts.writer are closed;
//   - the remaining output is copied into bufWriter;
//   - finished is set.
//
// From then on, Read only drains bufWriter.
func (ts *tarSum) Read(buf []byte) (int, error) {
	if ts.finished {
		// Archive fully processed: only buffered output remains.
		return ts.bufWriter.Read(buf)
	}
	// Grow the scratch buffer, rounding small requests up to 8K/16K/32K so
	// slightly larger later reads don't each reallocate.
	if len(ts.bufData) < len(buf) {
		switch {
		case len(buf) <= buf8K:
			ts.bufData = make([]byte, buf8K)
		case len(buf) <= buf16K:
			ts.bufData = make([]byte, buf16K)
		case len(buf) <= buf32K:
			ts.bufData = make([]byte, buf32K)
		default:
			ts.bufData = make([]byte, len(buf))
		}
	}
	// buf2 receives raw payload from the source archive; buf itself is only
	// ever filled from bufWriter.
	buf2 := ts.bufData[:len(buf)]

	n, err := ts.tarR.Read(buf2)
	if err != nil {
		if err == io.EOF {
			// End of the current entry. On the very first call no entry has
			// been opened yet (Next has not been called), which is why
			// ts.first suppresses recording a sum below.
			if _, err := ts.h.Write(buf2[:n]); err != nil {
				return 0, err
			}
			if !ts.first {
				// Record the finished entry's checksum and restart the hash
				// for the next entry.
				ts.sums = append(ts.sums, fileInfoSum{name: ts.currentFile, sum: hex.EncodeToString(ts.h.Sum(nil)), pos: ts.fileCounter})
				ts.fileCounter++
				ts.h.Reset()
			} else {
				ts.first = false
			}

			currentHeader, err := ts.tarR.Next()
			if err != nil {
				if err == io.EOF {
					// No more entries. Close the tar and compression streams
					// so bufWriter holds the complete output.
					if err := ts.tarW.Close(); err != nil {
						return 0, err
					}
					if _, err := io.Copy(ts.writer, ts.bufTar); err != nil {
						return 0, err
					}
					if err := ts.writer.Close(); err != nil {
						return 0, err
					}
					ts.finished = true
					// NOTE(review): buf is not written on this path. Returning n
					// is only correct when n == 0, which appears to hold for
					// tar.Reader.Read at io.EOF. Confirm.
					return n, nil
				}
				// NOTE(review): same caveat as above: n is reported although
				// buf was not written.
				return n, err
			}
			// Normalize the entry name: strip a leading "./" and a trailing "/".
			ts.currentFile = strings.TrimSuffix(strings.TrimPrefix(currentHeader.Name, "./"), "/")
			if err := ts.encodeHeader(currentHeader); err != nil {
				return 0, err
			}
			if err := ts.tarW.WriteHeader(currentHeader); err != nil {
				return 0, err
			}
			// NOTE(review): buf2[:n] was read before Next (it belongs to the
			// previous entry) but is written after the new header. This is
			// harmless only if n == 0 whenever Read returns io.EOF. Confirm.
			if _, err := ts.tarW.Write(buf2[:n]); err != nil {
				return 0, err
			}
			ts.tarW.Flush() // NOTE(review): any Flush error is discarded
			if _, err := io.Copy(ts.writer, ts.bufTar); err != nil {
				return 0, err
			}
			ts.writer.Flush() // NOTE(review): any Flush error is discarded

			return ts.bufWriter.Read(buf)
		}
		// Non-EOF read error from the source archive; bytes in buf2[:n], if
		// any, are neither hashed nor forwarded.
		return n, err
	}

	// Filling the hash buffer
	if _, err = ts.h.Write(buf2[:n]); err != nil {
		return 0, err
	}

	// Filling the tar writer
	if _, err = ts.tarW.Write(buf2[:n]); err != nil {
		return 0, err
	}
	ts.tarW.Flush() // NOTE(review): any Flush error is discarded

	// Filling the output writer
	if _, err = io.Copy(ts.writer, ts.bufTar); err != nil {
		return 0, err
	}
	ts.writer.Flush() // NOTE(review): any Flush error is discarded

	// NOTE(review): bytes.Buffer.Read reports io.EOF when empty. This relies
	// on the steps above always having produced some output. Confirm.
	return ts.bufWriter.Read(buf)
}
   261  
   262  func (ts *tarSum) Sum(extra []byte) string {
   263  	ts.sums.SortBySums()
   264  	h := ts.tHash.Hash()
   265  	if extra != nil {
   266  		h.Write(extra)
   267  	}
   268  	for _, fis := range ts.sums {
   269  		h.Write([]byte(fis.Sum()))
   270  	}
   271  	checksum := ts.Version().String() + "+" + ts.tHash.Name() + ":" + hex.EncodeToString(h.Sum(nil))
   272  	return checksum
   273  }
   274  
   275  func (ts *tarSum) GetSums() FileInfoSums {
   276  	return ts.sums
   277  }