github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/frame/encode.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package frame 16 17 import ( 18 "bytes" 19 "errors" 20 "fmt" 21 "io" 22 23 "github.com/datastax/go-cassandra-native-protocol/primitive" 24 ) 25 26 func (c *codec) EncodeFrame(frame *Frame, dest io.Writer) error { 27 if frame.Header.Flags.Contains(primitive.HeaderFlagCompressed) { 28 return c.encodeFrameCompressed(frame, dest) 29 } else { 30 return c.encodeFrameUncompressed(frame, dest) 31 } 32 } 33 34 func (c *codec) encodeFrameUncompressed(frame *Frame, dest io.Writer) error { 35 if encodedBodyLength, err := c.uncompressedBodyLength(frame.Header, frame.Body); err != nil { 36 return fmt.Errorf("cannot compute length of uncompressed message body: %w", err) 37 } else { 38 frame.Header.BodyLength = int32(encodedBodyLength) 39 } 40 if err := c.EncodeHeader(frame.Header, dest); err != nil { 41 return fmt.Errorf("cannot encode frame header: %w", err) 42 } else if err := c.EncodeBody(frame.Header, frame.Body, dest); err != nil { 43 return fmt.Errorf("cannot encode frame body: %w", err) 44 } 45 return nil 46 } 47 48 func (c *codec) encodeFrameCompressed(frame *Frame, dest io.Writer) error { 49 compressedBody := bytes.Buffer{} 50 if err := c.EncodeBody(frame.Header, frame.Body, &compressedBody); err != nil { 51 return fmt.Errorf("cannot encode frame body: %w", err) 52 } else { 53 frame.Header.BodyLength = int32(compressedBody.Len()) 54 if err := c.EncodeHeader(frame.Header, dest); err != nil { 55 return fmt.Errorf("cannot encode frame header: %w", err) 56 } else if _, err := compressedBody.WriteTo(dest); err != nil { 57 return fmt.Errorf("cannot concat frame body to frame header: %w", err) 58 } 59 } 60 return nil 61 } 62 63 func (c *codec) EncodeRawFrame(frame *RawFrame, dest io.Writer) error { 64 if err := primitive.CheckSupportedProtocolVersion(frame.Header.Version); err != nil { 65 return err 66 } else { 67 frame.Header.BodyLength = int32(len(frame.Body)) 68 if err := c.EncodeHeader(frame.Header, dest); err != nil { 69 return fmt.Errorf("cannot encode raw header: %w", err) 70 } else if _, err := dest.Write(frame.Body); err != nil { 71 return fmt.Errorf("cannot write raw body: %w", err) 72 } 73 } 74 return nil 75 } 76 77 func (c *codec) EncodeHeader(header *Header, dest io.Writer) error { 78 useBetaFlag := header.Flags.Contains(primitive.HeaderFlagUseBeta) 79 if err := primitive.CheckSupportedProtocolVersion(header.Version); err != nil { 80 return NewProtocolVersionErr(err.Error(), header.Version, useBetaFlag) 81 } else if header.Version.IsBeta() && !useBetaFlag { 82 return NewProtocolVersionErr("expected USE_BETA flag to be set", header.Version, useBetaFlag) 83 } 84 85 versionAndDirection := uint8(header.Version) 86 if header.IsResponse { 87 versionAndDirection |= 0b1000_0000 88 } 89 if err := primitive.WriteByte(versionAndDirection, dest); err != nil { 90 return fmt.Errorf("cannot encode header version and direction: %w", err) 91 } else if err := primitive.WriteByte(uint8(header.Flags), dest); err != nil { 92 return fmt.Errorf("cannot encode header flags: %w", err) 93 } else if err = primitive.WriteStreamId(header.StreamId, dest, header.Version); err != nil { 94 return fmt.Errorf("cannot encode header stream id: %w", err) 95 } else if err = primitive.WriteByte(uint8(header.OpCode), dest); err != nil { 96 return fmt.Errorf("cannot encode header opcode: %w", err) 97 } else if err = primitive.WriteInt(header.BodyLength, dest); err != nil { 98 return fmt.Errorf("cannot encode header body length: %w", err) 99 } 100 return nil 101 } 102 103 func (c *codec) EncodeBody(header *Header, body *Body, dest io.Writer) error { 104 if header.OpCode != body.Message.GetOpCode() { 105 return fmt.Errorf("opcode mismatch between header and body: %d != %d", header.OpCode, body.Message.GetOpCode()) 106 } else if header.Flags.Contains(primitive.HeaderFlagCompressed) { 107 if c.compressor == nil { 108 return errors.New("cannot compress body: no compressor available") 109 } else if uncompressedBodyLength, err := c.uncompressedBodyLength(header, body); err != nil { 110 return fmt.Errorf("cannot compute length of uncompressed message body: %w", err) 111 } else { 112 uncompressedBody := bytes.NewBuffer(make([]byte, 0, uncompressedBodyLength)) 113 if err = c.encodeBodyUncompressed(header, body, uncompressedBody); err != nil { 114 return fmt.Errorf("cannot encode body: %w", err) 115 } else if err := c.compressor.CompressWithLength(uncompressedBody, dest); err != nil { 116 return fmt.Errorf("cannot compress body: %w", err) 117 } 118 return nil 119 } 120 } else { 121 return c.encodeBodyUncompressed(header, body, dest) 122 } 123 } 124 125 func (c *codec) encodeBodyUncompressed(header *Header, body *Body, dest io.Writer) (err error) { 126 if header.Flags.Contains(primitive.HeaderFlagTracing) && body.Message.IsResponse() { 127 if err = primitive.WriteUuid(body.TracingId, dest); err != nil { 128 return fmt.Errorf("cannot encode body tracing id: %w", err) 129 } 130 } 131 if header.Flags.Contains(primitive.HeaderFlagCustomPayload) { 132 if header.Version < primitive.ProtocolVersion4 { 133 return fmt.Errorf("custom payloads are not supported in protocol version %v", header.Version) 134 } else if err = primitive.WriteBytesMap(body.CustomPayload, dest); err != nil { 135 return fmt.Errorf("cannot encode body custom payload: %w", err) 136 } 137 } 138 if header.Flags.Contains(primitive.HeaderFlagWarning) { 139 if header.Version < primitive.ProtocolVersion4 && body.Warnings != nil { 140 return fmt.Errorf("warnings are not supported in protocol version %v", header.Version) 141 } else if err = primitive.WriteStringList(body.Warnings, dest); err != nil { 142 return fmt.Errorf("cannot encode body warnings: %w", err) 143 } 144 } 145 if encoder, err := c.findMessageCodec(body.Message.GetOpCode()); err != nil { 146 return err 147 } else if err = encoder.Encode(body.Message, dest, header.Version); err != nil { 148 return fmt.Errorf("cannot encode body message: %w", err) 149 } 150 return nil 151 } 152 153 func (c *codec) uncompressedBodyLength(header *Header, body *Body) (length int, err error) { 154 if encoder, err := c.findMessageCodec(body.Message.GetOpCode()); err != nil { 155 return -1, err 156 } else if length, err = encoder.EncodedLength(body.Message, header.Version); err != nil { 157 return -1, fmt.Errorf("cannot compute message length: %w", err) 158 } 159 if header.Flags.Contains(primitive.HeaderFlagTracing) { 160 length += primitive.LengthOfUuid 161 } 162 if header.Flags.Contains(primitive.HeaderFlagCustomPayload) { 163 length += primitive.LengthOfBytesMap(body.CustomPayload) 164 } 165 if header.Flags.Contains(primitive.HeaderFlagWarning) { 166 length += primitive.LengthOfStringList(body.Warnings) 167 } 168 return length, nil 169 }