github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/primitive/values.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package primitive 16 17 import ( 18 "errors" 19 "fmt" 20 "io" 21 ) 22 23 type ValueType = int32 24 25 const ( 26 ValueTypeRegular = ValueType(0) 27 ValueTypeNull = ValueType(-1) 28 ValueTypeUnset = ValueType(-2) 29 ) 30 31 // Value models the [value] protocol primitive structure. 32 // +k8s:deepcopy-gen=true 33 type Value struct { 34 Type ValueType 35 Contents []byte 36 } 37 38 func NewValue(contents []byte) *Value { 39 if contents == nil { 40 return &Value{Type: ValueTypeNull} 41 } else { 42 return &Value{Type: ValueTypeRegular, Contents: contents} 43 } 44 } 45 46 func NewNullValue() *Value { 47 return NewValue(nil) 48 } 49 50 func NewUnsetValue() *Value { 51 return &Value{Type: ValueTypeUnset} 52 } 53 54 // [value] 55 56 func ReadValue(source io.Reader, version ProtocolVersion) (*Value, error) { 57 if length, err := ReadInt(source); err != nil { 58 return nil, fmt.Errorf("cannot read [value] length: %w", err) 59 } else if length == ValueTypeNull { 60 return NewNullValue(), nil 61 } else if length == ValueTypeUnset { 62 if version < ProtocolVersion4 { 63 return nil, fmt.Errorf("cannot use unset value with %v", version) 64 } 65 return NewUnsetValue(), nil 66 } else if length < 0 { 67 return nil, fmt.Errorf("invalid [value] length: %v", length) 68 } else if length == 0 { 69 return NewValue([]byte{}), nil 70 } else { 71 decoded := make([]byte, length) 72 if _, err := io.ReadFull(source, decoded); err != nil { 73 return nil, fmt.Errorf("cannot read [value] content: %w", err) 74 } 75 return NewValue(decoded), nil 76 } 77 } 78 79 func WriteValue(value *Value, dest io.Writer, version ProtocolVersion) error { 80 if value == nil { 81 return errors.New("cannot write a nil [value]") 82 } 83 switch value.Type { 84 case ValueTypeNull: 85 return WriteInt(ValueTypeNull, dest) 86 case ValueTypeUnset: 87 if !version.SupportsUnsetValues() { 88 return fmt.Errorf("cannot use unset value with %v", version) 89 } 90 return WriteInt(ValueTypeUnset, dest) 91 case ValueTypeRegular: 92 if value.Contents == nil { 93 return WriteInt(ValueTypeNull, dest) 94 } else { 95 length := len(value.Contents) 96 if err := WriteInt(int32(length), dest); err != nil { 97 return fmt.Errorf("cannot write [value] length: %w", err) 98 } else if n, err := dest.Write(value.Contents); err != nil { 99 return fmt.Errorf("cannot write [value] content: %w", err) 100 } else if n < length { 101 return errors.New("not enough capacity to write [value] content") 102 } 103 return nil 104 } 105 default: 106 return fmt.Errorf("unknown [value] type: %v", value.Type) 107 } 108 } 109 110 func LengthOfValue(value *Value) (int, error) { 111 if value == nil { 112 return -1, errors.New("cannot compute length of a nil [value]") 113 } 114 switch value.Type { 115 case ValueTypeNull: 116 return LengthOfInt, nil 117 case ValueTypeUnset: 118 return LengthOfInt, nil 119 case ValueTypeRegular: 120 return LengthOfInt + len(value.Contents), nil 121 default: 122 return -1, fmt.Errorf("unknown [value] type: %v", value.Type) 123 } 124 } 125 126 // positional [value]s 127 128 func ReadPositionalValues(source io.Reader, version ProtocolVersion) ([]*Value, error) { 129 if length, err := ReadShort(source); err != nil { 130 return nil, fmt.Errorf("cannot read positional [value]s length: %w", err) 131 } else { 132 decoded := make([]*Value, length) 133 for i := uint16(0); i < length; i++ { 134 if value, err := ReadValue(source, version); err != nil { 135 return nil, fmt.Errorf("cannot read positional [value]s element %d content: %w", i, err) 136 } else { 137 decoded[i] = value 138 } 139 } 140 return decoded, nil 141 } 142 } 143 144 func WritePositionalValues(values []*Value, dest io.Writer, version ProtocolVersion) error { 145 length := len(values) 146 if err := WriteShort(uint16(length), dest); err != nil { 147 return fmt.Errorf("cannot write positional [value]s length: %w", err) 148 } 149 for i, value := range values { 150 if err := WriteValue(value, dest, version); err != nil { 151 return fmt.Errorf("cannot write positional [value]s element %d content: %w", i, err) 152 } 153 } 154 return nil 155 } 156 157 func LengthOfPositionalValues(values []*Value) (length int, err error) { 158 length += LengthOfShort 159 for i, value := range values { 160 var valueLength int 161 valueLength, err = LengthOfValue(value) 162 if err != nil { 163 return -1, fmt.Errorf("cannot compute length of positional [value] %d: %w", i, err) 164 } 165 length += valueLength 166 } 167 return length, nil 168 } 169 170 // named [value]s 171 172 func ReadNamedValues(source io.Reader, version ProtocolVersion) (map[string]*Value, error) { 173 if length, err := ReadShort(source); err != nil { 174 return nil, fmt.Errorf("cannot read named [value]s length: %w", err) 175 } else { 176 decoded := make(map[string]*Value, length) 177 for i := uint16(0); i < length; i++ { 178 if name, err := ReadString(source); err != nil { 179 return nil, fmt.Errorf("cannot read named [value]s entry %d name: %w", i, err) 180 } else if value, err := ReadValue(source, version); err != nil { 181 return nil, fmt.Errorf("cannot read named [value]s entry %d content: %w", i, err) 182 } else { 183 decoded[name] = value 184 } 185 } 186 return decoded, nil 187 } 188 } 189 190 func WriteNamedValues(values map[string]*Value, dest io.Writer, version ProtocolVersion) error { 191 length := len(values) 192 if err := WriteShort(uint16(length), dest); err != nil { 193 return fmt.Errorf("cannot write named [value]s length: %w", err) 194 } 195 for name, value := range values { 196 if err := WriteString(name, dest); err != nil { 197 return fmt.Errorf("cannot write named [value]s entry '%v' name: %w", name, err) 198 } 199 if err := WriteValue(value, dest, version); err != nil { 200 return fmt.Errorf("cannot write named [value]s entry '%v' content: %w", name, err) 201 } 202 } 203 return nil 204 } 205 206 func LengthOfNamedValues(values map[string]*Value) (length int, err error) { 207 length += LengthOfShort 208 for name, value := range values { 209 var nameLength = LengthOfString(name) 210 var valueLength int 211 valueLength, err = LengthOfValue(value) 212 if err != nil { 213 return -1, fmt.Errorf("cannot compute length of named [value]s: %w", err) 214 } 215 length += nameLength 216 length += valueLength 217 } 218 return length, nil 219 }