github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/eventstream/decode.go (about)

     1  package eventstream
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"fmt"
     9  	"hash"
    10  	"hash/crc32"
    11  	"io"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  )
    15  
    16  // Decoder provides decoding of an Event Stream messages.
    17  type Decoder struct {
    18  	r      io.Reader
    19  	logger aws.Logger
    20  }
    21  
    22  // NewDecoder initializes and returns a Decoder for decoding event
    23  // stream messages from the reader provided.
    24  func NewDecoder(r io.Reader, opts ...func(*Decoder)) *Decoder {
    25  	d := &Decoder{
    26  		r: r,
    27  	}
    28  
    29  	for _, opt := range opts {
    30  		opt(d)
    31  	}
    32  
    33  	return d
    34  }
    35  
    36  // DecodeWithLogger adds a logger to be used by the decoder when decoding
    37  // stream events.
    38  func DecodeWithLogger(logger aws.Logger) func(*Decoder) {
    39  	return func(d *Decoder) {
    40  		d.logger = logger
    41  	}
    42  }
    43  
    44  // Decode attempts to decode a single message from the event stream reader.
    45  // Will return the event stream message, or error if Decode fails to read
    46  // the message from the stream.
    47  func (d *Decoder) Decode(payloadBuf []byte) (m Message, err error) {
    48  	reader := d.r
    49  	if d.logger != nil {
    50  		debugMsgBuf := bytes.NewBuffer(nil)
    51  		reader = io.TeeReader(reader, debugMsgBuf)
    52  		defer func() {
    53  			logMessageDecode(d.logger, debugMsgBuf, m, err)
    54  		}()
    55  	}
    56  
    57  	m, err = Decode(reader, payloadBuf)
    58  
    59  	return m, err
    60  }
    61  
    62  // Decode attempts to decode a single message from the event stream reader.
    63  // Will return the event stream message, or error if Decode fails to read
    64  // the message from the reader.
    65  func Decode(reader io.Reader, payloadBuf []byte) (m Message, err error) {
    66  	crc := crc32.New(crc32IEEETable)
    67  	hashReader := io.TeeReader(reader, crc)
    68  
    69  	prelude, err := decodePrelude(hashReader, crc)
    70  	if err != nil {
    71  		return Message{}, err
    72  	}
    73  
    74  	if prelude.HeadersLen > 0 {
    75  		lr := io.LimitReader(hashReader, int64(prelude.HeadersLen))
    76  		m.Headers, err = decodeHeaders(lr)
    77  		if err != nil {
    78  			return Message{}, err
    79  		}
    80  	}
    81  
    82  	if payloadLen := prelude.PayloadLen(); payloadLen > 0 {
    83  		buf, err := decodePayload(payloadBuf, io.LimitReader(hashReader, int64(payloadLen)))
    84  		if err != nil {
    85  			return Message{}, err
    86  		}
    87  		m.Payload = buf
    88  	}
    89  
    90  	msgCRC := crc.Sum32()
    91  	if err := validateCRC(reader, msgCRC); err != nil {
    92  		return Message{}, err
    93  	}
    94  
    95  	return m, nil
    96  }
    97  
    98  func logMessageDecode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, decodeErr error) {
    99  	w := bytes.NewBuffer(nil)
   100  	defer func() { logger.Log(w.String()) }()
   101  
   102  	fmt.Fprintf(w, "Raw message:\n%s\n",
   103  		hex.Dump(msgBuf.Bytes()))
   104  
   105  	if decodeErr != nil {
   106  		fmt.Fprintf(w, "Decode error: %v\n", decodeErr)
   107  		return
   108  	}
   109  
   110  	rawMsg, err := msg.rawMessage()
   111  	if err != nil {
   112  		fmt.Fprintf(w, "failed to create raw message, %v\n", err)
   113  		return
   114  	}
   115  
   116  	decodedMsg := decodedMessage{
   117  		rawMessage: rawMsg,
   118  		Headers:    decodedHeaders(msg.Headers),
   119  	}
   120  
   121  	fmt.Fprintf(w, "Decoded message:\n")
   122  	encoder := json.NewEncoder(w)
   123  	if err := encoder.Encode(decodedMsg); err != nil {
   124  		fmt.Fprintf(w, "failed to generate decoded message, %v\n", err)
   125  	}
   126  }
   127  
   128  func decodePrelude(r io.Reader, crc hash.Hash32) (messagePrelude, error) {
   129  	var p messagePrelude
   130  
   131  	var err error
   132  	p.Length, err = decodeUint32(r)
   133  	if err != nil {
   134  		return messagePrelude{}, err
   135  	}
   136  
   137  	p.HeadersLen, err = decodeUint32(r)
   138  	if err != nil {
   139  		return messagePrelude{}, err
   140  	}
   141  
   142  	if err := p.ValidateLens(); err != nil {
   143  		return messagePrelude{}, err
   144  	}
   145  
   146  	preludeCRC := crc.Sum32()
   147  	if err := validateCRC(r, preludeCRC); err != nil {
   148  		return messagePrelude{}, err
   149  	}
   150  
   151  	p.PreludeCRC = preludeCRC
   152  
   153  	return p, nil
   154  }
   155  
   156  func decodePayload(buf []byte, r io.Reader) ([]byte, error) {
   157  	w := bytes.NewBuffer(buf[0:0])
   158  
   159  	_, err := io.Copy(w, r)
   160  	return w.Bytes(), err
   161  }
   162  
   163  func decodeUint8(r io.Reader) (uint8, error) {
   164  	type byteReader interface {
   165  		ReadByte() (byte, error)
   166  	}
   167  
   168  	if br, ok := r.(byteReader); ok {
   169  		v, err := br.ReadByte()
   170  		return uint8(v), err
   171  	}
   172  
   173  	var b [1]byte
   174  	_, err := io.ReadFull(r, b[:])
   175  	return uint8(b[0]), err
   176  }
   177  func decodeUint16(r io.Reader) (uint16, error) {
   178  	var b [2]byte
   179  	bs := b[:]
   180  	_, err := io.ReadFull(r, bs)
   181  	if err != nil {
   182  		return 0, err
   183  	}
   184  	return binary.BigEndian.Uint16(bs), nil
   185  }
   186  func decodeUint32(r io.Reader) (uint32, error) {
   187  	var b [4]byte
   188  	bs := b[:]
   189  	_, err := io.ReadFull(r, bs)
   190  	if err != nil {
   191  		return 0, err
   192  	}
   193  	return binary.BigEndian.Uint32(bs), nil
   194  }
   195  func decodeUint64(r io.Reader) (uint64, error) {
   196  	var b [8]byte
   197  	bs := b[:]
   198  	_, err := io.ReadFull(r, bs)
   199  	if err != nil {
   200  		return 0, err
   201  	}
   202  	return binary.BigEndian.Uint64(bs), nil
   203  }
   204  
   205  func validateCRC(r io.Reader, expect uint32) error {
   206  	msgCRC, err := decodeUint32(r)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	if msgCRC != expect {
   212  		return ChecksumError{}
   213  	}
   214  
   215  	return nil
   216  }