github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/body_hash_test.go (about)

     1  package s3
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/aavshr/aws-sdk-go/aws"
    15  	"github.com/aavshr/aws-sdk-go/aws/request"
    16  	"github.com/aavshr/aws-sdk-go/internal/sdkio"
    17  )
    18  
    19  type errorReader struct{}
    20  
    21  func (errorReader) Read([]byte) (int, error) {
    22  	return 0, fmt.Errorf("errorReader error")
    23  }
    24  func (errorReader) Seek(int64, int) (int64, error) {
    25  	return 0, nil
    26  }
    27  
    28  func TestComputeBodyHases(t *testing.T) {
    29  	bodyContent := []byte("bodyContent goes here")
    30  
    31  	cases := []struct {
    32  		Req               *request.Request
    33  		ExpectMD5         string
    34  		ExpectSHA256      string
    35  		Error             string
    36  		DisableContentMD5 bool
    37  		Presigned         bool
    38  	}{
    39  		{
    40  			Req: &request.Request{
    41  				HTTPRequest: &http.Request{
    42  					Header: http.Header{},
    43  				},
    44  				Body: bytes.NewReader(bodyContent),
    45  			},
    46  			ExpectMD5:    "CqD6NNPvoNOBT/5pkjtzOw==",
    47  			ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
    48  		},
    49  		{
    50  			Req: &request.Request{
    51  				HTTPRequest: &http.Request{
    52  					Header: func() http.Header {
    53  						h := http.Header{}
    54  						h.Set(contentMD5Header, "MD5AlreadySet")
    55  						return h
    56  					}(),
    57  				},
    58  				Body: bytes.NewReader(bodyContent),
    59  			},
    60  			ExpectMD5:    "MD5AlreadySet",
    61  			ExpectSHA256: "3ff09c8b42a58a905e27835919ede45b61722e7cd400f30101bd9ed1a69a1825",
    62  		},
    63  		{
    64  			Req: &request.Request{
    65  				HTTPRequest: &http.Request{
    66  					Header: func() http.Header {
    67  						h := http.Header{}
    68  						h.Set(contentSha256Header, "SHA256AlreadySet")
    69  						return h
    70  					}(),
    71  				},
    72  				Body: bytes.NewReader(bodyContent),
    73  			},
    74  			ExpectMD5:    "CqD6NNPvoNOBT/5pkjtzOw==",
    75  			ExpectSHA256: "SHA256AlreadySet",
    76  		},
    77  		{
    78  			Req: &request.Request{
    79  				HTTPRequest: &http.Request{
    80  					Header: func() http.Header {
    81  						h := http.Header{}
    82  						h.Set(contentMD5Header, "MD5AlreadySet")
    83  						h.Set(contentSha256Header, "SHA256AlreadySet")
    84  						return h
    85  					}(),
    86  				},
    87  				Body: bytes.NewReader(bodyContent),
    88  			},
    89  			ExpectMD5:    "MD5AlreadySet",
    90  			ExpectSHA256: "SHA256AlreadySet",
    91  		},
    92  		{
    93  			Req: &request.Request{
    94  				HTTPRequest: &http.Request{
    95  					Header: http.Header{},
    96  				},
    97  				// Non-seekable reader
    98  				Body: aws.ReadSeekCloser(bytes.NewBuffer(bodyContent)),
    99  			},
   100  			ExpectMD5:    "",
   101  			ExpectSHA256: "",
   102  		},
   103  		{
   104  			Req: &request.Request{
   105  				HTTPRequest: &http.Request{
   106  					Header: http.Header{},
   107  				},
   108  				// Empty seekable body
   109  				Body: aws.ReadSeekCloser(bytes.NewReader(nil)),
   110  			},
   111  			ExpectMD5:    "1B2M2Y8AsgTpgAmY7PhCfg==",
   112  			ExpectSHA256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
   113  		},
   114  		{
   115  			Req: &request.Request{
   116  				HTTPRequest: &http.Request{
   117  					Header: http.Header{},
   118  				},
   119  				// failure while reading reader
   120  				Body: errorReader{},
   121  			},
   122  			ExpectMD5:    "",
   123  			ExpectSHA256: "",
   124  			Error:        "errorReader error",
   125  		},
   126  		{
   127  			// Disabled ContentMD5 validation
   128  			Req: &request.Request{
   129  				HTTPRequest: &http.Request{
   130  					Header: http.Header{},
   131  				},
   132  				Body: bytes.NewReader(bodyContent),
   133  			},
   134  			ExpectMD5:         "",
   135  			ExpectSHA256:      "",
   136  			DisableContentMD5: true,
   137  		},
   138  		{
   139  			// Disabled ContentMD5 validation
   140  			Req: &request.Request{
   141  				HTTPRequest: &http.Request{
   142  					Header: http.Header{},
   143  				},
   144  				Body: bytes.NewReader(bodyContent),
   145  			},
   146  			ExpectMD5:    "",
   147  			ExpectSHA256: "",
   148  			Presigned:    true,
   149  		},
   150  	}
   151  
   152  	for i, c := range cases {
   153  		c.Req.Config.S3DisableContentMD5Validation = aws.Bool(c.DisableContentMD5)
   154  
   155  		if c.Presigned {
   156  			c.Req.ExpireTime = 10 * time.Minute
   157  		}
   158  		computeBodyHashes(c.Req)
   159  
   160  		if e, a := c.ExpectMD5, c.Req.HTTPRequest.Header.Get(contentMD5Header); e != a {
   161  			t.Errorf("%d, expect %v md5, got %v", i, e, a)
   162  		}
   163  
   164  		if e, a := c.ExpectSHA256, c.Req.HTTPRequest.Header.Get(contentSha256Header); e != a {
   165  			t.Errorf("%d, expect %v sha256, got %v", i, e, a)
   166  		}
   167  
   168  		if len(c.Error) != 0 {
   169  			if c.Req.Error == nil {
   170  				t.Fatalf("%d, expect error, got none", i)
   171  			}
   172  			if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
   173  				t.Errorf("%d, expect %v error to be in %v", i, e, a)
   174  			}
   175  
   176  		} else if c.Req.Error != nil {
   177  			t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
   178  		}
   179  	}
   180  }
   181  
   182  func BenchmarkComputeBodyHashes(b *testing.B) {
   183  	body := bytes.NewReader(make([]byte, 2*1024))
   184  	req := &request.Request{
   185  		HTTPRequest: &http.Request{
   186  			Header: http.Header{},
   187  		},
   188  		Body: body,
   189  	}
   190  	b.ResetTimer()
   191  
   192  	for i := 0; i < b.N; i++ {
   193  		computeBodyHashes(req)
   194  		if req.Error != nil {
   195  			b.Fatalf("expect no error, got %v", req.Error)
   196  		}
   197  
   198  		req.HTTPRequest.Header = http.Header{}
   199  		body.Seek(0, sdkio.SeekStart)
   200  	}
   201  }
   202  
   203  func TestAskForTxEncodingAppendMD5(t *testing.T) {
   204  	cases := []struct {
   205  		DisableContentMD5 bool
   206  		Presigned         bool
   207  	}{
   208  		{DisableContentMD5: true},
   209  		{DisableContentMD5: false},
   210  		{Presigned: true},
   211  	}
   212  
   213  	for i, c := range cases {
   214  		req := &request.Request{
   215  			HTTPRequest: &http.Request{
   216  				Header: http.Header{},
   217  			},
   218  			Config: aws.Config{
   219  				S3DisableContentMD5Validation: aws.Bool(c.DisableContentMD5),
   220  			},
   221  		}
   222  		if c.Presigned {
   223  			req.ExpireTime = 10 * time.Minute
   224  		}
   225  
   226  		askForTxEncodingAppendMD5(req)
   227  
   228  		v := req.HTTPRequest.Header.Get(amzTeHeader)
   229  
   230  		expectHeader := !(c.DisableContentMD5 || c.Presigned)
   231  
   232  		if e, a := expectHeader, len(v) != 0; e != a {
   233  			t.Errorf("%d, expect %t disable content MD5, got %t, %s", i, e, a, v)
   234  		}
   235  	}
   236  }
   237  
   238  func TestUseMD5ValidationReader(t *testing.T) {
   239  	body := []byte("create a really cool md5 checksum of me")
   240  	bodySum := md5.Sum(body)
   241  	bodyWithSum := append(body, bodySum[:]...)
   242  
   243  	emptyBodySum := md5.Sum([]byte{})
   244  
   245  	cases := []struct {
   246  		Req      *request.Request
   247  		Error    string
   248  		Validate func(outupt interface{}) error
   249  	}{
   250  		{
   251  			// Positive: Use Validation reader
   252  			Req: &request.Request{
   253  				HTTPResponse: &http.Response{
   254  					Header: func() http.Header {
   255  						h := http.Header{}
   256  						h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
   257  						return h
   258  					}(),
   259  				},
   260  				Data: &GetObjectOutput{
   261  					Body:          ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
   262  					ContentLength: aws.Int64(int64(len(bodyWithSum))),
   263  				},
   264  			},
   265  			Validate: func(output interface{}) error {
   266  				getObjOut := output.(*GetObjectOutput)
   267  				reader, ok := getObjOut.Body.(*md5ValidationReader)
   268  				if !ok {
   269  					return fmt.Errorf("expect %T updated body reader, got %T",
   270  						(*md5ValidationReader)(nil), getObjOut.Body)
   271  				}
   272  
   273  				if reader.rawReader == nil {
   274  					return fmt.Errorf("expect rawReader not to be nil")
   275  				}
   276  				if reader.payload == nil {
   277  					return fmt.Errorf("expect payload not to be nil")
   278  				}
   279  				if e, a := int64(len(bodyWithSum)-md5.Size), reader.payloadLen; e != a {
   280  					return fmt.Errorf("expect %v payload len, got %v", e, a)
   281  				}
   282  				if reader.hash == nil {
   283  					return fmt.Errorf("expect hash not to be nil")
   284  				}
   285  
   286  				return nil
   287  			},
   288  		},
   289  		{
   290  			// Positive: Use Validation reader, empty object
   291  			Req: &request.Request{
   292  				HTTPResponse: &http.Response{
   293  					Header: func() http.Header {
   294  						h := http.Header{}
   295  						h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
   296  						return h
   297  					}(),
   298  				},
   299  				Data: &GetObjectOutput{
   300  					Body:          ioutil.NopCloser(bytes.NewReader(emptyBodySum[:])),
   301  					ContentLength: aws.Int64(int64(len(emptyBodySum[:]))),
   302  				},
   303  			},
   304  			Validate: func(output interface{}) error {
   305  				getObjOut := output.(*GetObjectOutput)
   306  				reader, ok := getObjOut.Body.(*md5ValidationReader)
   307  				if !ok {
   308  					return fmt.Errorf("expect %T updated body reader, got %T",
   309  						(*md5ValidationReader)(nil), getObjOut.Body)
   310  				}
   311  
   312  				if reader.rawReader == nil {
   313  					return fmt.Errorf("expect rawReader not to be nil")
   314  				}
   315  				if reader.payload == nil {
   316  					return fmt.Errorf("expect payload not to be nil")
   317  				}
   318  				if e, a := int64(len(emptyBodySum)-md5.Size), reader.payloadLen; e != a {
   319  					return fmt.Errorf("expect %v payload len, got %v", e, a)
   320  				}
   321  				if reader.hash == nil {
   322  					return fmt.Errorf("expect hash not to be nil")
   323  				}
   324  
   325  				return nil
   326  			},
   327  		},
   328  		{
   329  			// Negative: amzTxEncoding header not set
   330  			Req: &request.Request{
   331  				HTTPResponse: &http.Response{
   332  					Header: http.Header{},
   333  				},
   334  				Data: &GetObjectOutput{
   335  					Body:          ioutil.NopCloser(bytes.NewReader(body)),
   336  					ContentLength: aws.Int64(int64(len(body))),
   337  				},
   338  			},
   339  			Validate: func(output interface{}) error {
   340  				getObjOut := output.(*GetObjectOutput)
   341  				reader, ok := getObjOut.Body.(*md5ValidationReader)
   342  				if ok {
   343  					return fmt.Errorf("expect body reader not to be %T",
   344  						reader)
   345  				}
   346  
   347  				return nil
   348  			},
   349  		},
   350  		{
   351  			// Negative: Not GetObjectOutput type.
   352  			Req: &request.Request{
   353  				Operation: &request.Operation{
   354  					Name: "PutObject",
   355  				},
   356  				HTTPResponse: &http.Response{
   357  					Header: func() http.Header {
   358  						h := http.Header{}
   359  						h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
   360  						return h
   361  					}(),
   362  				},
   363  				Data: &PutObjectOutput{},
   364  			},
   365  			Error: "header received on unsupported API",
   366  			Validate: func(output interface{}) error {
   367  				_, ok := output.(*PutObjectOutput)
   368  				if !ok {
   369  					return fmt.Errorf("expect %T output not to change, got %T",
   370  						(*PutObjectOutput)(nil), output)
   371  				}
   372  
   373  				return nil
   374  			},
   375  		},
   376  		{
   377  			// Negative: invalid content length.
   378  			Req: &request.Request{
   379  				HTTPResponse: &http.Response{
   380  					Header: func() http.Header {
   381  						h := http.Header{}
   382  						h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
   383  						return h
   384  					}(),
   385  				},
   386  				Data: &GetObjectOutput{
   387  					Body:          ioutil.NopCloser(bytes.NewReader(bodyWithSum)),
   388  					ContentLength: aws.Int64(-1),
   389  				},
   390  			},
   391  			Error: "invalid Content-Length -1",
   392  			Validate: func(output interface{}) error {
   393  				getObjOut := output.(*GetObjectOutput)
   394  				reader, ok := getObjOut.Body.(*md5ValidationReader)
   395  				if ok {
   396  					return fmt.Errorf("expect body reader not to be %T",
   397  						reader)
   398  				}
   399  				return nil
   400  			},
   401  		},
   402  		{
   403  			// Negative: invalid content length, < md5.Size.
   404  			Req: &request.Request{
   405  				HTTPResponse: &http.Response{
   406  					Header: func() http.Header {
   407  						h := http.Header{}
   408  						h.Set(amzTxEncodingHeader, appendMD5TxEncoding)
   409  						return h
   410  					}(),
   411  				},
   412  				Data: &GetObjectOutput{
   413  					Body:          ioutil.NopCloser(bytes.NewReader(make([]byte, 5))),
   414  					ContentLength: aws.Int64(5),
   415  				},
   416  			},
   417  			Error: "invalid Content-Length 5",
   418  			Validate: func(output interface{}) error {
   419  				getObjOut := output.(*GetObjectOutput)
   420  				reader, ok := getObjOut.Body.(*md5ValidationReader)
   421  				if ok {
   422  					return fmt.Errorf("expect body reader not to be %T",
   423  						reader)
   424  				}
   425  				return nil
   426  			},
   427  		},
   428  	}
   429  
   430  	for i, c := range cases {
   431  		useMD5ValidationReader(c.Req)
   432  		if len(c.Error) != 0 {
   433  			if c.Req.Error == nil {
   434  				t.Fatalf("%d, expect error, got none", i)
   435  			}
   436  			if e, a := c.Error, c.Req.Error.Error(); !strings.Contains(a, e) {
   437  				t.Errorf("%d, expect %v error to be in %v", i, e, a)
   438  			}
   439  		} else if c.Req.Error != nil {
   440  			t.Errorf("%d, expect no error, got %v", i, c.Req.Error)
   441  		}
   442  
   443  		if c.Validate != nil {
   444  			if err := c.Validate(c.Req.Data); err != nil {
   445  				t.Errorf("%d, expect Data to validate, got %v", i, err)
   446  			}
   447  		}
   448  	}
   449  }
   450  
   451  func TestReaderMD5Validation(t *testing.T) {
   452  	body := []byte("create a really cool md5 checksum of me")
   453  	bodySum := md5.Sum(body)
   454  	bodyWithSum := append(body, bodySum[:]...)
   455  	emptyBodySum := md5.Sum([]byte{})
   456  	badBodySum := append(body, emptyBodySum[:]...)
   457  
   458  	cases := []struct {
   459  		Content       []byte
   460  		ContentReader io.ReadCloser
   461  		PayloadLen    int64
   462  		Error         string
   463  	}{
   464  		{
   465  			Content:    bodyWithSum,
   466  			PayloadLen: int64(len(body)),
   467  		},
   468  		{
   469  			Content:    emptyBodySum[:],
   470  			PayloadLen: 0,
   471  		},
   472  		{
   473  			Content:    badBodySum,
   474  			PayloadLen: int64(len(body)),
   475  			Error:      "expected MD5 checksum",
   476  		},
   477  		{
   478  			Content:    emptyBodySum[:len(emptyBodySum)-2],
   479  			PayloadLen: 0,
   480  			Error:      "unexpected EOF",
   481  		},
   482  		{
   483  			Content:    body,
   484  			PayloadLen: int64(len(body) * 2),
   485  			Error:      "unexpected EOF",
   486  		},
   487  		{
   488  			ContentReader: ioutil.NopCloser(errorReader{}),
   489  			PayloadLen:    int64(len(body)),
   490  			Error:         "errorReader error",
   491  		},
   492  	}
   493  
   494  	for i, c := range cases {
   495  		reader := c.ContentReader
   496  		if reader == nil {
   497  			reader = ioutil.NopCloser(bytes.NewReader(c.Content))
   498  		}
   499  		v := newMD5ValidationReader(reader, c.PayloadLen)
   500  
   501  		var actual bytes.Buffer
   502  		n, err := io.Copy(&actual, v)
   503  		if len(c.Error) != 0 {
   504  			if err == nil {
   505  				t.Errorf("%d, expect error, got none", i)
   506  			}
   507  			if e, a := c.Error, err.Error(); !strings.Contains(a, e) {
   508  				t.Errorf("%d, expect %v error to be in %v", i, e, a)
   509  			}
   510  			continue
   511  		} else if err != nil {
   512  			t.Errorf("%d, expect no error, got %v", i, err)
   513  			continue
   514  		}
   515  		if e, a := c.PayloadLen, n; e != a {
   516  			t.Errorf("%d, expect %v len, got %v", i, e, a)
   517  		}
   518  
   519  		if e, a := c.Content[:c.PayloadLen], actual.Bytes(); !bytes.Equal(e, a) {
   520  			t.Errorf("%d, expect:\n%v\nactual:\n%v", i, e, a)
   521  		}
   522  	}
   523  }