github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/varint.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "math/big" 19 "strconv" 20 21 "github.com/datastax/go-cassandra-native-protocol/datatype" 22 "github.com/datastax/go-cassandra-native-protocol/primitive" 23 ) 24 25 // Varint is a codec for the CQL varint type, a type that can handle arbitrary-length integers. Its preferred 26 // Go type is big.Int, but it can encode from and decode to most numeric types. 27 var Varint Codec = &varintCodec{} 28 29 type varintCodec struct{} 30 31 func (c *varintCodec) DataType() datatype.DataType { 32 return datatype.Varint 33 } 34 35 func (c *varintCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 36 var val *big.Int 37 if val, err = convertToBigInt(source); err == nil && val != nil { 38 dest = val.Bytes() 39 } 40 if err != nil { 41 err = errCannotEncode(source, c.DataType(), version, err) 42 } 43 return 44 } 45 46 func (c *varintCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 47 val := readBigInt(source) 48 wasNull = val == nil 49 if err = convertFromBigInt(val, wasNull, dest); err != nil { 50 err = errCannotDecode(dest, c.DataType(), version, err) 51 } 52 return 53 } 54 55 func convertToBigInt(source interface{}) (val *big.Int, err error) { 56 switch s := source.(type) { 57 case int64: 58 val = big.NewInt(s) 59 case int: 60 val = big.NewInt(int64(s)) 61 case int32: 62 val = big.NewInt(int64(s)) 63 case int16: 64 val = big.NewInt(int64(s)) 65 case int8: 66 val = big.NewInt(int64(s)) 67 case uint64: 68 val = new(big.Int).SetUint64(s) 69 case uint: 70 val = new(big.Int).SetUint64(uint64(s)) 71 case uint32: 72 val = new(big.Int).SetUint64(uint64(s)) 73 case uint16: 74 val = new(big.Int).SetUint64(uint64(s)) 75 case uint8: 76 val = new(big.Int).SetUint64(uint64(s)) 77 case string: 78 val, err = stringToBigInt(s) 79 case *int64: 80 if s != nil { 81 val = big.NewInt(*s) 82 } 83 case *int: 84 if s != nil { 85 val = big.NewInt(int64(*s)) 86 } 87 case *int32: 88 if s != nil { 89 val = big.NewInt(int64(*s)) 90 } 91 case *int16: 92 if s != nil { 93 val = big.NewInt(int64(*s)) 94 } 95 case *int8: 96 if s != nil { 97 val = big.NewInt(int64(*s)) 98 } 99 case *uint64: 100 if s != nil { 101 val = new(big.Int).SetUint64(*s) 102 } 103 case *uint: 104 if s != nil { 105 val = new(big.Int).SetUint64(uint64(*s)) 106 } 107 case *uint32: 108 if s != nil { 109 val = new(big.Int).SetUint64(uint64(*s)) 110 } 111 case *uint16: 112 if s != nil { 113 val = new(big.Int).SetUint64(uint64(*s)) 114 } 115 case *uint8: 116 if s != nil { 117 val = new(big.Int).SetUint64(uint64(*s)) 118 } 119 case *big.Int: 120 // Note: non-pointer big.Int is not supported as per its docs, it should always be a pointer. 121 val = s 122 case *string: 123 if s != nil { 124 val, err = stringToBigInt(*s) 125 } 126 case nil: 127 default: 128 err = ErrConversionNotSupported 129 } 130 if err != nil { 131 err = errSourceConversionFailed(source, val, err) 132 } 133 return 134 } 135 136 func convertFromBigInt(val *big.Int, wasNull bool, dest interface{}) (err error) { 137 switch d := dest.(type) { 138 case *interface{}: 139 if d == nil { 140 err = ErrNilDestination 141 } else if wasNull { 142 *d = nil 143 } else { 144 *d = val 145 } 146 case *int64: 147 if d == nil { 148 err = ErrNilDestination 149 } else if wasNull { 150 *d = 0 151 } else { 152 *d, err = bigIntToInt64(val) 153 } 154 case *int: 155 if d == nil { 156 err = ErrNilDestination 157 } else if wasNull { 158 *d = 0 159 } else { 160 *d, err = bigIntToInt(val, strconv.IntSize) 161 } 162 case *int32: 163 if d == nil { 164 err = ErrNilDestination 165 } else if wasNull { 166 *d = 0 167 } else { 168 *d, err = bigIntToInt32(val) 169 } 170 case *int16: 171 if d == nil { 172 err = ErrNilDestination 173 } else if wasNull { 174 *d = 0 175 } else { 176 *d, err = bigIntToInt16(val) 177 } 178 case *int8: 179 if d == nil { 180 err = ErrNilDestination 181 } else if wasNull { 182 *d = 0 183 } else { 184 *d, err = bigIntToInt8(val) 185 } 186 case *uint64: 187 if d == nil { 188 err = ErrNilDestination 189 } else if wasNull { 190 *d = 0 191 } else { 192 *d, err = bigIntToUint64(val) 193 } 194 case *uint: 195 if d == nil { 196 err = ErrNilDestination 197 } else if wasNull { 198 *d = 0 199 } else { 200 *d, err = bigIntToUint(val, strconv.IntSize) 201 } 202 case *uint32: 203 if d == nil { 204 err = ErrNilDestination 205 } else if wasNull { 206 *d = 0 207 } else { 208 *d, err = bigIntToUint32(val) 209 } 210 case *uint16: 211 if d == nil { 212 err = ErrNilDestination 213 } else if wasNull { 214 *d = 0 215 } else { 216 *d, err = bigIntToUint16(val) 217 } 218 case *uint8: 219 if d == nil { 220 err = ErrNilDestination 221 } else if wasNull { 222 *d = 0 223 } else { 224 *d, err = bigIntToUint8(val) 225 } 226 case *big.Int: 227 if d == nil { 228 err = ErrNilDestination 229 } else if wasNull { 230 *d = big.Int{} 231 } else { 232 *d = *val 233 } 234 case *string: 235 if d == nil { 236 err = ErrNilDestination 237 } else if wasNull { 238 *d = "" 239 } else { 240 *d = val.Text(10) 241 } 242 default: 243 err = errDestinationInvalid(dest) 244 } 245 if err != nil { 246 err = errDestinationConversionFailed(val, dest, err) 247 } 248 return 249 } 250 251 var ( 252 zeroBigInt = big.NewInt(0) 253 oneBigInt = big.NewInt(1) 254 ) 255 256 // Implementation note: the encoding scheme used for CQL varint is dictated by Java's implementation of 257 // BigInteger.toByteArray(). This scheme has nothing to do with the "Varint" functions declared in Go's binary package. 258 // Relevant readings for varint encoding in Go: 259 // https://groups.google.com/g/golang-nuts/c/TV4bRVrHZUw 260 // https://github.com/gocql/gocql/blob/go1.2/marshal.go#L729-L767 261 262 func writeBigInt(n *big.Int) []byte { 263 if n == nil { 264 return nil 265 } 266 switch n.Sign() { 267 case 1: 268 b := n.Bytes() 269 if b[0]&0x80 > 0 { 270 b = append([]byte{0}, b...) 271 } 272 return b 273 case -1: 274 length := uint(n.BitLen()/8+1) * 8 275 b := new(big.Int).Add(n, new(big.Int).Lsh(oneBigInt, length)).Bytes() 276 // When the most significant bit is on a byte 277 // boundary, we can get some extra significant 278 // bits, so strip them off when that happens. 279 if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { 280 b = b[1:] 281 } 282 return b 283 default: 284 return []byte{0} 285 } 286 } 287 288 func readBigInt(source []byte) (val *big.Int) { 289 length := len(source) 290 if length > 0 { 291 val = new(big.Int).SetBytes(source) 292 if source[0]&0x80 > 0 { 293 val.Sub(val, new(big.Int).Lsh(oneBigInt, uint(length)*8)) 294 } 295 } 296 return 297 }