github.com/minio/madmin-go@v1.7.5/estream/stream_test.go (about) 1 // 2 // MinIO Object Storage (c) 2022 MinIO, Inc. 3 // 4 // Licensed under the Apache License, Version 2.0 (the "License"); 5 // you may not use this file except in compliance with the License. 6 // You may obtain a copy of the License at 7 // 8 // http://www.apache.org/licenses/LICENSE-2.0 9 // 10 // Unless required by applicable law or agreed to in writing, software 11 // distributed under the License is distributed on an "AS IS" BASIS, 12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 // 16 17 package estream 18 19 import ( 20 "bytes" 21 crand "crypto/rand" 22 "crypto/rsa" 23 "io" 24 "os" 25 "testing" 26 ) 27 28 var testStreams = map[string][]byte{ 29 "stream1": bytes.Repeat([]byte("a"), 2000), 30 "stream2": bytes.Repeat([]byte("b"), 1<<20), 31 "stream3": bytes.Repeat([]byte("b"), 5), 32 "empty": {}, 33 } 34 35 func TestStreamRoundtrip(t *testing.T) { 36 var buf bytes.Buffer 37 w := NewWriter(&buf) 38 if err := w.AddKeyPlain(); err != nil { 39 t.Fatal(err) 40 } 41 wantStreams := 0 42 wantDecStreams := 0 43 for name, value := range testStreams { 44 st, err := w.AddEncryptedStream(name, []byte(name)) 45 if err != nil { 46 t.Fatal(err) 47 } 48 _, err = io.Copy(st, bytes.NewBuffer(value)) 49 if err != nil { 50 t.Fatal(err) 51 } 52 st.Close() 53 st, err = w.AddUnencryptedStream(name, []byte(name)) 54 if err != nil { 55 t.Fatal(err) 56 } 57 _, err = io.Copy(st, bytes.NewBuffer(value)) 58 if err != nil { 59 t.Fatal(err) 60 } 61 st.Close() 62 wantStreams += 2 63 wantDecStreams += 2 64 } 65 66 priv, err := rsa.GenerateKey(crand.Reader, 2048) 67 if err != nil { 68 t.Fatal(err) 69 } 70 err = w.AddKeyEncrypted(&priv.PublicKey) 71 if err != nil { 72 t.Fatal(err) 73 } 74 for name, value := range testStreams { 75 st, err := w.AddEncryptedStream(name, []byte(name)) 76 if err != nil { 77 t.Fatal(err) 78 } 79 _, err = io.Copy(st, bytes.NewBuffer(value)) 80 if err != nil { 81 t.Fatal(err) 82 } 83 st.Close() 84 st, err = w.AddUnencryptedStream(name, []byte(name)) 85 if err != nil { 86 t.Fatal(err) 87 } 88 _, err = io.Copy(st, bytes.NewBuffer(value)) 89 if err != nil { 90 t.Fatal(err) 91 } 92 st.Close() 93 wantStreams += 2 94 wantDecStreams++ 95 } 96 err = w.Close() 97 if err != nil { 98 t.Fatal(err) 99 } 100 101 // Read back... 102 b := buf.Bytes() 103 r, err := NewReader(bytes.NewBuffer(b)) 104 if err != nil { 105 t.Fatal(err) 106 } 107 r.SetPrivateKey(priv) 108 109 var gotStreams int 110 for { 111 st, err := r.NextStream() 112 if err == io.EOF { 113 break 114 } 115 if err != nil { 116 t.Fatalf("stream %d: %v", gotStreams, err) 117 } 118 want, ok := testStreams[st.Name] 119 if !ok { 120 t.Fatal("unexpected stream name", st.Name) 121 } 122 if !bytes.Equal(st.Extra, []byte(st.Name)) { 123 t.Fatal("unexpected stream extra:", st.Extra) 124 } 125 got, err := io.ReadAll(st) 126 if err != nil { 127 t.Fatalf("stream %d: %v", gotStreams, err) 128 } 129 if !bytes.Equal(got, want) { 130 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 131 } 132 gotStreams++ 133 } 134 if gotStreams != wantStreams { 135 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 136 } 137 138 // Read back, but skip encrypted streams. 139 r, err = NewReader(bytes.NewBuffer(b)) 140 if err != nil { 141 t.Fatal(err) 142 } 143 r.SkipEncrypted(true) 144 145 gotStreams = 0 146 for { 147 st, err := r.NextStream() 148 if err == io.EOF { 149 break 150 } 151 if err != nil { 152 t.Fatalf("stream %d: %v", gotStreams, err) 153 } 154 want, ok := testStreams[st.Name] 155 if !ok { 156 t.Fatal("unexpected stream name", st.Name) 157 } 158 if !bytes.Equal(st.Extra, []byte(st.Name)) { 159 t.Fatal("unexpected stream extra:", st.Extra) 160 } 161 got, err := io.ReadAll(st) 162 if err != nil { 163 t.Fatalf("stream %d: %v", gotStreams, err) 164 } 165 if !bytes.Equal(got, want) { 166 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 167 } 168 gotStreams++ 169 } 170 if gotStreams != wantDecStreams { 171 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 172 } 173 174 gotStreams = 0 175 r, err = NewReader(bytes.NewBuffer(b)) 176 if err != nil { 177 t.Fatal(err) 178 } 179 r.SkipEncrypted(true) 180 for { 181 st, err := r.NextStream() 182 if err == io.EOF { 183 break 184 } 185 if err != nil { 186 t.Fatalf("stream %d: %v", gotStreams, err) 187 } 188 _, ok := testStreams[st.Name] 189 if !ok { 190 t.Fatal("unexpected stream name", st.Name) 191 } 192 if !bytes.Equal(st.Extra, []byte(st.Name)) { 193 t.Fatal("unexpected stream extra:", st.Extra) 194 } 195 err = st.Skip() 196 if err != nil { 197 t.Fatalf("stream %d: %v", gotStreams, err) 198 } 199 gotStreams++ 200 } 201 if gotStreams != wantDecStreams { 202 t.Errorf("want %d streams, got %d", wantDecStreams, gotStreams) 203 } 204 205 if false { 206 r, err = NewReader(bytes.NewBuffer(b)) 207 if err != nil { 208 t.Fatal(err) 209 } 210 r.SkipEncrypted(true) 211 212 err = r.DebugStream(os.Stdout) 213 if err != nil { 214 t.Fatal(err) 215 } 216 } 217 } 218 219 func TestReplaceKeys(t *testing.T) { 220 var buf bytes.Buffer 221 w := NewWriter(&buf) 222 if err := w.AddKeyPlain(); err != nil { 223 t.Fatal(err) 224 } 225 wantStreams := 0 226 for name, value := range testStreams { 227 st, err := w.AddEncryptedStream(name, []byte(name)) 228 if err != nil { 229 t.Fatal(err) 230 } 231 _, err = io.Copy(st, bytes.NewBuffer(value)) 232 if err != nil { 233 t.Fatal(err) 234 } 235 st.Close() 236 st, err = w.AddUnencryptedStream(name, []byte(name)) 237 if err != nil { 238 t.Fatal(err) 239 } 240 _, err = io.Copy(st, bytes.NewBuffer(value)) 241 if err != nil { 242 t.Fatal(err) 243 } 244 st.Close() 245 wantStreams += 2 246 } 247 248 priv, err := rsa.GenerateKey(crand.Reader, 2048) 249 if err != nil { 250 t.Fatal(err) 251 } 252 err = w.AddKeyEncrypted(&priv.PublicKey) 253 if err != nil { 254 t.Fatal(err) 255 } 256 for name, value := range testStreams { 257 st, err := w.AddEncryptedStream(name, []byte(name)) 258 if err != nil { 259 t.Fatal(err) 260 } 261 _, err = io.Copy(st, bytes.NewBuffer(value)) 262 if err != nil { 263 t.Fatal(err) 264 } 265 st.Close() 266 st, err = w.AddUnencryptedStream(name, []byte(name)) 267 if err != nil { 268 t.Fatal(err) 269 } 270 _, err = io.Copy(st, bytes.NewBuffer(value)) 271 if err != nil { 272 t.Fatal(err) 273 } 274 st.Close() 275 wantStreams += 2 276 } 277 err = w.Close() 278 if err != nil { 279 t.Fatal(err) 280 } 281 282 priv2, err := rsa.GenerateKey(crand.Reader, 2048) 283 if err != nil { 284 t.Fatal(err) 285 } 286 287 var replaced bytes.Buffer 288 err = ReplaceKeys(&replaced, &buf, func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) { 289 if key == nil { 290 return nil, &priv2.PublicKey 291 } 292 if key.Equal(&priv.PublicKey) { 293 return priv, &priv2.PublicKey 294 } 295 t.Fatal("unknown key\n", *key, "\nwant\n", priv.PublicKey) 296 return nil, nil 297 }, ReplaceKeysOptions{EncryptAll: true}) 298 if err != nil { 299 t.Fatal(err) 300 } 301 302 // Read back... 303 r, err := NewReader(&replaced) 304 if err != nil { 305 t.Fatal(err) 306 } 307 308 // Use key provider. 309 r.PrivateKeyProvider(func(key *rsa.PublicKey) *rsa.PrivateKey { 310 if key.Equal(&priv2.PublicKey) { 311 return priv2 312 } 313 t.Fatal("unexpected public key") 314 return nil 315 }) 316 317 var gotStreams int 318 for { 319 st, err := r.NextStream() 320 if err == io.EOF { 321 break 322 } 323 if err != nil { 324 t.Fatalf("stream %d: %v", gotStreams, err) 325 } 326 want, ok := testStreams[st.Name] 327 if !ok { 328 t.Fatal("unexpected stream name", st.Name) 329 } 330 if st.SentEncrypted != (gotStreams&1 == 0) { 331 t.Errorf("stream %d was sent with unexpected encryption %v", gotStreams, st.SentEncrypted) 332 } 333 if !bytes.Equal(st.Extra, []byte(st.Name)) { 334 t.Fatal("unexpected stream extra:", st.Extra) 335 } 336 got, err := io.ReadAll(st) 337 if err != nil { 338 t.Fatalf("stream %d: %v", gotStreams, err) 339 } 340 if !bytes.Equal(got, want) { 341 t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want)) 342 } 343 gotStreams++ 344 } 345 if gotStreams != wantStreams { 346 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 347 } 348 } 349 350 func TestError(t *testing.T) { 351 var buf bytes.Buffer 352 w := NewWriter(&buf) 353 if err := w.AddKeyPlain(); err != nil { 354 t.Fatal(err) 355 } 356 want := "an error message!" 357 if err := w.AddError(want); err != nil { 358 t.Fatal(err) 359 } 360 w.Close() 361 362 // Read back... 363 r, err := NewReader(&buf) 364 if err != nil { 365 t.Fatal(err) 366 } 367 st, err := r.NextStream() 368 if err == nil { 369 t.Fatalf("did not receive error, got %v, err: %v", st, err) 370 } 371 if err.Error() != want { 372 t.Errorf("Expected %q, got %q", want, err.Error()) 373 } 374 } 375 376 func TestStreamReturnNonDecryptable(t *testing.T) { 377 var buf bytes.Buffer 378 w := NewWriter(&buf) 379 if err := w.AddKeyPlain(); err != nil { 380 t.Fatal(err) 381 } 382 383 priv, err := rsa.GenerateKey(crand.Reader, 2048) 384 if err != nil { 385 t.Fatal(err) 386 } 387 err = w.AddKeyEncrypted(&priv.PublicKey) 388 if err != nil { 389 t.Fatal(err) 390 } 391 wantStreams := len(testStreams) 392 for name, value := range testStreams { 393 st, err := w.AddEncryptedStream(name, []byte(name)) 394 if err != nil { 395 t.Fatal(err) 396 } 397 _, err = io.Copy(st, bytes.NewBuffer(value)) 398 if err != nil { 399 t.Fatal(err) 400 } 401 st.Close() 402 } 403 err = w.Close() 404 if err != nil { 405 t.Fatal(err) 406 } 407 408 // Read back... 409 b := buf.Bytes() 410 r, err := NewReader(bytes.NewBuffer(b)) 411 if err != nil { 412 t.Fatal(err) 413 } 414 r.ReturnNonDecryptable(true) 415 gotStreams := 0 416 for { 417 st, err := r.NextStream() 418 if err == io.EOF { 419 break 420 } 421 if err != ErrNoKey { 422 t.Fatalf("stream %d: %v", gotStreams, err) 423 } 424 _, ok := testStreams[st.Name] 425 if !ok { 426 t.Fatal("unexpected stream name", st.Name) 427 } 428 if !bytes.Equal(st.Extra, []byte(st.Name)) { 429 t.Fatal("unexpected stream extra:", st.Extra) 430 } 431 if !st.SentEncrypted { 432 t.Fatal("stream not marked as encrypted:", st.SentEncrypted) 433 } 434 err = st.Skip() 435 if err != nil { 436 t.Fatalf("stream %d: %v", gotStreams, err) 437 } 438 gotStreams++ 439 } 440 if gotStreams != wantStreams { 441 t.Errorf("want %d streams, got %d", wantStreams, gotStreams) 442 } 443 }