github.com/minio/madmin-go/v3@v3.0.51/estream/stream.go (about)

     1  //
     2  // Copyright (c) 2015-2022 MinIO, Inc.
     3  //
     4  // This file is part of MinIO Object Storage stack
     5  //
     6  // This program is free software: you can redistribute it and/or modify
     7  // it under the terms of the GNU Affero General Public License as
     8  // published by the Free Software Foundation, either version 3 of the
     9  // License, or (at your option) any later version.
    10  //
    11  // This program is distributed in the hope that it will be useful,
    12  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  // GNU Affero General Public License for more details.
    15  //
    16  // You should have received a copy of the GNU Affero General Public License
    17  // along with this program. If not, see <http://www.gnu.org/licenses/>.
    18  //
    19  
    20  package estream
    21  
    22  import (
    23  	"crypto/rand"
    24  	crand "crypto/rand"
    25  	"crypto/rsa"
    26  	"crypto/sha512"
    27  	"crypto/x509"
    28  	"encoding/hex"
    29  	"errors"
    30  	"fmt"
    31  	"hash"
    32  	"io"
    33  
    34  	"github.com/cespare/xxhash/v2"
    35  	"github.com/secure-io/sio-go"
    36  	"github.com/tinylib/msgp/msgp"
    37  )
    38  
    39  // ReplaceFn provides key replacement.
    40  //
    41  // When a key is found on stream, the function is called with the public key.
    42  // The function must then return a private key to decrypt matching the key sent.
    43  // The public key must then be specified that should be used to re-encrypt the stream.
    44  //
    45  // If no private key is sent and the public key matches the one sent to the function
    46  // the key will be kept as is. Other returned values will cause an error.
    47  //
    48  // For encrypting unencrypted keys on stream a nil key will be sent.
    49  // If a public key is returned the key will be encrypted with the public key.
    50  // No private key should be returned for this.
    51  type ReplaceFn func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey)
    52  
    53  // ReplaceKeysOptions allows passing additional options to ReplaceKeys.
    54  type ReplaceKeysOptions struct {
    55  	// If EncryptAll set all unencrypted keys will be encrypted.
    56  	EncryptAll bool
    57  
    58  	// PassErrors will pass through error an error packet,
    59  	// and not return an error.
    60  	PassErrors bool
    61  }
    62  
    63  // ReplaceKeys will replace the keys in a stream.
    64  //
    65  // A replace function must be provided. See ReplaceFn for functionality.
    66  // If encryptAll is set.
    67  func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, o ReplaceKeysOptions) error {
    68  	var ver [2]byte
    69  	if _, err := io.ReadFull(r, ver[:]); err != nil {
    70  		return err
    71  	}
    72  	switch ver[0] {
    73  	case 2:
    74  	default:
    75  		return fmt.Errorf("unknown stream version: 0x%x", ver[0])
    76  	}
    77  	if _, err := w.Write(ver[:]); err != nil {
    78  		return err
    79  	}
    80  	// Input
    81  	mr := msgp.NewReader(r)
    82  	mw := msgp.NewWriter(w)
    83  
    84  	// Temporary block storage.
    85  	block := make([]byte, 1024)
    86  
    87  	// Write a block.
    88  	writeBlock := func(id blockID, sz uint32, content []byte) error {
    89  		if err := mw.WriteInt8(int8(id)); err != nil {
    90  			return err
    91  		}
    92  		if err := mw.WriteUint32(sz); err != nil {
    93  			return err
    94  		}
    95  		_, err := mw.Write(content)
    96  		return err
    97  	}
    98  
    99  	for {
   100  		// Read block ID.
   101  		n, err := mr.ReadInt8()
   102  		if err != nil {
   103  			return err
   104  		}
   105  		id := blockID(n)
   106  
   107  		// Read size
   108  		sz, err := mr.ReadUint32()
   109  		if err != nil {
   110  			return err
   111  		}
   112  		if cap(block) < int(sz) {
   113  			block = make([]byte, sz)
   114  		}
   115  		block = block[:sz]
   116  		_, err = io.ReadFull(mr, block)
   117  		if err != nil {
   118  			return err
   119  		}
   120  
   121  		switch id {
   122  		case blockEncryptedKey:
   123  			ogBlock := block
   124  			// Read public key
   125  			publicKey, block, err := msgp.ReadBytesZC(block)
   126  			if err != nil {
   127  				return err
   128  			}
   129  
   130  			pk, err := x509.ParsePKCS1PublicKey(publicKey)
   131  			if err != nil {
   132  				return err
   133  			}
   134  
   135  			private, public := replace(pk)
   136  			if private == nil && public == pk {
   137  				if err := writeBlock(id, sz, ogBlock); err != nil {
   138  					return err
   139  				}
   140  			}
   141  			if private == nil {
   142  				return errors.New("no private key provided, unable to re-encrypt")
   143  			}
   144  
   145  			// Read cipher key
   146  			cipherKey, _, err := msgp.ReadBytesZC(block)
   147  			if err != nil {
   148  				return err
   149  			}
   150  
   151  			// Decrypt stream key
   152  			key, err := rsa.DecryptOAEP(sha512.New(), crand.Reader, private, cipherKey, nil)
   153  			if err != nil {
   154  				return err
   155  			}
   156  
   157  			if len(key) != 32 {
   158  				return fmt.Errorf("unexpected key length: %d", len(key))
   159  			}
   160  
   161  			cipherKey, err = rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil)
   162  			if err != nil {
   163  				return err
   164  			}
   165  
   166  			// Write Public key
   167  			tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public))
   168  			// Write encrypted cipher key
   169  			tmp = msgp.AppendBytes(tmp, cipherKey)
   170  			if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil {
   171  				return err
   172  			}
   173  		case blockPlainKey:
   174  			if !o.EncryptAll {
   175  				if err := writeBlock(id, sz, block); err != nil {
   176  					return err
   177  				}
   178  				continue
   179  			}
   180  			_, public := replace(nil)
   181  			if public == nil {
   182  				if err := writeBlock(id, sz, block); err != nil {
   183  					return err
   184  				}
   185  				continue
   186  			}
   187  			key, _, err := msgp.ReadBytesZC(block)
   188  			if err != nil {
   189  				return err
   190  			}
   191  			if len(key) != 32 {
   192  				return fmt.Errorf("unexpected key length: %d", len(key))
   193  			}
   194  			cipherKey, err := rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil)
   195  			if err != nil {
   196  				return err
   197  			}
   198  
   199  			// Write Public key
   200  			tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public))
   201  			// Write encrypted cipher key
   202  			tmp = msgp.AppendBytes(tmp, cipherKey)
   203  			if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil {
   204  				return err
   205  			}
   206  		case blockEOF:
   207  			if err := writeBlock(id, sz, block); err != nil {
   208  				return err
   209  			}
   210  			return mw.Flush()
   211  		case blockError:
   212  			if o.PassErrors {
   213  				if err := writeBlock(id, sz, block); err != nil {
   214  					return err
   215  				}
   216  				return mw.Flush()
   217  			}
   218  			// Return error
   219  			msg, _, err := msgp.ReadStringBytes(block)
   220  			if err != nil {
   221  				return err
   222  			}
   223  			return errors.New(msg)
   224  		default:
   225  			if err := writeBlock(id, sz, block); err != nil {
   226  				return err
   227  			}
   228  		}
   229  	}
   230  }
   231  
   232  // DebugStream will print stream block information to w.
   233  func (r *Reader) DebugStream(w io.Writer) error {
   234  	if r.err != nil {
   235  		return r.err
   236  	}
   237  	if r.inStream {
   238  		return errors.New("previous stream not read until EOF")
   239  	}
   240  	fmt.Fprintf(w, "stream major: %v, minor: %v\n", r.majorV, r.minorV)
   241  
   242  	// Temp storage for blocks.
   243  	block := make([]byte, 1024)
   244  	hashers := []hash.Hash{nil, xxhash.New()}
   245  	for {
   246  		// Read block ID.
   247  		n, err := r.mr.ReadInt8()
   248  		if err != nil {
   249  			return r.setErr(fmt.Errorf("reading block id: %w", err))
   250  		}
   251  		id := blockID(n)
   252  
   253  		// Read block size
   254  		sz, err := r.mr.ReadUint32()
   255  		if err != nil {
   256  			return r.setErr(fmt.Errorf("reading block size: %w", err))
   257  		}
   258  		fmt.Fprintf(w, "block type: %v, size: %d bytes, in stream: %v\n", id, sz, r.inStream)
   259  
   260  		// Read block data
   261  		if cap(block) < int(sz) {
   262  			block = make([]byte, sz)
   263  		}
   264  		block = block[:sz]
   265  		_, err = io.ReadFull(r.mr, block)
   266  		if err != nil {
   267  			return r.setErr(fmt.Errorf("reading block data: %w", err))
   268  		}
   269  
   270  		// Parse block
   271  		switch id {
   272  		case blockPlainKey:
   273  			// Read plaintext key.
   274  			key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32))
   275  			if err != nil {
   276  				return r.setErr(fmt.Errorf("reading key: %w", err))
   277  			}
   278  			if len(key) != 32 {
   279  				return r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   280  			}
   281  
   282  			// Set key for following streams.
   283  			r.key = (*[32]byte)(key)
   284  			fmt.Fprintf(w, "plain key read\n")
   285  
   286  		case blockEncryptedKey:
   287  			// Read public key
   288  			publicKey, block, err := msgp.ReadBytesZC(block)
   289  			if err != nil {
   290  				return r.setErr(fmt.Errorf("reading public key: %w", err))
   291  			}
   292  
   293  			// Request private key if we have a custom function.
   294  			if r.privateFn != nil {
   295  				fmt.Fprintf(w, "requesting private key from privateFn\n")
   296  				pk, err := x509.ParsePKCS1PublicKey(publicKey)
   297  				if err != nil {
   298  					return r.setErr(fmt.Errorf("parse public key: %w", err))
   299  				}
   300  				r.private = r.privateFn(pk)
   301  				if r.private == nil {
   302  					fmt.Fprintf(w, "privateFn did not provide private key\n")
   303  					if r.skipEncrypted || r.returnNonDec {
   304  						fmt.Fprintf(w, "continuing. skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec)
   305  						r.key = nil
   306  						continue
   307  					}
   308  					return r.setErr(errors.New("nil private key returned"))
   309  				}
   310  			}
   311  
   312  			// Read cipher key
   313  			cipherKey, _, err := msgp.ReadBytesZC(block)
   314  			if err != nil {
   315  				return r.setErr(fmt.Errorf("reading cipherkey: %w", err))
   316  			}
   317  			if r.private == nil {
   318  				if r.skipEncrypted || r.returnNonDec {
   319  					fmt.Fprintf(w, "no private key, continuing due to skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec)
   320  					r.key = nil
   321  					continue
   322  				}
   323  				return r.setErr(errors.New("private key has not been set"))
   324  			}
   325  
   326  			// Decrypt stream key
   327  			key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil)
   328  			if err != nil {
   329  				if r.returnNonDec {
   330  					fmt.Fprintf(w, "no private key, continuing due to returnNonDec: %v\n", r.returnNonDec)
   331  					r.key = nil
   332  					continue
   333  				}
   334  				return fmt.Errorf("decrypting key: %w", err)
   335  			}
   336  
   337  			if len(key) != 32 {
   338  				return r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   339  			}
   340  			r.key = (*[32]byte)(key)
   341  			fmt.Fprintf(w, "stream key decoded\n")
   342  
   343  		case blockPlainStream, blockEncStream:
   344  			// Read metadata
   345  			name, block, err := msgp.ReadStringBytes(block)
   346  			if err != nil {
   347  				return r.setErr(fmt.Errorf("reading name: %w", err))
   348  			}
   349  			extra, block, err := msgp.ReadBytesBytes(block, nil)
   350  			if err != nil {
   351  				return r.setErr(fmt.Errorf("reading extra: %w", err))
   352  			}
   353  			c, block, err := msgp.ReadUint8Bytes(block)
   354  			if err != nil {
   355  				return r.setErr(fmt.Errorf("reading checksum: %w", err))
   356  			}
   357  			checksum := checksumType(c)
   358  			if !checksum.valid() {
   359  				return r.setErr(fmt.Errorf("unknown checksum type %d", checksum))
   360  			}
   361  			fmt.Fprintf(w, "new stream. name: %v, extra size: %v, checksum type: %v\n", name, len(extra), checksum)
   362  
   363  			for _, h := range hashers {
   364  				if h != nil {
   365  					h.Reset()
   366  				}
   367  			}
   368  
   369  			// Return plaintext stream
   370  			if id == blockPlainStream {
   371  				r.inStream = true
   372  				continue
   373  			}
   374  
   375  			// Handle encrypted streams.
   376  			if r.key == nil {
   377  				if r.skipEncrypted {
   378  					fmt.Fprintf(w, "nil key, skipEncrypted: %v\n", r.skipEncrypted)
   379  					r.inStream = true
   380  					continue
   381  				}
   382  				return ErrNoKey
   383  			}
   384  			// Read stream nonce
   385  			nonce, _, err := msgp.ReadBytesZC(block)
   386  			if err != nil {
   387  				return r.setErr(fmt.Errorf("reading nonce: %w", err))
   388  			}
   389  
   390  			stream, err := sio.AES_256_GCM.Stream(r.key[:])
   391  			if err != nil {
   392  				return r.setErr(fmt.Errorf("initializing sio: %w", err))
   393  			}
   394  
   395  			// Check if nonce is expected length.
   396  			if len(nonce) != stream.NonceSize() {
   397  				return r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce)))
   398  			}
   399  			fmt.Fprintf(w, "nonce: %v\n", nonce)
   400  			r.inStream = true
   401  		case blockEOS:
   402  			if !r.inStream {
   403  				return errors.New("end-of-stream without being in stream")
   404  			}
   405  			h, _, err := msgp.ReadBytesZC(block)
   406  			if err != nil {
   407  				return r.setErr(fmt.Errorf("reading block data: %w", err))
   408  			}
   409  			fmt.Fprintf(w, "end-of-stream. stream hash: %s. data hashes: ", hex.EncodeToString(h))
   410  			for i, h := range hashers {
   411  				if h != nil {
   412  					fmt.Fprintf(w, "%s:%s. ", checksumType(i), hex.EncodeToString(h.Sum(nil)))
   413  				}
   414  			}
   415  			fmt.Fprint(w, "\n")
   416  			r.inStream = false
   417  		case blockEOF:
   418  			if r.inStream {
   419  				return errors.New("end-of-file without finishing stream")
   420  			}
   421  			fmt.Fprintf(w, "end-of-file\n")
   422  			return nil
   423  		case blockError:
   424  			msg, _, err := msgp.ReadStringBytes(block)
   425  			if err != nil {
   426  				return r.setErr(fmt.Errorf("reading error string: %w", err))
   427  			}
   428  			fmt.Fprintf(w, "error recorded on stream: %v\n", msg)
   429  			return nil
   430  		case blockDatablock:
   431  			buf, _, err := msgp.ReadBytesZC(block)
   432  			if err != nil {
   433  				return r.setErr(fmt.Errorf("reading block data: %w", err))
   434  			}
   435  			for _, h := range hashers {
   436  				if h != nil {
   437  					h.Write(buf)
   438  				}
   439  			}
   440  			fmt.Fprintf(w, "data block, length: %v\n", len(buf))
   441  		default:
   442  			fmt.Fprintf(w, "skipping block\n")
   443  			if id >= 0 {
   444  				return fmt.Errorf("unknown block type: %d", id)
   445  			}
   446  		}
   447  	}
   448  }