github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/decode.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frame 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "io" 22 "io/ioutil" 23 24 "github.com/datastax/go-cassandra-native-protocol/primitive" 25 ) 26 27 func (c *codec) DecodeFrame(source io.Reader) (*Frame, error) { 28 if header, err := c.DecodeHeader(source); err != nil { 29 return nil, fmt.Errorf("cannot decode frame header: %w", err) 30 } else if body, err := c.DecodeBody(header, source); err != nil { 31 return nil, fmt.Errorf("cannot decode frame body: %w", err) 32 } else { 33 return &Frame{Header: header, Body: body}, nil 34 } 35 } 36 37 func (c *codec) DecodeRawFrame(source io.Reader) (*RawFrame, error) { 38 if header, err := c.DecodeHeader(source); err != nil { 39 return nil, fmt.Errorf("cannot decode frame header: %w", err) 40 } else if body, err := c.DecodeRawBody(header, source); err != nil { 41 return nil, fmt.Errorf("cannot read frame body: %w", err) 42 } else { 43 return &RawFrame{Header: header, Body: body}, nil 44 } 45 } 46 47 func (c *codec) DecodeHeader(source io.Reader) (*Header, error) { 48 if versionAndDirection, err := primitive.ReadByte(source); err != nil { 49 return nil, fmt.Errorf("cannot decode header version and direction: %w", err) 50 } else { 51 isResponse := (versionAndDirection & 0b1000_0000) > 0 52 version := primitive.ProtocolVersion(versionAndDirection & 0b0111_1111) 53 header := &Header{ 54 IsResponse: isResponse, 55 Version: version, 56 } 57 58 var flags uint8 59 var err error 60 if flags, err = primitive.ReadByte(source); err != nil { 61 return nil, fmt.Errorf("cannot decode header flags: %w", err) 62 } 63 useBetaFlag := primitive.HeaderFlag(flags).Contains(primitive.HeaderFlagUseBeta) 64 65 var opCode uint8 66 if err = primitive.CheckSupportedProtocolVersion(version); err != nil { 67 return nil, NewProtocolVersionErr(err.Error(), version, useBetaFlag) 68 } else if version.IsBeta() && !useBetaFlag { 69 return nil, NewProtocolVersionErr("expected USE_BETA flag to be set", version, useBetaFlag) 70 } else if header.StreamId, err = primitive.ReadStreamId(source, version); err != nil { 71 return nil, fmt.Errorf("cannot decode header stream id: %w", err) 72 } else if opCode, err = primitive.ReadByte(source); err != nil { 73 return nil, fmt.Errorf("cannot decode header opcode: %w", err) 74 } else if header.BodyLength, err = primitive.ReadInt(source); err != nil { 75 return nil, fmt.Errorf("cannot decode header body length: %w", err) 76 } 77 header.Flags = primitive.HeaderFlag(flags) 78 header.OpCode = primitive.OpCode(opCode) 79 if err := primitive.CheckValidOpCode(header.OpCode); err != nil { 80 return nil, err 81 } else if isResponse { 82 if err := primitive.CheckResponseOpCode(header.OpCode); err != nil { 83 return nil, err 84 } 85 } else { 86 if err := primitive.CheckRequestOpCode(header.OpCode); err != nil { 87 return nil, err 88 } 89 } 90 return header, err 91 } 92 } 93 94 func (c *codec) DecodeBody(header *Header, source io.Reader) (body *Body, err error) { 95 if compressed := header.Flags.Contains(primitive.HeaderFlagCompressed); compressed { 96 if c.compressor == nil { 97 return nil, errors.New("cannot decompress body: no compressor available") 98 } else { 99 decompressedBody := &bytes.Buffer{} 100 if err := c.compressor.DecompressWithLength(io.LimitReader(source, int64(header.BodyLength)), decompressedBody); err != nil { 101 return nil, fmt.Errorf("cannot decompress body: %w", err) 102 } else { 103 source = decompressedBody 104 } 105 } 106 } 107 body = &Body{} 108 if header.IsResponse && header.Flags.Contains(primitive.HeaderFlagTracing) { 109 if body.TracingId, err = primitive.ReadUuid(source); err != nil { 110 return nil, fmt.Errorf("cannot decode body tracing id: %w", err) 111 } 112 } 113 if header.Flags.Contains(primitive.HeaderFlagCustomPayload) { 114 if body.CustomPayload, err = primitive.ReadBytesMap(source); err != nil { 115 return nil, fmt.Errorf("cannot decode body custom payload: %w", err) 116 } 117 } 118 if header.IsResponse && header.Flags.Contains(primitive.HeaderFlagWarning) { 119 if body.Warnings, err = primitive.ReadStringList(source); err != nil { 120 return nil, fmt.Errorf("cannot decode body warnings: %w", err) 121 } 122 } 123 if decoder, err := c.findMessageCodec(header.OpCode); err != nil { 124 return nil, err 125 } else if body.Message, err = decoder.Decode(source, header.Version); err != nil { 126 return nil, fmt.Errorf("cannot decode body message: %w", err) 127 } 128 return body, err 129 } 130 131 func (c *codec) DecodeRawBody(header *Header, source io.Reader) (body []byte, err error) { 132 if header.BodyLength < 0 { 133 return nil, fmt.Errorf("invalid body length: %d", header.BodyLength) 134 } else if header.BodyLength == 0 { 135 return []byte{}, nil 136 } 137 count := int64(header.BodyLength) 138 buf := bytes.NewBuffer(make([]byte, 0, count)) 139 if _, err := io.CopyN(buf, source, count); err != nil { 140 return nil, fmt.Errorf("cannot decode raw body: %w", err) 141 } 142 return buf.Bytes(), nil 143 } 144 145 func (c *codec) DiscardBody(header *Header, source io.Reader) (err error) { 146 if header.BodyLength < 0 { 147 return fmt.Errorf("invalid body length: %d", header.BodyLength) 148 } else if header.BodyLength == 0 { 149 return nil 150 } 151 count := int64(header.BodyLength) 152 switch s := source.(type) { 153 case io.Seeker: 154 _, err = s.Seek(count, io.SeekCurrent) 155 default: 156 _, err = io.CopyN(ioutil.Discard, s, count) 157 } 158 if err != nil { 159 err = fmt.Errorf("cannot discard body; %w", err) 160 } 161 return err 162 }