github.com/minio/madmin-go/v2@v2.2.1/estream/stream.go (about) 1 // 2 // Copyright (c) 2015-2022 MinIO, Inc. 3 // 4 // This file is part of MinIO Object Storage stack 5 // 6 // This program is free software: you can redistribute it and/or modify 7 // it under the terms of the GNU Affero General Public License as 8 // published by the Free Software Foundation, either version 3 of the 9 // License, or (at your option) any later version. 10 // 11 // This program is distributed in the hope that it will be useful, 12 // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 // GNU Affero General Public License for more details. 15 // 16 // You should have received a copy of the GNU Affero General Public License 17 // along with this program. If not, see <http://www.gnu.org/licenses/>. 18 // 19 20 package estream 21 22 import ( 23 "crypto/rand" 24 crand "crypto/rand" 25 "crypto/rsa" 26 "crypto/sha512" 27 "crypto/x509" 28 "encoding/hex" 29 "errors" 30 "fmt" 31 "hash" 32 "io" 33 34 "github.com/cespare/xxhash/v2" 35 "github.com/secure-io/sio-go" 36 "github.com/tinylib/msgp/msgp" 37 ) 38 39 // ReplaceFn provides key replacement. 40 // 41 // When a key is found on stream, the function is called with the public key. 42 // The function must then return a private key to decrypt matching the key sent. 43 // The public key must then be specified that should be used to re-encrypt the stream. 44 // 45 // If no private key is sent and the public key matches the one sent to the function 46 // the key will be kept as is. Other returned values will cause an error. 47 // 48 // For encrypting unencrypted keys on stream a nil key will be sent. 49 // If a public key is returned the key will be encrypted with the public key. 50 // No private key should be returned for this. 51 type ReplaceFn func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) 52 53 // ReplaceKeysOptions allows passing additional options to ReplaceKeys. 54 type ReplaceKeysOptions struct { 55 // If EncryptAll set all unencrypted keys will be encrypted. 56 EncryptAll bool 57 58 // PassErrors will pass through error an error packet, 59 // and not return an error. 60 PassErrors bool 61 } 62 63 // ReplaceKeys will replace the keys in a stream. 64 // 65 // A replace function must be provided. See ReplaceFn for functionality. 66 // If encryptAll is set. 67 func ReplaceKeys(w io.Writer, r io.Reader, replace ReplaceFn, o ReplaceKeysOptions) error { 68 var ver [2]byte 69 if _, err := io.ReadFull(r, ver[:]); err != nil { 70 return err 71 } 72 switch ver[0] { 73 case 2: 74 default: 75 return fmt.Errorf("unknown stream version: 0x%x", ver[0]) 76 } 77 if _, err := w.Write(ver[:]); err != nil { 78 return err 79 } 80 // Input 81 mr := msgp.NewReader(r) 82 mw := msgp.NewWriter(w) 83 84 // Temporary block storage. 85 block := make([]byte, 1024) 86 87 // Write a block. 88 writeBlock := func(id blockID, sz uint32, content []byte) error { 89 if err := mw.WriteInt8(int8(id)); err != nil { 90 return err 91 } 92 if err := mw.WriteUint32(sz); err != nil { 93 return err 94 } 95 _, err := mw.Write(content) 96 return err 97 } 98 99 for { 100 // Read block ID. 101 n, err := mr.ReadInt8() 102 if err != nil { 103 return err 104 } 105 id := blockID(n) 106 107 // Read size 108 sz, err := mr.ReadUint32() 109 if err != nil { 110 return err 111 } 112 if cap(block) < int(sz) { 113 block = make([]byte, sz) 114 } 115 block = block[:sz] 116 _, err = io.ReadFull(mr, block) 117 if err != nil { 118 return err 119 } 120 121 switch id { 122 case blockEncryptedKey: 123 ogBlock := block 124 // Read public key 125 publicKey, block, err := msgp.ReadBytesZC(block) 126 if err != nil { 127 return err 128 } 129 130 pk, err := x509.ParsePKCS1PublicKey(publicKey) 131 if err != nil { 132 return err 133 } 134 135 private, public := replace(pk) 136 if private == nil && public == pk { 137 if err := writeBlock(id, sz, ogBlock); err != nil { 138 return err 139 } 140 } 141 if private == nil { 142 return errors.New("no private key provided, unable to re-encrypt") 143 } 144 145 // Read cipher key 146 cipherKey, _, err := msgp.ReadBytesZC(block) 147 if err != nil { 148 return err 149 } 150 151 // Decrypt stream key 152 key, err := rsa.DecryptOAEP(sha512.New(), crand.Reader, private, cipherKey, nil) 153 if err != nil { 154 return err 155 } 156 157 if len(key) != 32 { 158 return fmt.Errorf("unexpected key length: %d", len(key)) 159 } 160 161 cipherKey, err = rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil) 162 if err != nil { 163 return err 164 } 165 166 // Write Public key 167 tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public)) 168 // Write encrypted cipher key 169 tmp = msgp.AppendBytes(tmp, cipherKey) 170 if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil { 171 return err 172 } 173 case blockPlainKey: 174 if !o.EncryptAll { 175 if err := writeBlock(id, sz, block); err != nil { 176 return err 177 } 178 continue 179 } 180 _, public := replace(nil) 181 if public == nil { 182 if err := writeBlock(id, sz, block); err != nil { 183 return err 184 } 185 continue 186 } 187 key, _, err := msgp.ReadBytesZC(block) 188 if err != nil { 189 return err 190 } 191 if len(key) != 32 { 192 return fmt.Errorf("unexpected key length: %d", len(key)) 193 } 194 cipherKey, err := rsa.EncryptOAEP(sha512.New(), crand.Reader, public, key[:], nil) 195 if err != nil { 196 return err 197 } 198 199 // Write Public key 200 tmp := msgp.AppendBytes(nil, x509.MarshalPKCS1PublicKey(public)) 201 // Write encrypted cipher key 202 tmp = msgp.AppendBytes(tmp, cipherKey) 203 if err := writeBlock(blockEncryptedKey, uint32(len(tmp)), tmp); err != nil { 204 return err 205 } 206 case blockEOF: 207 if err := writeBlock(id, sz, block); err != nil { 208 return err 209 } 210 return mw.Flush() 211 case blockError: 212 if o.PassErrors { 213 if err := writeBlock(id, sz, block); err != nil { 214 return err 215 } 216 return mw.Flush() 217 } 218 // Return error 219 msg, _, err := msgp.ReadStringBytes(block) 220 if err != nil { 221 return err 222 } 223 return errors.New(msg) 224 default: 225 if err := writeBlock(id, sz, block); err != nil { 226 return err 227 } 228 } 229 } 230 } 231 232 // DebugStream will print stream block information to w. 233 func (r *Reader) DebugStream(w io.Writer) error { 234 if r.err != nil { 235 return r.err 236 } 237 if r.inStream { 238 return errors.New("previous stream not read until EOF") 239 } 240 fmt.Fprintf(w, "stream major: %v, minor: %v\n", r.majorV, r.minorV) 241 242 // Temp storage for blocks. 243 block := make([]byte, 1024) 244 hashers := []hash.Hash{nil, xxhash.New()} 245 for { 246 // Read block ID. 247 n, err := r.mr.ReadInt8() 248 if err != nil { 249 return r.setErr(fmt.Errorf("reading block id: %w", err)) 250 } 251 id := blockID(n) 252 253 // Read block size 254 sz, err := r.mr.ReadUint32() 255 if err != nil { 256 return r.setErr(fmt.Errorf("reading block size: %w", err)) 257 } 258 fmt.Fprintf(w, "block type: %v, size: %d bytes, in stream: %v\n", id, sz, r.inStream) 259 260 // Read block data 261 if cap(block) < int(sz) { 262 block = make([]byte, sz) 263 } 264 block = block[:sz] 265 _, err = io.ReadFull(r.mr, block) 266 if err != nil { 267 return r.setErr(fmt.Errorf("reading block data: %w", err)) 268 } 269 270 // Parse block 271 switch id { 272 case blockPlainKey: 273 // Read plaintext key. 274 key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32)) 275 if err != nil { 276 return r.setErr(fmt.Errorf("reading key: %w", err)) 277 } 278 if len(key) != 32 { 279 return r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 280 } 281 282 // Set key for following streams. 283 r.key = (*[32]byte)(key) 284 fmt.Fprintf(w, "plain key read\n") 285 286 case blockEncryptedKey: 287 // Read public key 288 publicKey, block, err := msgp.ReadBytesZC(block) 289 if err != nil { 290 return r.setErr(fmt.Errorf("reading public key: %w", err)) 291 } 292 293 // Request private key if we have a custom function. 294 if r.privateFn != nil { 295 fmt.Fprintf(w, "requesting private key from privateFn\n") 296 pk, err := x509.ParsePKCS1PublicKey(publicKey) 297 if err != nil { 298 return r.setErr(fmt.Errorf("parse public key: %w", err)) 299 } 300 r.private = r.privateFn(pk) 301 if r.private == nil { 302 fmt.Fprintf(w, "privateFn did not provide private key\n") 303 if r.skipEncrypted || r.returnNonDec { 304 fmt.Fprintf(w, "continuing. skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec) 305 r.key = nil 306 continue 307 } 308 return r.setErr(errors.New("nil private key returned")) 309 } 310 } 311 312 // Read cipher key 313 cipherKey, _, err := msgp.ReadBytesZC(block) 314 if err != nil { 315 return r.setErr(fmt.Errorf("reading cipherkey: %w", err)) 316 } 317 if r.private == nil { 318 if r.skipEncrypted || r.returnNonDec { 319 fmt.Fprintf(w, "no private key, continuing due to skipEncrypted: %v, returnNonDec: %v\n", r.skipEncrypted, r.returnNonDec) 320 r.key = nil 321 continue 322 } 323 return r.setErr(errors.New("private key has not been set")) 324 } 325 326 // Decrypt stream key 327 key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil) 328 if err != nil { 329 if r.returnNonDec { 330 fmt.Fprintf(w, "no private key, continuing due to returnNonDec: %v\n", r.returnNonDec) 331 r.key = nil 332 continue 333 } 334 return fmt.Errorf("decrypting key: %w", err) 335 } 336 337 if len(key) != 32 { 338 return r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 339 } 340 r.key = (*[32]byte)(key) 341 fmt.Fprintf(w, "stream key decoded\n") 342 343 case blockPlainStream, blockEncStream: 344 // Read metadata 345 name, block, err := msgp.ReadStringBytes(block) 346 if err != nil { 347 return r.setErr(fmt.Errorf("reading name: %w", err)) 348 } 349 extra, block, err := msgp.ReadBytesBytes(block, nil) 350 if err != nil { 351 return r.setErr(fmt.Errorf("reading extra: %w", err)) 352 } 353 c, block, err := msgp.ReadUint8Bytes(block) 354 if err != nil { 355 return r.setErr(fmt.Errorf("reading checksum: %w", err)) 356 } 357 checksum := checksumType(c) 358 if !checksum.valid() { 359 return r.setErr(fmt.Errorf("unknown checksum type %d", checksum)) 360 } 361 fmt.Fprintf(w, "new stream. name: %v, extra size: %v, checksum type: %v\n", name, len(extra), checksum) 362 363 for _, h := range hashers { 364 if h != nil { 365 h.Reset() 366 } 367 } 368 369 // Return plaintext stream 370 if id == blockPlainStream { 371 r.inStream = true 372 continue 373 } 374 375 // Handle encrypted streams. 376 if r.key == nil { 377 if r.skipEncrypted { 378 fmt.Fprintf(w, "nil key, skipEncrypted: %v\n", r.skipEncrypted) 379 r.inStream = true 380 continue 381 } 382 return ErrNoKey 383 } 384 // Read stream nonce 385 nonce, _, err := msgp.ReadBytesZC(block) 386 if err != nil { 387 return r.setErr(fmt.Errorf("reading nonce: %w", err)) 388 } 389 390 stream, err := sio.AES_256_GCM.Stream(r.key[:]) 391 if err != nil { 392 return r.setErr(fmt.Errorf("initializing sio: %w", err)) 393 } 394 395 // Check if nonce is expected length. 396 if len(nonce) != stream.NonceSize() { 397 return r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce))) 398 } 399 fmt.Fprintf(w, "nonce: %v\n", nonce) 400 r.inStream = true 401 case blockEOS: 402 if !r.inStream { 403 return errors.New("end-of-stream without being in stream") 404 } 405 h, _, err := msgp.ReadBytesZC(block) 406 if err != nil { 407 return r.setErr(fmt.Errorf("reading block data: %w", err)) 408 } 409 fmt.Fprintf(w, "end-of-stream. stream hash: %s. data hashes: ", hex.EncodeToString(h)) 410 for i, h := range hashers { 411 if h != nil { 412 fmt.Fprintf(w, "%s:%s. ", checksumType(i), hex.EncodeToString(h.Sum(nil))) 413 } 414 } 415 fmt.Fprint(w, "\n") 416 r.inStream = false 417 case blockEOF: 418 if r.inStream { 419 return errors.New("end-of-file without finishing stream") 420 } 421 fmt.Fprintf(w, "end-of-file\n") 422 return nil 423 case blockError: 424 msg, _, err := msgp.ReadStringBytes(block) 425 if err != nil { 426 return r.setErr(fmt.Errorf("reading error string: %w", err)) 427 } 428 fmt.Fprintf(w, "error recorded on stream: %v\n", msg) 429 return nil 430 case blockDatablock: 431 buf, _, err := msgp.ReadBytesZC(block) 432 if err != nil { 433 return r.setErr(fmt.Errorf("reading block data: %w", err)) 434 } 435 for _, h := range hashers { 436 if h != nil { 437 h.Write(buf) 438 } 439 } 440 fmt.Fprintf(w, "data block, length: %v\n", len(buf)) 441 default: 442 fmt.Fprintf(w, "skipping block\n") 443 if id >= 0 { 444 return fmt.Errorf("unknown block type: %d", id) 445 } 446 } 447 } 448 }