github.com/aavshr/aws-sdk-go@v1.41.3/service/sqs/checksums.go (about)

     1  package sqs
     2  
     3  import (
     4  	"crypto/md5"
     5  	"encoding/hex"
     6  	"fmt"
     7  	"strings"
     8  
     9  	"github.com/aavshr/aws-sdk-go/aws"
    10  	"github.com/aavshr/aws-sdk-go/aws/awserr"
    11  	"github.com/aavshr/aws-sdk-go/aws/request"
    12  )
    13  
    14  var (
    15  	errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body")
    16  	errChecksumMissingMD5  = fmt.Errorf("cannot verify checksum. missing response MD5")
    17  )
    18  
    19  func setupChecksumValidation(r *request.Request) {
    20  	if aws.BoolValue(r.Config.DisableComputeChecksums) {
    21  		return
    22  	}
    23  
    24  	switch r.Operation.Name {
    25  	case opSendMessage:
    26  		r.Handlers.Unmarshal.PushBack(verifySendMessage)
    27  	case opSendMessageBatch:
    28  		r.Handlers.Unmarshal.PushBack(verifySendMessageBatch)
    29  	case opReceiveMessage:
    30  		r.Handlers.Unmarshal.PushBack(verifyReceiveMessage)
    31  	}
    32  }
    33  
    34  func verifySendMessage(r *request.Request) {
    35  	if r.DataFilled() && r.ParamsFilled() {
    36  		in := r.Params.(*SendMessageInput)
    37  		out := r.Data.(*SendMessageOutput)
    38  		err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody)
    39  		if err != nil {
    40  			setChecksumError(r, err.Error())
    41  		}
    42  	}
    43  }
    44  
    45  func verifySendMessageBatch(r *request.Request) {
    46  	if r.DataFilled() && r.ParamsFilled() {
    47  		entries := map[string]*SendMessageBatchResultEntry{}
    48  		ids := []string{}
    49  
    50  		out := r.Data.(*SendMessageBatchOutput)
    51  		for _, entry := range out.Successful {
    52  			entries[*entry.Id] = entry
    53  		}
    54  
    55  		in := r.Params.(*SendMessageBatchInput)
    56  		for _, entry := range in.Entries {
    57  			if e, ok := entries[*entry.Id]; ok {
    58  				if err := checksumsMatch(entry.MessageBody, e.MD5OfMessageBody); err != nil {
    59  					ids = append(ids, *e.MessageId)
    60  				}
    61  			}
    62  		}
    63  		if len(ids) > 0 {
    64  			setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
    65  		}
    66  	}
    67  }
    68  
    69  func verifyReceiveMessage(r *request.Request) {
    70  	if r.DataFilled() && r.ParamsFilled() {
    71  		ids := []string{}
    72  		out := r.Data.(*ReceiveMessageOutput)
    73  		for i, msg := range out.Messages {
    74  			err := checksumsMatch(msg.Body, msg.MD5OfBody)
    75  			if err != nil {
    76  				if msg.MessageId == nil {
    77  					if r.Config.Logger != nil {
    78  						r.Config.Logger.Log(fmt.Sprintf(
    79  							"WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.",
    80  							r.RequestID, i,
    81  						))
    82  					}
    83  					continue
    84  				}
    85  
    86  				ids = append(ids, *msg.MessageId)
    87  			}
    88  		}
    89  		if len(ids) > 0 {
    90  			setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", "))
    91  		}
    92  	}
    93  }
    94  
    95  func checksumsMatch(body, expectedMD5 *string) error {
    96  	if body == nil {
    97  		return errChecksumMissingBody
    98  	} else if expectedMD5 == nil {
    99  		return errChecksumMissingMD5
   100  	}
   101  
   102  	msum := md5.Sum([]byte(*body))
   103  	sum := hex.EncodeToString(msum[:])
   104  	if sum != *expectedMD5 {
   105  		return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, sum)
   106  	}
   107  
   108  	return nil
   109  }
   110  
   111  func setChecksumError(r *request.Request, format string, args ...interface{}) {
   112  	r.Retryable = aws.Bool(true)
   113  	r.Error = awserr.New("InvalidChecksum", fmt.Sprintf(format, args...), nil)
   114  }