github.com/aavshr/aws-sdk-go@v1.41.3/service/sqs/checksums.go (about) 1 package sqs 2 3 import ( 4 "crypto/md5" 5 "encoding/hex" 6 "fmt" 7 "strings" 8 9 "github.com/aavshr/aws-sdk-go/aws" 10 "github.com/aavshr/aws-sdk-go/aws/awserr" 11 "github.com/aavshr/aws-sdk-go/aws/request" 12 ) 13 14 var ( 15 errChecksumMissingBody = fmt.Errorf("cannot compute checksum. missing body") 16 errChecksumMissingMD5 = fmt.Errorf("cannot verify checksum. missing response MD5") 17 ) 18 19 func setupChecksumValidation(r *request.Request) { 20 if aws.BoolValue(r.Config.DisableComputeChecksums) { 21 return 22 } 23 24 switch r.Operation.Name { 25 case opSendMessage: 26 r.Handlers.Unmarshal.PushBack(verifySendMessage) 27 case opSendMessageBatch: 28 r.Handlers.Unmarshal.PushBack(verifySendMessageBatch) 29 case opReceiveMessage: 30 r.Handlers.Unmarshal.PushBack(verifyReceiveMessage) 31 } 32 } 33 34 func verifySendMessage(r *request.Request) { 35 if r.DataFilled() && r.ParamsFilled() { 36 in := r.Params.(*SendMessageInput) 37 out := r.Data.(*SendMessageOutput) 38 err := checksumsMatch(in.MessageBody, out.MD5OfMessageBody) 39 if err != nil { 40 setChecksumError(r, err.Error()) 41 } 42 } 43 } 44 45 func verifySendMessageBatch(r *request.Request) { 46 if r.DataFilled() && r.ParamsFilled() { 47 entries := map[string]*SendMessageBatchResultEntry{} 48 ids := []string{} 49 50 out := r.Data.(*SendMessageBatchOutput) 51 for _, entry := range out.Successful { 52 entries[*entry.Id] = entry 53 } 54 55 in := r.Params.(*SendMessageBatchInput) 56 for _, entry := range in.Entries { 57 if e, ok := entries[*entry.Id]; ok { 58 if err := checksumsMatch(entry.MessageBody, e.MD5OfMessageBody); err != nil { 59 ids = append(ids, *e.MessageId) 60 } 61 } 62 } 63 if len(ids) > 0 { 64 setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", ")) 65 } 66 } 67 } 68 69 func verifyReceiveMessage(r *request.Request) { 70 if r.DataFilled() && r.ParamsFilled() { 71 ids := []string{} 72 out := r.Data.(*ReceiveMessageOutput) 73 for i, msg := range out.Messages { 74 err := checksumsMatch(msg.Body, msg.MD5OfBody) 75 if err != nil { 76 if msg.MessageId == nil { 77 if r.Config.Logger != nil { 78 r.Config.Logger.Log(fmt.Sprintf( 79 "WARN: SQS.ReceiveMessage failed checksum request id: %s, message %d has no message ID.", 80 r.RequestID, i, 81 )) 82 } 83 continue 84 } 85 86 ids = append(ids, *msg.MessageId) 87 } 88 } 89 if len(ids) > 0 { 90 setChecksumError(r, "invalid messages: %s", strings.Join(ids, ", ")) 91 } 92 } 93 } 94 95 func checksumsMatch(body, expectedMD5 *string) error { 96 if body == nil { 97 return errChecksumMissingBody 98 } else if expectedMD5 == nil { 99 return errChecksumMissingMD5 100 } 101 102 msum := md5.Sum([]byte(*body)) 103 sum := hex.EncodeToString(msum[:]) 104 if sum != *expectedMD5 { 105 return fmt.Errorf("expected MD5 checksum '%s', got '%s'", *expectedMD5, sum) 106 } 107 108 return nil 109 } 110 111 func setChecksumError(r *request.Request, format string, args ...interface{}) { 112 r.Retryable = aws.Bool(true) 113 r.Error = awserr.New("InvalidChecksum", fmt.Sprintf(format, args...), nil) 114 }