github.com/minio/madmin-go@v1.7.5/estream/stream.go (about)

     1  //
     2  // MinIO Object Storage (c) 2022 MinIO, Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //      http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package estream
    18  
    19  import (
    20  	"crypto/rand"
    21  	crand "crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/sha512"
    24  	"crypto/x509"
    25  	"encoding/hex"
    26  	"errors"
    27  	"fmt"
    28  	"hash"
    29  	"io"
    30  
    31  	"github.com/cespare/xxhash/v2"
    32  	"github.com/secure-io/sio-go"
    33  	"github.com/tinylib/msgp/msgp"
    34  )
    35  
// ReplaceFn provides key replacement for ReplaceKeys.
//
// Encrypted keys: when an encrypted stream key is found on the stream, the
// function is called with the public key the stream key was encrypted with.
// It must return the private key matching that public key, so the stream key
// can be decrypted, and the public key the stream key should be re-encrypted
// with.
//
// To keep an encrypted key as is, return a nil private key together with the
// very same *rsa.PublicKey pointer that was passed in (ReplaceKeys compares
// the pointers, not the key values). Returning a nil private key with any
// other public key causes ReplaceKeys to return an error.
//
// Plaintext keys: when ReplaceKeysOptions.EncryptAll is set, the function is
// called with a nil key for each unencrypted stream key. If it returns a
// non-nil public key, the stream key is encrypted with that public key; if it
// returns a nil public key, the plaintext key is kept. The private key return
// value is ignored in this case and should be nil.
type ReplaceFn func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey)
    49  
// ReplaceKeysOptions allows passing additional options to ReplaceKeys.
type ReplaceKeysOptions struct {
	// EncryptAll, if set, offers every unencrypted (plaintext) stream key to
	// the ReplaceFn with a nil key. The key is encrypted with the returned
	// public key, or kept as plaintext if a nil public key is returned.
	// If not set, plaintext keys are copied through unchanged.
	EncryptAll bool

	// PassErrors, if set, copies an error block found on the stream to the
	// output and ends ReplaceKeys without returning the recorded error.
	// If not set, the error message recorded on the stream is returned as
	// an error.
	PassErrors bool
}
    59  
    60  // ReplaceKeys will replace the keys in a stream.
    61  //
    62  // A replace function must be provided. See ReplaceFn for functionality.
    63  // If encryptAll is set.
    64  func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, o ReplaceKeysOptions) error {
    65  	var ver [2]byte
    66  	if _, err := io.ReadFull(r, ver[:]); err != nil {
    67  		return err
    68  	}
    69  	switch ver[0] {
    70  	case 2:
    71  	default:
    72  		return fmt.Errorf("unknown stream version: 0x%x", ver[0])
    73  	}
    74  	if _, err := w.Write(ver[:]); err != nil {
    75  		return err
    76  	}
    77  	// Input
    78  	mr := msgp.NewReader(r)
    79  	mw := msgp.NewWriter(w)
    80  
    81  	// Temporary block storage.
    82  	block := make([]byte, 1024)
    83  
    84  	// Write a block.
    85  	writeBlock := func(id blockID, sz uint32, content []byte) error {
    86  		if err := mw.WriteInt8(int8(id)); err != nil {
    87  			return err
    88  		}
    89  		if err := mw.WriteUint32(sz); err != nil {
    90  			return err
    91  		}
    92  		_, err := mw.Write(content)
    93  		return err
    94  	}
    95  
    96  	for {
    97  		// Read block ID.
    98  		n, err := mr.ReadInt8()
    99  		if err != nil {
   100  			return err
   101  		}
   102  		id := blockID(n)
   103  
   104  		// Read size
   105  		sz, err := mr.ReadUint32()
   106  		if err != nil {
   107  			return err
   108  		}
   109  		if cap(block) < int(sz) {
   110  			block = make([]byte, sz)
   111  		}
   112  		block = block[:sz]
   113  		_, err = io.ReadFull(mr, block)
   114  		if err != nil {
   115  			return err
   116  		}
   117  
   118  		switch id {
   119  		case blockEncryptedKey:
   120  			ogBlock := block
   121  			// Read public key
   122  			publicKey, block, err := msgp.ReadBytesZC(block)
   123  			if err != nil {
   124  				return err
   125  			}
   126  
   127  			pk, err := x509.ParsePKCS1PublicKey(publicKey)
   128  			if err != nil {
   129  				return err
   130  			}
   131  
   132  			private, public := replace(pk)
   133  			if private == nil && public == pk {
   134  				if err := writeBlock(id, sz, ogBlock); err != nil {
   135  					return err
   136  				}
   137  			}
   138  			if private == nil {
   139  				return errors.New("no private key provided, unable to re-encrypt")
   140  			}
   141  
   142  			// Read cipher key
   143  			cipherKey, _, err := msgp.ReadBytesZC(block)
   144  			if err != nil {
   145  				return err
   146  			}
   147  
   148  			// Decrypt stream key
   149  			key, err := rsa.DecryptOAEP(sha512.New(), crand.Reader, private, cipherKey, nil)
   150  			if err != nil {
   151  				return err
   152  			}
   153  
   154  			if len(key) != 32 {
   155  				return fmt.Errorf("unexpected key length: %d", len(key))
   156  			}
   157  
   158  			cipherKey, err = rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil)
   159  			if err != nil {
   160  				return err
   161  			}
   162  
   163  			// Write Public key
   164  			tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public))
   165  			// Write encrypted cipher key
   166  			tmp = msgp.AppendBytes(tmp, cipherKey)
   167  			if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil {
   168  				return err
   169  			}
   170  		case blockPlainKey:
   171  			if !o.EncryptAll {
   172  				if err := writeBlock(id, sz, block); err != nil {
   173  					return err
   174  				}
   175  				continue
   176  			}
   177  			_, public := replace(nil)
   178  			if public == nil {
   179  				if err := writeBlock(id, sz, block); err != nil {
   180  					return err
   181  				}
   182  				continue
   183  			}
   184  			key, _, err := msgp.ReadBytesZC(block)
   185  			if err != nil {
   186  				return err
   187  			}
   188  			if len(key) != 32 {
   189  				return fmt.Errorf("unexpected key length: %d", len(key))
   190  			}
   191  			cipherKey, err := rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil)
   192  			if err != nil {
   193  				return err
   194  			}
   195  
   196  			// Write Public key
   197  			tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public))
   198  			// Write encrypted cipher key
   199  			tmp = msgp.AppendBytes(tmp, cipherKey)
   200  			if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil {
   201  				return err
   202  			}
   203  		case blockEOF:
   204  			if err := writeBlock(id, sz, block); err != nil {
   205  				return err
   206  			}
   207  			return mw.Flush()
   208  		case blockError:
   209  			if o.PassErrors {
   210  				if err := writeBlock(id, sz, block); err != nil {
   211  					return err
   212  				}
   213  				return mw.Flush()
   214  			}
   215  			// Return error
   216  			msg, _, err := msgp.ReadStringBytes(block)
   217  			if err != nil {
   218  				return err
   219  			}
   220  			return errors.New(msg)
   221  		default:
   222  			if err := writeBlock(id, sz, block); err != nil {
   223  				return err
   224  			}
   225  		}
   226  	}
   227  }
   228  
   229  // DebugStream will print stream block information to w.
   230  func (r *Reader) DebugStream(w io.Writer) error {
   231  	if r.err != nil {
   232  		return r.err
   233  	}
   234  	if r.inStream {
   235  		return errors.New("previous stream not read until EOF")
   236  	}
   237  	fmt.Fprintf(w, "stream major: %v, minor: %v\n", r.majorV, r.minorV)
   238  
   239  	// Temp storage for blocks.
   240  	block := make([]byte, 1024)
   241  	hashers := []hash.Hash{nil, xxhash.New()}
   242  	for {
   243  		// Read block ID.
   244  		n, err := r.mr.ReadInt8()
   245  		if err != nil {
   246  			return r.setErr(fmt.Errorf("reading block id: %w", err))
   247  		}
   248  		id := blockID(n)
   249  
   250  		// Read block size
   251  		sz, err := r.mr.ReadUint32()
   252  		if err != nil {
   253  			return r.setErr(fmt.Errorf("reading block size: %w", err))
   254  		}
   255  		fmt.Fprintf(w, "block type: %v, size: %d bytes, in stream: %v\n", id, sz, r.inStream)
   256  
   257  		// Read block data
   258  		if cap(block) < int(sz) {
   259  			block = make([]byte, sz)
   260  		}
   261  		block = block[:sz]
   262  		_, err = io.ReadFull(r.mr, block)
   263  		if err != nil {
   264  			return r.setErr(fmt.Errorf("reading block data: %w", err))
   265  		}
   266  
   267  		// Parse block
   268  		switch id {
   269  		case blockPlainKey:
   270  			// Read plaintext key.
   271  			key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32))
   272  			if err != nil {
   273  				return r.setErr(fmt.Errorf("reading key: %w", err))
   274  			}
   275  			if len(key) != 32 {
   276  				return r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   277  			}
   278  
   279  			// Set key for following streams.
   280  			r.key = (*[32]byte)(key)
   281  			fmt.Fprintf(w, "plain key read\n")
   282  
   283  		case blockEncryptedKey:
   284  			// Read public key
   285  			publicKey, block, err := msgp.ReadBytesZC(block)
   286  			if err != nil {
   287  				return r.setErr(fmt.Errorf("reading public key: %w", err))
   288  			}
   289  
   290  			// Request private key if we have a custom function.
   291  			if r.privateFn != nil {
   292  				fmt.Fprintf(w, "requesting private key from privateFn\n")
   293  				pk, err := x509.ParsePKCS1PublicKey(publicKey)
   294  				if err != nil {
   295  					return r.setErr(fmt.Errorf("parse public key: %w", err))
   296  				}
   297  				r.private = r.privateFn(pk)
   298  				if r.private == nil {
   299  					fmt.Fprintf(w, "privateFn did not provide private key\n")
   300  					if r.skipEncrypted || r.returnNonDec {
   301  						fmt.Fprintf(w, "continuing. skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec)
   302  						r.key = nil
   303  						continue
   304  					}
   305  					return r.setErr(errors.New("nil private key returned"))
   306  				}
   307  			}
   308  
   309  			// Read cipher key
   310  			cipherKey, _, err := msgp.ReadBytesZC(block)
   311  			if err != nil {
   312  				return r.setErr(fmt.Errorf("reading cipherkey: %w", err))
   313  			}
   314  			if r.private == nil {
   315  				if r.skipEncrypted || r.returnNonDec {
   316  					fmt.Fprintf(w, "no private key, continuing due to skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec)
   317  					r.key = nil
   318  					continue
   319  				}
   320  				return r.setErr(errors.New("private key has not been set"))
   321  			}
   322  
   323  			// Decrypt stream key
   324  			key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil)
   325  			if err != nil {
   326  				if r.returnNonDec {
   327  					fmt.Fprintf(w, "no private key, continuing due to returnNonDec: %v\n", r.returnNonDec)
   328  					r.key = nil
   329  					continue
   330  				}
   331  				return fmt.Errorf("decrypting key: %w", err)
   332  			}
   333  
   334  			if len(key) != 32 {
   335  				return r.setErr(fmt.Errorf("unexpected key length: %d", len(key)))
   336  			}
   337  			r.key = (*[32]byte)(key)
   338  			fmt.Fprintf(w, "stream key decoded\n")
   339  
   340  		case blockPlainStream, blockEncStream:
   341  			// Read metadata
   342  			name, block, err := msgp.ReadStringBytes(block)
   343  			if err != nil {
   344  				return r.setErr(fmt.Errorf("reading name: %w", err))
   345  			}
   346  			extra, block, err := msgp.ReadBytesBytes(block, nil)
   347  			if err != nil {
   348  				return r.setErr(fmt.Errorf("reading extra: %w", err))
   349  			}
   350  			c, block, err := msgp.ReadUint8Bytes(block)
   351  			if err != nil {
   352  				return r.setErr(fmt.Errorf("reading checksum: %w", err))
   353  			}
   354  			checksum := checksumType(c)
   355  			if !checksum.valid() {
   356  				return r.setErr(fmt.Errorf("unknown checksum type %d", checksum))
   357  			}
   358  			fmt.Fprintf(w, "new stream. name: %v, extra size: %v, checksum type: %v\n", name, len(extra), checksum)
   359  
   360  			for _, h := range hashers {
   361  				if h != nil {
   362  					h.Reset()
   363  				}
   364  			}
   365  
   366  			// Return plaintext stream
   367  			if id == blockPlainStream {
   368  				r.inStream = true
   369  				continue
   370  			}
   371  
   372  			// Handle encrypted streams.
   373  			if r.key == nil {
   374  				if r.skipEncrypted {
   375  					fmt.Fprintf(w, "nil key, skipEncrypted: %v\n", r.skipEncrypted)
   376  					r.inStream = true
   377  					continue
   378  				}
   379  				return ErrNoKey
   380  			}
   381  			// Read stream nonce
   382  			nonce, _, err := msgp.ReadBytesZC(block)
   383  			if err != nil {
   384  				return r.setErr(fmt.Errorf("reading nonce: %w", err))
   385  			}
   386  
   387  			stream, err := sio.AES_256_GCM.Stream(r.key[:])
   388  			if err != nil {
   389  				return r.setErr(fmt.Errorf("initializing sio: %w", err))
   390  			}
   391  
   392  			// Check if nonce is expected length.
   393  			if len(nonce) != stream.NonceSize() {
   394  				return r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce)))
   395  			}
   396  			fmt.Fprintf(w, "nonce: %v\n", nonce)
   397  			r.inStream = true
   398  		case blockEOS:
   399  			if !r.inStream {
   400  				return errors.New("end-of-stream without being in stream")
   401  			}
   402  			h, _, err := msgp.ReadBytesZC(block)
   403  			if err != nil {
   404  				return r.setErr(fmt.Errorf("reading block data: %w", err))
   405  			}
   406  			fmt.Fprintf(w, "end-of-stream. stream hash: %s. data hashes: ", hex.EncodeToString(h))
   407  			for i, h := range hashers {
   408  				if h != nil {
   409  					fmt.Fprintf(w, "%s:%s. ", checksumType(i), hex.EncodeToString(h.Sum(nil)))
   410  				}
   411  			}
   412  			fmt.Fprint(w, "\n")
   413  			r.inStream = false
   414  		case blockEOF:
   415  			if r.inStream {
   416  				return errors.New("end-of-file without finishing stream")
   417  			}
   418  			fmt.Fprintf(w, "end-of-file\n")
   419  			return nil
   420  		case blockError:
   421  			msg, _, err := msgp.ReadStringBytes(block)
   422  			if err != nil {
   423  				return r.setErr(fmt.Errorf("reading error string: %w", err))
   424  			}
   425  			fmt.Fprintf(w, "error recorded on stream: %v\n", msg)
   426  			return nil
   427  		case blockDatablock:
   428  			buf, _, err := msgp.ReadBytesZC(block)
   429  			if err != nil {
   430  				return r.setErr(fmt.Errorf("reading block data: %w", err))
   431  			}
   432  			for _, h := range hashers {
   433  				if h != nil {
   434  					h.Write(buf)
   435  				}
   436  			}
   437  			fmt.Fprintf(w, "data block, length: %v\n", len(buf))
   438  		default:
   439  			fmt.Fprintf(w, "skipping block\n")
   440  			if id >= 0 {
   441  				return fmt.Errorf("unknown block type: %d", id)
   442  			}
   443  		}
   444  	}
   445  }