github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/decode.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package frame
    16  
    17  import (
    18  	"bytes"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  
    24  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    25  )
    26  
    27  func (c *codec) DecodeFrame(source io.Reader) (*Frame, error) {
    28  	if header, err := c.DecodeHeader(source); err != nil {
    29  		return nil, fmt.Errorf("cannot decode frame header: %w", err)
    30  	} else if body, err := c.DecodeBody(header, source); err != nil {
    31  		return nil, fmt.Errorf("cannot decode frame body: %w", err)
    32  	} else {
    33  		return &Frame{Header: header, Body: body}, nil
    34  	}
    35  }
    36  
    37  func (c *codec) DecodeRawFrame(source io.Reader) (*RawFrame, error) {
    38  	if header, err := c.DecodeHeader(source); err != nil {
    39  		return nil, fmt.Errorf("cannot decode frame header: %w", err)
    40  	} else if body, err := c.DecodeRawBody(header, source); err != nil {
    41  		return nil, fmt.Errorf("cannot read frame body: %w", err)
    42  	} else {
    43  		return &RawFrame{Header: header, Body: body}, nil
    44  	}
    45  }
    46  
    47  func (c *codec) DecodeHeader(source io.Reader) (*Header, error) {
    48  	if versionAndDirection, err := primitive.ReadByte(source); err != nil {
    49  		return nil, fmt.Errorf("cannot decode header version and direction: %w", err)
    50  	} else {
    51  		isResponse := (versionAndDirection & 0b1000_0000) > 0
    52  		version := primitive.ProtocolVersion(versionAndDirection & 0b0111_1111)
    53  		header := &Header{
    54  			IsResponse: isResponse,
    55  			Version:    version,
    56  		}
    57  
    58  		var flags uint8
    59  		var err error
    60  		if flags, err = primitive.ReadByte(source); err != nil {
    61  			return nil, fmt.Errorf("cannot decode header flags: %w", err)
    62  		}
    63  		useBetaFlag := primitive.HeaderFlag(flags).Contains(primitive.HeaderFlagUseBeta)
    64  
    65  		var opCode uint8
    66  		if err = primitive.CheckSupportedProtocolVersion(version); err != nil {
    67  			return nil, NewProtocolVersionErr(err.Error(), version, useBetaFlag)
    68  		} else if version.IsBeta() && !useBetaFlag {
    69  			return nil, NewProtocolVersionErr("expected USE_BETA flag to be set", version, useBetaFlag)
    70  		} else if header.StreamId, err = primitive.ReadStreamId(source, version); err != nil {
    71  			return nil, fmt.Errorf("cannot decode header stream id: %w", err)
    72  		} else if opCode, err = primitive.ReadByte(source); err != nil {
    73  			return nil, fmt.Errorf("cannot decode header opcode: %w", err)
    74  		} else if header.BodyLength, err = primitive.ReadInt(source); err != nil {
    75  			return nil, fmt.Errorf("cannot decode header body length: %w", err)
    76  		}
    77  		header.Flags = primitive.HeaderFlag(flags)
    78  		header.OpCode = primitive.OpCode(opCode)
    79  		if err := primitive.CheckValidOpCode(header.OpCode); err != nil {
    80  			return nil, err
    81  		} else if isResponse {
    82  			if err := primitive.CheckResponseOpCode(header.OpCode); err != nil {
    83  				return nil, err
    84  			}
    85  		} else {
    86  			if err := primitive.CheckRequestOpCode(header.OpCode); err != nil {
    87  				return nil, err
    88  			}
    89  		}
    90  		return header, err
    91  	}
    92  }
    93  
    94  func (c *codec) DecodeBody(header *Header, source io.Reader) (body *Body, err error) {
    95  	if compressed := header.Flags.Contains(primitive.HeaderFlagCompressed); compressed {
    96  		if c.compressor == nil {
    97  			return nil, errors.New("cannot decompress body: no compressor available")
    98  		} else {
    99  			decompressedBody := &bytes.Buffer{}
   100  			if err := c.compressor.DecompressWithLength(io.LimitReader(source, int64(header.BodyLength)), decompressedBody); err != nil {
   101  				return nil, fmt.Errorf("cannot decompress body: %w", err)
   102  			} else {
   103  				source = decompressedBody
   104  			}
   105  		}
   106  	}
   107  	body = &Body{}
   108  	if header.IsResponse && header.Flags.Contains(primitive.HeaderFlagTracing) {
   109  		if body.TracingId, err = primitive.ReadUuid(source); err != nil {
   110  			return nil, fmt.Errorf("cannot decode body tracing id: %w", err)
   111  		}
   112  	}
   113  	if header.Flags.Contains(primitive.HeaderFlagCustomPayload) {
   114  		if body.CustomPayload, err = primitive.ReadBytesMap(source); err != nil {
   115  			return nil, fmt.Errorf("cannot decode body custom payload: %w", err)
   116  		}
   117  	}
   118  	if header.IsResponse && header.Flags.Contains(primitive.HeaderFlagWarning) {
   119  		if body.Warnings, err = primitive.ReadStringList(source); err != nil {
   120  			return nil, fmt.Errorf("cannot decode body warnings: %w", err)
   121  		}
   122  	}
   123  	if decoder, err := c.findMessageCodec(header.OpCode); err != nil {
   124  		return nil, err
   125  	} else if body.Message, err = decoder.Decode(source, header.Version); err != nil {
   126  		return nil, fmt.Errorf("cannot decode body message: %w", err)
   127  	}
   128  	return body, err
   129  }
   130  
   131  func (c *codec) DecodeRawBody(header *Header, source io.Reader) (body []byte, err error) {
   132  	if header.BodyLength < 0 {
   133  		return nil, fmt.Errorf("invalid body length: %d", header.BodyLength)
   134  	} else if header.BodyLength == 0 {
   135  		return []byte{}, nil
   136  	}
   137  	count := int64(header.BodyLength)
   138  	buf := bytes.NewBuffer(make([]byte, 0, count))
   139  	if _, err := io.CopyN(buf, source, count); err != nil {
   140  		return nil, fmt.Errorf("cannot decode raw body: %w", err)
   141  	}
   142  	return buf.Bytes(), nil
   143  }
   144  
   145  func (c *codec) DiscardBody(header *Header, source io.Reader) (err error) {
   146  	if header.BodyLength < 0 {
   147  		return fmt.Errorf("invalid body length: %d", header.BodyLength)
   148  	} else if header.BodyLength == 0 {
   149  		return nil
   150  	}
   151  	count := int64(header.BodyLength)
   152  	switch s := source.(type) {
   153  	case io.Seeker:
   154  		_, err = s.Seek(count, io.SeekCurrent)
   155  	default:
   156  		_, err = io.CopyN(ioutil.Discard, s, count)
   157  	}
   158  	if err != nil {
   159  		err = fmt.Errorf("cannot discard body; %w", err)
   160  	}
   161  	return err
   162  }