github.com/aavshr/aws-sdk-go@v1.41.3/private/protocol/eventstream/encode.go (about)

     1  package eventstream
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/hex"
     7  	"encoding/json"
     8  	"fmt"
     9  	"hash"
    10  	"hash/crc32"
    11  	"io"
    12  
    13  	"github.com/aavshr/aws-sdk-go/aws"
    14  )
    15  
    16  // Encoder provides EventStream message encoding.
    17  type Encoder struct {
    18  	w      io.Writer
    19  	logger aws.Logger
    20  
    21  	headersBuf *bytes.Buffer
    22  }
    23  
    24  // NewEncoder initializes and returns an Encoder to encode Event Stream
    25  // messages to an io.Writer.
    26  func NewEncoder(w io.Writer, opts ...func(*Encoder)) *Encoder {
    27  	e := &Encoder{
    28  		w:          w,
    29  		headersBuf: bytes.NewBuffer(nil),
    30  	}
    31  
    32  	for _, opt := range opts {
    33  		opt(e)
    34  	}
    35  
    36  	return e
    37  }
    38  
    39  // EncodeWithLogger adds a logger to be used by the encode when decoding
    40  // stream events.
    41  func EncodeWithLogger(logger aws.Logger) func(*Encoder) {
    42  	return func(d *Encoder) {
    43  		d.logger = logger
    44  	}
    45  }
    46  
    47  // Encode encodes a single EventStream message to the io.Writer the Encoder
    48  // was created with. An error is returned if writing the message fails.
    49  func (e *Encoder) Encode(msg Message) (err error) {
    50  	e.headersBuf.Reset()
    51  
    52  	writer := e.w
    53  	if e.logger != nil {
    54  		encodeMsgBuf := bytes.NewBuffer(nil)
    55  		writer = io.MultiWriter(writer, encodeMsgBuf)
    56  		defer func() {
    57  			logMessageEncode(e.logger, encodeMsgBuf, msg, err)
    58  		}()
    59  	}
    60  
    61  	if err = EncodeHeaders(e.headersBuf, msg.Headers); err != nil {
    62  		return err
    63  	}
    64  
    65  	crc := crc32.New(crc32IEEETable)
    66  	hashWriter := io.MultiWriter(writer, crc)
    67  
    68  	headersLen := uint32(e.headersBuf.Len())
    69  	payloadLen := uint32(len(msg.Payload))
    70  
    71  	if err = encodePrelude(hashWriter, crc, headersLen, payloadLen); err != nil {
    72  		return err
    73  	}
    74  
    75  	if headersLen > 0 {
    76  		if _, err = io.Copy(hashWriter, e.headersBuf); err != nil {
    77  			return err
    78  		}
    79  	}
    80  
    81  	if payloadLen > 0 {
    82  		if _, err = hashWriter.Write(msg.Payload); err != nil {
    83  			return err
    84  		}
    85  	}
    86  
    87  	msgCRC := crc.Sum32()
    88  	return binary.Write(writer, binary.BigEndian, msgCRC)
    89  }
    90  
    91  func logMessageEncode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, encodeErr error) {
    92  	w := bytes.NewBuffer(nil)
    93  	defer func() { logger.Log(w.String()) }()
    94  
    95  	fmt.Fprintf(w, "Message to encode:\n")
    96  	encoder := json.NewEncoder(w)
    97  	if err := encoder.Encode(msg); err != nil {
    98  		fmt.Fprintf(w, "Failed to get encoded message, %v\n", err)
    99  	}
   100  
   101  	if encodeErr != nil {
   102  		fmt.Fprintf(w, "Encode error: %v\n", encodeErr)
   103  		return
   104  	}
   105  
   106  	fmt.Fprintf(w, "Raw message:\n%s\n", hex.Dump(msgBuf.Bytes()))
   107  }
   108  
   109  func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32) error {
   110  	p := messagePrelude{
   111  		Length:     minMsgLen + headersLen + payloadLen,
   112  		HeadersLen: headersLen,
   113  	}
   114  	if err := p.ValidateLens(); err != nil {
   115  		return err
   116  	}
   117  
   118  	err := binaryWriteFields(w, binary.BigEndian,
   119  		p.Length,
   120  		p.HeadersLen,
   121  	)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	p.PreludeCRC = crc.Sum32()
   127  	err = binary.Write(w, binary.BigEndian, p.PreludeCRC)
   128  	if err != nil {
   129  		return err
   130  	}
   131  
   132  	return nil
   133  }
   134  
   135  // EncodeHeaders writes the header values to the writer encoded in the event
   136  // stream format. Returns an error if a header fails to encode.
   137  func EncodeHeaders(w io.Writer, headers Headers) error {
   138  	for _, h := range headers {
   139  		hn := headerName{
   140  			Len: uint8(len(h.Name)),
   141  		}
   142  		copy(hn.Name[:hn.Len], h.Name)
   143  		if err := hn.encode(w); err != nil {
   144  			return err
   145  		}
   146  
   147  		if err := h.Value.encode(w); err != nil {
   148  			return err
   149  		}
   150  	}
   151  
   152  	return nil
   153  }
   154  
   155  func binaryWriteFields(w io.Writer, order binary.ByteOrder, vs ...interface{}) error {
   156  	for _, v := range vs {
   157  		if err := binary.Write(w, order, v); err != nil {
   158  			return err
   159  		}
   160  	}
   161  	return nil
   162  }