github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/message/batch.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package message
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"io"
    21  
    22  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    23  )
    24  
// Batch is a BATCH request message. The zero value is NOT a valid message; at least one batch child must be provided.
// +k8s:deepcopy-gen=true
// +k8s:deepcopy-gen:interfaces=github.com/datastax/go-cassandra-native-protocol/message.Message
type Batch struct {
	// The batch type: LOGGED, UNLOGGED or COUNTER. This field is mandatory; its default is LOGGED, as the zero value
	// of primitive.BatchType is primitive.BatchTypeLogged. The LOGGED type is equivalent to a regular CQL3 batch
	// statement BEGIN BATCH ... APPLY BATCH.
	Type primitive.BatchType
	// The child statements of this batch. At least one batch child must be provided; the wire format encodes the
	// number of children as an unsigned short, so at most 0xFFFF children can be encoded.
	Children []*BatchChild
	// The consistency level to use to execute the batch statements. This field is mandatory; its default is ANY, as the
	// zero value of primitive.ConsistencyLevel is primitive.ConsistencyLevelAny.
	Consistency primitive.ConsistencyLevel
	// The (optional) serial consistency level to use when executing the query. Serial consistency is available
	// starting with protocol version 3. A nil value means no serial consistency is sent.
	SerialConsistency *primitive.ConsistencyLevel
	// The default timestamp for the query in microseconds (negative values are discouraged but supported for
	// backward compatibility reasons except for the smallest negative value (-2^63) that is forbidden). If provided,
	// this will replace the server-side assigned timestamp as default timestamp. Note that a timestamp in the query
	// itself (that is, if the query has a USING TIMESTAMP clause) will still override this timestamp.
	// Default timestamps are valid for protocol versions 3 and higher. A nil value means no default timestamp is sent.
	DefaultTimestamp *int64
	// The keyspace in which to execute queries, when the target table name is not qualified. Optional; an empty
	// string means no keyspace is sent. Introduced in Protocol Version 5, also present in DSE protocol v2.
	Keyspace string
	// The "now" time, in seconds, to use when executing the batch. Optional; a nil value means it is not sent.
	// Introduced in Protocol Version 5, not present in DSE protocol versions.
	NowInSeconds *int32
}
    53  
    54  func (m *Batch) IsResponse() bool {
    55  	return false
    56  }
    57  
    58  func (m *Batch) GetOpCode() primitive.OpCode {
    59  	return primitive.OpCodeBatch
    60  }
    61  
    62  func (m *Batch) String() string {
    63  	return fmt.Sprintf("BATCH (%d statements)", len(m.Children))
    64  }
    65  
    66  // Flags are the flags of this BATCH message. BATCH messages have flags starting with protocol version 3.
    67  func (m *Batch) Flags() primitive.QueryFlag {
    68  	var flags primitive.QueryFlag
    69  	if m.SerialConsistency != nil {
    70  		flags = flags.Add(primitive.QueryFlagSerialConsistency)
    71  	}
    72  	if m.DefaultTimestamp != nil {
    73  		flags = flags.Add(primitive.QueryFlagDefaultTimestamp)
    74  	}
    75  	if m.Keyspace != "" {
    76  		flags = flags.Add(primitive.QueryFlagWithKeyspace)
    77  	}
    78  	if m.NowInSeconds != nil {
    79  		flags = flags.Add(primitive.QueryFlagNowInSeconds)
    80  	}
    81  	// Note: the named values flag is in theory possible, but server-side implementation is
    82  	// broken. See https://issues.apache.org/jira/browse/CASSANDRA-10246
    83  	return flags
    84  }
    85  
// BatchChild represents a BATCH child statement: either a CQL query string or the id of a prepared statement,
// together with the positional values bound to it.
// +k8s:deepcopy-gen=true
type BatchChild struct {
	// The CQL statement to execute. Exactly one of Query or Id must be present, never both. When encoding, a
	// non-empty Query takes precedence over Id.
	Query string
	// The prepared id of the statement to execute. Exactly one of Query or Id must be present, never both. It is
	// only encoded when Query is empty, in which case it must be non-empty.
	Id []byte
	// The positional values bound to the statement.
	// Note: named values are in theory possible, but their server-side implementation is
	// broken. See https://issues.apache.org/jira/browse/CASSANDRA-10246
	Values []*primitive.Value
}
    97  
// batchCodec encodes, decodes and computes the encoded length of BATCH messages (primitive.OpCodeBatch).
type batchCodec struct{}
    99  
   100  func (c *batchCodec) Encode(msg Message, dest io.Writer, version primitive.ProtocolVersion) (err error) {
   101  	batch, ok := msg.(*Batch)
   102  	if !ok {
   103  		return errors.New(fmt.Sprintf("expected *message.Batch, got %T", msg))
   104  	}
   105  	if err = primitive.CheckValidBatchType(batch.Type); err != nil {
   106  		return err
   107  	} else if err = primitive.WriteByte(uint8(batch.Type), dest); err != nil {
   108  		return fmt.Errorf("cannot write BATCH type: %w", err)
   109  	}
   110  	childrenCount := len(batch.Children)
   111  	if childrenCount > 0xFFFF {
   112  		return errors.New(fmt.Sprintf("BATCH messages can contain at most %d child queries", 0xFFFF))
   113  	} else if err = primitive.WriteShort(uint16(childrenCount), dest); err != nil {
   114  		return fmt.Errorf("cannot write BATCH query count: %w", err)
   115  	}
   116  	for i, child := range batch.Children {
   117  		if child.Query != "" {
   118  			if err = primitive.WriteByte(uint8(primitive.BatchChildTypeQueryString), dest); err != nil {
   119  				return fmt.Errorf("cannot write BATCH query kind 0 for child #%d: %w", i, err)
   120  			} else if child.Query == "" {
   121  				return fmt.Errorf("cannot write empty BATCH query string for child #%d", i)
   122  			} else if err = primitive.WriteLongString(child.Query, dest); err != nil {
   123  				return fmt.Errorf("cannot write BATCH query string for child #%d: %w", i, err)
   124  			}
   125  		} else {
   126  			if err = primitive.WriteByte(uint8(primitive.BatchChildTypePreparedId), dest); err != nil {
   127  				return fmt.Errorf("cannot write BATCH query kind 1 for child #%d: %w", i, err)
   128  			} else if len(child.Id) == 0 {
   129  				return fmt.Errorf("cannot write empty BATCH query id for child #%d: %w", i, err)
   130  			} else if err = primitive.WriteShortBytes(child.Id, dest); err != nil {
   131  				return fmt.Errorf("cannot write BATCH query id for child #%d: %w", i, err)
   132  			}
   133  		}
   134  		if err = primitive.WritePositionalValues(child.Values, dest, version); err != nil {
   135  			return fmt.Errorf("cannot write BATCH positional values for child #%d: %w", i, err)
   136  		}
   137  	}
   138  	if err = primitive.WriteShort(uint16(batch.Consistency), dest); err != nil {
   139  		return fmt.Errorf("cannot write BATCH consistency: %w", err)
   140  	}
   141  	if version.SupportsBatchQueryFlags() {
   142  		flags := batch.Flags()
   143  		if version.Uses4BytesQueryFlags() {
   144  			err = primitive.WriteInt(int32(flags), dest)
   145  		} else {
   146  			err = primitive.WriteByte(uint8(flags), dest)
   147  		}
   148  		if err != nil {
   149  			return fmt.Errorf("cannot write BATCH query flags: %w", err)
   150  		}
   151  		if version.SupportsQueryFlag(primitive.QueryFlagSerialConsistency) && flags.Contains(primitive.QueryFlagSerialConsistency) {
   152  			if err = primitive.WriteShort(uint16(*batch.SerialConsistency), dest); err != nil {
   153  				return fmt.Errorf("cannot write BATCH serial consistency: %w", err)
   154  			}
   155  		}
   156  		if version.SupportsQueryFlag(primitive.QueryFlagDefaultTimestamp) && flags.Contains(primitive.QueryFlagDefaultTimestamp) {
   157  			if err = primitive.WriteLong(*batch.DefaultTimestamp, dest); err != nil {
   158  				return fmt.Errorf("cannot write BATCH default timestamp: %w", err)
   159  			}
   160  		}
   161  		if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) {
   162  			if batch.Keyspace == "" {
   163  				return errors.New("cannot write BATCH empty keyspace")
   164  			} else if err = primitive.WriteString(batch.Keyspace, dest); err != nil {
   165  				return fmt.Errorf("cannot write BATCH keyspace: %w", err)
   166  			}
   167  		}
   168  		if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) {
   169  			if err = primitive.WriteInt(*batch.NowInSeconds, dest); err != nil {
   170  				return fmt.Errorf("cannot write BATCH now-in-seconds: %w", err)
   171  			}
   172  		}
   173  	}
   174  	return nil
   175  }
   176  
   177  func (c *batchCodec) EncodedLength(msg Message, version primitive.ProtocolVersion) (length int, err error) {
   178  	batch, ok := msg.(*Batch)
   179  	if !ok {
   180  		return -1, errors.New(fmt.Sprintf("expected *message.Batch, got %T", msg))
   181  	}
   182  	childrenCount := len(batch.Children)
   183  	if childrenCount > 0xFFFF {
   184  		return -1, errors.New(fmt.Sprintf("BATCH messages can contain at most %d queries", 0xFFFF))
   185  	}
   186  	length += primitive.LengthOfByte  // type
   187  	length += primitive.LengthOfShort // number of queries
   188  	for i, child := range batch.Children {
   189  		length += primitive.LengthOfByte // child type
   190  		if child.Query != "" {
   191  			length += primitive.LengthOfLongString(child.Query)
   192  		} else {
   193  			length += primitive.LengthOfShortBytes(child.Id)
   194  		}
   195  		if valuesLength, err := primitive.LengthOfPositionalValues(child.Values); err != nil {
   196  			return -1, fmt.Errorf("cannot compute length of BATCH positional values for child #%d: %w", i, err)
   197  		} else {
   198  			length += valuesLength
   199  		}
   200  	}
   201  	length += primitive.LengthOfShort // consistency level
   202  	// flags
   203  	if version.SupportsBatchQueryFlags() {
   204  		if version.Uses4BytesQueryFlags() {
   205  			length += primitive.LengthOfInt
   206  		} else {
   207  			length += primitive.LengthOfByte
   208  		}
   209  		flags := batch.Flags()
   210  		if version.SupportsQueryFlag(primitive.QueryFlagSerialConsistency) && flags.Contains(primitive.QueryFlagSerialConsistency) {
   211  			length += primitive.LengthOfShort
   212  		}
   213  		if version.SupportsQueryFlag(primitive.QueryFlagDefaultTimestamp) && flags.Contains(primitive.QueryFlagDefaultTimestamp) {
   214  			length += primitive.LengthOfLong
   215  		}
   216  		if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) {
   217  			length += primitive.LengthOfString(batch.Keyspace)
   218  		}
   219  		if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) {
   220  			length += primitive.LengthOfInt
   221  		}
   222  	}
   223  	return length, nil
   224  }
   225  
   226  func (c *batchCodec) Decode(source io.Reader, version primitive.ProtocolVersion) (msg Message, err error) {
   227  	batch := &Batch{}
   228  	var batchType uint8
   229  	if batchType, err = primitive.ReadByte(source); err != nil {
   230  		return nil, fmt.Errorf("cannot read BATCH type: %w", err)
   231  	}
   232  	batch.Type = primitive.BatchType(batchType)
   233  	if err = primitive.CheckValidBatchType(batch.Type); err != nil {
   234  		return nil, err
   235  	}
   236  	var childrenCount uint16
   237  	if childrenCount, err = primitive.ReadShort(source); err != nil {
   238  		return nil, fmt.Errorf("cannot read BATCH query count: %w", err)
   239  	}
   240  	batch.Children = make([]*BatchChild, childrenCount)
   241  	for i := 0; i < int(childrenCount); i++ {
   242  		var childType uint8
   243  		if childType, err = primitive.ReadByte(source); err != nil {
   244  			return nil, fmt.Errorf("cannot read BATCH child type for child #%d: %w", i, err)
   245  		}
   246  		var child = &BatchChild{}
   247  		switch primitive.BatchChildType(childType) {
   248  		case primitive.BatchChildTypeQueryString:
   249  			if child.Query, err = primitive.ReadLongString(source); err != nil {
   250  				return nil, fmt.Errorf("cannot read BATCH query string for child #%d: %w", i, err)
   251  			}
   252  		case primitive.BatchChildTypePreparedId:
   253  			if child.Id, err = primitive.ReadShortBytes(source); err != nil {
   254  				return nil, fmt.Errorf("cannot read BATCH query id for child #%d: %w", i, err)
   255  			}
   256  		default:
   257  			return nil, fmt.Errorf("unsupported BATCH child type for child #%d: %v", i, childType)
   258  		}
   259  		if child.Values, err = primitive.ReadPositionalValues(source, version); err != nil {
   260  			return nil, fmt.Errorf("cannot read BATCH positional values for child #%d: %w", i, err)
   261  		}
   262  		batch.Children[i] = child
   263  	}
   264  	var batchConsistency uint16
   265  	if batchConsistency, err = primitive.ReadShort(source); err != nil {
   266  		return nil, fmt.Errorf("cannot read BATCH consistency: %w", err)
   267  	}
   268  	batch.Consistency = primitive.ConsistencyLevel(batchConsistency)
   269  	if version.SupportsBatchQueryFlags() {
   270  		var flags primitive.QueryFlag
   271  		if version.Uses4BytesQueryFlags() {
   272  			var f int32
   273  			f, err = primitive.ReadInt(source)
   274  			flags = primitive.QueryFlag(f)
   275  		} else {
   276  			var f uint8
   277  			f, err = primitive.ReadByte(source)
   278  			flags = primitive.QueryFlag(f)
   279  		}
   280  		if err != nil {
   281  			return nil, fmt.Errorf("cannot read BATCH query flags: %w", err)
   282  		}
   283  		if flags.Contains(primitive.QueryFlagValueNames) {
   284  			return nil, errors.New("cannot use BATCH with named values, see CASSANDRA-10246")
   285  		}
   286  		if flags.Contains(primitive.QueryFlagSerialConsistency) {
   287  			var batchSerialConsistencyUint uint16
   288  			if batchSerialConsistencyUint, err = primitive.ReadShort(source); err != nil {
   289  				return nil, fmt.Errorf("cannot read BATCH serial consistency: %w", err)
   290  			}
   291  			batchSerialConsistency := primitive.ConsistencyLevel(batchSerialConsistencyUint)
   292  			batch.SerialConsistency = &batchSerialConsistency
   293  		}
   294  		if flags.Contains(primitive.QueryFlagDefaultTimestamp) {
   295  			var batchDefaultTimestamp int64
   296  			if batchDefaultTimestamp, err = primitive.ReadLong(source); err != nil {
   297  				return nil, fmt.Errorf("cannot read BATCH default timestamp: %w", err)
   298  			}
   299  			batch.DefaultTimestamp = &batchDefaultTimestamp
   300  		}
   301  		if version.SupportsQueryFlag(primitive.QueryFlagWithKeyspace) && flags.Contains(primitive.QueryFlagWithKeyspace) {
   302  			if batch.Keyspace, err = primitive.ReadString(source); err != nil {
   303  				return nil, fmt.Errorf("cannot read BATCH keyspace: %w", err)
   304  			}
   305  		}
   306  		if version.SupportsQueryFlag(primitive.QueryFlagNowInSeconds) && flags.Contains(primitive.QueryFlagNowInSeconds) {
   307  			var batchNowInSeconds int32
   308  			if batchNowInSeconds, err = primitive.ReadInt(source); err != nil {
   309  				return nil, fmt.Errorf("cannot read BATCH now-in-seconds: %w", err)
   310  			}
   311  			batch.NowInSeconds = &batchNowInSeconds
   312  		}
   313  	}
   314  	return batch, nil
   315  }
   316  
   317  func (c *batchCodec) GetOpCode() primitive.OpCode {
   318  	return primitive.OpCodeBatch
   319  }