github.com/minio/madmin-go/v2@v2.2.1/estream/reader.go (about) 1 // 2 // Copyright (c) 2015-2022 MinIO, Inc. 3 // 4 // This file is part of MinIO Object Storage stack 5 // 6 // This program is free software: you can redistribute it and/or modify 7 // it under the terms of the GNU Affero General Public License as 8 // published by the Free Software Foundation, either version 3 of the 9 // License, or (at your option) any later version. 10 // 11 // This program is distributed in the hope that it will be useful, 12 // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 // GNU Affero General Public License for more details. 15 // 16 // You should have received a copy of the GNU Affero General Public License 17 // along with this program. If not, see <http://www.gnu.org/licenses/>. 18 // 19 20 package estream 21 22 import ( 23 "bytes" 24 "crypto/rand" 25 "crypto/rsa" 26 "crypto/sha512" 27 "crypto/x509" 28 "encoding/hex" 29 "errors" 30 "fmt" 31 "io" 32 "runtime" 33 34 "github.com/cespare/xxhash/v2" 35 "github.com/secure-io/sio-go" 36 "github.com/tinylib/msgp/msgp" 37 ) 38 39 type Reader struct { 40 mr *msgp.Reader 41 majorV uint8 42 minorV uint8 43 err error 44 inStream bool 45 key *[32]byte 46 private *rsa.PrivateKey 47 privateFn func(key *rsa.PublicKey) *rsa.PrivateKey 48 skipEncrypted bool 49 returnNonDec bool 50 } 51 52 // ErrNoKey is returned when a stream cannot be decrypted. 53 // The Skip function on the stream can be called to skip to the next. 54 var ErrNoKey = errors.New("no valid private key found") 55 56 // NewReader will return a Reader that will split streams. 57 func NewReader(r io.Reader) (*Reader, error) { 58 var ver [2]byte 59 if _, err := io.ReadFull(r, ver[:]); err != nil { 60 return nil, err 61 } 62 switch ver[0] { 63 case 2: 64 default: 65 return nil, fmt.Errorf("unknown stream version: 0x%x", ver[0]) 66 } 67 68 return &Reader{mr: msgp.NewReader(r), majorV: ver[0], minorV: ver[1]}, nil 69 } 70 71 // SetPrivateKey will set the private key to allow stream decryption. 72 // This overrides any function set by PrivateKeyProvider. 73 func (r *Reader) SetPrivateKey(k *rsa.PrivateKey) { 74 r.privateFn = nil 75 r.private = k 76 } 77 78 // PrivateKeyProvider will ask for a private key matching the public key. 79 // If the function returns a nil private key the stream key will not be decrypted 80 // and if SkipEncrypted has been set any streams with this key will be silently skipped. 81 // This overrides any key set by SetPrivateKey. 82 func (r *Reader) PrivateKeyProvider(fn func(key *rsa.PublicKey) *rsa.PrivateKey) { 83 r.privateFn = fn 84 r.private = nil 85 } 86 87 // SkipEncrypted will skip encrypted streams if no private key has been set. 88 func (r *Reader) SkipEncrypted(b bool) { 89 r.skipEncrypted = b 90 } 91 92 // ReturnNonDecryptable will return non-decryptable stream headers. 93 // Streams are returned with ErrNoKey error. 94 // Streams with this error cannot be read, but the Skip function can be invoked. 95 // SkipEncrypted overrides this. 96 func (r *Reader) ReturnNonDecryptable(b bool) { 97 r.returnNonDec = b 98 } 99 100 // Stream returns the next stream. 101 type Stream struct { 102 io.Reader 103 Name string 104 Extra []byte 105 SentEncrypted bool 106 107 parent *Reader 108 } 109 110 // NextStream will return the next stream. 111 // Before calling this the previous stream must be read until EOF, 112 // or Skip() should have been called. 113 // Will return nil, io.EOF when there are no more streams. 114 func (r *Reader) NextStream() (*Stream, error) { 115 if r.err != nil { 116 return nil, r.err 117 } 118 if r.inStream { 119 return nil, errors.New("previous stream not read until EOF") 120 } 121 122 // Temp storage for blocks. 123 block := make([]byte, 1024) 124 for { 125 // Read block ID. 126 n, err := r.mr.ReadInt8() 127 if err != nil { 128 return nil, r.setErr(err) 129 } 130 id := blockID(n) 131 132 // Read block size 133 sz, err := r.mr.ReadUint32() 134 if err != nil { 135 return nil, r.setErr(err) 136 } 137 138 // Read block data 139 if cap(block) < int(sz) { 140 block = make([]byte, sz) 141 } 142 block = block[:sz] 143 _, err = io.ReadFull(r.mr, block) 144 if err != nil { 145 return nil, r.setErr(err) 146 } 147 148 // Parse block 149 switch id { 150 case blockPlainKey: 151 // Read plaintext key. 152 key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32)) 153 if err != nil { 154 return nil, r.setErr(err) 155 } 156 if len(key) != 32 { 157 return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 158 } 159 160 // Set key for following streams. 161 r.key = (*[32]byte)(key) 162 case blockEncryptedKey: 163 // Read public key 164 publicKey, block, err := msgp.ReadBytesZC(block) 165 if err != nil { 166 return nil, r.setErr(err) 167 } 168 169 // Request private key if we have a custom function. 170 if r.privateFn != nil { 171 pk, err := x509.ParsePKCS1PublicKey(publicKey) 172 if err != nil { 173 return nil, r.setErr(err) 174 } 175 r.private = r.privateFn(pk) 176 if r.private == nil { 177 if r.skipEncrypted || r.returnNonDec { 178 r.key = nil 179 continue 180 } 181 return nil, r.setErr(errors.New("nil private key returned")) 182 } 183 } 184 185 // Read cipher key 186 cipherKey, _, err := msgp.ReadBytesZC(block) 187 if err != nil { 188 return nil, r.setErr(err) 189 } 190 if r.private == nil { 191 if r.skipEncrypted || r.returnNonDec { 192 r.key = nil 193 continue 194 } 195 return nil, r.setErr(errors.New("private key has not been set")) 196 } 197 198 // Decrypt stream key 199 key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil) 200 if err != nil { 201 if r.returnNonDec { 202 r.key = nil 203 continue 204 } 205 return nil, err 206 } 207 208 if len(key) != 32 { 209 return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 210 } 211 r.key = (*[32]byte)(key) 212 213 case blockPlainStream, blockEncStream: 214 // Read metadata 215 name, block, err := msgp.ReadStringBytes(block) 216 if err != nil { 217 return nil, r.setErr(err) 218 } 219 extra, block, err := msgp.ReadBytesBytes(block, nil) 220 if err != nil { 221 return nil, r.setErr(err) 222 } 223 c, block, err := msgp.ReadUint8Bytes(block) 224 if err != nil { 225 return nil, r.setErr(err) 226 } 227 checksum := checksumType(c) 228 if !checksum.valid() { 229 return nil, r.setErr(fmt.Errorf("unknown checksum type %d", checksum)) 230 } 231 232 // Return plaintext stream 233 if id == blockPlainStream { 234 return &Stream{ 235 Reader: r.newStreamReader(checksum), 236 Name: name, 237 Extra: extra, 238 parent: r, 239 }, nil 240 } 241 242 // Handle encrypted streams. 243 if r.key == nil { 244 if r.skipEncrypted { 245 if err := r.skipDataBlocks(); err != nil { 246 return nil, r.setErr(err) 247 } 248 continue 249 } 250 return &Stream{ 251 SentEncrypted: true, 252 Reader: nil, 253 Name: name, 254 Extra: extra, 255 parent: r, 256 }, ErrNoKey 257 } 258 // Read stream nonce 259 nonce, _, err := msgp.ReadBytesZC(block) 260 if err != nil { 261 return nil, r.setErr(err) 262 } 263 264 stream, err := sio.AES_256_GCM.Stream(r.key[:]) 265 if err != nil { 266 return nil, r.setErr(err) 267 } 268 269 // Check if nonce is expected length. 270 if len(nonce) != stream.NonceSize() { 271 return nil, r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce))) 272 } 273 274 encr := stream.DecryptReader(r.newStreamReader(checksum), nonce, nil) 275 return &Stream{ 276 SentEncrypted: true, 277 Reader: encr, 278 Name: name, 279 Extra: extra, 280 parent: r, 281 }, nil 282 case blockEOS: 283 return nil, errors.New("end-of-stream without being in stream") 284 case blockEOF: 285 return nil, io.EOF 286 case blockError: 287 msg, _, err := msgp.ReadStringBytes(block) 288 if err != nil { 289 return nil, r.setErr(err) 290 } 291 return nil, r.setErr(errors.New(msg)) 292 default: 293 if id >= 0 { 294 return nil, fmt.Errorf("unknown block type: %d", id) 295 } 296 } 297 } 298 } 299 300 // skipDataBlocks reads data blocks until end. 301 func (r *Reader) skipDataBlocks() error { 302 for { 303 // Read block ID. 304 n, err := r.mr.ReadInt8() 305 if err != nil { 306 return err 307 } 308 id := blockID(n) 309 sz, err := r.mr.ReadUint32() 310 if err != nil { 311 return err 312 } 313 if id == blockError { 314 msg, err := r.mr.ReadString() 315 if err != nil { 316 return err 317 } 318 return errors.New(msg) 319 } 320 // Discard data 321 _, err = io.CopyN(io.Discard, r.mr, int64(sz)) 322 if err != nil { 323 return err 324 } 325 switch id { 326 case blockDatablock: 327 // Skip data 328 case blockEOS: 329 // Done 330 r.inStream = false 331 return nil 332 default: 333 if id >= 0 { 334 return fmt.Errorf("unknown block type: %d", id) 335 } 336 } 337 } 338 } 339 340 // setErr sets a stateful error. 341 func (r *Reader) setErr(err error) error { 342 if r.err != nil { 343 return r.err 344 } 345 if err == nil { 346 return err 347 } 348 if errors.Is(err, io.EOF) { 349 r.err = io.ErrUnexpectedEOF 350 } 351 if false { 352 _, file, line, ok := runtime.Caller(1) 353 if ok { 354 err = fmt.Errorf("%s:%d: %w", file, line, err) 355 } 356 } 357 r.err = err 358 return err 359 } 360 361 type streamReader struct { 362 up *Reader 363 h xxhash.Digest 364 buf bytes.Buffer 365 tmp []byte 366 isEOF bool 367 check checksumType 368 } 369 370 // newStreamReader creates a stream reader that can be read to get all data blocks. 371 func (r *Reader) newStreamReader(ct checksumType) *streamReader { 372 sr := &streamReader{up: r, check: ct} 373 sr.h.Reset() 374 r.inStream = true 375 return sr 376 } 377 378 // Skip the remainder of the stream. 379 func (s *Stream) Skip() error { 380 if sr, ok := s.Reader.(*streamReader); ok { 381 sr.isEOF = true 382 sr.buf.Reset() 383 } 384 return s.parent.skipDataBlocks() 385 } 386 387 // Read will return data blocks as on stream. 388 func (r *streamReader) Read(b []byte) (int, error) { 389 if r.isEOF { 390 return 0, io.EOF 391 } 392 if r.up.err != nil { 393 return 0, r.up.err 394 } 395 for { 396 // If we have anything in the buffer return that first. 397 if r.buf.Len() > 0 { 398 n, err := r.buf.Read(b) 399 if err == io.EOF { 400 err = nil 401 } 402 return n, r.up.setErr(err) 403 } 404 405 // Read block 406 n, err := r.up.mr.ReadInt8() 407 if err != nil { 408 return 0, r.up.setErr(err) 409 } 410 id := blockID(n) 411 412 // Read size... 413 sz, err := r.up.mr.ReadUint32() 414 if err != nil { 415 return 0, r.up.setErr(err) 416 } 417 418 switch id { 419 case blockDatablock: 420 // Read block 421 buf, err := r.up.mr.ReadBytes(r.tmp[:0]) 422 if err != nil { 423 return 0, r.up.setErr(err) 424 } 425 426 // Write to buffer and checksum 427 if r.check == checksumTypeXxhash { 428 r.h.Write(buf) 429 } 430 r.tmp = buf 431 r.buf.Write(buf) 432 case blockEOS: 433 // Verify stream checksum if any. 434 hash, err := r.up.mr.ReadBytes(nil) 435 if err != nil { 436 return 0, r.up.setErr(err) 437 } 438 switch r.check { 439 case checksumTypeXxhash: 440 got := r.h.Sum(nil) 441 if !bytes.Equal(hash, got) { 442 return 0, r.up.setErr(fmt.Errorf("checksum mismatch, want %s, got %s", hex.EncodeToString(hash), hex.EncodeToString(got))) 443 } 444 case checksumTypeNone: 445 default: 446 return 0, r.up.setErr(fmt.Errorf("unknown checksum id %d", r.check)) 447 } 448 r.isEOF = true 449 r.up.inStream = false 450 return 0, io.EOF 451 case blockError: 452 msg, err := r.up.mr.ReadString() 453 if err != nil { 454 return 0, r.up.setErr(err) 455 } 456 return 0, r.up.setErr(errors.New(msg)) 457 default: 458 if id >= 0 { 459 return 0, fmt.Errorf("unexpected block type: %d", id) 460 } 461 // Skip block... 462 _, err := io.CopyN(io.Discard, r.up.mr, int64(sz)) 463 if err != nil { 464 return 0, r.up.setErr(err) 465 } 466 } 467 } 468 }