github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/testutil/compare.go (about)

     1  package testutil
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math/big"
     7  	"strings"
     8  
     9  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    10  
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    14  )
    15  
    16  var ErrNotComparable = xerrors.Wrap(fmt.Errorf("not comparable"))
    17  
    18  // Compare compares its operands.
    19  // It returns -1, 0, 1 if l < r, l == r, l > r. Returns error if types are not comparable.
    20  // Comparable types are all integer types, UUID, DyNumber, Float, Double, String, UTF8,
    21  // Date, Datetime, Timestamp, Tuples and Lists.
    22  // Primitive arguments are comparable if their types are the same.
    23  // Optional types is comparable to underlying types, e.g. Optional<Optional<Float>> is comparable to Float.
    24  // Null value is comparable to non-null value of the same types and is considered less than any non-null value.
    25  // Tuples and Lists are comparable if their elements are comparable.
    26  // Tuples and Lists are compared lexicographically. If tuples (lists) have different length and elements of the
    27  // shorter tuple (list) are all equal to corresponding elements of the other tuple (list), than the shorter tuple (list)
    28  // is considered less than the longer one.
    29  func Compare(l, r value.Value) (int, error) {
    30  	a := allocator.New()
    31  	defer a.Free()
    32  
    33  	return compare(unwrapTypedValue(value.ToYDB(l, a)), unwrapTypedValue(value.ToYDB(r, a)))
    34  }
    35  
    36  func unwrapTypedValue(v *Ydb.TypedValue) *Ydb.TypedValue {
    37  	typ := v.GetType()
    38  	val := v.GetValue()
    39  	for opt := typ.GetOptionalType(); opt != nil; opt = typ.GetOptionalType() {
    40  		typ = opt.GetItem()
    41  		if nested := val.GetNestedValue(); nested != nil {
    42  			val = nested
    43  		}
    44  	}
    45  
    46  	return &Ydb.TypedValue{Type: typ, Value: val}
    47  }
    48  
    49  func compare(lhs, rhs *Ydb.TypedValue) (int, error) {
    50  	lTypeID := lhs.GetType().GetTypeId()
    51  	rTypeID := rhs.GetType().GetTypeId()
    52  	switch {
    53  	case lTypeID != rTypeID:
    54  		return 0, notComparableError(lhs, rhs)
    55  	case lTypeID != Ydb.Type_PRIMITIVE_TYPE_ID_UNSPECIFIED:
    56  		return comparePrimitives(lTypeID, lhs.GetValue(), rhs.GetValue())
    57  	case lhs.GetType().GetTupleType() != nil && rhs.GetType().GetTupleType() != nil:
    58  		return compareTuplesOrLists(expandTuple(lhs), expandTuple(rhs))
    59  	case lhs.GetType().GetListType() != nil && rhs.GetType().GetListType() != nil:
    60  		return compareTuplesOrLists(expandList(lhs), expandList(rhs))
    61  	case lhs.GetType().GetStructType() != nil && rhs.GetType().GetStructType() != nil:
    62  		return compareStructs(expandStruct(lhs), expandStruct(rhs))
    63  	default:
    64  		return 0, notComparableError(lhs, rhs)
    65  	}
    66  }
    67  
    68  func expandItems(v *Ydb.TypedValue, itemType func(i int) *Ydb.Type) []*Ydb.TypedValue {
    69  	size := len(v.GetValue().GetItems())
    70  	values := make([]*Ydb.TypedValue, 0, size)
    71  	for i, val := range v.GetValue().GetItems() {
    72  		values = append(values, unwrapTypedValue(&Ydb.TypedValue{Type: itemType(i), Value: val}))
    73  	}
    74  
    75  	return values
    76  }
    77  
    78  func expandList(v *Ydb.TypedValue) []*Ydb.TypedValue {
    79  	return expandItems(v, func(i int) *Ydb.Type {
    80  		return v.GetType().GetListType().GetItem()
    81  	})
    82  }
    83  
    84  func expandStruct(v *Ydb.TypedValue) []*Ydb.TypedValue {
    85  	return expandItems(v, func(i int) *Ydb.Type {
    86  		return v.GetType().GetStructType().GetMembers()[i].GetType()
    87  	})
    88  }
    89  
    90  func expandTuple(v *Ydb.TypedValue) []*Ydb.TypedValue {
    91  	tuple := v.GetType().GetTupleType()
    92  	size := len(tuple.GetElements())
    93  	values := make([]*Ydb.TypedValue, 0, size)
    94  	for idx, typ := range tuple.GetElements() {
    95  		values = append(values, unwrapTypedValue(&Ydb.TypedValue{Type: typ, Value: v.GetValue().GetItems()[idx]}))
    96  	}
    97  
    98  	return values
    99  }
   100  
   101  func notComparableError(lhs, rhs interface{}) error {
   102  	return xerrors.WithStackTrace(fmt.Errorf("%w: %v and %v", ErrNotComparable, lhs, rhs), xerrors.WithSkipDepth(1))
   103  }
   104  
   105  func comparePrimitives(t Ydb.Type_PrimitiveTypeId, lhs, rhs *Ydb.Value) (int, error) {
   106  	_, lIsNull := lhs.GetValue().(*Ydb.Value_NullFlagValue)
   107  	_, rIsNull := rhs.GetValue().(*Ydb.Value_NullFlagValue)
   108  	if lIsNull {
   109  		if rIsNull {
   110  			return 0, nil
   111  		}
   112  
   113  		return -1, nil
   114  	}
   115  	if rIsNull {
   116  		return 1, nil
   117  	}
   118  
   119  	if compare, found := comparators[t]; found {
   120  		return compare(lhs, rhs), nil
   121  	}
   122  	// special cases
   123  	switch t {
   124  	case Ydb.Type_DYNUMBER:
   125  		return compareDyNumber(lhs, rhs)
   126  	default:
   127  		return 0, notComparableError(lhs, rhs)
   128  	}
   129  }
   130  
   131  func compareTuplesOrLists(lhs, rhs []*Ydb.TypedValue) (int, error) {
   132  	for i, lval := range lhs {
   133  		if i >= len(rhs) {
   134  			// lhs is longer than rhs, first len(rhs) elements equal
   135  			return 1, nil
   136  		}
   137  		rval := rhs[i]
   138  		cmp, err := compare(lval, rval)
   139  		if err != nil {
   140  			return 0, xerrors.WithStackTrace(err)
   141  		}
   142  		if cmp != 0 {
   143  			return cmp, nil
   144  		}
   145  	}
   146  	// len(lhs) elements equal
   147  	if len(rhs) > len(lhs) {
   148  		return -1, nil
   149  	}
   150  
   151  	return 0, nil
   152  }
   153  
   154  func compareStructs(lhs, rhs []*Ydb.TypedValue) (int, error) {
   155  	for i, lval := range lhs {
   156  		if i >= len(rhs) {
   157  			// lhs is longer than rhs, first len(rhs) elements equal
   158  			return 1, nil
   159  		}
   160  		rval := rhs[i]
   161  		cmp, err := compare(lval, rval)
   162  		if err != nil {
   163  			return 0, xerrors.WithStackTrace(err)
   164  		}
   165  		if cmp != 0 {
   166  			return cmp, nil
   167  		}
   168  	}
   169  	// len(lhs) elements equal
   170  	if len(rhs) > len(lhs) {
   171  		return -1, nil
   172  	}
   173  
   174  	return 0, nil
   175  }
   176  
   177  type comparator func(l, r *Ydb.Value) int
   178  
   179  var comparators = map[Ydb.Type_PrimitiveTypeId]comparator{
   180  	Ydb.Type_BOOL:      compareBool,
   181  	Ydb.Type_INT8:      compareInt32,
   182  	Ydb.Type_UINT8:     compareUint32,
   183  	Ydb.Type_INT16:     compareInt32,
   184  	Ydb.Type_UINT16:    compareUint32,
   185  	Ydb.Type_INT32:     compareInt32,
   186  	Ydb.Type_UINT32:    compareUint32,
   187  	Ydb.Type_INT64:     compareInt64,
   188  	Ydb.Type_UINT64:    compareUint64,
   189  	Ydb.Type_FLOAT:     compareFloat,
   190  	Ydb.Type_DOUBLE:    compareDouble,
   191  	Ydb.Type_DATE:      compareUint32,
   192  	Ydb.Type_DATETIME:  compareUint32,
   193  	Ydb.Type_TIMESTAMP: compareUint64,
   194  	Ydb.Type_INTERVAL:  compareInt64,
   195  	Ydb.Type_STRING:    compareBytes,
   196  	Ydb.Type_UTF8:      compareText,
   197  	Ydb.Type_UUID:      compareUUID,
   198  }
   199  
   200  func compareUint32(l, r *Ydb.Value) int {
   201  	ll := l.GetUint32Value()
   202  	rr := r.GetUint32Value()
   203  	switch {
   204  	case ll < rr:
   205  		return -1
   206  	case ll > rr:
   207  		return 1
   208  	default:
   209  		return 0
   210  	}
   211  }
   212  
   213  func compareInt32(l, r *Ydb.Value) int {
   214  	ll := l.GetInt32Value()
   215  	rr := r.GetInt32Value()
   216  	switch {
   217  	case ll < rr:
   218  		return -1
   219  	case ll > rr:
   220  		return 1
   221  	default:
   222  		return 0
   223  	}
   224  }
   225  
   226  func compareUint64(l, r *Ydb.Value) int {
   227  	ll := l.GetUint64Value()
   228  	rr := r.GetUint64Value()
   229  	switch {
   230  	case ll < rr:
   231  		return -1
   232  	case ll > rr:
   233  		return 1
   234  	default:
   235  		return 0
   236  	}
   237  }
   238  
   239  func compareInt64(l, r *Ydb.Value) int {
   240  	ll := l.GetInt64Value()
   241  	rr := r.GetInt64Value()
   242  	switch {
   243  	case ll < rr:
   244  		return -1
   245  	case ll > rr:
   246  		return 1
   247  	default:
   248  		return 0
   249  	}
   250  }
   251  
   252  func compareFloat(l, r *Ydb.Value) int {
   253  	ll := l.GetFloatValue()
   254  	rr := r.GetFloatValue()
   255  	switch {
   256  	case ll < rr:
   257  		return -1
   258  	case ll > rr:
   259  		return 1
   260  	default:
   261  		return 0
   262  	}
   263  }
   264  
   265  func compareDouble(l, r *Ydb.Value) int {
   266  	ll := l.GetDoubleValue()
   267  	rr := r.GetDoubleValue()
   268  	switch {
   269  	case ll < rr:
   270  		return -1
   271  	case ll > rr:
   272  		return 1
   273  	default:
   274  		return 0
   275  	}
   276  }
   277  
   278  func compareText(l, r *Ydb.Value) int {
   279  	ll := l.GetTextValue()
   280  	rr := r.GetTextValue()
   281  
   282  	return strings.Compare(ll, rr)
   283  }
   284  
   285  func compareBytes(l, r *Ydb.Value) int {
   286  	ll := l.GetBytesValue()
   287  	rr := r.GetBytesValue()
   288  
   289  	return bytes.Compare(ll, rr)
   290  }
   291  
   292  func compareBool(l, r *Ydb.Value) int {
   293  	ll := l.GetBoolValue()
   294  	rr := r.GetBoolValue()
   295  	if ll {
   296  		if rr {
   297  			return 0
   298  		}
   299  
   300  		return 1
   301  	}
   302  	if rr {
   303  		return -1
   304  	}
   305  
   306  	return 0
   307  }
   308  
   309  func compareDyNumber(l, r *Ydb.Value) (int, error) {
   310  	ll := l.GetTextValue()
   311  	rr := r.GetTextValue()
   312  	lf, _, err := big.ParseFloat(ll, 10, 127, big.ToNearestEven)
   313  	if err != nil {
   314  		return 0, xerrors.WithStackTrace(err)
   315  	}
   316  	rf, _, err := big.ParseFloat(rr, 10, 127, big.ToNearestEven)
   317  	if err != nil {
   318  		return 0, err
   319  	}
   320  
   321  	return lf.Cmp(rf), nil
   322  }
   323  
   324  func compareUUID(l, r *Ydb.Value) int {
   325  	lh := l.GetHigh_128()
   326  	rh := r.GetHigh_128()
   327  	switch {
   328  	case lh > rh:
   329  		return 1
   330  	case lh < rh:
   331  		return -1
   332  	}
   333  	ll := l.GetLow_128()
   334  	rl := r.GetLow_128()
   335  	switch {
   336  	case ll < rl:
   337  		return -1
   338  	case ll > rl:
   339  		return 1
   340  	default:
   341  		return 0
   342  	}
   343  }