github.com/minio/madmin-go/v3@v3.0.51/estream/reader.go (about)

     1  //
     2  // Copyright (c) 2015-2022 MinIO, Inc.
     3  //
     4  // This file is part of MinIO Object Storage stack
     5  //
     6  // This program is free software: you can redistribute it and/or modify
     7  // it under the terms of the GNU Affero General Public License as
     8  // published by the Free Software Foundation, either version 3 of the
     9  // License, or (at your option) any later version.
    10  //
    11  // This program is distributed in the hope that it will be useful,
    12  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  // GNU Affero General Public License for more details.
    15  //
    16  // You should have received a copy of the GNU Affero General Public License
    17  // along with this program. If not, see <http://www.gnu.org/licenses/>.
    18  //
    19  
    20  package estream
    21  
    22  import (
    23  	"bytes"
    24  	"crypto/rand"
    25  	"crypto/rsa"
    26  	"crypto/sha512"
    27  	"crypto/x509"
    28  	"encoding/hex"
    29  	"errors"
    30  	"fmt"
    31  	"io"
    32  	"runtime"
    33  
    34  	"github.com/cespare/xxhash/v2"
    35  	"github.com/secure-io/sio-go"
    36  	"github.com/tinylib/msgp/msgp"
    37  )
    38  
    39  type Reader struct {
    40  	mr            *msgp.Reader
    41  	majorV        uint8
    42  	minorV        uint8
    43  	err           error
    44  	inStream      bool
    45  	key           *[32]byte
    46  	private       *rsa.PrivateKey
    47  	privateFn     func(key *rsa.PublicKey) *rsa.PrivateKey
    48  	skipEncrypted bool
    49  	returnNonDec  bool
    50  }
    51  
    52  // ErrNoKey is returned when a stream cannot be decrypted.
    53  // The Skip function on the stream can be called to skip to the next.
    54  var ErrNoKey = errors.New("no valid private key found")
    55  
    56  // NewReader will return a Reader that will split streams.
    57  func NewReader(r io.Reader) (*Reader, error) {
    58  	var ver [2]byte
    59  	if _, err := io.ReadFull(r, ver[:]); err != nil {
    60  		return nil, err
    61  	}
    62  	switch ver[0] {
    63  	case 2:
    64  	default:
    65  		return nil, fmt.Errorf("unknown stream version: 0x%x", ver[0])
    66  	}
    67  
    68  	return &Reader{mr: msgp.NewReader(r), majorV: ver[0], minorV: ver[1]}, nil
    69  }
    70  
    71  // SetPrivateKey will set the private key to allow stream decryption.
    72  // This overrides any function set by PrivateKeyProvider.
    73  func (r *Reader) SetPrivateKey(k *rsa.PrivateKey) {
    74  	r.privateFn = nil
    75  	r.private = k
    76  }
    77  
    78  // PrivateKeyProvider will ask for a private key matching the public key.
    79  // If the function returns a nil private key the stream key will not be decrypted
    80  // and if SkipEncrypted has been set any streams with this key will be silently skipped.
    81  // This overrides any key set by SetPrivateKey.
    82  func (r *Reader) PrivateKeyProvider(fn func(key *rsa.PublicKey) *rsa.PrivateKey) {
    83  	r.privateFn = fn
    84  	r.private = nil
    85  }
    86  
    87  // SkipEncrypted will skip encrypted streams if no private key has been set.
    88  func (r *Reader) SkipEncrypted(b bool) {
    89  	r.skipEncrypted = b
    90  }
    91  
    92  // ReturnNonDecryptable will return non-decryptable stream headers.
    93  // Streams are returned with ErrNoKey error.
    94  // Streams with this error cannot be read, but the Skip function can be invoked.
    95  // SkipEncrypted overrides this.
    96  func (r *Reader) ReturnNonDecryptable(b bool) {
    97  	r.returnNonDec = b
    98  }
    99  
   100  // Stream returns the next stream.
   101  type Stream struct {
   102  	io.Reader
   103  	Name          string
   104  	Extra         []byte
   105  	SentEncrypted bool
   106  
   107  	parent *Reader
   108  }
   109  
   110  // NextStream will return the next stream.
   111  // Before calling this the previous stream must be read until EOF,
   112  // or Skip() should have been called.
   113  // Will return nil, io.EOF when there are no more streams.
   114  func (r *Reader) NextStream() (*Stream, error) {
   115  	if r.err != nil {
   116  		return nil, r.err
   117  	}
   118  	if r.inStream {
   119  		return nil, errors.New("previous stream not read until EOF")
   120  	}
   121  
   122  	// Temp storage for blocks.
   123  	block := make([]byte, 1024)
   124  	for {
   125  		// Read block ID.
   126  		n, err := r.mr.ReadInt8()
   127  		if err != nil {
   128  			return nil, r.setErr(err)
   129  		}
   130  		id := blockID(n)
   131  
   132  		// Read block size
   133  		sz, err := r.mr.ReadUint32()
   134  		if err != nil {
   135  			return nil, r.setErr(err)
   136  		}
   137  
   138  		// Read block data
   139  		if cap(block) < int(sz) {
   140  			block = make([]byte, sz)
   141  		}
   142  		block = block[:sz]
   143  		_, err = io.ReadFull(r.mr, block)
   144  		if err != nil {
   145  			return nil, r.setErr(err)
   146  		}
   147  
   148  		// Parse block
   149  		switch id {
   150  		case blockPlainKey:
   151  			// Read plaintext key.
   152  			key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32))
   153  			if err != nil {
   154  				return nil, r.setErr(err)
   155  			}
   156  			if len(key) != 32 {
   157  				return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   158  			}
   159  
   160  			// Set key for following streams.
   161  			r.key = (*[32]byte)(key)
   162  		case blockEncryptedKey:
   163  			// Read public key
   164  			publicKey, block, err := msgp.ReadBytesZC(block)
   165  			if err != nil {
   166  				return nil, r.setErr(err)
   167  			}
   168  
   169  			// Request private key if we have a custom function.
   170  			if r.privateFn != nil {
   171  				pk, err := x509.ParsePKCS1PublicKey(publicKey)
   172  				if err != nil {
   173  					return nil, r.setErr(err)
   174  				}
   175  				r.private = r.privateFn(pk)
   176  				if r.private == nil {
   177  					if r.skipEncrypted || r.returnNonDec {
   178  						r.key = nil
   179  						continue
   180  					}
   181  					return nil, r.setErr(errors.New("nil private key returned"))
   182  				}
   183  			}
   184  
   185  			// Read cipher key
   186  			cipherKey, _, err := msgp.ReadBytesZC(block)
   187  			if err != nil {
   188  				return nil, r.setErr(err)
   189  			}
   190  			if r.private == nil {
   191  				if r.skipEncrypted || r.returnNonDec {
   192  					r.key = nil
   193  					continue
   194  				}
   195  				return nil, r.setErr(errors.New("private key has not been set"))
   196  			}
   197  
   198  			// Decrypt stream key
   199  			key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil)
   200  			if err != nil {
   201  				if r.returnNonDec {
   202  					r.key = nil
   203  					continue
   204  				}
   205  				return nil, err
   206  			}
   207  
   208  			if len(key) != 32 {
   209  				return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   210  			}
   211  			r.key = (*[32]byte)(key)
   212  
   213  		case blockPlainStream, blockEncStream:
   214  			// Read metadata
   215  			name, block, err := msgp.ReadStringBytes(block)
   216  			if err != nil {
   217  				return nil, r.setErr(err)
   218  			}
   219  			extra, block, err := msgp.ReadBytesBytes(block, nil)
   220  			if err != nil {
   221  				return nil, r.setErr(err)
   222  			}
   223  			c, block, err := msgp.ReadUint8Bytes(block)
   224  			if err != nil {
   225  				return nil, r.setErr(err)
   226  			}
   227  			checksum := checksumType(c)
   228  			if !checksum.valid() {
   229  				return nil, r.setErr(fmt.Errorf("unknown checksum type %d", checksum))
   230  			}
   231  
   232  			// Return plaintext stream
   233  			if id == blockPlainStream {
   234  				return &Stream{
   235  					Reader: r.newStreamReader(checksum),
   236  					Name:   name,
   237  					Extra:  extra,
   238  					parent: r,
   239  				}, nil
   240  			}
   241  
   242  			// Handle encrypted streams.
   243  			if r.key == nil {
   244  				if r.skipEncrypted {
   245  					if err := r.skipDataBlocks(); err != nil {
   246  						return nil, r.setErr(err)
   247  					}
   248  					continue
   249  				}
   250  				return &Stream{
   251  					SentEncrypted: true,
   252  					Reader:        nil,
   253  					Name:          name,
   254  					Extra:         extra,
   255  					parent:        r,
   256  				}, ErrNoKey
   257  			}
   258  			// Read stream nonce
   259  			nonce, _, err := msgp.ReadBytesZC(block)
   260  			if err != nil {
   261  				return nil, r.setErr(err)
   262  			}
   263  
   264  			stream, err := sio.AES_256_GCM.Stream(r.key[:])
   265  			if err != nil {
   266  				return nil, r.setErr(err)
   267  			}
   268  
   269  			// Check if nonce is expected length.
   270  			if len(nonce) != stream.NonceSize() {
   271  				return nil, r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce)))
   272  			}
   273  
   274  			encr := stream.DecryptReader(r.newStreamReader(checksum), nonce, nil)
   275  			return &Stream{
   276  				SentEncrypted: true,
   277  				Reader:        encr,
   278  				Name:          name,
   279  				Extra:         extra,
   280  				parent:        r,
   281  			}, nil
   282  		case blockEOS:
   283  			return nil, errors.New("end-of-stream without being in stream")
   284  		case blockEOF:
   285  			return nil, io.EOF
   286  		case blockError:
   287  			msg, _, err := msgp.ReadStringBytes(block)
   288  			if err != nil {
   289  				return nil, r.setErr(err)
   290  			}
   291  			return nil, r.setErr(errors.New(msg))
   292  		default:
   293  			if id >= 0 {
   294  				return nil, fmt.Errorf("unknown block type: %d", id)
   295  			}
   296  		}
   297  	}
   298  }
   299  
   300  // skipDataBlocks reads data blocks until end.
   301  func (r *Reader) skipDataBlocks() error {
   302  	for {
   303  		// Read block ID.
   304  		n, err := r.mr.ReadInt8()
   305  		if err != nil {
   306  			return err
   307  		}
   308  		id := blockID(n)
   309  		sz, err := r.mr.ReadUint32()
   310  		if err != nil {
   311  			return err
   312  		}
   313  		if id == blockError {
   314  			msg, err := r.mr.ReadString()
   315  			if err != nil {
   316  				return err
   317  			}
   318  			return errors.New(msg)
   319  		}
   320  		// Discard data
   321  		_, err = io.CopyN(io.Discard, r.mr, int64(sz))
   322  		if err != nil {
   323  			return err
   324  		}
   325  		switch id {
   326  		case blockDatablock:
   327  			// Skip data
   328  		case blockEOS:
   329  			// Done
   330  			r.inStream = false
   331  			return nil
   332  		default:
   333  			if id >= 0 {
   334  				return fmt.Errorf("unknown block type: %d", id)
   335  			}
   336  		}
   337  	}
   338  }
   339  
   340  // setErr sets a stateful error.
   341  func (r *Reader) setErr(err error) error {
   342  	if r.err != nil {
   343  		return r.err
   344  	}
   345  	if err == nil {
   346  		return err
   347  	}
   348  	if errors.Is(err, io.EOF) {
   349  		r.err = io.ErrUnexpectedEOF
   350  	}
   351  	if false {
   352  		_, file, line, ok := runtime.Caller(1)
   353  		if ok {
   354  			err = fmt.Errorf("%s:%d: %w", file, line, err)
   355  		}
   356  	}
   357  	r.err = err
   358  	return err
   359  }
   360  
   361  type streamReader struct {
   362  	up    *Reader
   363  	h     xxhash.Digest
   364  	buf   bytes.Buffer
   365  	tmp   []byte
   366  	isEOF bool
   367  	check checksumType
   368  }
   369  
   370  // newStreamReader creates a stream reader that can be read to get all data blocks.
   371  func (r *Reader) newStreamReader(ct checksumType) *streamReader {
   372  	sr := &streamReader{up: r, check: ct}
   373  	sr.h.Reset()
   374  	r.inStream = true
   375  	return sr
   376  }
   377  
   378  // Skip the remainder of the stream.
   379  func (s *Stream) Skip() error {
   380  	if sr, ok := s.Reader.(*streamReader); ok {
   381  		sr.isEOF = true
   382  		sr.buf.Reset()
   383  	}
   384  	return s.parent.skipDataBlocks()
   385  }
   386  
   387  // Read will return data blocks as on stream.
   388  func (r *streamReader) Read(b []byte) (int, error) {
   389  	if r.isEOF {
   390  		return 0, io.EOF
   391  	}
   392  	if r.up.err != nil {
   393  		return 0, r.up.err
   394  	}
   395  	for {
   396  		// If we have anything in the buffer return that first.
   397  		if r.buf.Len() > 0 {
   398  			n, err := r.buf.Read(b)
   399  			if err == io.EOF {
   400  				err = nil
   401  			}
   402  			return n, r.up.setErr(err)
   403  		}
   404  
   405  		// Read block
   406  		n, err := r.up.mr.ReadInt8()
   407  		if err != nil {
   408  			return 0, r.up.setErr(err)
   409  		}
   410  		id := blockID(n)
   411  
   412  		// Read size...
   413  		sz, err := r.up.mr.ReadUint32()
   414  		if err != nil {
   415  			return 0, r.up.setErr(err)
   416  		}
   417  
   418  		switch id {
   419  		case blockDatablock:
   420  			// Read block
   421  			buf, err := r.up.mr.ReadBytes(r.tmp[:0])
   422  			if err != nil {
   423  				return 0, r.up.setErr(err)
   424  			}
   425  
   426  			// Write to buffer and checksum
   427  			if r.check == checksumTypeXxhash {
   428  				r.h.Write(buf)
   429  			}
   430  			r.tmp = buf
   431  			r.buf.Write(buf)
   432  		case blockEOS:
   433  			// Verify stream checksum if any.
   434  			hash, err := r.up.mr.ReadBytes(nil)
   435  			if err != nil {
   436  				return 0, r.up.setErr(err)
   437  			}
   438  			switch r.check {
   439  			case checksumTypeXxhash:
   440  				got := r.h.Sum(nil)
   441  				if !bytes.Equal(hash, got) {
   442  					return 0, r.up.setErr(fmt.Errorf("checksum mismatch, want %s, got %s", hex.EncodeToString(hash), hex.EncodeToString(got)))
   443  				}
   444  			case checksumTypeNone:
   445  			default:
   446  				return 0, r.up.setErr(fmt.Errorf("unknown checksum id %d", r.check))
   447  			}
   448  			r.isEOF = true
   449  			r.up.inStream = false
   450  			return 0, io.EOF
   451  		case blockError:
   452  			msg, err := r.up.mr.ReadString()
   453  			if err != nil {
   454  				return 0, r.up.setErr(err)
   455  			}
   456  			return 0, r.up.setErr(errors.New(msg))
   457  		default:
   458  			if id >= 0 {
   459  				return 0, fmt.Errorf("unexpected block type: %d", id)
   460  			}
   461  			// Skip block...
   462  			_, err := io.CopyN(io.Discard, r.up.mr, int64(sz))
   463  			if err != nil {
   464  				return 0, r.up.setErr(err)
   465  			}
   466  		}
   467  	}
   468  }