github.com/minio/madmin-go@v1.7.5/estream/reader.go (about)

     1  //
     2  // MinIO Object Storage (c) 2022 MinIO, Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //      http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package estream
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/sha512"
    24  	"crypto/x509"
    25  	"encoding/hex"
    26  	"errors"
    27  	"fmt"
    28  	"io"
    29  	"runtime"
    30  
    31  	"github.com/cespare/xxhash/v2"
    32  	"github.com/secure-io/sio-go"
    33  	"github.com/tinylib/msgp/msgp"
    34  )
    35  
    36  type Reader struct {
    37  	mr            *msgp.Reader
    38  	majorV        uint8
    39  	minorV        uint8
    40  	err           error
    41  	inStream      bool
    42  	key           *[32]byte
    43  	private       *rsa.PrivateKey
    44  	privateFn     func(key *rsa.PublicKey) *rsa.PrivateKey
    45  	skipEncrypted bool
    46  	returnNonDec  bool
    47  }
    48  
    49  // ErrNoKey is returned when a stream cannot be decrypted.
    50  // The Skip function on the stream can be called to skip to the next.
    51  var ErrNoKey = errors.New("no valid private key found")
    52  
    53  // NewReader will return a Reader that will split streams.
    54  func NewReader(r io.Reader) (*Reader, error) {
    55  	var ver [2]byte
    56  	if _, err := io.ReadFull(r, ver[:]); err != nil {
    57  		return nil, err
    58  	}
    59  	switch ver[0] {
    60  	case 2:
    61  	default:
    62  		return nil, fmt.Errorf("unknown stream version: 0x%x", ver[0])
    63  	}
    64  
    65  	return &Reader{mr: msgp.NewReader(r), majorV: ver[0], minorV: ver[1]}, nil
    66  }
    67  
    68  // SetPrivateKey will set the private key to allow stream decryption.
    69  // This overrides any function set by PrivateKeyProvider.
    70  func (r *Reader) SetPrivateKey(k *rsa.PrivateKey) {
    71  	r.privateFn = nil
    72  	r.private = k
    73  }
    74  
    75  // PrivateKeyProvider will ask for a private key matching the public key.
    76  // If the function returns a nil private key the stream key will not be decrypted
    77  // and if SkipEncrypted has been set any streams with this key will be silently skipped.
    78  // This overrides any key set by SetPrivateKey.
    79  func (r *Reader) PrivateKeyProvider(fn func(key *rsa.PublicKey) *rsa.PrivateKey) {
    80  	r.privateFn = fn
    81  	r.private = nil
    82  }
    83  
    84  // SkipEncrypted will skip encrypted streams if no private key has been set.
    85  func (r *Reader) SkipEncrypted(b bool) {
    86  	r.skipEncrypted = b
    87  }
    88  
    89  // ReturnNonDecryptable will return non-decryptable stream headers.
    90  // Streams are returned with ErrNoKey error.
    91  // Streams with this error cannot be read, but the Skip function can be invoked.
    92  // SkipEncrypted overrides this.
    93  func (r *Reader) ReturnNonDecryptable(b bool) {
    94  	r.returnNonDec = b
    95  }
    96  
    97  // Stream returns the next stream.
    98  type Stream struct {
    99  	io.Reader
   100  	Name          string
   101  	Extra         []byte
   102  	SentEncrypted bool
   103  
   104  	parent *Reader
   105  }
   106  
   107  // NextStream will return the next stream.
   108  // Before calling this the previous stream must be read until EOF,
   109  // or Skip() should have been called.
   110  // Will return nil, io.EOF when there are no more streams.
   111  func (r *Reader) NextStream() (*Stream, error) {
   112  	if r.err != nil {
   113  		return nil, r.err
   114  	}
   115  	if r.inStream {
   116  		return nil, errors.New("previous stream not read until EOF")
   117  	}
   118  
   119  	// Temp storage for blocks.
   120  	block := make([]byte, 1024)
   121  	for {
   122  		// Read block ID.
   123  		n, err := r.mr.ReadInt8()
   124  		if err != nil {
   125  			return nil, r.setErr(err)
   126  		}
   127  		id := blockID(n)
   128  
   129  		// Read block size
   130  		sz, err := r.mr.ReadUint32()
   131  		if err != nil {
   132  			return nil, r.setErr(err)
   133  		}
   134  
   135  		// Read block data
   136  		if cap(block) < int(sz) {
   137  			block = make([]byte, sz)
   138  		}
   139  		block = block[:sz]
   140  		_, err = io.ReadFull(r.mr, block)
   141  		if err != nil {
   142  			return nil, r.setErr(err)
   143  		}
   144  
   145  		// Parse block
   146  		switch id {
   147  		case blockPlainKey:
   148  			// Read plaintext key.
   149  			key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32))
   150  			if err != nil {
   151  				return nil, r.setErr(err)
   152  			}
   153  			if len(key) != 32 {
   154  				return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   155  			}
   156  
   157  			// Set key for following streams.
   158  			r.key = (*[32]byte)(key)
   159  		case blockEncryptedKey:
   160  			// Read public key
   161  			publicKey, block, err := msgp.ReadBytesZC(block)
   162  			if err != nil {
   163  				return nil, r.setErr(err)
   164  			}
   165  
   166  			// Request private key if we have a custom function.
   167  			if r.privateFn != nil {
   168  				pk, err := x509.ParsePKCS1PublicKey(publicKey)
   169  				if err != nil {
   170  					return nil, r.setErr(err)
   171  				}
   172  				r.private = r.privateFn(pk)
   173  				if r.private == nil {
   174  					if r.skipEncrypted || r.returnNonDec {
   175  						r.key = nil
   176  						continue
   177  					}
   178  					return nil, r.setErr(errors.New("nil private key returned"))
   179  				}
   180  			}
   181  
   182  			// Read cipher key
   183  			cipherKey, _, err := msgp.ReadBytesZC(block)
   184  			if err != nil {
   185  				return nil, r.setErr(err)
   186  			}
   187  			if r.private == nil {
   188  				if r.skipEncrypted || r.returnNonDec {
   189  					r.key = nil
   190  					continue
   191  				}
   192  				return nil, r.setErr(errors.New("private key has not been set"))
   193  			}
   194  
   195  			// Decrypt stream key
   196  			key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil)
   197  			if err != nil {
   198  				if r.returnNonDec {
   199  					r.key = nil
   200  					continue
   201  				}
   202  				return nil, err
   203  			}
   204  
   205  			if len(key) != 32 {
   206  				return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   207  			}
   208  			r.key = (*[32]byte)(key)
   209  
   210  		case blockPlainStream, blockEncStream:
   211  			// Read metadata
   212  			name, block, err := msgp.ReadStringBytes(block)
   213  			if err != nil {
   214  				return nil, r.setErr(err)
   215  			}
   216  			extra, block, err := msgp.ReadBytesBytes(block, nil)
   217  			if err != nil {
   218  				return nil, r.setErr(err)
   219  			}
   220  			c, block, err := msgp.ReadUint8Bytes(block)
   221  			if err != nil {
   222  				return nil, r.setErr(err)
   223  			}
   224  			checksum := checksumType(c)
   225  			if !checksum.valid() {
   226  				return nil, r.setErr(fmt.Errorf("unknown checksum type %d", checksum))
   227  			}
   228  
   229  			// Return plaintext stream
   230  			if id == blockPlainStream {
   231  				return &Stream{
   232  					Reader: r.newStreamReader(checksum),
   233  					Name:   name,
   234  					Extra:  extra,
   235  					parent: r,
   236  				}, nil
   237  			}
   238  
   239  			// Handle encrypted streams.
   240  			if r.key == nil {
   241  				if r.skipEncrypted {
   242  					if err := r.skipDataBlocks(); err != nil {
   243  						return nil, r.setErr(err)
   244  					}
   245  					continue
   246  				}
   247  				return &Stream{
   248  					SentEncrypted: true,
   249  					Reader:        nil,
   250  					Name:          name,
   251  					Extra:         extra,
   252  					parent:        r,
   253  				}, ErrNoKey
   254  			}
   255  			// Read stream nonce
   256  			nonce, _, err := msgp.ReadBytesZC(block)
   257  			if err != nil {
   258  				return nil, r.setErr(err)
   259  			}
   260  
   261  			stream, err := sio.AES_256_GCM.Stream(r.key[:])
   262  			if err != nil {
   263  				return nil, r.setErr(err)
   264  			}
   265  
   266  			// Check if nonce is expected length.
   267  			if len(nonce) != stream.NonceSize() {
   268  				return nil, r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce)))
   269  			}
   270  
   271  			encr := stream.DecryptReader(r.newStreamReader(checksum), nonce, nil)
   272  			return &Stream{
   273  				SentEncrypted: true,
   274  				Reader:        encr,
   275  				Name:          name,
   276  				Extra:         extra,
   277  				parent:        r,
   278  			}, nil
   279  		case blockEOS:
   280  			return nil, errors.New("end-of-stream without being in stream")
   281  		case blockEOF:
   282  			return nil, io.EOF
   283  		case blockError:
   284  			msg, _, err := msgp.ReadStringBytes(block)
   285  			if err != nil {
   286  				return nil, r.setErr(err)
   287  			}
   288  			return nil, r.setErr(errors.New(msg))
   289  		default:
   290  			if id >= 0 {
   291  				return nil, fmt.Errorf("unknown block type: %d", id)
   292  			}
   293  		}
   294  	}
   295  }
   296  
   297  // skipDataBlocks reads data blocks until end.
   298  func (r *Reader) skipDataBlocks() error {
   299  	for {
   300  		// Read block ID.
   301  		n, err := r.mr.ReadInt8()
   302  		if err != nil {
   303  			return err
   304  		}
   305  		id := blockID(n)
   306  		sz, err := r.mr.ReadUint32()
   307  		if err != nil {
   308  			return err
   309  		}
   310  		if id == blockError {
   311  			msg, err := r.mr.ReadString()
   312  			if err != nil {
   313  				return err
   314  			}
   315  			return errors.New(msg)
   316  		}
   317  		// Discard data
   318  		_, err = io.CopyN(io.Discard, r.mr, int64(sz))
   319  		if err != nil {
   320  			return err
   321  		}
   322  		switch id {
   323  		case blockDatablock:
   324  			// Skip data
   325  		case blockEOS:
   326  			// Done
   327  			r.inStream = false
   328  			return nil
   329  		default:
   330  			if id >= 0 {
   331  				return fmt.Errorf("unknown block type: %d", id)
   332  			}
   333  		}
   334  	}
   335  }
   336  
   337  // setErr sets a stateful error.
   338  func (r *Reader) setErr(err error) error {
   339  	if r.err != nil {
   340  		return r.err
   341  	}
   342  	if err == nil {
   343  		return err
   344  	}
   345  	if errors.Is(err, io.EOF) {
   346  		r.err = io.ErrUnexpectedEOF
   347  	}
   348  	if false {
   349  		_, file, line, ok := runtime.Caller(1)
   350  		if ok {
   351  			err = fmt.Errorf("%s:%d: %w", file, line, err)
   352  		}
   353  	}
   354  	r.err = err
   355  	return err
   356  }
   357  
   358  type streamReader struct {
   359  	up    *Reader
   360  	h     xxhash.Digest
   361  	buf   bytes.Buffer
   362  	tmp   []byte
   363  	isEOF bool
   364  	check checksumType
   365  }
   366  
   367  // newStreamReader creates a stream reader that can be read to get all data blocks.
   368  func (r *Reader) newStreamReader(ct checksumType) *streamReader {
   369  	sr := &streamReader{up: r, check: ct}
   370  	sr.h.Reset()
   371  	r.inStream = true
   372  	return sr
   373  }
   374  
   375  // Skip the remainder of the stream.
   376  func (s *Stream) Skip() error {
   377  	if sr, ok := s.Reader.(*streamReader); ok {
   378  		sr.isEOF = true
   379  		sr.buf.Reset()
   380  	}
   381  	return s.parent.skipDataBlocks()
   382  }
   383  
   384  // Read will return data blocks as on stream.
   385  func (r *streamReader) Read(b []byte) (int, error) {
   386  	if r.isEOF {
   387  		return 0, io.EOF
   388  	}
   389  	if r.up.err != nil {
   390  		return 0, r.up.err
   391  	}
   392  	for {
   393  		// If we have anything in the buffer return that first.
   394  		if r.buf.Len() > 0 {
   395  			n, err := r.buf.Read(b)
   396  			if err == io.EOF {
   397  				err = nil
   398  			}
   399  			return n, r.up.setErr(err)
   400  		}
   401  
   402  		// Read block
   403  		n, err := r.up.mr.ReadInt8()
   404  		if err != nil {
   405  			return 0, r.up.setErr(err)
   406  		}
   407  		id := blockID(n)
   408  
   409  		// Read size...
   410  		sz, err := r.up.mr.ReadUint32()
   411  		if err != nil {
   412  			return 0, r.up.setErr(err)
   413  		}
   414  
   415  		switch id {
   416  		case blockDatablock:
   417  			// Read block
   418  			buf, err := r.up.mr.ReadBytes(r.tmp[:0])
   419  			if err != nil {
   420  				return 0, r.up.setErr(err)
   421  			}
   422  
   423  			// Write to buffer and checksum
   424  			if r.check == checksumTypeXxhash {
   425  				r.h.Write(buf)
   426  			}
   427  			r.tmp = buf
   428  			r.buf.Write(buf)
   429  		case blockEOS:
   430  			// Verify stream checksum if any.
   431  			hash, err := r.up.mr.ReadBytes(nil)
   432  			if err != nil {
   433  				return 0, r.up.setErr(err)
   434  			}
   435  			switch r.check {
   436  			case checksumTypeXxhash:
   437  				got := r.h.Sum(nil)
   438  				if !bytes.Equal(hash, got) {
   439  					return 0, r.up.setErr(fmt.Errorf("checksum mismatch, want %s, got %s", hex.EncodeToString(hash), hex.EncodeToString(got)))
   440  				}
   441  			case checksumTypeNone:
   442  			default:
   443  				return 0, r.up.setErr(fmt.Errorf("unknown checksum id %d", r.check))
   444  			}
   445  			r.isEOF = true
   446  			r.up.inStream = false
   447  			return 0, io.EOF
   448  		case blockError:
   449  			msg, err := r.up.mr.ReadString()
   450  			if err != nil {
   451  				return 0, r.up.setErr(err)
   452  			}
   453  			return 0, r.up.setErr(errors.New(msg))
   454  		default:
   455  			if id >= 0 {
   456  				return 0, fmt.Errorf("unexpected block type: %d", id)
   457  			}
   458  			// Skip block...
   459  			_, err := io.CopyN(io.Discard, r.up.mr, int64(sz))
   460  			if err != nil {
   461  				return 0, r.up.setErr(err)
   462  			}
   463  		}
   464  	}
   465  }