github.com/aavshr/aws-sdk-go@v1.41.3/service/s3/body_hash.go (about)

     1  package s3
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/md5"
     6  	"crypto/sha256"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"hash"
    11  	"io"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    15  	"github.com/aavshr/aws-sdk-go/aws/request"
    16  )
    17  
    18  const (
    19  	contentMD5Header    = "Content-Md5"
    20  	contentSha256Header = "X-Amz-Content-Sha256"
    21  	amzTeHeader         = "X-Amz-Te"
    22  	amzTxEncodingHeader = "X-Amz-Transfer-Encoding"
    23  
    24  	appendMD5TxEncoding = "append-md5"
    25  )
    26  
    27  // computeBodyHashes will add Content MD5 and Content Sha256 hashes to the
    28  // request. If the body is not seekable or S3DisableContentMD5Validation set
    29  // this handler will be ignored.
    30  func computeBodyHashes(r *request.Request) {
    31  	if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
    32  		return
    33  	}
    34  	if r.IsPresigned() {
    35  		return
    36  	}
    37  	if r.Error != nil || !aws.IsReaderSeekable(r.Body) {
    38  		return
    39  	}
    40  
    41  	var md5Hash, sha256Hash hash.Hash
    42  	hashers := make([]io.Writer, 0, 2)
    43  
    44  	// Determine upfront which hashes can be set without overriding user
    45  	// provide header data.
    46  	if v := r.HTTPRequest.Header.Get(contentMD5Header); len(v) == 0 {
    47  		md5Hash = md5.New()
    48  		hashers = append(hashers, md5Hash)
    49  	}
    50  
    51  	if v := r.HTTPRequest.Header.Get(contentSha256Header); len(v) == 0 {
    52  		sha256Hash = sha256.New()
    53  		hashers = append(hashers, sha256Hash)
    54  	}
    55  
    56  	// Create the destination writer based on the hashes that are not already
    57  	// provided by the user.
    58  	var dst io.Writer
    59  	switch len(hashers) {
    60  	case 0:
    61  		return
    62  	case 1:
    63  		dst = hashers[0]
    64  	default:
    65  		dst = io.MultiWriter(hashers...)
    66  	}
    67  
    68  	if _, err := aws.CopySeekableBody(dst, r.Body); err != nil {
    69  		r.Error = awserr.New("BodyHashError", "failed to compute body hashes", err)
    70  		return
    71  	}
    72  
    73  	// For the hashes created, set the associated headers that the user did not
    74  	// already provide.
    75  	if md5Hash != nil {
    76  		sum := make([]byte, md5.Size)
    77  		encoded := make([]byte, md5Base64EncLen)
    78  
    79  		base64.StdEncoding.Encode(encoded, md5Hash.Sum(sum[0:0]))
    80  		r.HTTPRequest.Header[contentMD5Header] = []string{string(encoded)}
    81  	}
    82  
    83  	if sha256Hash != nil {
    84  		encoded := make([]byte, sha256HexEncLen)
    85  		sum := make([]byte, sha256.Size)
    86  
    87  		hex.Encode(encoded, sha256Hash.Sum(sum[0:0]))
    88  		r.HTTPRequest.Header[contentSha256Header] = []string{string(encoded)}
    89  	}
    90  }
    91  
    92  const (
    93  	md5Base64EncLen = (md5.Size + 2) / 3 * 4 // base64.StdEncoding.EncodedLen
    94  	sha256HexEncLen = sha256.Size * 2        // hex.EncodedLen
    95  )
    96  
    97  // Adds the x-amz-te: append_md5 header to the request. This requests the service
    98  // responds with a trailing MD5 checksum.
    99  //
   100  // Will not ask for append MD5 if disabled, the request is presigned or,
   101  // or the API operation does not support content MD5 validation.
   102  func askForTxEncodingAppendMD5(r *request.Request) {
   103  	if aws.BoolValue(r.Config.S3DisableContentMD5Validation) {
   104  		return
   105  	}
   106  	if r.IsPresigned() {
   107  		return
   108  	}
   109  	r.HTTPRequest.Header.Set(amzTeHeader, appendMD5TxEncoding)
   110  }
   111  
   112  func useMD5ValidationReader(r *request.Request) {
   113  	if r.Error != nil {
   114  		return
   115  	}
   116  
   117  	if v := r.HTTPResponse.Header.Get(amzTxEncodingHeader); v != appendMD5TxEncoding {
   118  		return
   119  	}
   120  
   121  	var bodyReader *io.ReadCloser
   122  	var contentLen int64
   123  	switch tv := r.Data.(type) {
   124  	case *GetObjectOutput:
   125  		bodyReader = &tv.Body
   126  		contentLen = aws.Int64Value(tv.ContentLength)
   127  		// Update ContentLength hiden the trailing MD5 checksum.
   128  		tv.ContentLength = aws.Int64(contentLen - md5.Size)
   129  		tv.ContentRange = aws.String(r.HTTPResponse.Header.Get("X-Amz-Content-Range"))
   130  	default:
   131  		r.Error = awserr.New("ChecksumValidationError",
   132  			fmt.Sprintf("%s: %s header received on unsupported API, %s",
   133  				amzTxEncodingHeader, appendMD5TxEncoding, r.Operation.Name,
   134  			), nil)
   135  		return
   136  	}
   137  
   138  	if contentLen < md5.Size {
   139  		r.Error = awserr.New("ChecksumValidationError",
   140  			fmt.Sprintf("invalid Content-Length %d for %s %s",
   141  				contentLen, appendMD5TxEncoding, amzTxEncodingHeader,
   142  			), nil)
   143  		return
   144  	}
   145  
   146  	// Wrap and swap the response body reader with the validation reader.
   147  	*bodyReader = newMD5ValidationReader(*bodyReader, contentLen-md5.Size)
   148  }
   149  
   150  type md5ValidationReader struct {
   151  	rawReader io.ReadCloser
   152  	payload   io.Reader
   153  	hash      hash.Hash
   154  
   155  	payloadLen int64
   156  	read       int64
   157  }
   158  
   159  func newMD5ValidationReader(reader io.ReadCloser, payloadLen int64) *md5ValidationReader {
   160  	h := md5.New()
   161  	return &md5ValidationReader{
   162  		rawReader:  reader,
   163  		payload:    io.TeeReader(&io.LimitedReader{R: reader, N: payloadLen}, h),
   164  		hash:       h,
   165  		payloadLen: payloadLen,
   166  	}
   167  }
   168  
   169  func (v *md5ValidationReader) Read(p []byte) (n int, err error) {
   170  	n, err = v.payload.Read(p)
   171  	if err != nil && err != io.EOF {
   172  		return n, err
   173  	}
   174  
   175  	v.read += int64(n)
   176  
   177  	if err == io.EOF {
   178  		if v.read != v.payloadLen {
   179  			return n, io.ErrUnexpectedEOF
   180  		}
   181  		expectSum := make([]byte, md5.Size)
   182  		actualSum := make([]byte, md5.Size)
   183  		if _, sumReadErr := io.ReadFull(v.rawReader, expectSum); sumReadErr != nil {
   184  			return n, sumReadErr
   185  		}
   186  		actualSum = v.hash.Sum(actualSum[0:0])
   187  		if !bytes.Equal(expectSum, actualSum) {
   188  			return n, awserr.New("InvalidChecksum",
   189  				fmt.Sprintf("expected MD5 checksum %s, got %s",
   190  					hex.EncodeToString(expectSum),
   191  					hex.EncodeToString(actualSum),
   192  				),
   193  				nil)
   194  		}
   195  	}
   196  
   197  	return n, err
   198  }
   199  
   200  func (v *md5ValidationReader) Close() error {
   201  	return v.rawReader.Close()
   202  }