github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/encode.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frame
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  
    23  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    24  )
    25  
    26  func (c *codec) EncodeFrame(frame *Frame, dest io.Writer) error {
    27  	if frame.Header.Flags.Contains(primitive.HeaderFlagCompressed) {
    28  		return c.encodeFrameCompressed(frame, dest)
    29  	} else {
    30  		return c.encodeFrameUncompressed(frame, dest)
    31  	}
    32  }
    33  
    34  func (c *codec) encodeFrameUncompressed(frame *Frame, dest io.Writer) error {
    35  	if encodedBodyLength, err := c.uncompressedBodyLength(frame.Header, frame.Body); err != nil {
    36  		return fmt.Errorf("cannot compute length of uncompressed message body: %w", err)
    37  	} else {
    38  		frame.Header.BodyLength = int32(encodedBodyLength)
    39  	}
    40  	if err := c.EncodeHeader(frame.Header, dest); err != nil {
    41  		return fmt.Errorf("cannot encode frame header: %w", err)
    42  	} else if err := c.EncodeBody(frame.Header, frame.Body, dest); err != nil {
    43  		return fmt.Errorf("cannot encode frame body: %w", err)
    44  	}
    45  	return nil
    46  }
    47  
    48  func (c *codec) encodeFrameCompressed(frame *Frame, dest io.Writer) error {
    49  	compressedBody := bytes.Buffer{}
    50  	if err := c.EncodeBody(frame.Header, frame.Body, &compressedBody); err != nil {
    51  		return fmt.Errorf("cannot encode frame body: %w", err)
    52  	} else {
    53  		frame.Header.BodyLength = int32(compressedBody.Len())
    54  		if err := c.EncodeHeader(frame.Header, dest); err != nil {
    55  			return fmt.Errorf("cannot encode frame header: %w", err)
    56  		} else if _, err := compressedBody.WriteTo(dest); err != nil {
    57  			return fmt.Errorf("cannot concat frame body to frame header: %w", err)
    58  		}
    59  	}
    60  	return nil
    61  }
    62  
    63  func (c *codec) EncodeRawFrame(frame *RawFrame, dest io.Writer) error {
    64  	if err := primitive.CheckSupportedProtocolVersion(frame.Header.Version); err != nil {
    65  		return err
    66  	} else {
    67  		frame.Header.BodyLength = int32(len(frame.Body))
    68  		if err := c.EncodeHeader(frame.Header, dest); err != nil {
    69  			return fmt.Errorf("cannot encode raw header: %w", err)
    70  		} else if _, err := dest.Write(frame.Body); err != nil {
    71  			return fmt.Errorf("cannot write raw body: %w", err)
    72  		}
    73  	}
    74  	return nil
    75  }
    76  
    77  func (c *codec) EncodeHeader(header *Header, dest io.Writer) error {
    78  	useBetaFlag := header.Flags.Contains(primitive.HeaderFlagUseBeta)
    79  	if err := primitive.CheckSupportedProtocolVersion(header.Version); err != nil {
    80  		return NewProtocolVersionErr(err.Error(), header.Version, useBetaFlag)
    81  	} else if header.Version.IsBeta() && !useBetaFlag {
    82  		return NewProtocolVersionErr("expected USE_BETA flag to be set", header.Version, useBetaFlag)
    83  	}
    84  
    85  	versionAndDirection := uint8(header.Version)
    86  	if header.IsResponse {
    87  		versionAndDirection |= 0b1000_0000
    88  	}
    89  	if err := primitive.WriteByte(versionAndDirection, dest); err != nil {
    90  		return fmt.Errorf("cannot encode header version and direction: %w", err)
    91  	} else if err := primitive.WriteByte(uint8(header.Flags), dest); err != nil {
    92  		return fmt.Errorf("cannot encode header flags: %w", err)
    93  	} else if err = primitive.WriteStreamId(header.StreamId, dest, header.Version); err != nil {
    94  		return fmt.Errorf("cannot encode header stream id: %w", err)
    95  	} else if err = primitive.WriteByte(uint8(header.OpCode), dest); err != nil {
    96  		return fmt.Errorf("cannot encode header opcode: %w", err)
    97  	} else if err = primitive.WriteInt(header.BodyLength, dest); err != nil {
    98  		return fmt.Errorf("cannot encode header body length: %w", err)
    99  	}
   100  	return nil
   101  }
   102  
   103  func (c *codec) EncodeBody(header *Header, body *Body, dest io.Writer) error {
   104  	if header.OpCode != body.Message.GetOpCode() {
   105  		return fmt.Errorf("opcode mismatch between header and body: %d != %d", header.OpCode, body.Message.GetOpCode())
   106  	} else if header.Flags.Contains(primitive.HeaderFlagCompressed) {
   107  		if c.compressor == nil {
   108  			return errors.New("cannot compress body: no compressor available")
   109  		} else if uncompressedBodyLength, err := c.uncompressedBodyLength(header, body); err != nil {
   110  			return fmt.Errorf("cannot compute length of uncompressed message body: %w", err)
   111  		} else {
   112  			uncompressedBody := bytes.NewBuffer(make([]byte, 0, uncompressedBodyLength))
   113  			if err = c.encodeBodyUncompressed(header, body, uncompressedBody); err != nil {
   114  				return fmt.Errorf("cannot encode body: %w", err)
   115  			} else if err := c.compressor.CompressWithLength(uncompressedBody, dest); err != nil {
   116  				return fmt.Errorf("cannot compress body: %w", err)
   117  			}
   118  			return nil
   119  		}
   120  	} else {
   121  		return c.encodeBodyUncompressed(header, body, dest)
   122  	}
   123  }
   124  
   125  func (c *codec) encodeBodyUncompressed(header *Header, body *Body, dest io.Writer) (err error) {
   126  	if header.Flags.Contains(primitive.HeaderFlagTracing) && body.Message.IsResponse() {
   127  		if err = primitive.WriteUuid(body.TracingId, dest); err != nil {
   128  			return fmt.Errorf("cannot encode body tracing id: %w", err)
   129  		}
   130  	}
   131  	if header.Flags.Contains(primitive.HeaderFlagCustomPayload) {
   132  		if header.Version < primitive.ProtocolVersion4 {
   133  			return fmt.Errorf("custom payloads are not supported in protocol version %v", header.Version)
   134  		} else if err = primitive.WriteBytesMap(body.CustomPayload, dest); err != nil {
   135  			return fmt.Errorf("cannot encode body custom payload: %w", err)
   136  		}
   137  	}
   138  	if header.Flags.Contains(primitive.HeaderFlagWarning) {
   139  		if header.Version < primitive.ProtocolVersion4 && body.Warnings != nil {
   140  			return fmt.Errorf("warnings are not supported in protocol version %v", header.Version)
   141  		} else if err = primitive.WriteStringList(body.Warnings, dest); err != nil {
   142  			return fmt.Errorf("cannot encode body warnings: %w", err)
   143  		}
   144  	}
   145  	if encoder, err := c.findMessageCodec(body.Message.GetOpCode()); err != nil {
   146  		return err
   147  	} else if err = encoder.Encode(body.Message, dest, header.Version); err != nil {
   148  		return fmt.Errorf("cannot encode body message: %w", err)
   149  	}
   150  	return nil
   151  }
   152  
   153  func (c *codec) uncompressedBodyLength(header *Header, body *Body) (length int, err error) {
   154  	if encoder, err := c.findMessageCodec(body.Message.GetOpCode()); err != nil {
   155  		return -1, err
   156  	} else if length, err = encoder.EncodedLength(body.Message, header.Version); err != nil {
   157  		return -1, fmt.Errorf("cannot compute message length: %w", err)
   158  	}
   159  	if header.Flags.Contains(primitive.HeaderFlagTracing) {
   160  		length += primitive.LengthOfUuid
   161  	}
   162  	if header.Flags.Contains(primitive.HeaderFlagCustomPayload) {
   163  		length += primitive.LengthOfBytesMap(body.CustomPayload)
   164  	}
   165  	if header.Flags.Contains(primitive.HeaderFlagWarning) {
   166  		length += primitive.LengthOfStringList(body.Warnings)
   167  	}
   168  	return length, nil
   169  }