github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/message/batch.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package message 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 22 "github.com/datastax/go-cassandra-native-protocol/primitive" 23 ) 24 25 // Batch is a BATCH request message. The zero value is NOT a valid message; at least one batch child must be provided. 26 // +k8s:deepcopy-gen=true 27 // +k8s:deepcopy-gen:interfaces=github.com/datastax/go-cassandra-native-protocol/message.Message 28 type Batch struct { 29 // The batch type: LOGGED, UNLOGGED or COUNTER. This field is mandatory; its default is LOGGED, as the zero value 30 // of primitive.BatchType is primitive.BatchTypeLogged. The LOGGED type is equivalent to a regular CQL3 batch 31 // statement CREATE BATCH ... APPLY BATCH. 32 Type primitive.BatchType 33 // This batch children statements. At least one batch child must be provided. 34 Children []*BatchChild 35 // The consistency level to use to execute the batch statements. This field is mandatory; its default is ANY, as the 36 // zero value of primitive.ConsistencyLevel is primitive.ConsistencyLevelAny. 37 Consistency primitive.ConsistencyLevel 38 // The (optional) serial consistency level to use when executing the query. Serial consistency is available 39 // starting with protocol version 3. 40 SerialConsistency *primitive.ConsistencyLevel 41 // The default timestamp for the query in microseconds (negative values are discouraged but supported for 42 // backward compatibility reasons except for the smallest negative value (-2^63) that is forbidden). If provided, 43 // this will replace the server-side assigned timestamp as default timestamp. Note that a timestamp in the query 44 // itself (that is, if the query has a USING TIMESTAMP clause) will still override this timestamp. 45 // Default timestamps are valid for protocol versions 3 and higher. 46 DefaultTimestamp *int64 47 // The keyspace in which to execute queries, when the target table name is not qualified. Optional. Introduced in 48 // Protocol Version 5, also present in DSE protocol v2. 49 Keyspace string 50 // Introduced in Protocol Version 5, not present in DSE protocol versions. 51 NowInSeconds *int32 52 } 53 54 func (m *Batch) IsResponse() bool { 55 return false 56 } 57 58 func (m *Batch) GetOpCode() primitive.OpCode { 59 return primitive.OpCodeBatch 60 } 61 62 func (m *Batch) String() string { 63 return fmt.Sprintf("BATCH (%d statements)", len(m.Children)) 64 } 65 66 // Flags are the flags of this BATCH message. BATCH messages have flags starting with protocol version 3. 67 func (m *Batch) Flags() primitive.QueryFlag { 68 var flags primitive.QueryFlag 69 if m.SerialConsistency != nil { 70 flags = flags.Add(primitive.QueryFlagSerialConsistency) 71 } 72 if m.DefaultTimestamp != nil { 73 flags = flags.Add(primitive.QueryFlagDefaultTimestamp) 74 } 75 if m.Keyspace != "" { 76 flags = flags.Add(primitive.QueryFlagWithKeyspace) 77 } 78 if m.NowInSeconds != nil { 79 flags = flags.Add(primitive.QueryFlagNowInSeconds) 80 } 81 // Note: the named values flag is in theory possible, but server-side implementation is 82 // broken. See https://issues.apache.org/jira/browse/CASSANDRA-10246 83 return flags 84 } 85 86 // BatchChild represents a BATCH child statement. 87 // +k8s:deepcopy-gen=true 88 type BatchChild struct { 89 // The CQL statement to execute. Exactly one of Query or Id must be present, never both. 90 Query string 91 // The prepared id of the statement to execute. Exactly one of Query or Id must be present, never both. 92 Id []byte 93 // Note: named values are in theory possible, but their server-side implementation is 94 // broken. See https://issues.apache.org/jira/browse/CASSANDRA-10246 95 Values []*primitive.Value 96 } 97 98 type batchCodec struct{} 99 100 func (c *batchCodec) Encode(msg Message, dest io.Writer, version primitive.ProtocolVersion) (err error) { 101 batch, ok := msg.(*Batch) 102 if !ok { 103 return errors.New(fmt.Sprintf("expected *message.Batch, got %T", msg)) 104 } 105 if err = primitive.CheckValidBatchType(batch.Type); err != nil { 106 return err 107 } else if err = primitive.WriteByte(uint8(batch.Type), dest); err != nil { 108 return fmt.Errorf("cannot write BATCH type: %w", err) 109 } 110 childrenCount := len(batch.Children) 111 if childrenCount > 0xFFFF { 112 return errors.New(fmt.Sprintf("BATCH messages can contain at most %d child queries", 0xFFFF)) 113 } else if err = primitive.WriteShort(uint16(childrenCount), dest); err != nil { 114 return fmt.Errorf("cannot write BATCH query count: %w", err) 115 } 116 for i, child := range batch.Children { 117 if child.Query != "" { 118 if err = primitive.WriteByte(uint8(primitive.BatchChildTypeQueryString), dest); err != nil { 119 return fmt.Errorf("cannot write BATCH query kind 0 for child #%d: %w", i, err) 120 } else if child.Query == "" { 121 return fmt.Errorf("cannot write empty BATCH query string for child #%d", i) 122 } else if err = primitive.WriteLongString(child.Query, dest); err != nil { 123 return fmt.Errorf("cannot write BATCH query string for child #%d: %w", i, err) 124 } 125 } else { 126 if err = primitive.WriteByte(uint8(primitive.BatchChildTypePreparedId), dest); err != nil { 127 return fmt.Errorf("cannot write BATCH query kind 1 for child #%d: %w", i, err) 128 } else if len(child.Id) == 0 { 129 return fmt.Errorf("cannot write empty BATCH query id for child #%d: %w", i, err) 130 } else if err = primitive.WriteShortBytes(child.Id, dest); err != nil { 131 return fmt.Errorf("cannot write BATCH query id for child #%d: %w", i, err) 132 } 133 } 134 if err = primitive.WritePositionalValues(child.Values, dest, version); err != nil { 135 return fmt.Errorf("cannot write BATCH positional values for child #%d: %w", i, err) 136 } 137 } 138 if err = primitive.WriteShort(uint16(batch.Consistency), dest); err != nil { 139 return fmt.Errorf("cannot write BATCH consistency: %w", err) 140 } 141 if version.SupportsBatchQueryFlags() { 142 flags := batch.Flags() 143 if version.Uses4BytesQueryFlags() { 144 err = primitive.WriteInt(int32(flags), dest) 145 } else { 146 err = primitive.WriteByte(uint8(flags), dest) 147 } 148 if err != nil { 149 return fmt.Errorf("cannot write BATCH query flags: %w", err) 150 } 151 if version.SupportsQueryFlag(primitive.QueryFlagSerialConsistency) && flags.Contains(primitive.QueryFlagSerialConsistency) { 152 if err = primitive.WriteShort(uint16(*batch.SerialConsistency), dest); err != nil { 153 return fmt.Errorf("cannot write BATCH serial consistency: %w", err) 154 } 155 } 156 if version.SupportsQueryFlag(primitive.QueryFlagDefaultTimestamp) && flags.Contains(primitive.QueryFlagDefaultTimestamp) { 157 if err = primitive.WriteLong(*batch.DefaultTimestamp, dest); err != nil { 158 return fmt.Errorf("cannot write BATCH default timestamp: %w", err) 159 } 160 } 161 if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) { 162 if batch.Keyspace == "" { 163 return errors.New("cannot write BATCH empty keyspace") 164 } else if err = primitive.WriteString(batch.Keyspace, dest); err != nil { 165 return fmt.Errorf("cannot write BATCH keyspace: %w", err) 166 } 167 } 168 if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) { 169 if err = primitive.WriteInt(*batch.NowInSeconds, dest); err != nil { 170 return fmt.Errorf("cannot write BATCH now-in-seconds: %w", err) 171 } 172 } 173 } 174 return nil 175 } 176 177 func (c *batchCodec) EncodedLength(msg Message, version primitive.ProtocolVersion) (length int, err error) { 178 batch, ok := msg.(*Batch) 179 if !ok { 180 return -1, errors.New(fmt.Sprintf("expected *message.Batch, got %T", msg)) 181 } 182 childrenCount := len(batch.Children) 183 if childrenCount > 0xFFFF { 184 return -1, errors.New(fmt.Sprintf("BATCH messages can contain at most %d queries", 0xFFFF)) 185 } 186 length += primitive.LengthOfByte // type 187 length += primitive.LengthOfShort // number of queries 188 for i, child := range batch.Children { 189 length += primitive.LengthOfByte // child type 190 if child.Query != "" { 191 length += primitive.LengthOfLongString(child.Query) 192 } else { 193 length += primitive.LengthOfShortBytes(child.Id) 194 } 195 if valuesLength, err := primitive.LengthOfPositionalValues(child.Values); err != nil { 196 return -1, fmt.Errorf("cannot compute length of BATCH positional values for child #%d: %w", i, err) 197 } else { 198 length += valuesLength 199 } 200 } 201 length += primitive.LengthOfShort // consistency level 202 // flags 203 if version.SupportsBatchQueryFlags() { 204 if version.Uses4BytesQueryFlags() { 205 length += primitive.LengthOfInt 206 } else { 207 length += primitive.LengthOfByte 208 } 209 flags := batch.Flags() 210 if version.SupportsQueryFlag(primitive.QueryFlagSerialConsistency) && flags.Contains(primitive.QueryFlagSerialConsistency) { 211 length += primitive.LengthOfShort 212 } 213 if version.SupportsQueryFlag(primitive.QueryFlagDefaultTimestamp) && flags.Contains(primitive.QueryFlagDefaultTimestamp) { 214 length += primitive.LengthOfLong 215 } 216 if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) { 217 length += primitive.LengthOfString(batch.Keyspace) 218 } 219 if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) { 220 length += primitive.LengthOfInt 221 } 222 } 223 return length, nil 224 } 225 226 func (c *batchCodec) Decode(source io.Reader, version primitive.ProtocolVersion) (msg Message, err error) { 227 batch := &Batch{} 228 var batchType uint8 229 if batchType, err = primitive.ReadByte(source); err != nil { 230 return nil, fmt.Errorf("cannot read BATCH type: %w", err) 231 } 232 batch.Type = primitive.BatchType(batchType) 233 if err = primitive.CheckValidBatchType(batch.Type); err != nil { 234 return nil, err 235 } 236 var childrenCount uint16 237 if childrenCount, err = primitive.ReadShort(source); err != nil { 238 return nil, fmt.Errorf("cannot read BATCH query count: %w", err) 239 } 240 batch.Children = make([]*BatchChild, childrenCount) 241 for i := 0; i < int(childrenCount); i++ { 242 var childType uint8 243 if childType, err = primitive.ReadByte(source); err != nil { 244 return nil, fmt.Errorf("cannot read BATCH child type for child #%d: %w", i, err) 245 } 246 var child = &BatchChild{} 247 switch primitive.BatchChildType(childType) { 248 case primitive.BatchChildTypeQueryString: 249 if child.Query, err = primitive.ReadLongString(source); err != nil { 250 return nil, fmt.Errorf("cannot read BATCH query string for child #%d: %w", i, err) 251 } 252 case primitive.BatchChildTypePreparedId: 253 if child.Id, err = primitive.ReadShortBytes(source); err != nil { 254 return nil, fmt.Errorf("cannot read BATCH query id for child #%d: %w", i, err) 255 } 256 default: 257 return nil, fmt.Errorf("unsupported BATCH child type for child #%d: %v", i, childType) 258 } 259 if child.Values, err = primitive.ReadPositionalValues(source, version); err != nil { 260 return nil, fmt.Errorf("cannot read BATCH positional values for child #%d: %w", i, err) 261 } 262 batch.Children[i] = child 263 } 264 var batchConsistency uint16 265 if batchConsistency, err = primitive.ReadShort(source); err != nil { 266 return nil, fmt.Errorf("cannot read BATCH consistency: %w", err) 267 } 268 batch.Consistency = primitive.ConsistencyLevel(batchConsistency) 269 if version.SupportsBatchQueryFlags() { 270 var flags primitive.QueryFlag 271 if version.Uses4BytesQueryFlags() { 272 var f int32 273 f, err = primitive.ReadInt(source) 274 flags = primitive.QueryFlag(f) 275 } else { 276 var f uint8 277 f, err = primitive.ReadByte(source) 278 flags = primitive.QueryFlag(f) 279 } 280 if err != nil { 281 return nil, fmt.Errorf("cannot read BATCH query flags: %w", err) 282 } 283 if flags.Contains(primitive.QueryFlagValueNames) { 284 return nil, errors.New("cannot use BATCH with named values, see CASSANDRA-10246") 285 } 286 if flags.Contains(primitive.QueryFlagSerialConsistency) { 287 var batchSerialConsistencyUint uint16 288 if batchSerialConsistencyUint, err = primitive.ReadShort(source); err != nil { 289 return nil, fmt.Errorf("cannot read BATCH serial consistency: %w", err) 290 } 291 batchSerialConsistency := primitive.ConsistencyLevel(batchSerialConsistencyUint) 292 batch.SerialConsistency = &batchSerialConsistency 293 } 294 if flags.Contains(primitive.QueryFlagDefaultTimestamp) { 295 var batchDefaultTimestamp int64 296 if batchDefaultTimestamp, err = primitive.ReadLong(source); err != nil { 297 return nil, fmt.Errorf("cannot read BATCH default timestamp: %w", err) 298 } 299 batch.DefaultTimestamp = &batchDefaultTimestamp 300 } 301 if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) { 302 if batch.Keyspace, err = primitive.ReadString(source); err != nil { 303 return nil, fmt.Errorf("cannot read BATCH keyspace: %w", err) 304 } 305 } 306 if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) { 307 var batchNowInSeconds int32 308 if batchNowInSeconds, err = primitive.ReadInt(source); err != nil { 309 return nil, fmt.Errorf("cannot read BATCH now-in-seconds: %w", err) 310 } 311 batch.NowInSeconds = &batchNowInSeconds 312 } 313 } 314 return batch, nil 315 } 316 317 func (c *batchCodec) GetOpCode() primitive.OpCode { 318 return primitive.OpCodeBatch 319 }