github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/testutil/compare.go (about) 1 package testutil 2 3 import ( 4 "bytes" 5 "fmt" 6 "math/big" 7 "strings" 8 9 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 10 11 "github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 14 ) 15 16 var ErrNotComparable = xerrors.Wrap(fmt.Errorf("not comparable")) 17 18 // Compare compares its operands. 19 // It returns -1, 0, 1 if l < r, l == r, l > r. Returns error if types are not comparable. 20 // Comparable types are all integer types, UUID, DyNumber, Float, Double, String, UTF8, 21 // Date, Datetime, Timestamp, Tuples and Lists. 22 // Primitive arguments are comparable if their types are the same. 23 // Optional types is comparable to underlying types, e.g. Optional<Optional<Float>> is comparable to Float. 24 // Null value is comparable to non-null value of the same types and is considered less than any non-null value. 25 // Tuples and Lists are comparable if their elements are comparable. 26 // Tuples and Lists are compared lexicographically. If tuples (lists) have different length and elements of the 27 // shorter tuple (list) are all equal to corresponding elements of the other tuple (list), than the shorter tuple (list) 28 // is considered less than the longer one. 29 func Compare(l, r value.Value) (int, error) { 30 a := allocator.New() 31 defer a.Free() 32 33 return compare(unwrapTypedValue(value.ToYDB(l, a)), unwrapTypedValue(value.ToYDB(r, a))) 34 } 35 36 func unwrapTypedValue(v *Ydb.TypedValue) *Ydb.TypedValue { 37 typ := v.GetType() 38 val := v.GetValue() 39 for opt := typ.GetOptionalType(); opt != nil; opt = typ.GetOptionalType() { 40 typ = opt.GetItem() 41 if nested := val.GetNestedValue(); nested != nil { 42 val = nested 43 } 44 } 45 46 return &Ydb.TypedValue{Type: typ, Value: val} 47 } 48 49 func compare(lhs, rhs *Ydb.TypedValue) (int, error) { 50 lTypeID := lhs.GetType().GetTypeId() 51 rTypeID := rhs.GetType().GetTypeId() 52 switch { 53 case lTypeID != rTypeID: 54 return 0, notComparableError(lhs, rhs) 55 case lTypeID != Ydb.Type_PRIMITIVE_TYPE_ID_UNSPECIFIED: 56 return comparePrimitives(lTypeID, lhs.GetValue(), rhs.GetValue()) 57 case lhs.GetType().GetTupleType() != nil && rhs.GetType().GetTupleType() != nil: 58 return compareTuplesOrLists(expandTuple(lhs), expandTuple(rhs)) 59 case lhs.GetType().GetListType() != nil && rhs.GetType().GetListType() != nil: 60 return compareTuplesOrLists(expandList(lhs), expandList(rhs)) 61 case lhs.GetType().GetStructType() != nil && rhs.GetType().GetStructType() != nil: 62 return compareStructs(expandStruct(lhs), expandStruct(rhs)) 63 default: 64 return 0, notComparableError(lhs, rhs) 65 } 66 } 67 68 func expandItems(v *Ydb.TypedValue, itemType func(i int) *Ydb.Type) []*Ydb.TypedValue { 69 size := len(v.GetValue().GetItems()) 70 values := make([]*Ydb.TypedValue, 0, size) 71 for i, val := range v.GetValue().GetItems() { 72 values = append(values, unwrapTypedValue(&Ydb.TypedValue{Type: itemType(i), Value: val})) 73 } 74 75 return values 76 } 77 78 func expandList(v *Ydb.TypedValue) []*Ydb.TypedValue { 79 return expandItems(v, func(i int) *Ydb.Type { 80 return v.GetType().GetListType().GetItem() 81 }) 82 } 83 84 func expandStruct(v *Ydb.TypedValue) []*Ydb.TypedValue { 85 return expandItems(v, func(i int) *Ydb.Type { 86 return v.GetType().GetStructType().GetMembers()[i].GetType() 87 }) 88 } 89 90 func expandTuple(v *Ydb.TypedValue) []*Ydb.TypedValue { 91 tuple := v.GetType().GetTupleType() 92 size := len(tuple.GetElements()) 93 values := make([]*Ydb.TypedValue, 0, size) 94 for idx, typ := range tuple.GetElements() { 95 values = append(values, unwrapTypedValue(&Ydb.TypedValue{Type: typ, Value: v.GetValue().GetItems()[idx]})) 96 } 97 98 return values 99 } 100 101 func notComparableError(lhs, rhs interface{}) error { 102 return xerrors.WithStackTrace(fmt.Errorf("%w: %v and %v", ErrNotComparable, lhs, rhs), xerrors.WithSkipDepth(1)) 103 } 104 105 func comparePrimitives(t Ydb.Type_PrimitiveTypeId, lhs, rhs *Ydb.Value) (int, error) { 106 _, lIsNull := lhs.GetValue().(*Ydb.Value_NullFlagValue) 107 _, rIsNull := rhs.GetValue().(*Ydb.Value_NullFlagValue) 108 if lIsNull { 109 if rIsNull { 110 return 0, nil 111 } 112 113 return -1, nil 114 } 115 if rIsNull { 116 return 1, nil 117 } 118 119 if compare, found := comparators[t]; found { 120 return compare(lhs, rhs), nil 121 } 122 // special cases 123 switch t { 124 case Ydb.Type_DYNUMBER: 125 return compareDyNumber(lhs, rhs) 126 default: 127 return 0, notComparableError(lhs, rhs) 128 } 129 } 130 131 func compareTuplesOrLists(lhs, rhs []*Ydb.TypedValue) (int, error) { 132 for i, lval := range lhs { 133 if i >= len(rhs) { 134 // lhs is longer than rhs, first len(rhs) elements equal 135 return 1, nil 136 } 137 rval := rhs[i] 138 cmp, err := compare(lval, rval) 139 if err != nil { 140 return 0, xerrors.WithStackTrace(err) 141 } 142 if cmp != 0 { 143 return cmp, nil 144 } 145 } 146 // len(lhs) elements equal 147 if len(rhs) > len(lhs) { 148 return -1, nil 149 } 150 151 return 0, nil 152 } 153 154 func compareStructs(lhs, rhs []*Ydb.TypedValue) (int, error) { 155 for i, lval := range lhs { 156 if i >= len(rhs) { 157 // lhs is longer than rhs, first len(rhs) elements equal 158 return 1, nil 159 } 160 rval := rhs[i] 161 cmp, err := compare(lval, rval) 162 if err != nil { 163 return 0, xerrors.WithStackTrace(err) 164 } 165 if cmp != 0 { 166 return cmp, nil 167 } 168 } 169 // len(lhs) elements equal 170 if len(rhs) > len(lhs) { 171 return -1, nil 172 } 173 174 return 0, nil 175 } 176 177 type comparator func(l, r *Ydb.Value) int 178 179 var comparators = map[Ydb.Type_PrimitiveTypeId]comparator{ 180 Ydb.Type_BOOL: compareBool, 181 Ydb.Type_INT8: compareInt32, 182 Ydb.Type_UINT8: compareUint32, 183 Ydb.Type_INT16: compareInt32, 184 Ydb.Type_UINT16: compareUint32, 185 Ydb.Type_INT32: compareInt32, 186 Ydb.Type_UINT32: compareUint32, 187 Ydb.Type_INT64: compareInt64, 188 Ydb.Type_UINT64: compareUint64, 189 Ydb.Type_FLOAT: compareFloat, 190 Ydb.Type_DOUBLE: compareDouble, 191 Ydb.Type_DATE: compareUint32, 192 Ydb.Type_DATETIME: compareUint32, 193 Ydb.Type_TIMESTAMP: compareUint64, 194 Ydb.Type_INTERVAL: compareInt64, 195 Ydb.Type_STRING: compareBytes, 196 Ydb.Type_UTF8: compareText, 197 Ydb.Type_UUID: compareUUID, 198 } 199 200 func compareUint32(l, r *Ydb.Value) int { 201 ll := l.GetUint32Value() 202 rr := r.GetUint32Value() 203 switch { 204 case ll < rr: 205 return -1 206 case ll > rr: 207 return 1 208 default: 209 return 0 210 } 211 } 212 213 func compareInt32(l, r *Ydb.Value) int { 214 ll := l.GetInt32Value() 215 rr := r.GetInt32Value() 216 switch { 217 case ll < rr: 218 return -1 219 case ll > rr: 220 return 1 221 default: 222 return 0 223 } 224 } 225 226 func compareUint64(l, r *Ydb.Value) int { 227 ll := l.GetUint64Value() 228 rr := r.GetUint64Value() 229 switch { 230 case ll < rr: 231 return -1 232 case ll > rr: 233 return 1 234 default: 235 return 0 236 } 237 } 238 239 func compareInt64(l, r *Ydb.Value) int { 240 ll := l.GetInt64Value() 241 rr := r.GetInt64Value() 242 switch { 243 case ll < rr: 244 return -1 245 case ll > rr: 246 return 1 247 default: 248 return 0 249 } 250 } 251 252 func compareFloat(l, r *Ydb.Value) int { 253 ll := l.GetFloatValue() 254 rr := r.GetFloatValue() 255 switch { 256 case ll < rr: 257 return -1 258 case ll > rr: 259 return 1 260 default: 261 return 0 262 } 263 } 264 265 func compareDouble(l, r *Ydb.Value) int { 266 ll := l.GetDoubleValue() 267 rr := r.GetDoubleValue() 268 switch { 269 case ll < rr: 270 return -1 271 case ll > rr: 272 return 1 273 default: 274 return 0 275 } 276 } 277 278 func compareText(l, r *Ydb.Value) int { 279 ll := l.GetTextValue() 280 rr := r.GetTextValue() 281 282 return strings.Compare(ll, rr) 283 } 284 285 func compareBytes(l, r *Ydb.Value) int { 286 ll := l.GetBytesValue() 287 rr := r.GetBytesValue() 288 289 return bytes.Compare(ll, rr) 290 } 291 292 func compareBool(l, r *Ydb.Value) int { 293 ll := l.GetBoolValue() 294 rr := r.GetBoolValue() 295 if ll { 296 if rr { 297 return 0 298 } 299 300 return 1 301 } 302 if rr { 303 return -1 304 } 305 306 return 0 307 } 308 309 func compareDyNumber(l, r *Ydb.Value) (int, error) { 310 ll := l.GetTextValue() 311 rr := r.GetTextValue() 312 lf, _, err := big.ParseFloat(ll, 10, 127, big.ToNearestEven) 313 if err != nil { 314 return 0, xerrors.WithStackTrace(err) 315 } 316 rf, _, err := big.ParseFloat(rr, 10, 127, big.ToNearestEven) 317 if err != nil { 318 return 0, err 319 } 320 321 return lf.Cmp(rf), nil 322 } 323 324 func compareUUID(l, r *Ydb.Value) int { 325 lh := l.GetHigh_128() 326 rh := r.GetHigh_128() 327 switch { 328 case lh > rh: 329 return 1 330 case lh < rh: 331 return -1 332 } 333 ll := l.GetLow_128() 334 rl := r.GetLow_128() 335 switch { 336 case ll < rl: 337 return -1 338 case ll > rl: 339 return 1 340 default: 341 return 0 342 } 343 }