github.com/minio/madmin-go@v1.7.5/estream/reader.go (about) 1 // 2 // MinIO Object Storage (c) 2022 MinIO, Inc. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 package estream 18 19 import ( 20 "bytes" 21 "crypto/rand" 22 "crypto/rsa" 23 "crypto/sha512" 24 "crypto/x509" 25 "encoding/hex" 26 "errors" 27 "fmt" 28 "io" 29 "runtime" 30 31 "github.com/cespare/xxhash/v2" 32 "github.com/secure-io/sio-go" 33 "github.com/tinylib/msgp/msgp" 34 ) 35 36 type Reader struct { 37 mr *msgp.Reader 38 majorV uint8 39 minorV uint8 40 err error 41 inStream bool 42 key *[32]byte 43 private *rsa.PrivateKey 44 privateFn func(key *rsa.PublicKey) *rsa.PrivateKey 45 skipEncrypted bool 46 returnNonDec bool 47 } 48 49 // ErrNoKey is returned when a stream cannot be decrypted. 50 // The Skip function on the stream can be called to skip to the next. 51 var ErrNoKey = errors.New("no valid private key found") 52 53 // NewReader will return a Reader that will split streams. 54 func NewReader(r io.Reader) (*Reader, error) { 55 var ver [2]byte 56 if _, err := io.ReadFull(r, ver[:]); err != nil { 57 return nil, err 58 } 59 switch ver[0] { 60 case 2: 61 default: 62 return nil, fmt.Errorf("unknown stream version: 0x%x", ver[0]) 63 } 64 65 return &Reader{mr: msgp.NewReader(r), majorV: ver[0], minorV: ver[1]}, nil 66 } 67 68 // SetPrivateKey will set the private key to allow stream decryption. 69 // This overrides any function set by PrivateKeyProvider. 70 func (r *Reader) SetPrivateKey(k *rsa.PrivateKey) { 71 r.privateFn = nil 72 r.private = k 73 } 74 75 // PrivateKeyProvider will ask for a private key matching the public key. 76 // If the function returns a nil private key the stream key will not be decrypted 77 // and if SkipEncrypted has been set any streams with this key will be silently skipped. 78 // This overrides any key set by SetPrivateKey. 79 func (r *Reader) PrivateKeyProvider(fn func(key *rsa.PublicKey) *rsa.PrivateKey) { 80 r.privateFn = fn 81 r.private = nil 82 } 83 84 // SkipEncrypted will skip encrypted streams if no private key has been set. 85 func (r *Reader) SkipEncrypted(b bool) { 86 r.skipEncrypted = b 87 } 88 89 // ReturnNonDecryptable will return non-decryptable stream headers. 90 // Streams are returned with ErrNoKey error. 91 // Streams with this error cannot be read, but the Skip function can be invoked. 92 // SkipEncrypted overrides this. 93 func (r *Reader) ReturnNonDecryptable(b bool) { 94 r.returnNonDec = b 95 } 96 97 // Stream returns the next stream. 98 type Stream struct { 99 io.Reader 100 Name string 101 Extra []byte 102 SentEncrypted bool 103 104 parent *Reader 105 } 106 107 // NextStream will return the next stream. 108 // Before calling this the previous stream must be read until EOF, 109 // or Skip() should have been called. 110 // Will return nil, io.EOF when there are no more streams. 111 func (r *Reader) NextStream() (*Stream, error) { 112 if r.err != nil { 113 return nil, r.err 114 } 115 if r.inStream { 116 return nil, errors.New("previous stream not read until EOF") 117 } 118 119 // Temp storage for blocks. 120 block := make([]byte, 1024) 121 for { 122 // Read block ID. 123 n, err := r.mr.ReadInt8() 124 if err != nil { 125 return nil, r.setErr(err) 126 } 127 id := blockID(n) 128 129 // Read block size 130 sz, err := r.mr.ReadUint32() 131 if err != nil { 132 return nil, r.setErr(err) 133 } 134 135 // Read block data 136 if cap(block) < int(sz) { 137 block = make([]byte, sz) 138 } 139 block = block[:sz] 140 _, err = io.ReadFull(r.mr, block) 141 if err != nil { 142 return nil, r.setErr(err) 143 } 144 145 // Parse block 146 switch id { 147 case blockPlainKey: 148 // Read plaintext key. 149 key, _, err := msgp.ReadBytesBytes(block, make([]byte, 0, 32)) 150 if err != nil { 151 return nil, r.setErr(err) 152 } 153 if len(key) != 32 { 154 return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 155 } 156 157 // Set key for following streams. 158 r.key = (*[32]byte)(key) 159 case blockEncryptedKey: 160 // Read public key 161 publicKey, block, err := msgp.ReadBytesZC(block) 162 if err != nil { 163 return nil, r.setErr(err) 164 } 165 166 // Request private key if we have a custom function. 167 if r.privateFn != nil { 168 pk, err := x509.ParsePKCS1PublicKey(publicKey) 169 if err != nil { 170 return nil, r.setErr(err) 171 } 172 r.private = r.privateFn(pk) 173 if r.private == nil { 174 if r.skipEncrypted || r.returnNonDec { 175 r.key = nil 176 continue 177 } 178 return nil, r.setErr(errors.New("nil private key returned")) 179 } 180 } 181 182 // Read cipher key 183 cipherKey, _, err := msgp.ReadBytesZC(block) 184 if err != nil { 185 return nil, r.setErr(err) 186 } 187 if r.private == nil { 188 if r.skipEncrypted || r.returnNonDec { 189 r.key = nil 190 continue 191 } 192 return nil, r.setErr(errors.New("private key has not been set")) 193 } 194 195 // Decrypt stream key 196 key, err := rsa.DecryptOAEP(sha512.New(), rand.Reader, r.private, cipherKey, nil) 197 if err != nil { 198 if r.returnNonDec { 199 r.key = nil 200 continue 201 } 202 return nil, err 203 } 204 205 if len(key) != 32 { 206 return nil, r.setErr(fmt.Errorf("unexpected key length: %d", len(key))) 207 } 208 r.key = (*[32]byte)(key) 209 210 case blockPlainStream, blockEncStream: 211 // Read metadata 212 name, block, err := msgp.ReadStringBytes(block) 213 if err != nil { 214 return nil, r.setErr(err) 215 } 216 extra, block, err := msgp.ReadBytesBytes(block, nil) 217 if err != nil { 218 return nil, r.setErr(err) 219 } 220 c, block, err := msgp.ReadUint8Bytes(block) 221 if err != nil { 222 return nil, r.setErr(err) 223 } 224 checksum := checksumType(c) 225 if !checksum.valid() { 226 return nil, r.setErr(fmt.Errorf("unknown checksum type %d", checksum)) 227 } 228 229 // Return plaintext stream 230 if id == blockPlainStream { 231 return &Stream{ 232 Reader: r.newStreamReader(checksum), 233 Name: name, 234 Extra: extra, 235 parent: r, 236 }, nil 237 } 238 239 // Handle encrypted streams. 240 if r.key == nil { 241 if r.skipEncrypted { 242 if err := r.skipDataBlocks(); err != nil { 243 return nil, r.setErr(err) 244 } 245 continue 246 } 247 return &Stream{ 248 SentEncrypted: true, 249 Reader: nil, 250 Name: name, 251 Extra: extra, 252 parent: r, 253 }, ErrNoKey 254 } 255 // Read stream nonce 256 nonce, _, err := msgp.ReadBytesZC(block) 257 if err != nil { 258 return nil, r.setErr(err) 259 } 260 261 stream, err := sio.AES_256_GCM.Stream(r.key[:]) 262 if err != nil { 263 return nil, r.setErr(err) 264 } 265 266 // Check if nonce is expected length. 267 if len(nonce) != stream.NonceSize() { 268 return nil, r.setErr(fmt.Errorf("unexpected nonce length: %d", len(nonce))) 269 } 270 271 encr := stream.DecryptReader(r.newStreamReader(checksum), nonce, nil) 272 return &Stream{ 273 SentEncrypted: true, 274 Reader: encr, 275 Name: name, 276 Extra: extra, 277 parent: r, 278 }, nil 279 case blockEOS: 280 return nil, errors.New("end-of-stream without being in stream") 281 case blockEOF: 282 return nil, io.EOF 283 case blockError: 284 msg, _, err := msgp.ReadStringBytes(block) 285 if err != nil { 286 return nil, r.setErr(err) 287 } 288 return nil, r.setErr(errors.New(msg)) 289 default: 290 if id >= 0 { 291 return nil, fmt.Errorf("unknown block type: %d", id) 292 } 293 } 294 } 295 } 296 297 // skipDataBlocks reads data blocks until end. 298 func (r *Reader) skipDataBlocks() error { 299 for { 300 // Read block ID. 301 n, err := r.mr.ReadInt8() 302 if err != nil { 303 return err 304 } 305 id := blockID(n) 306 sz, err := r.mr.ReadUint32() 307 if err != nil { 308 return err 309 } 310 if id == blockError { 311 msg, err := r.mr.ReadString() 312 if err != nil { 313 return err 314 } 315 return errors.New(msg) 316 } 317 // Discard data 318 _, err = io.CopyN(io.Discard, r.mr, int64(sz)) 319 if err != nil { 320 return err 321 } 322 switch id { 323 case blockDatablock: 324 // Skip data 325 case blockEOS: 326 // Done 327 r.inStream = false 328 return nil 329 default: 330 if id >= 0 { 331 return fmt.Errorf("unknown block type: %d", id) 332 } 333 } 334 } 335 } 336 337 // setErr sets a stateful error. 338 func (r *Reader) setErr(err error) error { 339 if r.err != nil { 340 return r.err 341 } 342 if err == nil { 343 return err 344 } 345 if errors.Is(err, io.EOF) { 346 r.err = io.ErrUnexpectedEOF 347 } 348 if false { 349 _, file, line, ok := runtime.Caller(1) 350 if ok { 351 err = fmt.Errorf("%s:%d: %w", file, line, err) 352 } 353 } 354 r.err = err 355 return err 356 } 357 358 type streamReader struct { 359 up *Reader 360 h xxhash.Digest 361 buf bytes.Buffer 362 tmp []byte 363 isEOF bool 364 check checksumType 365 } 366 367 // newStreamReader creates a stream reader that can be read to get all data blocks. 368 func (r *Reader) newStreamReader(ct checksumType) *streamReader { 369 sr := &streamReader{up: r, check: ct} 370 sr.h.Reset() 371 r.inStream = true 372 return sr 373 } 374 375 // Skip the remainder of the stream. 376 func (s *Stream) Skip() error { 377 if sr, ok := s.Reader.(*streamReader); ok { 378 sr.isEOF = true 379 sr.buf.Reset() 380 } 381 return s.parent.skipDataBlocks() 382 } 383 384 // Read will return data blocks as on stream. 385 func (r *streamReader) Read(b []byte) (int, error) { 386 if r.isEOF { 387 return 0, io.EOF 388 } 389 if r.up.err != nil { 390 return 0, r.up.err 391 } 392 for { 393 // If we have anything in the buffer return that first. 394 if r.buf.Len() > 0 { 395 n, err := r.buf.Read(b) 396 if err == io.EOF { 397 err = nil 398 } 399 return n, r.up.setErr(err) 400 } 401 402 // Read block 403 n, err := r.up.mr.ReadInt8() 404 if err != nil { 405 return 0, r.up.setErr(err) 406 } 407 id := blockID(n) 408 409 // Read size... 410 sz, err := r.up.mr.ReadUint32() 411 if err != nil { 412 return 0, r.up.setErr(err) 413 } 414 415 switch id { 416 case blockDatablock: 417 // Read block 418 buf, err := r.up.mr.ReadBytes(r.tmp[:0]) 419 if err != nil { 420 return 0, r.up.setErr(err) 421 } 422 423 // Write to buffer and checksum 424 if r.check == checksumTypeXxhash { 425 r.h.Write(buf) 426 } 427 r.tmp = buf 428 r.buf.Write(buf) 429 case blockEOS: 430 // Verify stream checksum if any. 431 hash, err := r.up.mr.ReadBytes(nil) 432 if err != nil { 433 return 0, r.up.setErr(err) 434 } 435 switch r.check { 436 case checksumTypeXxhash: 437 got := r.h.Sum(nil) 438 if !bytes.Equal(hash, got) { 439 return 0, r.up.setErr(fmt.Errorf("checksum mismatch, want %s, got %s", hex.EncodeToString(hash), hex.EncodeToString(got))) 440 } 441 case checksumTypeNone: 442 default: 443 return 0, r.up.setErr(fmt.Errorf("unknown checksum id %d", r.check)) 444 } 445 r.isEOF = true 446 r.up.inStream = false 447 return 0, io.EOF 448 case blockError: 449 msg, err := r.up.mr.ReadString() 450 if err != nil { 451 return 0, r.up.setErr(err) 452 } 453 return 0, r.up.setErr(errors.New(msg)) 454 default: 455 if id >= 0 { 456 return 0, fmt.Errorf("unexpected block type: %d", id) 457 } 458 // Skip block... 459 _, err := io.CopyN(io.Discard, r.up.mr, int64(sz)) 460 if err != nil { 461 return 0, r.up.setErr(err) 462 } 463 } 464 } 465 }