github.com/minio/madmin-go@v1.7.5/estream/stream.go (about) 1 // 2 // MinIO Object Storage (c) 2022 MinIO, Inc. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 package estream 18 19 import ( 20 "crypto/rand" 21 crand "crypto/rand" 22 "crypto/rsa" 23 "crypto/sha512" 24 "crypto/x509" 25 "encoding/hex" 26 "errors" 27 "fmt" 28 "hash" 29 "io" 30 31 "github.com/cespare/xxhash/v2" 32 "github.com/secure-io/sio-go" 33 "github.com/tinylib/msgp/msgp" 34 ) 35 36 // ReplaceFn provides key replacement. 37 // 38 // When a key is found on stream, the function is called with the public key. 39 // The function must then return a private key to decrypt matching the key sent. 40 // The public key must then be specified that should be used to re-encrypt the stream. 41 // 42 // If no private key is sent and the public key matches the one sent to the function 43 // the key will be kept as is. Other returned values will cause an error. 44 // 45 // For encrypting unencrypted keys on stream a nil key will be sent. 46 // If a public key is returned the key will be encrypted with the public key. 47 // No private key should be returned for this. 48 type ReplaceFn func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) 49 50 // ReplaceKeysOptions allows passing additional options to ReplaceKeys. 51 type ReplaceKeysOptions struct { 52 // If EncryptAll set all unencrypted keys will be encrypted. 53 EncryptAll bool 54 55 // PassErrors will pass through error an error packet, 56 // and not return an error. 57 PassErrors bool 58 } 59 60 // ReplaceKeys will replace the keys in a stream. 61 // 62 // A replace function must be provided. See ReplaceFn for functionality. 63 // If encryptAll is set. 64 func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, o ReplaceKeysOptions) error { 65 var ver [2]byte 66 if _, err := io.ReadFull(r, ver[:]); err != nil { 67 return err 68 } 69 switch ver[0] { 70 case 2: 71 default: 72 return fmt.Errorf("unknown stream version: 0x%x", ver[0]) 73 } 74 if _, err := w.Write(ver[:]); err != nil { 75 return err 76 } 77 // Input 78 mr := msgp.NewReader(r) 79 mw := msgp.NewWriter(w) 80 81 // Temporary block storage. 82 block := make([]byte, 1024) 83 84 // Write a block. 85 writeBlock := func(id blockID, sz uint32, content []byte) error { 86 if err := mw.WriteInt8(int8(id)); err != nil { 87 return err 88 } 89 if err := mw.WriteUint32(sz); err != nil { 90 return err 91 } 92 _, err := mw.Write(content) 93 return err 94 } 95 96 for { 97 // Read block ID. 98 n, err := mr.ReadInt8() 99 if err != nil { 100 return err 101 } 102 id := blockID(n) 103 104 // Read size 105 sz, err := mr.ReadUint32() 106 if err != nil { 107 return err 108 } 109 if cap(block) < int(sz) { 110 block = make([]byte, sz) 111 } 112 block = block[:sz] 113 _, err = io.ReadFull(mr, block) 114 if err != nil { 115 return err 116 } 117 118 switch id { 119 case blockEncryptedKey: 120 ogBlock := block 121 // Read public key 122 publicKey, block, err := msgp.ReadBytesZC(block) 123 if err != nil { 124 return err 125 } 126 127 pk, err := x509.ParsePKCS1PublicKey(publicKey) 128 if err != nil { 129 return err 130 } 131 132 private, public := replace(pk) 133 if private == nil && public == pk { 134 if err := writeBlock(id, sz, ogBlock); err != nil { 135 return err 136 } 137 } 138 if private == nil { 139 return errors.New("no private key provided, unable to re-encrypt") 140 } 141 142 // Read cipher key 143 cipherKey, _, err := msgp.ReadBytesZC(block) 144 if err != nil { 145 return err 146 } 147 148 // Decrypt stream key 149 key, err := rsa.DecryptOAEP(sha512.New(), crand.Reader, private, cipherKey, nil) 150 if err != nil { 151 return err 152 } 153 154 if len(key) != 32 { 155 return fmt.Errorf("unexpected key length: %d", len(key)) 156 } 157 158 cipherKey, err = rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil) 159 if err != nil { 160 return err 161 } 162 163 // Write Public key 164 tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public)) 165 // Write encrypted cipher key 166 tmp = msgp.AppendBytes(tmp, cipherKey) 167 if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil { 168 return err 169 } 170 case blockPlainKey: 171 if !o.EncryptAll { 172 if err := writeBlock(id, sz, block); err != nil { 173 return err 174 } 175 continue 176 } 177 _, public := replace(nil) 178 if public == nil { 179 if err := writeBlock(id, sz, block); err != nil { 180 return err 181 } 182 continue 183 } 184 key, _, err := msgp.ReadBytesZC(block) 185 if err != nil { 186 return err 187 } 188 if len(key) != 32 { 189 return fmt.Errorf("unexpected key length: %d", len(key)) 190 } 191 cipherKey, err := rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil) 192 if err != nil { 193 return err 194 } 195 196 // Write Public key 197 tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public)) 198 // Write encrypted cipher key 199 tmp = msgp.AppendBytes(tmp, cipherKey) 200 if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil { 201 return err 202 } 203 case blockEOF: 204 if err := writeBlock(id, sz, block); err != nil { 205 return err 206 } 207 return mw.Flush() 208 case blockError: 209 if o.PassErrors { 210 if err := writeBlock(id, sz, block); err != nil { 211 return err 212 } 213 return mw.Flush() 214 } 215 // Return error 216 msg, _, err := msgp.ReadStringBytes(block) 217 if err != nil { 218 return err 219 } 220 return errors.New(msg) 221 default: 222 if err := writeBlock(id, sz, block); err != nil { 223 return err 224 } 225 } 226 } 227 } 228 229 // DebugStream will print stream block information to w. 230 func (r *Reader) DebugStream(w io.Writer) error { 231 if r.err != nil { 232 return r.err 233 } 234 if r.inStream { 235 return errors.New("previous stream not read until EOF") 236 } 237 fmt.Fprintf(w, "stream major: %v, minor: %v\n", r.majorV, r.minorV) 238 239 // Temp storage for blocks. 240 block := make([]byte, 1024) 241 hashers := []hash.Hash{nil, xxhash.New()} 242 for { 243 // Read block ID. 244 n, err := r.mr.ReadInt8() 245 if err != nil { 246 return r.setErr(fmt.Errorf("reading block id: %w", err)) 247 } 248 id := blockID(n) 249 250 // Read block size 251 sz, err := r.mr.ReadUint32() 252 if err != nil { 253 return r.setErr(fmt.Errorf("reading block size: %w", err)) 254 } 255 fmt.Fprintf(w, "block type: %v, size: %d bytes, in stream: %v\n", id, sz, r.inStream) 256 257 // Read block data 258 if cap(block) < int(sz) { 259 block = make([]byte, sz) 260 } 261 block = block[:sz] 262 _, err = io.ReadFull(r.mr, block) 263 if err != nil { 264 return r.setErr(fmt.Errorf("reading block data: %w", err)) 265 } 266 267 // Parse block 268 switch id { 269 case blockPlainKey: 270 // Read plaintext key. 271 key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32)) 272 if err != nil { 273 return r.setErr(fmt.Errorf("reading key: %w", err)) 274 } 275 if len(key) != 32 { 276 return r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 277 } 278 279 // Set key for following streams. 280 r.key = (*[32]byte)(key) 281 fmt.Fprintf(w, "plain key read\n") 282 283 case blockEncryptedKey: 284 // Read public key 285 publicKey, block, err := msgp.ReadBytesZC(block) 286 if err != nil { 287 return r.setErr(fmt.Errorf("reading public key: %w", err)) 288 } 289 290 // Request private key if we have a custom function. 291 if r.privateFn != nil { 292 fmt.Fprintf(w, "requesting private key from privateFn\n") 293 pk, err := x509.ParsePKCS1PublicKey(publicKey) 294 if err != nil { 295 return r.setErr(fmt.Errorf("parse public key: %w", err)) 296 } 297 r.private = r.privateFn(pk) 298 if r.private == nil { 299 fmt.Fprintf(w, "privateFn did not provide private key\n") 300 if r.skipEncrypted || r.returnNonDec { 301 fmt.Fprintf(w, "continuing. skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec) 302 r.key = nil 303 continue 304 } 305 return r.setErr(errors.New("nil private key returned")) 306 } 307 } 308 309 // Read cipher key 310 cipherKey, _, err := msgp.ReadBytesZC(block) 311 if err != nil { 312 return r.setErr(fmt.Errorf("reading cipherkey: %w", err)) 313 } 314 if r.private == nil { 315 if r.skipEncrypted || r.returnNonDec { 316 fmt.Fprintf(w, "no private key, continuing due to skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec) 317 r.key = nil 318 continue 319 } 320 return r.setErr(errors.New("private key has not been set")) 321 } 322 323 // Decrypt stream key 324 key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil) 325 if err != nil { 326 if r.returnNonDec { 327 fmt.Fprintf(w, "no private key, continuing due to returnNonDec: %v\n", r.returnNonDec) 328 r.key = nil 329 continue 330 } 331 return fmt.Errorf("decrypting key: %w", err) 332 } 333 334 if len(key) != 32 { 335 return r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 336 } 337 r.key = (*[32]byte)(key) 338 fmt.Fprintf(w, "stream key decoded\n") 339 340 case blockPlainStream, blockEncStream: 341 // Read metadata 342 name, block, err := msgp.ReadStringBytes(block) 343 if err != nil { 344 return r.setErr(fmt.Errorf("reading name: %w", err)) 345 } 346 extra, block, err := msgp.ReadBytesBytes(block, nil) 347 if err != nil { 348 return r.setErr(fmt.Errorf("reading extra: %w", err)) 349 } 350 c, block, err := msgp.ReadUint8Bytes(block) 351 if err != nil { 352 return r.setErr(fmt.Errorf("reading checksum: %w", err)) 353 } 354 checksum := checksumType(c) 355 if !checksum.valid() { 356 return r.setErr(fmt.Errorf("unknown checksum type %d", checksum)) 357 } 358 fmt.Fprintf(w, "new stream. name: %v, extra size: %v, checksum type: %v\n", name, len(extra), checksum) 359 360 for _, h := range hashers { 361 if h != nil { 362 h.Reset() 363 } 364 } 365 366 // Return plaintext stream 367 if id == blockPlainStream { 368 r.inStream = true 369 continue 370 } 371 372 // Handle encrypted streams. 373 if r.key == nil { 374 if r.skipEncrypted { 375 fmt.Fprintf(w, "nil key, skipEncrypted: %v\n", r.skipEncrypted) 376 r.inStream = true 377 continue 378 } 379 return ErrNoKey 380 } 381 // Read stream nonce 382 nonce, _, err := msgp.ReadBytesZC(block) 383 if err != nil { 384 return r.setErr(fmt.Errorf("reading nonce: %w", err)) 385 } 386 387 stream, err := sio.AES_256_GCM.Stream(r.key[:]) 388 if err != nil { 389 return r.setErr(fmt.Errorf("initializing sio: %w", err)) 390 } 391 392 // Check if nonce is expected length. 393 if len(nonce) != stream.NonceSize() { 394 return r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce))) 395 } 396 fmt.Fprintf(w, "nonce: %v\n", nonce) 397 r.inStream = true 398 case blockEOS: 399 if !r.inStream { 400 return errors.New("end-of-stream without being in stream") 401 } 402 h, _, err := msgp.ReadBytesZC(block) 403 if err != nil { 404 return r.setErr(fmt.Errorf("reading block data: %w", err)) 405 } 406 fmt.Fprintf(w, "end-of-stream. stream hash: %s. data hashes: ", hex.EncodeToString(h)) 407 for i, h := range hashers { 408 if h != nil { 409 fmt.Fprintf(w, "%s:%s. ", checksumType(i), hex.EncodeToString(h.Sum(nil))) 410 } 411 } 412 fmt.Fprint(w, "\n") 413 r.inStream = false 414 case blockEOF: 415 if r.inStream { 416 return errors.New("end-of-file without finishing stream") 417 } 418 fmt.Fprintf(w, "end-of-file\n") 419 return nil 420 case blockError: 421 msg, _, err := msgp.ReadStringBytes(block) 422 if err != nil { 423 return r.setErr(fmt.Errorf("reading error string: %w", err)) 424 } 425 fmt.Fprintf(w, "error recorded on stream: %v\n", msg) 426 return nil 427 case blockDatablock: 428 buf, _, err := msgp.ReadBytesZC(block) 429 if err != nil { 430 return r.setErr(fmt.Errorf("reading block data: %w", err)) 431 } 432 for _, h := range hashers { 433 if h != nil { 434 h.Write(buf) 435 } 436 } 437 fmt.Fprintf(w, "data block, length: %v\n", len(buf)) 438 default: 439 fmt.Fprintf(w, "skipping block\n") 440 if id >= 0 { 441 return fmt.Errorf("unknown block type: %d", id) 442 } 443 } 444 } 445 }