github.com/grailbio/base@v0.0.11/digest/digest.go (about)

     1  // Copyright 2017 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package digest provides a generalized representation for digests
     6  // computed with cryptographic hash functions. It provides an efficient
     7  // in-memory representation as well as serialization.
     8  package digest
     9  
    10  import (
    11  	"bufio"
    12  	"bytes"
    13  	"crypto"
    14  	"crypto/rand"
    15  	"encoding/binary"
    16  	"encoding/gob"
    17  	"encoding/hex"
    18  	"encoding/json"
    19  	"errors"
    20  	"fmt"
    21  	"hash"
    22  	"io"
    23  	mathrand "math/rand"
    24  	"strings"
    25  )
    26  
// maxSize is the length, in bytes, of the fixed-size array that backs every
// Digest. It is 64 so that a SHA-512 digest fits.
const maxSize = 64
    28  
// digestHash identifies a hash function in the binary format written by
// WriteDigest and read back by ReadDigest.
//
// crypto.Hash values are not guaranteed to be stable over releases, so these
// package-local constants are used instead. The order of digestHashes should
// never be changed, to maintain compatibility over time, and new values
// should always be appended. The initial set has been ordered to match the
// order in crypto.Hash at this change to maintain backward compatibility.
type digestHash uint

const (
	MD4         digestHash = 1 + iota // crypto.MD4
	MD5                               // crypto.MD5
	SHA1                              // crypto.SHA1
	SHA224                            // crypto.SHA224
	SHA256                            // crypto.SHA256
	SHA384                            // crypto.SHA384
	SHA512                            // crypto.SHA512
	MD5SHA1                           // crypto.MD5SHA1
	RIPEMD160                         // crypto.RIPEMD160
	SHA3_224                          // crypto.SHA3_224
	SHA3_256                          // crypto.SHA3_256
	SHA3_384                          // crypto.SHA3_384
	SHA3_512                          // crypto.SHA3_512
	SHA512_224                        // crypto.SHA512_224
	SHA512_256                        // crypto.SHA512_256
	BLAKE2s_256                       // crypto.BLAKE2s_256
	BLAKE2b_256                       // crypto.BLAKE2b_256
	BLAKE2b_384                       // crypto.BLAKE2b_384
	BLAKE2b_512                       // crypto.BLAKE2b_512

	// zeroString is the textual form of the zero Digest, as produced by
	// Digest.String and accepted by Parse.
	zeroString = "<zero>"
)
    60  
var (
	// digestToCryptoHashes maps the stable, serialized digestHash identifiers
	// to the corresponding crypto.Hash values. ReadDigest uses it to decode
	// the two-byte hash header.
	digestToCryptoHashes = map[digestHash]crypto.Hash{
		MD4:         crypto.MD4,
		MD5:         crypto.MD5,
		SHA1:        crypto.SHA1,
		SHA224:      crypto.SHA224,
		SHA256:      crypto.SHA256,
		SHA384:      crypto.SHA384,
		SHA512:      crypto.SHA512,
		MD5SHA1:     crypto.MD5SHA1,
		RIPEMD160:   crypto.RIPEMD160,
		SHA3_224:    crypto.SHA3_224,
		SHA3_256:    crypto.SHA3_256,
		SHA3_384:    crypto.SHA3_384,
		SHA3_512:    crypto.SHA3_512,
		SHA512_224:  crypto.SHA512_224,
		SHA512_256:  crypto.SHA512_256,
		BLAKE2s_256: crypto.BLAKE2s_256,
		BLAKE2b_256: crypto.BLAKE2b_256,
		BLAKE2b_384: crypto.BLAKE2b_384,
		BLAKE2b_512: crypto.BLAKE2b_512,
	}
	// cryptoToDigestHashes is the inverse of digestToCryptoHashes. It is
	// populated by init() and consulted by WriteDigest.
	cryptoToDigestHashes = map[crypto.Hash]digestHash{}
)
    85  
var (
	// shortSuffix is an all-zero array covering every digest byte after the
	// first four; IsShort compares a digest's tail against it.
	shortSuffix = [maxSize - 4]byte{}
	// zeros is an all-zero array of the full digest capacity, used as a
	// source of zero bytes by Truncate and IsAbbrev.
	zeros = [maxSize]byte{}
)
    90  
var (
	// name maps each supported hash function to the textual name used as the
	// prefix of a digest's string form (see Digest.String and Parse).
	//
	// NOTE(review): the BLAKE2 hashes present in digestToCryptoHashes have no
	// entry here, so digests using them render with an empty name and cannot
	// be parsed back by name — confirm whether that omission is intentional.
	name = map[crypto.Hash]string{
		crypto.MD4:        "md4",
		crypto.MD5:        "md5",
		crypto.SHA1:       "sha1",
		crypto.SHA224:     "sha224",
		crypto.SHA256:     "sha256",
		crypto.SHA384:     "sha384",
		crypto.SHA512:     "sha512",
		crypto.SHA512_224: "sha512_224",
		crypto.SHA512_256: "sha512_256",
		crypto.SHA3_224:   "sha3_224",
		crypto.SHA3_256:   "sha3_256",
		crypto.SHA3_384:   "sha3_384",
		crypto.SHA3_512:   "sha3_512",
		crypto.MD5SHA1:    "md5sha1",
		crypto.RIPEMD160:  "ripemd160",
	}
	// hashes is the inverse of name (textual name to hash function). It is
	// populated by init().
	hashes = map[string]crypto.Hash{}
)
   111  
var (
	// ErrInvalidDigest is returned when an attempt is made to parse an
	// invalid digest.
	ErrInvalidDigest = errors.New("invalid digest")
	// ErrHashUnavailable is returned when a Digest's hash function was not
	// imported (linked into the binary).
	ErrHashUnavailable = errors.New("the requested hash function is not available")
	// ErrWrongHash is returned when a Digest's hash does not match the hash
	// of the Digester.
	ErrWrongHash = errors.New("wrong hash")
	// ErrShortRead is returned when an EOF is encountered while attempting
	// to read a Digest.
	ErrShortRead = errors.New("short read")
)
   122  
   123  func init() {
   124  	for h, name := range name {
   125  		hashes[name] = h
   126  	}
   127  	for dh, ch := range digestToCryptoHashes {
   128  		cryptoToDigestHashes[ch] = dh
   129  	}
   130  }
   131  
// Digest represents a digest computed with a cryptographic hash
// function. It uses a fixed-size representation and is directly
// comparable. The zero value (h == 0) is the "zero digest"; see IsZero.
type Digest struct {
	// h is the hash function that produced the digest.
	h crypto.Hash
	// b holds the digest value in its leading bytes. The array is sized for
	// the largest supported hash (maxSize).
	b [maxSize]byte
}

// Compile-time checks that Digest implements the gob encoding interfaces.
var _ gob.GobEncoder = Digest{}
var _ gob.GobDecoder = (*Digest)(nil)
   142  
   143  // GobEncode implements Gob encoding for digests.
   144  func (d Digest) GobEncode() ([]byte, error) {
   145  	b := make([]byte, binary.MaxVarintLen64+d.h.Size())
   146  	n := binary.PutUvarint(b, uint64(d.h))
   147  	copy(b[n:], d.b[:d.h.Size()])
   148  	return b[:n+d.h.Size()], nil
   149  }
   150  
   151  // GobDecode implements Gob decoding for digests.
   152  func (d *Digest) GobDecode(p []byte) error {
   153  	h, n := binary.Uvarint(p)
   154  	if n == 0 {
   155  		return errors.New("short buffer")
   156  	}
   157  	if n < 0 {
   158  		return errors.New("invalid hash")
   159  	}
   160  	d.h = crypto.Hash(h)
   161  	if len(p)-n != d.h.Size() {
   162  		return errors.New("invalid digest")
   163  	}
   164  	copy(d.b[:], p[n:])
   165  	return nil
   166  }
   167  
   168  // Parse parses a string representation of Digest, as defined by
   169  // Digest.String().
   170  func Parse(s string) (Digest, error) {
   171  	if s == "" || s == zeroString {
   172  		return Digest{}, nil
   173  	}
   174  	parts := strings.Split(s, ":")
   175  	if len(parts) != 2 {
   176  		return Digest{}, ErrInvalidDigest
   177  	}
   178  	name, hex := parts[0], parts[1]
   179  	h, ok := hashes[name]
   180  	if !ok {
   181  		return Digest{}, ErrInvalidDigest
   182  	}
   183  	return ParseHash(h, hex)
   184  }
   185  
   186  // ParseHash parses hex string hx produced by the the hash h into a
   187  // Digest.
   188  func ParseHash(h crypto.Hash, hx string) (Digest, error) {
   189  	if !h.Available() {
   190  		return Digest{}, ErrHashUnavailable
   191  	}
   192  	b, err := hex.DecodeString(hx)
   193  	if err != nil {
   194  		return Digest{}, err
   195  	}
   196  	d := Digest{h: h}
   197  	copy(d.b[:], b)
   198  	if !d.valid() {
   199  		return Digest{}, ErrInvalidDigest
   200  	}
   201  	return d, nil
   202  }
   203  
   204  // New returns a new literal digest with the provided hash and
   205  // value.
   206  func New(h crypto.Hash, b []byte) Digest {
   207  	d := Digest{h: h}
   208  	copy(d.b[:], b)
   209  	return d
   210  }
   211  
   212  // IsZero returns whether the digest is the zero digest.
   213  func (d Digest) IsZero() bool { return d.h == 0 }
   214  
   215  // Hash returns the cryptographic hash used to produce this Digest.
   216  func (d Digest) Hash() crypto.Hash { return d.h }
   217  
   218  // Hex returns the padded hexadecimal representation of the Digest.
   219  func (d Digest) Hex() string {
   220  	n := d.h.Size()
   221  	return fmt.Sprintf("%0*x", 2*n, d.b[:n])
   222  }
   223  
   224  // HexN returns the padded hexadecimal representation of the digest's
   225  // first n bytes. N must be smaller or equal to the digest's size, or
   226  // else it panics.
   227  func (d Digest) HexN(n int) string {
   228  	if d.h.Size() < n {
   229  		panic("n is too large")
   230  	}
   231  	return fmt.Sprintf("%0*x", 2*n, d.b[:n])
   232  }
   233  
   234  // Short returns a short (prefix) version of the Digest's hexadecimal
   235  // representation.
   236  func (d Digest) Short() string {
   237  	return d.Hex()[0:8]
   238  }
   239  
   240  // Name returns the name of the digest's hash.
   241  func (d Digest) Name() string {
   242  	return name[d.h]
   243  }
   244  
   245  // Less defines an order of digests of the same hash. panics if two
   246  // Less digests with different hashes are compared.
   247  func (d Digest) Less(e Digest) bool {
   248  	if d.h != e.h {
   249  		panic("incompatible hashes")
   250  	}
   251  	return bytes.Compare(d.b[:], e.b[:]) < 0
   252  }
   253  
   254  // Bytes returns the byte representation for this digest.
   255  func (d Digest) Bytes() []byte {
   256  	var b bytes.Buffer
   257  	if _, err := WriteDigest(&b, d); err != nil {
   258  		panic("failed to write file digest " + d.String() + ": " + err.Error())
   259  	}
   260  	return b.Bytes()
   261  }
   262  
   263  // Mix mixes digests d and e with XOR.
   264  func (d *Digest) Mix(e Digest) {
   265  	if d.h == 0 {
   266  		*d = e
   267  		return
   268  	}
   269  	if d.h != e.h {
   270  		panic("mismatched hashes")
   271  	}
   272  	for i := range d.b {
   273  		d.b[i] ^= e.b[i]
   274  	}
   275  }
   276  
   277  // Truncate truncates digest d to n bytes. Truncate
   278  // panics if n is greater than the digest's hash size.
   279  func (d *Digest) Truncate(n int) {
   280  	if d.h.Size() < n {
   281  		panic("n is too large")
   282  	}
   283  	copy(d.b[n:], zeros[:])
   284  }
   285  
   286  // IsShort tells whether d is a "short" digest, comprising
   287  // only the initial 4 bytes.
   288  func (d Digest) IsShort() bool {
   289  	return bytes.HasSuffix(d.b[:], shortSuffix[:])
   290  }
   291  
   292  // IsAbbrev tells whether d is an "abbreviated" digest, comprising
   293  // no more than half of the digest bytes.
   294  func (d Digest) IsAbbrev() bool {
   295  	return bytes.HasSuffix(d.b[:], zeros[d.h.Size()/2:])
   296  }
   297  
   298  // NPrefix returns the number of nonzero leading bytes in the
   299  // digest, after which the remaining bytes are zero.
   300  func (d Digest) NPrefix() int {
   301  	for i := d.h.Size() - 1; i >= 0; i-- {
   302  		if d.b[i] != 0 {
   303  			return i + 1
   304  		}
   305  	}
   306  	return 0
   307  }
   308  
   309  // Expands tells whether digest d expands digest e.
   310  func (d Digest) Expands(e Digest) bool {
   311  	n := e.NPrefix()
   312  	return bytes.HasPrefix(d.b[:], e.b[:n])
   313  }
   314  
   315  // String returns the full string representation of the digest: the digest
   316  // name, followed by ":", followed by its hexadecimal value.
   317  func (d Digest) String() string {
   318  	if d.IsZero() {
   319  		return zeroString
   320  	}
   321  	return fmt.Sprintf("%s:%s", name[d.h], d.Hex())
   322  }
   323  
   324  // ShortString returns a short representation of the digest, comprising
   325  // the digest name and its first n bytes.
   326  func (d Digest) ShortString(n int) string {
   327  	if d.IsZero() {
   328  		return zeroString
   329  	}
   330  	return fmt.Sprintf("%s:%s", name[d.h], d.HexN(n))
   331  }
   332  
   333  func (d Digest) valid() bool {
   334  	return d.h.Available() && len(d.b) >= d.h.Size()
   335  }
   336  
   337  // MarshalJSON marshals the Digest into JSON format.
   338  func (d Digest) MarshalJSON() ([]byte, error) {
   339  	return json.Marshal(d.String())
   340  }
   341  
   342  // UnmarshalJSON unmarshals a digest from JSON data.
   343  func (d *Digest) UnmarshalJSON(b []byte) error {
   344  	var s string
   345  	if err := json.Unmarshal(b, &s); err != nil {
   346  		return err
   347  	}
   348  	var err error
   349  	*d, err = Parse(s)
   350  	return err
   351  }
   352  
// Digester computes digests based on a cryptographic hash function.
// Its value is the crypto.Hash that its methods use to create, parse and
// compute Digests.
type Digester crypto.Hash
   355  
   356  // New returns a new digest with the provided literal contents. New
   357  // panics if the digest size does not match the hash function's length.
   358  func (d Digester) New(b []byte) Digest {
   359  	if crypto.Hash(d).Size() != len(b) {
   360  		panic("digest: bad digest length")
   361  	}
   362  	return New(crypto.Hash(d), b)
   363  }
   364  
   365  // Parse parses a string into a Digest with the cryptographic hash of
   366  // Digester. The input string is in the form of Digest.String, except
   367  // that the hash name may be omitted--it is then instead assumed to
   368  // be the hash function associated with the Digester.
   369  func (d Digester) Parse(s string) (Digest, error) {
   370  	if s == "" || s == zeroString {
   371  		return Digest{h: crypto.Hash(d)}, nil
   372  	}
   373  	parts := strings.Split(s, ":")
   374  	switch len(parts) {
   375  	default:
   376  		return Digest{}, ErrInvalidDigest
   377  	case 1:
   378  		return ParseHash(crypto.Hash(d), s)
   379  	case 2:
   380  		dgst, err := Parse(s)
   381  		if err != nil {
   382  			return Digest{}, err
   383  		}
   384  		if dgst.h != crypto.Hash(d) {
   385  			return Digest{}, ErrWrongHash
   386  		}
   387  		return dgst, nil
   388  	}
   389  }
   390  
   391  // FromBytes computes a Digest from a slice of bytes.
   392  func (d Digester) FromBytes(p []byte) Digest {
   393  	w := crypto.Hash(d).New()
   394  	if _, err := w.Write(p); err != nil {
   395  		panic("hash returned error " + err.Error())
   396  	}
   397  	return New(crypto.Hash(d), w.Sum(nil))
   398  }
   399  
   400  // FromString computes a Digest from a string.
   401  func (d Digester) FromString(s string) Digest {
   402  	return d.FromBytes([]byte(s))
   403  }
   404  
   405  // FromDigests computes a Digest over other Digests.
   406  func (d Digester) FromDigests(digests ...Digest) Digest {
   407  	w := crypto.Hash(d).New()
   408  	for _, d := range digests {
   409  		// TODO(saito,pknudsgaaard,schandra)
   410  		//
   411  		// grail.com/pipeline/release/internal/reference passes an empty Digest and
   412  		// fails here. We need to be more principled about the values passed here,
   413  		// so we intentionally drop errors here.
   414  		WriteDigest(w, d)
   415  	}
   416  	return New(crypto.Hash(d), w.Sum(nil))
   417  }
   418  
   419  // Rand returns a random digest generated by the random
   420  // provided generator. If no generator is provided (r is nil),
   421  // Rand uses the system's cryptographically secure random
   422  // number generator.
   423  func (d Digester) Rand(r *mathrand.Rand) Digest {
   424  	dg := Digest{h: crypto.Hash(d)}
   425  	var (
   426  		err error
   427  		p   = dg.b[:dg.h.Size()]
   428  	)
   429  	if r != nil {
   430  		_, err = r.Read(p)
   431  	} else {
   432  		_, err = rand.Read(p)
   433  	}
   434  	if err != nil {
   435  		panic(err)
   436  	}
   437  	return dg
   438  }
   439  
   440  // NewWriter returns a Writer that can be used to compute Digests of long inputs.
   441  func (d Digester) NewWriter() Writer {
   442  	hw := crypto.Hash(d).New()
   443  	return Writer{h: crypto.Hash(d), hw: hw, w: bufio.NewWriter(hw)}
   444  }
   445  
   446  const digesterBufferSize = 256
   447  
   448  // NewWriterShort returns a Writer that can be used to compute Digests of short inputs (ie, order of KBs)
   449  func (d Digester) NewWriterShort() Writer {
   450  	hw := crypto.Hash(d).New()
   451  	return Writer{h: crypto.Hash(d), hw: hw, w: bufio.NewWriterSize(hw, digesterBufferSize)}
   452  }
   453  
// Writer provides an io.Writer to which digested bytes are
// written and from which a Digest is produced.
type Writer struct {
	// h is the hash function recorded in the Digests this Writer produces.
	h crypto.Hash
	// hw is the running hash state that ultimately receives the bytes.
	hw hash.Hash
	// w buffers writes in front of hw; it is flushed by Digest.
	w *bufio.Writer
}
   461  
   462  func (d Writer) Write(p []byte) (n int, err error) {
   463  	return d.w.Write(p)
   464  }
   465  
   466  func (d Writer) WriteString(s string) (n int, err error) {
   467  	return d.w.WriteString(s)
   468  }
   469  
   470  // Digest produces the current Digest of the Writer.
   471  // It does not reset its internal state.
   472  func (d Writer) Digest() Digest {
   473  	if err := d.w.Flush(); err != nil {
   474  		panic(fmt.Sprintf("digest.Digest.Flush: %v", err))
   475  	}
   476  	return New(d.h, d.hw.Sum(nil))
   477  }
   478  
   479  // WriteDigest is a convenience function to write a (binary)
   480  // Digest to an io.Writer. Its format is two bytes representing
   481  // the hash function, followed by the hash value itself.
   482  //
   483  // Writing a zero digest is disallowed; WriteDigest panics in
   484  // this case.
   485  func WriteDigest(w io.Writer, d Digest) (n int, err error) {
   486  	if d.IsZero() {
   487  		panic("digest.WriteDigest: attempted to write a zero digest")
   488  	}
   489  	digestHash, ok := cryptoToDigestHashes[d.h]
   490  	if !ok {
   491  		return n, fmt.Errorf("cannot convert %v to a digestHash", d.h)
   492  	}
   493  	b := [2]byte{byte(digestHash >> 8), byte(digestHash & 0xff)}
   494  	n, err = w.Write(b[:])
   495  	if err != nil {
   496  		return n, err
   497  	}
   498  	m, err := w.Write(d.b[:d.h.Size()])
   499  	return n + m, err
   500  }
   501  
   502  // ReadDigest is a convenience function to read a (binary)
   503  // Digest from an io.Reader, as written by WriteDigest.
   504  func ReadDigest(r io.Reader) (Digest, error) {
   505  	var d Digest
   506  	n, err := r.Read(d.b[0:2])
   507  	if err != nil {
   508  		return Digest{}, err
   509  	}
   510  	if n < 2 {
   511  		return Digest{}, ErrShortRead
   512  	}
   513  	d.h = digestToCryptoHashes[digestHash(d.b[0])<<8|digestHash(d.b[1])]
   514  	if !d.h.Available() {
   515  		return Digest{}, ErrHashUnavailable
   516  	}
   517  	n, err = r.Read(d.b[0:d.h.Size()])
   518  	if err != nil {
   519  		return Digest{}, err
   520  	}
   521  	if n < d.h.Size() {
   522  		return Digest{}, ErrShortRead
   523  	}
   524  	return d, nil
   525  }
   526  
   527  // MarshalJSON generates a JSON format byte slice from a Digester.
   528  func (d Digester) MarshalJSON() ([]byte, error) {
   529  	txt, ok := name[crypto.Hash(d)]
   530  
   531  	if !ok {
   532  		return nil, fmt.Errorf("Cannot convert %v to string", d)
   533  	}
   534  
   535  	return []byte(fmt.Sprintf(`"%s"`, txt)), nil
   536  }
   537  
   538  // UnmarshalJSON converts from a JSON format byte slice to a Digester.
   539  func (d *Digester) UnmarshalJSON(b []byte) error {
   540  	str := string(b)
   541  
   542  	val, ok := hashes[strings.Trim(str, `"`)]
   543  
   544  	if !ok {
   545  		return fmt.Errorf("Cannot convert %s to digest.Digester", string(b))
   546  	}
   547  
   548  	*d = Digester(val)
   549  	return nil
   550  }