github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/message/result_metadata.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package message
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  
    21  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    22  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    23  )
    24  
// ColumnMetadata describes a single column or bound variable. It is the element type of VariablesMetadata.Columns
// and RowsMetadata.Columns, and is therefore used in both PreparedResult and RowsResult messages.
//
// Keyspace and Table identify the table the column belongs to; when all the columns of a metadata share the same
// keyspace and table, the codecs in this file write them only once (global table spec). Name and Type are written
// for every column. Index is neither read nor written by the codecs in this file.
// +k8s:deepcopy-gen=true
type ColumnMetadata struct {
	Keyspace string
	Table    string
	Name     string
	Index    int32
	Type     datatype.DataType
}
    34  
// VariablesMetadata is used in PreparedResult to indicate metadata about the prepared statement's bound variables.
// Columns holds one entry per bound variable; the encoded column count is always len(Columns).
// +k8s:deepcopy-gen=true
type VariablesMetadata struct {
	// The indices of variables belonging to the table's partition key, if any. Valid from protocol version 4 onwards:
	// the field is neither encoded nor decoded for protocol versions lower than 4, and is left nil when decoding a
	// message that declares no partition key indices.
	PkIndices []uint16
	Columns   []*ColumnMetadata
}
    43  
    44  func (rm *VariablesMetadata) Flags() (flag primitive.VariablesFlag) {
    45  	if len(rm.Columns) > 0 && haveSameTable(rm.Columns) {
    46  		flag |= primitive.VariablesFlagGlobalTablesSpec
    47  	}
    48  	return flag
    49  }
    50  
// RowsMetadata is used in RowsResult to indicate metadata about the result set present in the result response;
// and in PreparedResult, to indicate metadata about the result set that the prepared statement will produce once
// executed.
// +k8s:deepcopy-gen=true
type RowsMetadata struct {
	// Must be always present, even when Columns is nil. If Columns is non-empty, the value of ColumnCount must match
	// len(Columns), otherwise an error is returned when encoding.
	ColumnCount int32
	// PagingState is a [bytes] value. If provided (non-nil), the HAS_MORE_PAGES flag is set, which means that this
	// page of results is not the last page.
	PagingState []byte
	// Valid for protocol version 5 and DSE protocol version 2 only. If non-nil, the METADATA_CHANGED flag is set and
	// the value is encoded as [short bytes].
	NewResultMetadataId []byte
	// Valid for DSE protocol versions only. The continuous paging flag is set, and the page number encoded, only
	// when this value is strictly positive.
	ContinuousPageNumber int32
	// Valid for DSE protocol versions only. Only taken into account when ContinuousPageNumber is strictly positive.
	LastContinuousPage bool
	// If nil or empty, the NO_METADATA flag is set. In a PreparedResult, will be non-nil if the statement is a SELECT.
	Columns []*ColumnMetadata
}
    70  
    71  func (rm *RowsMetadata) Flags() (flag primitive.RowsFlag) {
    72  	if len(rm.Columns) == 0 {
    73  		flag |= primitive.RowsFlagNoMetadata
    74  	} else if haveSameTable(rm.Columns) {
    75  		flag |= primitive.RowsFlagGlobalTablesSpec
    76  	}
    77  	if rm.PagingState != nil {
    78  		flag |= primitive.RowsFlagHasMorePages
    79  	}
    80  	if rm.NewResultMetadataId != nil {
    81  		flag |= primitive.RowsFlagMetadataChanged
    82  	}
    83  	if rm.ContinuousPageNumber > 0 {
    84  		flag |= primitive.RowsFlagDseContinuousPaging
    85  		if rm.LastContinuousPage {
    86  			flag |= primitive.RowsFlagDseLastContinuousPage
    87  		}
    88  	}
    89  	return flag
    90  }
    91  
    92  func encodeVariablesMetadata(metadata *VariablesMetadata, dest io.Writer, version primitive.ProtocolVersion) (err error) {
    93  	if metadata == nil {
    94  		metadata = &VariablesMetadata{}
    95  	}
    96  	flags := metadata.Flags()
    97  	if err = primitive.WriteInt(int32(flags), dest); err != nil {
    98  		return fmt.Errorf("cannot write RESULT Prepared variables metadata flags: %w", err)
    99  	}
   100  	if err = primitive.WriteInt(int32(len(metadata.Columns)), dest); err != nil {
   101  		return fmt.Errorf("cannot write RESULT Prepared variables metadata column count: %w", err)
   102  	}
   103  	if version >= primitive.ProtocolVersion4 {
   104  		if err = primitive.WriteInt(int32(len(metadata.PkIndices)), dest); err != nil {
   105  			return fmt.Errorf("cannot write RESULT Prepared variables metadata pk indices length: %w", err)
   106  		}
   107  		for i, idx := range metadata.PkIndices {
   108  			if err = primitive.WriteShort(idx, dest); err != nil {
   109  				return fmt.Errorf("cannot write RESULT Prepared variables metadata pk indices element %d: %w", i, err)
   110  			}
   111  		}
   112  	}
   113  	if len(metadata.Columns) > 0 {
   114  		globalTableSpec := flags.Contains(primitive.VariablesFlagGlobalTablesSpec)
   115  		if err = encodeColumnsMetadata(globalTableSpec, metadata.Columns, dest, version); err != nil {
   116  			return fmt.Errorf("cannot write RESULT Prepared variables metadata column cols: %w", err)
   117  		}
   118  	}
   119  	return nil
   120  }
   121  
   122  func lengthOfVariablesMetadata(metadata *VariablesMetadata, version primitive.ProtocolVersion) (length int, err error) {
   123  	if metadata == nil {
   124  		metadata = &VariablesMetadata{}
   125  	}
   126  	length += primitive.LengthOfInt // flags
   127  	length += primitive.LengthOfInt // column count
   128  	if version >= primitive.ProtocolVersion4 {
   129  		length += primitive.LengthOfInt // pk count
   130  		length += primitive.LengthOfShort * len(metadata.PkIndices)
   131  	}
   132  	if len(metadata.Columns) > 0 {
   133  		globalTableSpec := metadata.Flags()&primitive.VariablesFlagGlobalTablesSpec > 0
   134  		var lcs int
   135  		if lcs, err = lengthOfColumnsMetadata(globalTableSpec, metadata.Columns, version); err != nil {
   136  			return -1, fmt.Errorf("cannot compute length of RESULT Prepared variables metadata column cols: %w", err)
   137  		}
   138  		length += lcs
   139  	}
   140  	return length, nil
   141  }
   142  
   143  func decodeVariablesMetadata(source io.Reader, version primitive.ProtocolVersion) (metadata *VariablesMetadata, err error) {
   144  	metadata = &VariablesMetadata{}
   145  	var f int32
   146  	if f, err = primitive.ReadInt(source); err != nil {
   147  		return nil, fmt.Errorf("cannot read RESULT Prepared variables metadata flags: %w", err)
   148  	}
   149  	var flags = primitive.VariablesFlag(f)
   150  	var columnCount int32
   151  	if columnCount, err = primitive.ReadInt(source); err != nil {
   152  		return nil, fmt.Errorf("cannot read RESULT Prepared variables metadata column count: %w", err)
   153  	}
   154  	if version >= primitive.ProtocolVersion4 {
   155  		var pkCount int32
   156  		if pkCount, err = primitive.ReadInt(source); err != nil {
   157  			return nil, fmt.Errorf("cannot read RESULT Prepared variables metadata pk indices length: %w", err)
   158  		}
   159  		if pkCount > 0 {
   160  			metadata.PkIndices = make([]uint16, pkCount)
   161  			for i := 0; i < int(pkCount); i++ {
   162  				if metadata.PkIndices[i], err = primitive.ReadShort(source); err != nil {
   163  					return nil, fmt.Errorf("cannot read RESULT Prepared variables metadata pk index element %d: %w", i, err)
   164  				}
   165  			}
   166  		}
   167  	}
   168  	if columnCount > 0 {
   169  		globalTableSpec := flags.Contains(primitive.VariablesFlagGlobalTablesSpec)
   170  		if metadata.Columns, err = decodeColumnsMetadata(globalTableSpec, columnCount, source, version); err != nil {
   171  			return nil, fmt.Errorf("cannot read RESULT Prepared variables metadata column cols: %w", err)
   172  		}
   173  	}
   174  	return metadata, nil
   175  }
   176  
   177  func encodeRowsMetadata(metadata *RowsMetadata, dest io.Writer, version primitive.ProtocolVersion) (err error) {
   178  	if metadata == nil {
   179  		metadata = &RowsMetadata{}
   180  	}
   181  	flags := metadata.Flags()
   182  	if err = primitive.WriteInt(int32(flags), dest); err != nil {
   183  		return fmt.Errorf("cannot write RESULT Rows metadata flags: %w", err)
   184  	}
   185  	columnSpecsLength := len(metadata.Columns)
   186  	if columnSpecsLength > 0 && int(metadata.ColumnCount) != columnSpecsLength {
   187  		return fmt.Errorf(
   188  			"invalid RESULT Rows metadata: metadata.ColumnCount %d != len(metadata.ColumnSpecs) %d",
   189  			metadata.ColumnCount,
   190  			columnSpecsLength,
   191  		)
   192  	}
   193  	if err = primitive.WriteInt(metadata.ColumnCount, dest); err != nil {
   194  		return fmt.Errorf("cannot write RESULT Rows metadata column count: %w", err)
   195  	}
   196  	if flags.Contains(primitive.RowsFlagHasMorePages) {
   197  		if err = primitive.WriteBytes(metadata.PagingState, dest); err != nil {
   198  			return fmt.Errorf("cannot write RESULT Rows metadata paging state: %w", err)
   199  		}
   200  	}
   201  	if flags.Contains(primitive.RowsFlagMetadataChanged) {
   202  		if err = primitive.WriteShortBytes(metadata.NewResultMetadataId, dest); err != nil {
   203  			return fmt.Errorf("cannot write RESULT Rows metadata new result metadata id: %w", err)
   204  		}
   205  	}
   206  	if flags.Contains(primitive.RowsFlagDseContinuousPaging) {
   207  		if err = primitive.WriteInt(metadata.ContinuousPageNumber, dest); err != nil {
   208  			return fmt.Errorf("cannot write RESULT Rows metadata continuous page number: %w", err)
   209  		}
   210  	}
   211  	if flags&primitive.RowsFlagNoMetadata == 0 && columnSpecsLength > 0 {
   212  		globalTableSpec := flags.Contains(primitive.RowsFlagGlobalTablesSpec)
   213  		if err = encodeColumnsMetadata(globalTableSpec, metadata.Columns, dest, version); err != nil {
   214  			return fmt.Errorf("cannot write RESULT Rows metadata column specs: %w", err)
   215  		}
   216  	}
   217  	return nil
   218  }
   219  
   220  func lengthOfRowsMetadata(metadata *RowsMetadata, version primitive.ProtocolVersion) (length int, err error) {
   221  	if metadata == nil {
   222  		metadata = &RowsMetadata{}
   223  	}
   224  	length += primitive.LengthOfInt // flags
   225  	length += primitive.LengthOfInt // column count
   226  	flags := metadata.Flags()
   227  	if flags.Contains(primitive.RowsFlagHasMorePages) {
   228  		length += primitive.LengthOfBytes(metadata.PagingState)
   229  	}
   230  	if flags.Contains(primitive.RowsFlagMetadataChanged) {
   231  		length += primitive.LengthOfShortBytes(metadata.NewResultMetadataId)
   232  	}
   233  	if flags.Contains(primitive.RowsFlagDseContinuousPaging) {
   234  		length += primitive.LengthOfInt // continuous page number
   235  	}
   236  	if flags&primitive.RowsFlagNoMetadata == 0 && len(metadata.Columns) > 0 {
   237  		globalTableSpec := flags.Contains(primitive.RowsFlagGlobalTablesSpec)
   238  		var lengthOfCols int
   239  		if lengthOfCols, err = lengthOfColumnsMetadata(globalTableSpec, metadata.Columns, version); err != nil {
   240  			return -1, fmt.Errorf("cannot compute length of RESULT Rows metadata column cols: %w", err)
   241  		}
   242  		length += lengthOfCols
   243  	}
   244  	return length, nil
   245  }
   246  
   247  func decodeRowsMetadata(source io.Reader, version primitive.ProtocolVersion) (metadata *RowsMetadata, err error) {
   248  	metadata = &RowsMetadata{}
   249  	var f int32
   250  	if f, err = primitive.ReadInt(source); err != nil {
   251  		return nil, fmt.Errorf("cannot read RESULT Rows metadata flags: %w", err)
   252  	}
   253  	var flags = primitive.RowsFlag(f)
   254  	if metadata.ColumnCount, err = primitive.ReadInt(source); err != nil {
   255  		return nil, fmt.Errorf("cannot read RESULT Rows metadata column count: %w", err)
   256  	}
   257  	if flags.Contains(primitive.RowsFlagHasMorePages) {
   258  		if metadata.PagingState, err = primitive.ReadBytes(source); err != nil {
   259  			return nil, fmt.Errorf("cannot read RESULT Rows metadata paging state: %w", err)
   260  		}
   261  	}
   262  	if flags.Contains(primitive.RowsFlagMetadataChanged) {
   263  		if metadata.NewResultMetadataId, err = primitive.ReadShortBytes(source); err != nil {
   264  			return nil, fmt.Errorf("cannot read RESULT Rows metadata new result metadata id: %w", err)
   265  		}
   266  	}
   267  	if flags.Contains(primitive.RowsFlagDseContinuousPaging) {
   268  		if metadata.ContinuousPageNumber, err = primitive.ReadInt(source); err != nil {
   269  			return nil, fmt.Errorf("cannot read RESULT Rows metadata continuous paging number: %w", err)
   270  		}
   271  		metadata.LastContinuousPage = flags.Contains(primitive.RowsFlagDseLastContinuousPage)
   272  	}
   273  	if flags&primitive.RowsFlagNoMetadata == 0 {
   274  		globalTableSpec := flags.Contains(primitive.RowsFlagGlobalTablesSpec)
   275  		if metadata.Columns, err = decodeColumnsMetadata(globalTableSpec, metadata.ColumnCount, source, version); err != nil {
   276  			return nil, fmt.Errorf("cannot read RESULT Rows metadata column cols: %w", err)
   277  		}
   278  	}
   279  	return metadata, nil
   280  }
   281  
   282  func encodeColumnsMetadata(globalTableSpec bool, cols []*ColumnMetadata, dest io.Writer, version primitive.ProtocolVersion) (err error) {
   283  	if globalTableSpec {
   284  		firstCol := cols[0]
   285  		if err = primitive.WriteString(firstCol.Keyspace, dest); err != nil {
   286  			return fmt.Errorf("cannot write column col global keyspace: %w", err)
   287  		}
   288  		if err = primitive.WriteString(firstCol.Table, dest); err != nil {
   289  			return fmt.Errorf("cannot write column col global table: %w", err)
   290  		}
   291  	}
   292  	for i, col := range cols {
   293  		if !globalTableSpec {
   294  			if err = primitive.WriteString(col.Keyspace, dest); err != nil {
   295  				return fmt.Errorf("cannot write column col %d keyspace: %w", i, err)
   296  			}
   297  			if err = primitive.WriteString(col.Table, dest); err != nil {
   298  				return fmt.Errorf("cannot write column col %d table: %w", i, err)
   299  			}
   300  		}
   301  		if err = primitive.WriteString(col.Name, dest); err != nil {
   302  			return fmt.Errorf("cannot write column col %d name: %w", i, err)
   303  		}
   304  		if err = datatype.WriteDataType(col.Type, dest, version); err != nil {
   305  			return fmt.Errorf("cannot write column col %d type: %w", i, err)
   306  		}
   307  	}
   308  	return nil
   309  }
   310  
   311  func lengthOfColumnsMetadata(globalTableSpec bool, cols []*ColumnMetadata, version primitive.ProtocolVersion) (length int, err error) {
   312  	if globalTableSpec {
   313  		firstCol := cols[0]
   314  		length += primitive.LengthOfString(firstCol.Keyspace)
   315  		length += primitive.LengthOfString(firstCol.Table)
   316  	}
   317  	for i, col := range cols {
   318  		if !globalTableSpec {
   319  			length += primitive.LengthOfString(col.Keyspace)
   320  			length += primitive.LengthOfString(col.Table)
   321  		}
   322  		length += primitive.LengthOfString(col.Name)
   323  		if lengthOfDataType, err := datatype.LengthOfDataType(col.Type, version); err != nil {
   324  			return -1, fmt.Errorf("cannot compute length column col %d type: %w", i, err)
   325  		} else {
   326  			length += lengthOfDataType
   327  		}
   328  	}
   329  	return
   330  }
   331  
   332  func decodeColumnsMetadata(globalTableSpec bool, columnCount int32, source io.Reader, version primitive.ProtocolVersion) (cols []*ColumnMetadata, err error) {
   333  	var globalKsName string
   334  	var globalTableName string
   335  	if globalTableSpec {
   336  		if globalKsName, err = primitive.ReadString(source); err != nil {
   337  			return nil, fmt.Errorf("cannot read column col global keyspace: %w", err)
   338  		}
   339  		if globalTableName, err = primitive.ReadString(source); err != nil {
   340  			return nil, fmt.Errorf("cannot read column col global table: %w", err)
   341  		}
   342  	}
   343  	cols = make([]*ColumnMetadata, columnCount)
   344  	for i := 0; i < int(columnCount); i++ {
   345  		cols[i] = &ColumnMetadata{}
   346  		if globalTableSpec {
   347  			cols[i].Keyspace = globalKsName
   348  		} else {
   349  			if cols[i].Keyspace, err = primitive.ReadString(source); err != nil {
   350  				return nil, fmt.Errorf("cannot read column col %d keyspace: %w", i, err)
   351  			}
   352  		}
   353  		if globalTableSpec {
   354  			cols[i].Table = globalTableName
   355  		} else {
   356  			if cols[i].Table, err = primitive.ReadString(source); err != nil {
   357  				return nil, fmt.Errorf("cannot read column col %d table: %w", i, err)
   358  			}
   359  		}
   360  		if cols[i].Name, err = primitive.ReadString(source); err != nil {
   361  			return nil, fmt.Errorf("cannot read column col %d name: %w", i, err)
   362  		}
   363  		if cols[i].Type, err = datatype.ReadDataType(source, version); err != nil {
   364  			return nil, fmt.Errorf("cannot read column col %d type: %w", i, err)
   365  		}
   366  	}
   367  	return cols, nil
   368  }
   369  
   370  func haveSameTable(cols []*ColumnMetadata) bool {
   371  	if cols == nil || len(cols) == 0 {
   372  		return false
   373  	}
   374  	first := true
   375  	var ksName string
   376  	var tableName string
   377  	for _, col := range cols {
   378  		if first {
   379  			first = false
   380  			ksName = col.Keyspace
   381  			tableName = col.Table
   382  		} else if col.Keyspace != ksName || col.Table != tableName {
   383  			return false
   384  		}
   385  	}
   386  	return true
   387  }