github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/segment/decode.go (about)

     1  // Copyright 2021 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package segment
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/binary"
    20  	"fmt"
    21  	"io"
    22  
    23  	"github.com/datastax/go-cassandra-native-protocol/crc"
    24  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    25  )
    26  
    27  func (c *codec) DecodeSegment(source io.Reader) (*Segment, error) {
    28  	if header, err := c.decodeSegmentHeader(source); err != nil {
    29  		return nil, fmt.Errorf("cannot decode segment header: %w", err)
    30  	} else if payload, err := c.decodeSegmentPayload(header, source); err != nil {
    31  		return nil, fmt.Errorf("cannot decode segment payload: %w", err)
    32  	} else {
    33  		return &Segment{
    34  			Header:  header,
    35  			Payload: payload,
    36  		}, nil
    37  	}
    38  }
    39  
    40  func (c *codec) decodeSegmentHeader(source io.Reader) (*Header, error) {
    41  	// Read header data (little endian)
    42  	var headerData uint64
    43  	headerLength := c.headerLength()
    44  	for i := 0; i < headerLength; i++ {
    45  		if b, err := primitive.ReadByte(source); err != nil {
    46  			return nil, fmt.Errorf("cannot read segment header at byte %v: %w", i, err)
    47  		} else {
    48  			headerData |= uint64(b) << (8 * i)
    49  		}
    50  	}
    51  	// Read CRC (little endian) and check it
    52  	var expectedHeaderCrc uint32
    53  	for i := 0; i < Crc24Length; i++ {
    54  		if b, err := primitive.ReadByte(source); err != nil {
    55  			return nil, fmt.Errorf("cannot read segment header CRC: %w", err)
    56  		} else {
    57  			expectedHeaderCrc |= uint32(b) << (8 * i)
    58  		}
    59  	}
    60  	actualHeaderCrc := crc.ChecksumKoopman(headerData, headerLength)
    61  	if actualHeaderCrc != expectedHeaderCrc {
    62  		return nil, fmt.Errorf(
    63  			"crc mismatch on header %x: received %x, computed %x",
    64  			headerData,
    65  			expectedHeaderCrc,
    66  			actualHeaderCrc)
    67  	}
    68  	header := &Header{Crc24: actualHeaderCrc}
    69  	if c.compressor == nil {
    70  		header.CompressedPayloadLength = 0
    71  		header.UncompressedPayloadLength = int32(headerData & MaxPayloadLength)
    72  	} else {
    73  		header.CompressedPayloadLength = int32(headerData & MaxPayloadLength)
    74  		headerData >>= 17
    75  		header.UncompressedPayloadLength = int32(headerData & MaxPayloadLength)
    76  		if header.UncompressedPayloadLength == 0 {
    77  			// the server chose not to compress
    78  			header.UncompressedPayloadLength = header.CompressedPayloadLength
    79  			header.CompressedPayloadLength = 0
    80  		}
    81  	}
    82  	headerData >>= 17
    83  	header.IsSelfContained = (headerData & 1) == 1
    84  	return header, nil
    85  }
    86  
    87  func (c *codec) decodeSegmentPayload(header *Header, source io.Reader) (*Payload, error) {
    88  	// Extract payload
    89  	var length int32
    90  	if c.compressor == nil || header.CompressedPayloadLength == 0 {
    91  		length = header.UncompressedPayloadLength
    92  	} else {
    93  		length = header.CompressedPayloadLength
    94  	}
    95  	encodedPayload := make([]byte, length)
    96  	if _, err := io.ReadFull(source, encodedPayload); err != nil {
    97  		return nil, fmt.Errorf("cannot read encoded payload: %w", err)
    98  	}
    99  	// Read and check CRC
   100  	var expectedPayloadCrc uint32
   101  	if err := binary.Read(source, binary.LittleEndian, &expectedPayloadCrc); err != nil {
   102  		return nil, fmt.Errorf("cannot read segment payload CRC: %w", err)
   103  	}
   104  	actualPayloadCrc := crc.ChecksumIEEE(encodedPayload)
   105  	if actualPayloadCrc != expectedPayloadCrc {
   106  		return nil, fmt.Errorf(
   107  			"crc mismatch on payload: received %x, computed %x",
   108  			expectedPayloadCrc, actualPayloadCrc)
   109  	}
   110  	payload := &Payload{Crc32: actualPayloadCrc}
   111  	// Decompress payload if needed
   112  	if c.compressor == nil || header.CompressedPayloadLength == 0 {
   113  		payload.UncompressedData = encodedPayload
   114  	} else {
   115  		rawData := bytes.NewBuffer(make([]byte, 0, length))
   116  		if err := c.compressor.Decompress(bytes.NewReader(encodedPayload), rawData); err != nil {
   117  			return nil, fmt.Errorf("cannot decompress segment payload: %w", err)
   118  		}
   119  		payload.UncompressedData = rawData.Bytes()
   120  	}
   121  	return payload, nil
   122  }
   123  
   124  func (c *codec) headerLength() int {
   125  	if c.compressor == nil {
   126  		return UncompressedHeaderLength
   127  	}
   128  	return CompressedHeaderLength
   129  }