github.com/minio/madmin-go/v2@v2.2.1/estream/stream_test.go (about)

     1  //
     2  // Copyright (c) 2015-2022 MinIO, Inc.
     3  //
     4  // This file is part of MinIO Object Storage stack
     5  //
     6  // This program is free software: you can redistribute it and/or modify
     7  // it under the terms of the GNU Affero General Public License as
     8  // published by the Free Software Foundation, either version 3 of the
     9  // License, or (at your option) any later version.
    10  //
    11  // This program is distributed in the hope that it will be useful,
    12  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  // GNU Affero General Public License for more details.
    15  //
    16  // You should have received a copy of the GNU Affero General Public License
    17  // along with this program. If not, see <http://www.gnu.org/licenses/>.
    18  //
    19  
    20  package estream
    21  
    22  import (
    23  	"bytes"
    24  	crand "crypto/rand"
    25  	"crypto/rsa"
    26  	"io"
    27  	"os"
    28  	"testing"
    29  )
    30  
    31  var testStreams = map[string][]byte{
    32  	"stream1": bytes.Repeat([]byte("a"), 2000),
    33  	"stream2": bytes.Repeat([]byte("b"), 1<<20),
    34  	"stream3": bytes.Repeat([]byte("b"), 5),
    35  	"empty":   {},
    36  }
    37  
    38  func TestStreamRoundtrip(t *testing.T) {
    39  	var buf bytes.Buffer
    40  	w := NewWriter(&buf)
    41  	if err := w.AddKeyPlain(); err != nil {
    42  		t.Fatal(err)
    43  	}
    44  	wantStreams := 0
    45  	wantDecStreams := 0
    46  	for name, value := range testStreams {
    47  		st, err := w.AddEncryptedStream(name, []byte(name))
    48  		if err != nil {
    49  			t.Fatal(err)
    50  		}
    51  		_, err = io.Copy(st, bytes.NewBuffer(value))
    52  		if err != nil {
    53  			t.Fatal(err)
    54  		}
    55  		st.Close()
    56  		st, err = w.AddUnencryptedStream(name, []byte(name))
    57  		if err != nil {
    58  			t.Fatal(err)
    59  		}
    60  		_, err = io.Copy(st, bytes.NewBuffer(value))
    61  		if err != nil {
    62  			t.Fatal(err)
    63  		}
    64  		st.Close()
    65  		wantStreams += 2
    66  		wantDecStreams += 2
    67  	}
    68  
    69  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
    70  	if err != nil {
    71  		t.Fatal(err)
    72  	}
    73  	err = w.AddKeyEncrypted(&priv.PublicKey)
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	for name, value := range testStreams {
    78  		st, err := w.AddEncryptedStream(name, []byte(name))
    79  		if err != nil {
    80  			t.Fatal(err)
    81  		}
    82  		_, err = io.Copy(st, bytes.NewBuffer(value))
    83  		if err != nil {
    84  			t.Fatal(err)
    85  		}
    86  		st.Close()
    87  		st, err = w.AddUnencryptedStream(name, []byte(name))
    88  		if err != nil {
    89  			t.Fatal(err)
    90  		}
    91  		_, err = io.Copy(st, bytes.NewBuffer(value))
    92  		if err != nil {
    93  			t.Fatal(err)
    94  		}
    95  		st.Close()
    96  		wantStreams += 2
    97  		wantDecStreams++
    98  	}
    99  	err = w.Close()
   100  	if err != nil {
   101  		t.Fatal(err)
   102  	}
   103  
   104  	// Read back...
   105  	b := buf.Bytes()
   106  	r, err := NewReader(bytes.NewBuffer(b))
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  	r.SetPrivateKey(priv)
   111  
   112  	var gotStreams int
   113  	for {
   114  		st, err := r.NextStream()
   115  		if err == io.EOF {
   116  			break
   117  		}
   118  		if err != nil {
   119  			t.Fatalf("stream %d: %v", gotStreams, err)
   120  		}
   121  		want, ok := testStreams[st.Name]
   122  		if !ok {
   123  			t.Fatal("unexpected stream name", st.Name)
   124  		}
   125  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   126  			t.Fatal("unexpected stream extra:", st.Extra)
   127  		}
   128  		got, err := io.ReadAll(st)
   129  		if err != nil {
   130  			t.Fatalf("stream %d: %v", gotStreams, err)
   131  		}
   132  		if !bytes.Equal(got, want) {
   133  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   134  		}
   135  		gotStreams++
   136  	}
   137  	if gotStreams != wantStreams {
   138  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   139  	}
   140  
   141  	// Read back, but skip encrypted streams.
   142  	r, err = NewReader(bytes.NewBuffer(b))
   143  	if err != nil {
   144  		t.Fatal(err)
   145  	}
   146  	r.SkipEncrypted(true)
   147  
   148  	gotStreams = 0
   149  	for {
   150  		st, err := r.NextStream()
   151  		if err == io.EOF {
   152  			break
   153  		}
   154  		if err != nil {
   155  			t.Fatalf("stream %d: %v", gotStreams, err)
   156  		}
   157  		want, ok := testStreams[st.Name]
   158  		if !ok {
   159  			t.Fatal("unexpected stream name", st.Name)
   160  		}
   161  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   162  			t.Fatal("unexpected stream extra:", st.Extra)
   163  		}
   164  		got, err := io.ReadAll(st)
   165  		if err != nil {
   166  			t.Fatalf("stream %d: %v", gotStreams, err)
   167  		}
   168  		if !bytes.Equal(got, want) {
   169  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   170  		}
   171  		gotStreams++
   172  	}
   173  	if gotStreams != wantDecStreams {
   174  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   175  	}
   176  
   177  	gotStreams = 0
   178  	r, err = NewReader(bytes.NewBuffer(b))
   179  	if err != nil {
   180  		t.Fatal(err)
   181  	}
   182  	r.SkipEncrypted(true)
   183  	for {
   184  		st, err := r.NextStream()
   185  		if err == io.EOF {
   186  			break
   187  		}
   188  		if err != nil {
   189  			t.Fatalf("stream %d: %v", gotStreams, err)
   190  		}
   191  		_, ok := testStreams[st.Name]
   192  		if !ok {
   193  			t.Fatal("unexpected stream name", st.Name)
   194  		}
   195  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   196  			t.Fatal("unexpected stream extra:", st.Extra)
   197  		}
   198  		err = st.Skip()
   199  		if err != nil {
   200  			t.Fatalf("stream %d: %v", gotStreams, err)
   201  		}
   202  		gotStreams++
   203  	}
   204  	if gotStreams != wantDecStreams {
   205  		t.Errorf("want %d streams, got %d", wantDecStreams, gotStreams)
   206  	}
   207  
   208  	if false {
   209  		r, err = NewReader(bytes.NewBuffer(b))
   210  		if err != nil {
   211  			t.Fatal(err)
   212  		}
   213  		r.SkipEncrypted(true)
   214  
   215  		err = r.DebugStream(os.Stdout)
   216  		if err != nil {
   217  			t.Fatal(err)
   218  		}
   219  	}
   220  }
   221  
   222  func TestReplaceKeys(t *testing.T) {
   223  	var buf bytes.Buffer
   224  	w := NewWriter(&buf)
   225  	if err := w.AddKeyPlain(); err != nil {
   226  		t.Fatal(err)
   227  	}
   228  	wantStreams := 0
   229  	for name, value := range testStreams {
   230  		st, err := w.AddEncryptedStream(name, []byte(name))
   231  		if err != nil {
   232  			t.Fatal(err)
   233  		}
   234  		_, err = io.Copy(st, bytes.NewBuffer(value))
   235  		if err != nil {
   236  			t.Fatal(err)
   237  		}
   238  		st.Close()
   239  		st, err = w.AddUnencryptedStream(name, []byte(name))
   240  		if err != nil {
   241  			t.Fatal(err)
   242  		}
   243  		_, err = io.Copy(st, bytes.NewBuffer(value))
   244  		if err != nil {
   245  			t.Fatal(err)
   246  		}
   247  		st.Close()
   248  		wantStreams += 2
   249  	}
   250  
   251  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
   252  	if err != nil {
   253  		t.Fatal(err)
   254  	}
   255  	err = w.AddKeyEncrypted(&priv.PublicKey)
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  	for name, value := range testStreams {
   260  		st, err := w.AddEncryptedStream(name, []byte(name))
   261  		if err != nil {
   262  			t.Fatal(err)
   263  		}
   264  		_, err = io.Copy(st, bytes.NewBuffer(value))
   265  		if err != nil {
   266  			t.Fatal(err)
   267  		}
   268  		st.Close()
   269  		st, err = w.AddUnencryptedStream(name, []byte(name))
   270  		if err != nil {
   271  			t.Fatal(err)
   272  		}
   273  		_, err = io.Copy(st, bytes.NewBuffer(value))
   274  		if err != nil {
   275  			t.Fatal(err)
   276  		}
   277  		st.Close()
   278  		wantStreams += 2
   279  	}
   280  	err = w.Close()
   281  	if err != nil {
   282  		t.Fatal(err)
   283  	}
   284  
   285  	priv2, err := rsa.GenerateKey(crand.Reader, 2048)
   286  	if err != nil {
   287  		t.Fatal(err)
   288  	}
   289  
   290  	var replaced bytes.Buffer
   291  	err = ReplaceKeys(&replaced, &buf, func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) {
   292  		if key == nil {
   293  			return nil, &priv2.PublicKey
   294  		}
   295  		if key.Equal(&priv.PublicKey) {
   296  			return priv, &priv2.PublicKey
   297  		}
   298  		t.Fatal("unknown key\n", *key, "\nwant\n", priv.PublicKey)
   299  		return nil, nil
   300  	}, ReplaceKeysOptions{EncryptAll: true})
   301  	if err != nil {
   302  		t.Fatal(err)
   303  	}
   304  
   305  	// Read back...
   306  	r, err := NewReader(&replaced)
   307  	if err != nil {
   308  		t.Fatal(err)
   309  	}
   310  
   311  	// Use key provider.
   312  	r.PrivateKeyProvider(func(key *rsa.PublicKey) *rsa.PrivateKey {
   313  		if key.Equal(&priv2.PublicKey) {
   314  			return priv2
   315  		}
   316  		t.Fatal("unexpected public key")
   317  		return nil
   318  	})
   319  
   320  	var gotStreams int
   321  	for {
   322  		st, err := r.NextStream()
   323  		if err == io.EOF {
   324  			break
   325  		}
   326  		if err != nil {
   327  			t.Fatalf("stream %d: %v", gotStreams, err)
   328  		}
   329  		want, ok := testStreams[st.Name]
   330  		if !ok {
   331  			t.Fatal("unexpected stream name", st.Name)
   332  		}
   333  		if st.SentEncrypted != (gotStreams&1 == 0) {
   334  			t.Errorf("stream %d was sent with unexpected encryption %v", gotStreams, st.SentEncrypted)
   335  		}
   336  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   337  			t.Fatal("unexpected stream extra:", st.Extra)
   338  		}
   339  		got, err := io.ReadAll(st)
   340  		if err != nil {
   341  			t.Fatalf("stream %d: %v", gotStreams, err)
   342  		}
   343  		if !bytes.Equal(got, want) {
   344  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   345  		}
   346  		gotStreams++
   347  	}
   348  	if gotStreams != wantStreams {
   349  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   350  	}
   351  }
   352  
   353  func TestError(t *testing.T) {
   354  	var buf bytes.Buffer
   355  	w := NewWriter(&buf)
   356  	if err := w.AddKeyPlain(); err != nil {
   357  		t.Fatal(err)
   358  	}
   359  	want := "an error message!"
   360  	if err := w.AddError(want); err != nil {
   361  		t.Fatal(err)
   362  	}
   363  	w.Close()
   364  
   365  	// Read back...
   366  	r, err := NewReader(&buf)
   367  	if err != nil {
   368  		t.Fatal(err)
   369  	}
   370  	st, err := r.NextStream()
   371  	if err == nil {
   372  		t.Fatalf("did not receive error, got %v, err: %v", st, err)
   373  	}
   374  	if err.Error() != want {
   375  		t.Errorf("Expected %q, got %q", want, err.Error())
   376  	}
   377  }
   378  
   379  func TestStreamReturnNonDecryptable(t *testing.T) {
   380  	var buf bytes.Buffer
   381  	w := NewWriter(&buf)
   382  	if err := w.AddKeyPlain(); err != nil {
   383  		t.Fatal(err)
   384  	}
   385  
   386  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
   387  	if err != nil {
   388  		t.Fatal(err)
   389  	}
   390  	err = w.AddKeyEncrypted(&priv.PublicKey)
   391  	if err != nil {
   392  		t.Fatal(err)
   393  	}
   394  	wantStreams := len(testStreams)
   395  	for name, value := range testStreams {
   396  		st, err := w.AddEncryptedStream(name, []byte(name))
   397  		if err != nil {
   398  			t.Fatal(err)
   399  		}
   400  		_, err = io.Copy(st, bytes.NewBuffer(value))
   401  		if err != nil {
   402  			t.Fatal(err)
   403  		}
   404  		st.Close()
   405  	}
   406  	err = w.Close()
   407  	if err != nil {
   408  		t.Fatal(err)
   409  	}
   410  
   411  	// Read back...
   412  	b := buf.Bytes()
   413  	r, err := NewReader(bytes.NewBuffer(b))
   414  	if err != nil {
   415  		t.Fatal(err)
   416  	}
   417  	r.ReturnNonDecryptable(true)
   418  	gotStreams := 0
   419  	for {
   420  		st, err := r.NextStream()
   421  		if err == io.EOF {
   422  			break
   423  		}
   424  		if err != ErrNoKey {
   425  			t.Fatalf("stream %d: %v", gotStreams, err)
   426  		}
   427  		_, ok := testStreams[st.Name]
   428  		if !ok {
   429  			t.Fatal("unexpected stream name", st.Name)
   430  		}
   431  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   432  			t.Fatal("unexpected stream extra:", st.Extra)
   433  		}
   434  		if !st.SentEncrypted {
   435  			t.Fatal("stream not marked as encrypted:", st.SentEncrypted)
   436  		}
   437  		err = st.Skip()
   438  		if err != nil {
   439  			t.Fatalf("stream %d: %v", gotStreams, err)
   440  		}
   441  		gotStreams++
   442  	}
   443  	if gotStreams != wantStreams {
   444  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   445  	}
   446  }