github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/datacodec/duration.go (about)

     1  // Copyright 2020 DataStax
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datacodec
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"io"
    21  	"time"
    22  
    23  	"github.com/datastax/go-cassandra-native-protocol/datatype"
    24  	"github.com/datastax/go-cassandra-native-protocol/primitive"
    25  )
    26  
    27  // CqlDuration is a CQL type introduced in protocol v5. A duration can either be positive or negative. If a duration is
    28  // positive all the integers must be positive or zero. If a duration is negative all the numbers must be negative or
    29  // zero.
    30  type CqlDuration struct {
    31  	Months int32
    32  	Days   int32
    33  	Nanos  time.Duration
    34  }
    35  
    36  // Duration is a codec for the CQL duration type, introduced in protocol v5. There is no built-in representation of
    37  // arbitrary-precision duration values in Go's standard library. This is why this codec can only encode from and decode
    38  // to CqlDuration.
    39  var Duration Codec = &durationCodec{}
    40  
    41  type durationCodec struct {
    42  }
    43  
    44  func (c *durationCodec) DataType() datatype.DataType {
    45  	return datatype.Duration
    46  }
    47  
    48  func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
    49  	if !version.SupportsDataType(c.DataType().Code()) {
    50  		err = errDataTypeNotSupported(c.DataType(), version)
    51  	} else {
    52  		var val CqlDuration
    53  		var wasNil bool
    54  		if val, wasNil, err = convertToDuration(source); err == nil && !wasNil {
    55  			dest = writeDuration(val)
    56  		}
    57  	}
    58  	if err != nil {
    59  		err = errCannotEncode(source, c.DataType(), version, err)
    60  	}
    61  	return
    62  }
    63  
    64  func (c *durationCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
    65  	if !version.SupportsDataType(c.DataType().Code()) {
    66  		wasNull = len(source) == 0
    67  		err = errDataTypeNotSupported(c.DataType(), version)
    68  	} else {
    69  		var val CqlDuration
    70  		if val, wasNull, err = readDuration(source); err == nil {
    71  			err = convertFromDuration(val, wasNull, dest)
    72  		}
    73  	}
    74  	if err != nil {
    75  		err = errCannotDecode(dest, c.DataType(), version, err)
    76  	}
    77  	return
    78  }
    79  
    80  func convertToDuration(source interface{}) (val CqlDuration, wasNil bool, err error) {
    81  	switch s := source.(type) {
    82  	case CqlDuration:
    83  		val = s
    84  	case *CqlDuration:
    85  		if wasNil = s == nil; !wasNil {
    86  			val = *s
    87  		}
    88  	case nil:
    89  		wasNil = true
    90  	default:
    91  		err = ErrConversionNotSupported
    92  	}
    93  	if err != nil {
    94  		err = errSourceConversionFailed(source, val, err)
    95  	}
    96  	return
    97  }
    98  
    99  func convertFromDuration(val CqlDuration, wasNull bool, dest interface{}) (err error) {
   100  	switch d := dest.(type) {
   101  	case *interface{}:
   102  		if d == nil {
   103  			err = ErrNilDestination
   104  		} else if wasNull {
   105  			*d = nil
   106  		} else {
   107  			*d = val
   108  		}
   109  	case *CqlDuration:
   110  		if d == nil {
   111  			err = ErrNilDestination
   112  		} else if wasNull {
   113  			*d = CqlDuration{}
   114  		} else {
   115  			*d = val
   116  		}
   117  	default:
   118  		err = errDestinationInvalid(dest)
   119  	}
   120  	if err != nil {
   121  		err = errDestinationConversionFailed(val, dest, err)
   122  	}
   123  	return
   124  }
   125  
   126  // Implementation notes from the protocol specs:
   127  // A duration is composed of 3 signed variable length integers ([vint]s).
   128  // The first [vint] represents a number of months, the second [vint] represents
   129  // a number of days, and the last [vint] represents a number of nanoseconds.
   130  // The number of months and days must be valid 32 bits integers whereas the
   131  // number of nanoseconds must be a valid 64 bits integer.
   132  
   133  func writeDuration(val CqlDuration) []byte {
   134  	writer := &bytes.Buffer{}
   135  	_, _ = primitive.WriteVint(int64(val.Months), writer)
   136  	_, _ = primitive.WriteVint(int64(val.Days), writer)
   137  	_, _ = primitive.WriteVint(int64(val.Nanos), writer)
   138  	return writer.Bytes()
   139  }
   140  
   141  func readDuration(source []byte) (val CqlDuration, wasNull bool, err error) {
   142  	length := len(source)
   143  	wasNull = length == 0
   144  	if !wasNull {
   145  		var months, days, nanos int64
   146  		var rm, rd, rn int
   147  		reader := bytes.NewReader(source)
   148  		months, rm, err = primitive.ReadVint(reader)
   149  		if err == nil {
   150  			_, _ = reader.Seek(int64(rm), io.SeekStart)
   151  			days, rd, err = primitive.ReadVint(reader)
   152  			if err == nil {
   153  				_, _ = reader.Seek(int64(rm+rd), io.SeekStart)
   154  				nanos, rn, err = primitive.ReadVint(reader)
   155  				if err == nil {
   156  					read := rm + rd + rn
   157  					if length == read {
   158  						val.Months = int32(months)
   159  						val.Days = int32(days)
   160  						val.Nanos = time.Duration(nanos)
   161  					} else {
   162  						err = errBytesRemaining(length, length-read)
   163  					}
   164  				} else {
   165  					err = fmt.Errorf("cannot read duration nanos: %w", err)
   166  				}
   167  			} else {
   168  				err = fmt.Errorf("cannot read duration days: %w", err)
   169  			}
   170  		} else {
   171  			err = fmt.Errorf("cannot read duration months: %w", err)
   172  		}
   173  	}
   174  	if err != nil {
   175  		err = errCannotRead(val, err)
   176  	}
   177  	return
   178  }