github.com/minio/madmin-go@v1.7.5/estream/stream_test.go (about)

     1  //
     2  // MinIO Object Storage (c) 2022 MinIO, Inc.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //      http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  //
    16  
    17  package estream
    18  
    19  import (
    20  	"bytes"
    21  	crand "crypto/rand"
    22  	"crypto/rsa"
    23  	"io"
    24  	"os"
    25  	"testing"
    26  )
    27  
    28  var testStreams = map[string][]byte{
    29  	"stream1": bytes.Repeat([]byte("a"), 2000),
    30  	"stream2": bytes.Repeat([]byte("b"), 1<<20),
    31  	"stream3": bytes.Repeat([]byte("b"), 5),
    32  	"empty":   {},
    33  }
    34  
    35  func TestStreamRoundtrip(t *testing.T) {
    36  	var buf bytes.Buffer
    37  	w := NewWriter(&buf)
    38  	if err := w.AddKeyPlain(); err != nil {
    39  		t.Fatal(err)
    40  	}
    41  	wantStreams := 0
    42  	wantDecStreams := 0
    43  	for name, value := range testStreams {
    44  		st, err := w.AddEncryptedStream(name, []byte(name))
    45  		if err != nil {
    46  			t.Fatal(err)
    47  		}
    48  		_, err = io.Copy(st, bytes.NewBuffer(value))
    49  		if err != nil {
    50  			t.Fatal(err)
    51  		}
    52  		st.Close()
    53  		st, err = w.AddUnencryptedStream(name, []byte(name))
    54  		if err != nil {
    55  			t.Fatal(err)
    56  		}
    57  		_, err = io.Copy(st, bytes.NewBuffer(value))
    58  		if err != nil {
    59  			t.Fatal(err)
    60  		}
    61  		st.Close()
    62  		wantStreams += 2
    63  		wantDecStreams += 2
    64  	}
    65  
    66  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	err = w.AddKeyEncrypted(&priv.PublicKey)
    71  	if err != nil {
    72  		t.Fatal(err)
    73  	}
    74  	for name, value := range testStreams {
    75  		st, err := w.AddEncryptedStream(name, []byte(name))
    76  		if err != nil {
    77  			t.Fatal(err)
    78  		}
    79  		_, err = io.Copy(st, bytes.NewBuffer(value))
    80  		if err != nil {
    81  			t.Fatal(err)
    82  		}
    83  		st.Close()
    84  		st, err = w.AddUnencryptedStream(name, []byte(name))
    85  		if err != nil {
    86  			t.Fatal(err)
    87  		}
    88  		_, err = io.Copy(st, bytes.NewBuffer(value))
    89  		if err != nil {
    90  			t.Fatal(err)
    91  		}
    92  		st.Close()
    93  		wantStreams += 2
    94  		wantDecStreams++
    95  	}
    96  	err = w.Close()
    97  	if err != nil {
    98  		t.Fatal(err)
    99  	}
   100  
   101  	// Read back...
   102  	b := buf.Bytes()
   103  	r, err := NewReader(bytes.NewBuffer(b))
   104  	if err != nil {
   105  		t.Fatal(err)
   106  	}
   107  	r.SetPrivateKey(priv)
   108  
   109  	var gotStreams int
   110  	for {
   111  		st, err := r.NextStream()
   112  		if err == io.EOF {
   113  			break
   114  		}
   115  		if err != nil {
   116  			t.Fatalf("stream %d: %v", gotStreams, err)
   117  		}
   118  		want, ok := testStreams[st.Name]
   119  		if !ok {
   120  			t.Fatal("unexpected stream name", st.Name)
   121  		}
   122  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   123  			t.Fatal("unexpected stream extra:", st.Extra)
   124  		}
   125  		got, err := io.ReadAll(st)
   126  		if err != nil {
   127  			t.Fatalf("stream %d: %v", gotStreams, err)
   128  		}
   129  		if !bytes.Equal(got, want) {
   130  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   131  		}
   132  		gotStreams++
   133  	}
   134  	if gotStreams != wantStreams {
   135  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   136  	}
   137  
   138  	// Read back, but skip encrypted streams.
   139  	r, err = NewReader(bytes.NewBuffer(b))
   140  	if err != nil {
   141  		t.Fatal(err)
   142  	}
   143  	r.SkipEncrypted(true)
   144  
   145  	gotStreams = 0
   146  	for {
   147  		st, err := r.NextStream()
   148  		if err == io.EOF {
   149  			break
   150  		}
   151  		if err != nil {
   152  			t.Fatalf("stream %d: %v", gotStreams, err)
   153  		}
   154  		want, ok := testStreams[st.Name]
   155  		if !ok {
   156  			t.Fatal("unexpected stream name", st.Name)
   157  		}
   158  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   159  			t.Fatal("unexpected stream extra:", st.Extra)
   160  		}
   161  		got, err := io.ReadAll(st)
   162  		if err != nil {
   163  			t.Fatalf("stream %d: %v", gotStreams, err)
   164  		}
   165  		if !bytes.Equal(got, want) {
   166  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   167  		}
   168  		gotStreams++
   169  	}
   170  	if gotStreams != wantDecStreams {
   171  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   172  	}
   173  
   174  	gotStreams = 0
   175  	r, err = NewReader(bytes.NewBuffer(b))
   176  	if err != nil {
   177  		t.Fatal(err)
   178  	}
   179  	r.SkipEncrypted(true)
   180  	for {
   181  		st, err := r.NextStream()
   182  		if err == io.EOF {
   183  			break
   184  		}
   185  		if err != nil {
   186  			t.Fatalf("stream %d: %v", gotStreams, err)
   187  		}
   188  		_, ok := testStreams[st.Name]
   189  		if !ok {
   190  			t.Fatal("unexpected stream name", st.Name)
   191  		}
   192  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   193  			t.Fatal("unexpected stream extra:", st.Extra)
   194  		}
   195  		err = st.Skip()
   196  		if err != nil {
   197  			t.Fatalf("stream %d: %v", gotStreams, err)
   198  		}
   199  		gotStreams++
   200  	}
   201  	if gotStreams != wantDecStreams {
   202  		t.Errorf("want %d streams, got %d", wantDecStreams, gotStreams)
   203  	}
   204  
   205  	if false {
   206  		r, err = NewReader(bytes.NewBuffer(b))
   207  		if err != nil {
   208  			t.Fatal(err)
   209  		}
   210  		r.SkipEncrypted(true)
   211  
   212  		err = r.DebugStream(os.Stdout)
   213  		if err != nil {
   214  			t.Fatal(err)
   215  		}
   216  	}
   217  }
   218  
   219  func TestReplaceKeys(t *testing.T) {
   220  	var buf bytes.Buffer
   221  	w := NewWriter(&buf)
   222  	if err := w.AddKeyPlain(); err != nil {
   223  		t.Fatal(err)
   224  	}
   225  	wantStreams := 0
   226  	for name, value := range testStreams {
   227  		st, err := w.AddEncryptedStream(name, []byte(name))
   228  		if err != nil {
   229  			t.Fatal(err)
   230  		}
   231  		_, err = io.Copy(st, bytes.NewBuffer(value))
   232  		if err != nil {
   233  			t.Fatal(err)
   234  		}
   235  		st.Close()
   236  		st, err = w.AddUnencryptedStream(name, []byte(name))
   237  		if err != nil {
   238  			t.Fatal(err)
   239  		}
   240  		_, err = io.Copy(st, bytes.NewBuffer(value))
   241  		if err != nil {
   242  			t.Fatal(err)
   243  		}
   244  		st.Close()
   245  		wantStreams += 2
   246  	}
   247  
   248  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
   249  	if err != nil {
   250  		t.Fatal(err)
   251  	}
   252  	err = w.AddKeyEncrypted(&priv.PublicKey)
   253  	if err != nil {
   254  		t.Fatal(err)
   255  	}
   256  	for name, value := range testStreams {
   257  		st, err := w.AddEncryptedStream(name, []byte(name))
   258  		if err != nil {
   259  			t.Fatal(err)
   260  		}
   261  		_, err = io.Copy(st, bytes.NewBuffer(value))
   262  		if err != nil {
   263  			t.Fatal(err)
   264  		}
   265  		st.Close()
   266  		st, err = w.AddUnencryptedStream(name, []byte(name))
   267  		if err != nil {
   268  			t.Fatal(err)
   269  		}
   270  		_, err = io.Copy(st, bytes.NewBuffer(value))
   271  		if err != nil {
   272  			t.Fatal(err)
   273  		}
   274  		st.Close()
   275  		wantStreams += 2
   276  	}
   277  	err = w.Close()
   278  	if err != nil {
   279  		t.Fatal(err)
   280  	}
   281  
   282  	priv2, err := rsa.GenerateKey(crand.Reader, 2048)
   283  	if err != nil {
   284  		t.Fatal(err)
   285  	}
   286  
   287  	var replaced bytes.Buffer
   288  	err = ReplaceKeys(&replaced, &buf, func(key *rsa.PublicKey) (*rsa.PrivateKey, *rsa.PublicKey) {
   289  		if key == nil {
   290  			return nil, &priv2.PublicKey
   291  		}
   292  		if key.Equal(&priv.PublicKey) {
   293  			return priv, &priv2.PublicKey
   294  		}
   295  		t.Fatal("unknown key\n", *key, "\nwant\n", priv.PublicKey)
   296  		return nil, nil
   297  	}, ReplaceKeysOptions{EncryptAll: true})
   298  	if err != nil {
   299  		t.Fatal(err)
   300  	}
   301  
   302  	// Read back...
   303  	r, err := NewReader(&replaced)
   304  	if err != nil {
   305  		t.Fatal(err)
   306  	}
   307  
   308  	// Use key provider.
   309  	r.PrivateKeyProvider(func(key *rsa.PublicKey) *rsa.PrivateKey {
   310  		if key.Equal(&priv2.PublicKey) {
   311  			return priv2
   312  		}
   313  		t.Fatal("unexpected public key")
   314  		return nil
   315  	})
   316  
   317  	var gotStreams int
   318  	for {
   319  		st, err := r.NextStream()
   320  		if err == io.EOF {
   321  			break
   322  		}
   323  		if err != nil {
   324  			t.Fatalf("stream %d: %v", gotStreams, err)
   325  		}
   326  		want, ok := testStreams[st.Name]
   327  		if !ok {
   328  			t.Fatal("unexpected stream name", st.Name)
   329  		}
   330  		if st.SentEncrypted != (gotStreams&1 == 0) {
   331  			t.Errorf("stream %d was sent with unexpected encryption %v", gotStreams, st.SentEncrypted)
   332  		}
   333  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   334  			t.Fatal("unexpected stream extra:", st.Extra)
   335  		}
   336  		got, err := io.ReadAll(st)
   337  		if err != nil {
   338  			t.Fatalf("stream %d: %v", gotStreams, err)
   339  		}
   340  		if !bytes.Equal(got, want) {
   341  			t.Errorf("stream %d: content mismatch (len %d,%d)", gotStreams, len(got), len(want))
   342  		}
   343  		gotStreams++
   344  	}
   345  	if gotStreams != wantStreams {
   346  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   347  	}
   348  }
   349  
   350  func TestError(t *testing.T) {
   351  	var buf bytes.Buffer
   352  	w := NewWriter(&buf)
   353  	if err := w.AddKeyPlain(); err != nil {
   354  		t.Fatal(err)
   355  	}
   356  	want := "an error message!"
   357  	if err := w.AddError(want); err != nil {
   358  		t.Fatal(err)
   359  	}
   360  	w.Close()
   361  
   362  	// Read back...
   363  	r, err := NewReader(&buf)
   364  	if err != nil {
   365  		t.Fatal(err)
   366  	}
   367  	st, err := r.NextStream()
   368  	if err == nil {
   369  		t.Fatalf("did not receive error, got %v, err: %v", st, err)
   370  	}
   371  	if err.Error() != want {
   372  		t.Errorf("Expected %q, got %q", want, err.Error())
   373  	}
   374  }
   375  
   376  func TestStreamReturnNonDecryptable(t *testing.T) {
   377  	var buf bytes.Buffer
   378  	w := NewWriter(&buf)
   379  	if err := w.AddKeyPlain(); err != nil {
   380  		t.Fatal(err)
   381  	}
   382  
   383  	priv, err := rsa.GenerateKey(crand.Reader, 2048)
   384  	if err != nil {
   385  		t.Fatal(err)
   386  	}
   387  	err = w.AddKeyEncrypted(&priv.PublicKey)
   388  	if err != nil {
   389  		t.Fatal(err)
   390  	}
   391  	wantStreams := len(testStreams)
   392  	for name, value := range testStreams {
   393  		st, err := w.AddEncryptedStream(name, []byte(name))
   394  		if err != nil {
   395  			t.Fatal(err)
   396  		}
   397  		_, err = io.Copy(st, bytes.NewBuffer(value))
   398  		if err != nil {
   399  			t.Fatal(err)
   400  		}
   401  		st.Close()
   402  	}
   403  	err = w.Close()
   404  	if err != nil {
   405  		t.Fatal(err)
   406  	}
   407  
   408  	// Read back...
   409  	b := buf.Bytes()
   410  	r, err := NewReader(bytes.NewBuffer(b))
   411  	if err != nil {
   412  		t.Fatal(err)
   413  	}
   414  	r.ReturnNonDecryptable(true)
   415  	gotStreams := 0
   416  	for {
   417  		st, err := r.NextStream()
   418  		if err == io.EOF {
   419  			break
   420  		}
   421  		if err != ErrNoKey {
   422  			t.Fatalf("stream %d: %v", gotStreams, err)
   423  		}
   424  		_, ok := testStreams[st.Name]
   425  		if !ok {
   426  			t.Fatal("unexpected stream name", st.Name)
   427  		}
   428  		if !bytes.Equal(st.Extra, []byte(st.Name)) {
   429  			t.Fatal("unexpected stream extra:", st.Extra)
   430  		}
   431  		if !st.SentEncrypted {
   432  			t.Fatal("stream not marked as encrypted:", st.SentEncrypted)
   433  		}
   434  		err = st.Skip()
   435  		if err != nil {
   436  			t.Fatalf("stream %d: %v", gotStreams, err)
   437  		}
   438  		gotStreams++
   439  	}
   440  	if gotStreams != wantStreams {
   441  		t.Errorf("want %d streams, got %d", wantStreams, gotStreams)
   442  	}
   443  }