github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/duration.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package datacodec 16 17 import ( 18 "bytes" 19 "fmt" 20 "io" 21 "time" 22 23 "github.com/datastax/go-cassandra-native-protocol/datatype" 24 "github.com/datastax/go-cassandra-native-protocol/primitive" 25 ) 26 27 // CqlDuration is a CQL type introduced in protocol v5. A duration can either be positive or negative. If a duration is 28 // positive all the integers must be positive or zero. If a duration is negative all the numbers must be negative or 29 // zero. 30 type CqlDuration struct { 31 Months int32 32 Days int32 33 Nanos time.Duration 34 } 35 36 // Duration is a codec for the CQL duration type, introduced in protocol v5. There is no built-in representation of 37 // arbitrary-precision duration values in Go's standard library. This is why this codec can only encode from and decode 38 // to CqlDuration. 39 var Duration Codec = &durationCodec{} 40 41 type durationCodec struct { 42 } 43 44 func (c *durationCodec) DataType() datatype.DataType { 45 return datatype.Duration 46 } 47 48 func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { 49 if !version.SupportsDataType(c.DataType().Code()) { 50 err = errDataTypeNotSupported(c.DataType(), version) 51 } else { 52 var val CqlDuration 53 var wasNil bool 54 if val, wasNil, err = convertToDuration(source); err == nil && !wasNil { 55 dest = writeDuration(val) 56 } 57 } 58 if err != nil { 59 err = errCannotEncode(source, c.DataType(), version, err) 60 } 61 return 62 } 63 64 func (c *durationCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { 65 if !version.SupportsDataType(c.DataType().Code()) { 66 wasNull = len(source) == 0 67 err = errDataTypeNotSupported(c.DataType(), version) 68 } else { 69 var val CqlDuration 70 if val, wasNull, err = readDuration(source); err == nil { 71 err = convertFromDuration(val, wasNull, dest) 72 } 73 } 74 if err != nil { 75 err = errCannotDecode(dest, c.DataType(), version, err) 76 } 77 return 78 } 79 80 func convertToDuration(source interface{}) (val CqlDuration, wasNil bool, err error) { 81 switch s := source.(type) { 82 case CqlDuration: 83 val = s 84 case *CqlDuration: 85 if wasNil = s == nil; !wasNil { 86 val = *s 87 } 88 case nil: 89 wasNil = true 90 default: 91 err = ErrConversionNotSupported 92 } 93 if err != nil { 94 err = errSourceConversionFailed(source, val, err) 95 } 96 return 97 } 98 99 func convertFromDuration(val CqlDuration, wasNull bool, dest interface{}) (err error) { 100 switch d := dest.(type) { 101 case *interface{}: 102 if d == nil { 103 err = ErrNilDestination 104 } else if wasNull { 105 *d = nil 106 } else { 107 *d = val 108 } 109 case *CqlDuration: 110 if d == nil { 111 err = ErrNilDestination 112 } else if wasNull { 113 *d = CqlDuration{} 114 } else { 115 *d = val 116 } 117 default: 118 err = errDestinationInvalid(dest) 119 } 120 if err != nil { 121 err = errDestinationConversionFailed(val, dest, err) 122 } 123 return 124 } 125 126 // Implementation notes from the protocol specs: 127 // A duration is composed of 3 signed variable length integers ([vint]s). 128 // The first [vint] represents a number of months, the second [vint] represents 129 // a number of days, and the last [vint] represents a number of nanoseconds. 130 // The number of months and days must be valid 32 bits integers whereas the 131 // number of nanoseconds must be a valid 64 bits integer. 132 133 func writeDuration(val CqlDuration) []byte { 134 writer := &bytes.Buffer{} 135 _, _ = primitive.WriteVint(int64(val.Months), writer) 136 _, _ = primitive.WriteVint(int64(val.Days), writer) 137 _, _ = primitive.WriteVint(int64(val.Nanos), writer) 138 return writer.Bytes() 139 } 140 141 func readDuration(source []byte) (val CqlDuration, wasNull bool, err error) { 142 length := len(source) 143 wasNull = length == 0 144 if !wasNull { 145 var months, days, nanos int64 146 var rm, rd, rn int 147 reader := bytes.NewReader(source) 148 months, rm, err = primitive.ReadVint(reader) 149 if err == nil { 150 _, _ = reader.Seek(int64(rm), io.SeekStart) 151 days, rd, err = primitive.ReadVint(reader) 152 if err == nil { 153 _, _ = reader.Seek(int64(rm+rd), io.SeekStart) 154 nanos, rn, err = primitive.ReadVint(reader) 155 if err == nil { 156 read := rm + rd + rn 157 if length == read { 158 val.Months = int32(months) 159 val.Days = int32(days) 160 val.Nanos = time.Duration(nanos) 161 } else { 162 err = errBytesRemaining(length, length-read) 163 } 164 } else { 165 err = fmt.Errorf("cannot read duration nanos: %w", err) 166 } 167 } else { 168 err = fmt.Errorf("cannot read duration days: %w", err) 169 } 170 } else { 171 err = fmt.Errorf("cannot read duration months: %w", err) 172 } 173 } 174 if err != nil { 175 err = errCannotRead(val, err) 176 } 177 return 178 }