github.com/minio/madmin-go/v3@v3.0.51/estream/stream_test.go (about) 1 // 2 // Copyright (c) 2015-2022 MinIO, Inc. 3 // 4 // This file is part of MinIO Object Storage stack 5 // 6 // This program is free software: you can redistribute it and/or modify 7 // it under the terms of the GNU Affero General Public License as 8 // published by the Free Software Foundation, either version 3 of the 9 // License, or (at your option) any later version. 10 // 11 // This program is distributed in the hope that it will be useful, 12 // but WITHOUT ANY WARRANTY; without even the implied warranty of 13 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 // GNU Affero General Public License for more details. 15 // 16 // You should have received a copy of the GNU Affero General Public License 17 // along with this program. If not, see <http://www.gnu.org/licenses/>. 18 // 19 20 package estream 21 22 import ( 23 "bytes" 24 crand "crypto/rand" 25 "crypto/rsa" 26 "io" 27 "os" 28 "testing" 29 ) 30 31 var testStreams = map[string][]byte{ 32 "stream1": bytes.Repeat([]byte("a"), 2000), 33 "stream2": bytes.Repeat([]byte("b"), 1<<20), 34 "stream3": bytes.Repeat([]byte("b"), 5), 35 "empty": {}, 36 } 37 38 func TestStreamRoundtrip(t *testing.T) { 39 var buf bytes.Buffer 40 w := NewWriter(&buf) 41 if err := w.AddKeyPlain(); err != nil { 42 t.Fatal(err) 43 } 44 wantStreams := 0 45 wantDecStreams := 0 46 for name, value := range testStreams { 47 st, err := w.AddEncryptedStream(name, []byte(name)) 48 if err != nil { 49 t.Fatal(err) 50 } 51 _, err = io.Copy(st, bytes.NewBuffer(value)) 52 if err != nil { 53 t.Fatal(err) 54 } 55 st.Close() 56 st, err = w.AddUnencryptedStream(name, []byte(name)) 57 if err != nil { 58 t.Fatal(err) 59 } 60 _, err = io.Copy(st, bytes.NewBuffer(value)) 61 if err != nil { 62 t.Fatal(err) 63 } 64 st.Close() 65 wantStreams += 2 66 wantDecStreams += 2 67 } 68 69 priv, err := rsa.GenerateKey(crand.Reader, 2048) 70 if err != nil { 71 t.Fatal(err) 72 } 73 err = w.AddKeyEncrypted(&priv.PublicKey) 74 if err != nil { 75 t.Fatal(err) 76 } 77 for name, value := range testStreams { 78 st, err := w.AddEncryptedStream(name, []byte(name)) 79 if err != nil { 80 t.Fatal(err) 81 } 82 _, err = io.Copy(st, bytes.NewBuffer(value)) 83 if err != nil { 84 t.Fatal(err) 85 } 86 st.Close() 87 st, err = w.AddUnencryptedStream(name, []byte(name)) 88 if err != nil { 89 t.Fatal(err) 90 } 91 _, err = io.Copy(st, bytes.NewBuffer(value)) 92 if err != nil { 93 t.Fatal(err) 94 } 95 st.Close() 96 wantStreams += 2 97 wantDecStreams++ 98 } 99 err = w.Close() 100 if err != nil { 101 t.Fatal(err) 102 } 103 104 // Read back... 105 b := buf.Bytes() 106 r, err := NewReader(bytes.NewBuffer(b)) 107 if err != nil { 108 t.Fatal(err) 109 } 110 r.SetPrivateKey(priv) 111 112 var gotStreams int 113 for { 114 st, err := r.NextStream() 115 if err == io.EOF { 116 break 117 } 118 if err != nil { 119 t.Fatalf("stream %d: %v", gotStreams, err) 120 } 121 want, ok := testStreams[st.Name] 122 if !ok { 123 t.Fatal("unexpected stream name", st.Name) 124 } 125 if !bytes.Equal(st.Extra, []byte(st.Name)) { 126 t.Fatal("unexpected stream extra:", st.Extra) 127 } 128 got, err := io.ReadAll(st) 129 if err != nil { 130 t.Fatalf("stream %d: %v", gotStreams, err) 131 } 132 if !bytes.Equal(got, want) { 133 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 134 } 135 gotStreams++ 136 } 137 if gotStreams != wantStreams { 138 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 139 } 140 141 // Read back, but skip encrypted streams. 142 r, err = NewReader(bytes.NewBuffer(b)) 143 if err != nil { 144 t.Fatal(err) 145 } 146 r.SkipEncrypted(true) 147 148 gotStreams = 0 149 for { 150 st, err := r.NextStream() 151 if err == io.EOF { 152 break 153 } 154 if err != nil { 155 t.Fatalf("stream %d: %v", gotStreams, err) 156 } 157 want, ok := testStreams[st.Name] 158 if !ok { 159 t.Fatal("unexpected stream name", st.Name) 160 } 161 if !bytes.Equal(st.Extra, []byte(st.Name)) { 162 t.Fatal("unexpected stream extra:", st.Extra) 163 } 164 got, err := io.ReadAll(st) 165 if err != nil { 166 t.Fatalf("stream %d: %v", gotStreams, err) 167 } 168 if !bytes.Equal(got, want) { 169 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 170 } 171 gotStreams++ 172 } 173 if gotStreams != wantDecStreams { 174 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 175 } 176 177 gotStreams = 0 178 r, err = NewReader(bytes.NewBuffer(b)) 179 if err != nil { 180 t.Fatal(err) 181 } 182 r.SkipEncrypted(true) 183 for { 184 st, err := r.NextStream() 185 if err == io.EOF { 186 break 187 } 188 if err != nil { 189 t.Fatalf("stream %d: %v", gotStreams, err) 190 } 191 _, ok := testStreams[st.Name] 192 if !ok { 193 t.Fatal("unexpected stream name", st.Name) 194 } 195 if !bytes.Equal(st.Extra, []byte(st.Name)) { 196 t.Fatal("unexpected stream extra:", st.Extra) 197 } 198 err = st.Skip() 199 if err != nil { 200 t.Fatalf("stream %d: %v", gotStreams, err) 201 } 202 gotStreams++ 203 } 204 if gotStreams != wantDecStreams { 205 t.Errorf("want %d streams, got %d", wantDecStreams, gotStreams) 206 } 207 208 if false { 209 r, err = NewReader(bytes.NewBuffer(b)) 210 if err != nil { 211 t.Fatal(err) 212 } 213 r.SkipEncrypted(true) 214 215 err = r.DebugStream(os.Stdout) 216 if err != nil { 217 t.Fatal(err) 218 } 219 } 220 } 221 222 func TestReplaceKeys(t *testing.T) { 223 var buf bytes.Buffer 224 w := NewWriter(&buf) 225 if err := w.AddKeyPlain(); err != nil { 226 t.Fatal(err) 227 } 228 wantStreams := 0 229 for name, value := range testStreams { 230 st, err := w.AddEncryptedStream(name, []byte(name)) 231 if err != nil { 232 t.Fatal(err) 233 } 234 _, err = io.Copy(st, bytes.NewBuffer(value)) 235 if err != nil { 236 t.Fatal(err) 237 } 238 st.Close() 239 st, err = w.AddUnencryptedStream(name, []byte(name)) 240 if err != nil { 241 t.Fatal(err) 242 } 243 _, err = io.Copy(st, bytes.NewBuffer(value)) 244 if err != nil { 245 t.Fatal(err) 246 } 247 st.Close() 248 wantStreams += 2 249 } 250 251 priv, err := rsa.GenerateKey(crand.Reader, 2048) 252 if err != nil { 253 t.Fatal(err) 254 } 255 err = w.AddKeyEncrypted(&priv.PublicKey) 256 if err != nil { 257 t.Fatal(err) 258 } 259 for name, value := range testStreams { 260 st, err := w.AddEncryptedStream(name, []byte(name)) 261 if err != nil { 262 t.Fatal(err) 263 } 264 _, err = io.Copy(st, bytes.NewBuffer(value)) 265 if err != nil { 266 t.Fatal(err) 267 } 268 st.Close() 269 st, err = w.AddUnencryptedStream(name, []byte(name)) 270 if err != nil { 271 t.Fatal(err) 272 } 273 _, err = io.Copy(st, bytes.NewBuffer(value)) 274 if err != nil { 275 t.Fatal(err) 276 } 277 st.Close() 278 wantStreams += 2 279 } 280 err = w.Close() 281 if err != nil { 282 t.Fatal(err) 283 } 284 285 priv2, err := rsa.GenerateKey(crand.Reader, 2048) 286 if err != nil { 287 t.Fatal(err) 288 } 289 290 var replaced bytes.Buffer 291 err = ReplaceKeys(&replaced, &buf, func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) { 292 if key == nil { 293 return nil, &priv2.PublicKey 294 } 295 if key.Equal(&priv.PublicKey) { 296 return priv, &priv2.PublicKey 297 } 298 t.Fatal("unknown key\n", *key, "\nwant\n", priv.PublicKey) 299 return nil, nil 300 }, ReplaceKeysOptions{EncryptAll: true}) 301 if err != nil { 302 t.Fatal(err) 303 } 304 305 // Read back... 306 r, err := NewReader(&replaced) 307 if err != nil { 308 t.Fatal(err) 309 } 310 311 // Use key provider. 312 r.PrivateKeyProvider(func(key *rsa.PublicKey) *rsa.PrivateKey { 313 if key.Equal(&priv2.PublicKey) { 314 return priv2 315 } 316 t.Fatal("unexpected public key") 317 return nil 318 }) 319 320 var gotStreams int 321 for { 322 st, err := r.NextStream() 323 if err == io.EOF { 324 break 325 } 326 if err != nil { 327 t.Fatalf("stream %d: %v", gotStreams, err) 328 } 329 want, ok := testStreams[st.Name] 330 if !ok { 331 t.Fatal("unexpected stream name", st.Name) 332 } 333 if st.SentEncrypted != (gotStreams&1 == 0) { 334 t.Errorf("stream %d was sent with unexpected encryption %v", gotStreams, st.SentEncrypted) 335 } 336 if !bytes.Equal(st.Extra, []byte(st.Name)) { 337 t.Fatal("unexpected stream extra:", st.Extra) 338 } 339 got, err := io.ReadAll(st) 340 if err != nil { 341 t.Fatalf("stream %d: %v", gotStreams, err) 342 } 343 if !bytes.Equal(got, want) { 344 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 345 } 346 gotStreams++ 347 } 348 if gotStreams != wantStreams { 349 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 350 } 351 } 352 353 func TestError(t *testing.T) { 354 var buf bytes.Buffer 355 w := NewWriter(&buf) 356 if err := w.AddKeyPlain(); err != nil { 357 t.Fatal(err) 358 } 359 want := "an error message!" 360 if err := w.AddError(want); err != nil { 361 t.Fatal(err) 362 } 363 w.Close() 364 365 // Read back... 366 r, err := NewReader(&buf) 367 if err != nil { 368 t.Fatal(err) 369 } 370 st, err := r.NextStream() 371 if err == nil { 372 t.Fatalf("did not receive error, got %v, err: %v", st, err) 373 } 374 if err.Error() != want { 375 t.Errorf("Expected %q, got %q", want, err.Error()) 376 } 377 } 378 379 func TestStreamReturnNonDecryptable(t *testing.T) { 380 var buf bytes.Buffer 381 w := NewWriter(&buf) 382 if err := w.AddKeyPlain(); err != nil { 383 t.Fatal(err) 384 } 385 386 priv, err := rsa.GenerateKey(crand.Reader, 2048) 387 if err != nil { 388 t.Fatal(err) 389 } 390 err = w.AddKeyEncrypted(&priv.PublicKey) 391 if err != nil { 392 t.Fatal(err) 393 } 394 wantStreams := len(testStreams) 395 for name, value := range testStreams { 396 st, err := w.AddEncryptedStream(name, []byte(name)) 397 if err != nil { 398 t.Fatal(err) 399 } 400 _, err = io.Copy(st, bytes.NewBuffer(value)) 401 if err != nil { 402 t.Fatal(err) 403 } 404 st.Close() 405 } 406 err = w.Close() 407 if err != nil { 408 t.Fatal(err) 409 } 410 411 // Read back... 412 b := buf.Bytes() 413 r, err := NewReader(bytes.NewBuffer(b)) 414 if err != nil { 415 t.Fatal(err) 416 } 417 r.ReturnNonDecryptable(true) 418 gotStreams := 0 419 for { 420 st, err := r.NextStream() 421 if err == io.EOF { 422 break 423 } 424 if err != ErrNoKey { 425 t.Fatalf("stream %d: %v", gotStreams, err) 426 } 427 _, ok := testStreams[st.Name] 428 if !ok { 429 t.Fatal("unexpected stream name", st.Name) 430 } 431 if !bytes.Equal(st.Extra, []byte(st.Name)) { 432 t.Fatal("unexpected stream extra:", st.Extra) 433 } 434 if !st.SentEncrypted { 435 t.Fatal("stream not marked as encrypted:", st.SentEncrypted) 436 } 437 err = st.Skip() 438 if err != nil { 439 t.Fatalf("stream %d: %v", gotStreams, err) 440 } 441 gotStreams++ 442 } 443 if gotStreams != wantStreams { 444 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 445 } 446 }