github.com/willyham/dosa@v2.3.1-0.20171024181418-1e446d37ee71+incompatible/range_conditions.go (about) 1 // Copyright (c) 2017 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package dosa 22 23 import ( 24 "time" 25 26 "bytes" 27 "strings" 28 29 "github.com/pkg/errors" 30 ) 31 32 // Condition holds an operator and a value for a condition on a field. 33 type Condition struct { 34 Op Operator 35 Value FieldValue 36 } 37 38 // EnsureValidRangeConditions checks if the conditions for a range query is valid. 39 // The transform arg is a function to transform the column name to a better representation for error message under 40 // different circumstances. For example, on client side it can transform the column name to actual go struct field name; 41 // and on the server side, an identity transformer func can be used. 42 func EnsureValidRangeConditions(ed *EntityDefinition, pk *PrimaryKey, columnConditions map[string][]*Condition, transform func(string) string) error { 43 unconstrainedPartitionKeySet := pk.PartitionKeySet() 44 columnTypes := ed.ColumnTypes() 45 46 clusteringKeyConditions := make([][]*Condition, len(pk.ClusteringKeys)) 47 48 COND: 49 for column, conds := range columnConditions { 50 if _, ok := unconstrainedPartitionKeySet[column]; ok { 51 delete(unconstrainedPartitionKeySet, column) 52 if err := ensureExactOneEqCondition(columnTypes[column], conds); err != nil { 53 return errors.Wrapf(err, "invalid conditions for partition key: %s", transform(column)) 54 } 55 continue 56 } 57 58 for i, c := range pk.ClusteringKeys { 59 if column == c.Name { 60 clusteringKeyConditions[i] = conds 61 continue COND 62 } 63 } 64 65 return errors.Errorf("cannot enforce condition on non-key column: %s", transform(column)) 66 } 67 68 if len(unconstrainedPartitionKeySet) > 0 { 69 names := []string{} 70 for k := range unconstrainedPartitionKeySet { 71 names = append(names, transform(k)) 72 } 73 return errors.Errorf("missing Eq condition on partition keys: %v", names) 74 } 75 76 if err := ensureClusteringKeyConditions(pk.ClusteringKeys, columnTypes, clusteringKeyConditions, transform); err != nil { 77 return errors.Wrap(err, "conditions for clustering keys are invalid") 78 } 79 80 return nil 81 } 82 83 func ensureExactOneEqCondition(t Type, conditions []*Condition) error { 84 if len(conditions) != 1 { 85 return errors.Errorf("expect exact one Eq condition, found: %v", conditions) 86 } 87 88 r := conditions[0] 89 if r.Op != Eq { 90 return errors.Errorf("only Eq condition is allowed on this column for this query, found: %s", r.Op.String()) 91 } 92 93 if err := ensureTypeMatch(t, r.Value); err != nil { 94 return errors.Wrap(err, "the value in condition does not have expected type") 95 } 96 return nil 97 } 98 99 func ensureClusteringKeyConditions(cks []*ClusteringKey, columnTypes map[string]Type, 100 clusteringKeyConditions [][]*Condition, transform func(string) string) error { 101 // ensure conditions are applied to consecutive clustering keys 102 lastConstrainedIndex := -1 103 for i, conditions := range clusteringKeyConditions { 104 if len(conditions) > 0 { 105 if lastConstrainedIndex != i-1 { 106 return errors.Errorf("conditions must be applied consecutively on clustering keys, "+ 107 "but at least one clustering key is unconstrained before: %s", transform(cks[i].Name)) 108 } 109 lastConstrainedIndex = i 110 } 111 } 112 113 // ensure only Eq is applied to clustering keys except for the last constrained one 114 for i := 0; i < lastConstrainedIndex; i++ { 115 name := cks[i].Name 116 if err := ensureExactOneEqCondition(columnTypes[name], clusteringKeyConditions[i]); err != nil { 117 return errors.Wrapf(err, "exact one Eq condition can be applied except for the last "+ 118 "constrained clustering key, found invalid condition for key: %s", transform(name)) 119 } 120 } 121 122 // ensure the last constrained clustering key has valid conditions 123 if lastConstrainedIndex >= 0 { 124 name := cks[lastConstrainedIndex].Name 125 if err := ensureValidConditions(columnTypes[name], clusteringKeyConditions[lastConstrainedIndex]); err != nil { 126 return errors.Wrapf(err, "invalid or unsupported conditions for clustering key: %s", transform(name)) 127 } 128 } 129 130 return nil 131 } 132 133 const conditionsRule = ` 134 If you have a Lt or LtOrEq operator on a column, you can also have a Gt or GtOrEq on the same column. 135 No other combinations of operators are permitted. 136 ` 137 138 // Start with simple rules as specified in `conditionsRule` above. 139 // Hence, the length of valid conditions slice is either one or two (won't be called if zero length). 140 func ensureValidConditions(t Type, conditions []*Condition) error { 141 // check type sanity 142 for _, r := range conditions { 143 if err := ensureTypeMatch(t, r.Value); err != nil { 144 return errors.Wrap(err, "invalid condition") 145 } 146 } 147 148 switch { 149 case len(conditions) == 1: 150 return nil // single condition is always valid 151 case len(conditions) > 2: 152 return errors.Errorf("conditions: %v, rules: %s", conditions, conditionsRule) 153 } 154 155 r0 := conditions[0] 156 r1 := conditions[1] 157 // sort conditions according to operators so we have few cases to handle 158 if r0.Op >= r1.Op { 159 r0, r1 = r1, r0 160 } 161 162 op0 := r0.Op 163 v0 := r0.Value 164 op1 := r1.Op 165 v1 := r1.Value 166 167 switch { 168 // v1 < fv < v0, v1 <= fv < v0, v1 < fv <= v0 ===> v0 > v1 169 case op0 == Lt && op1 == Gt, op0 == Lt && op1 == GtOrEq, op0 == LtOrEq && op1 == Gt: 170 if compare(t, v0, v1) <= 0 { 171 return errors.Errorf("invalid range: %v", conditions) 172 } 173 // v1 <= fv <= v0 ===> v0 >= v1 174 case op0 == LtOrEq && op1 == GtOrEq: 175 if compare(t, v0, v1) < 0 { 176 return errors.Errorf("invalid range: %v", conditions) 177 } 178 default: // invalid combination of operators 179 return errors.Errorf("unsupported conditions: %v, rules: %s", conditions, conditionsRule) 180 181 } 182 183 return nil 184 } 185 186 // compare compares two values; return 0 if equal, -1 if <, 1 if >. 187 // Assumes args are valid. 188 func compare(t Type, a, b interface{}) int { 189 switch t { 190 case TUUID: 191 // TODO: make sure if comparison for UUID like below makes sense. 192 return strings.Compare(string(a.(UUID)), string(b.(UUID))) 193 case Int64: 194 return int(a.(int64) - b.(int64)) 195 case Int32: 196 return int(a.(int32) - b.(int32)) 197 case String: 198 return strings.Compare(a.(string), b.(string)) 199 case Blob: 200 return bytes.Compare(a.([]byte), b.([]byte)) 201 case Bool: 202 // TODO: we don't need to order bools for range query and should report error if people do dumb things 203 var ia, ib int 204 if a.(bool) { 205 ia = 1 206 } 207 if b.(bool) { 208 ib = 1 209 } 210 return ia - ib 211 case Double: 212 fa := a.(float64) 213 fb := b.(float64) 214 if fa < fb { 215 return -1 216 } 217 if fa > fb { 218 return 1 219 } 220 return 0 221 case Timestamp: 222 ta := a.(time.Time) 223 tb := b.(time.Time) 224 if ta.Before(tb) { 225 return -1 226 } 227 if ta.After(tb) { 228 return 1 229 } 230 return 0 231 } 232 panic("invalid type") // shouldn't reach here 233 } 234 235 func ensureTypeMatch(t Type, v FieldValue) error { 236 switch t { 237 case TUUID: 238 if _, ok := v.(UUID); !ok { 239 return errors.Errorf("invalid value for UUID type: %v", v) 240 } 241 case Int64: 242 if _, ok := v.(int64); !ok { 243 return errors.Errorf("invalid value for int64 type: %v", v) 244 } 245 case Int32: 246 if _, ok := v.(int32); !ok { 247 return errors.Errorf("invalid value for int32 type: %v", v) 248 } 249 case String: 250 if _, ok := v.(string); !ok { 251 return errors.Errorf("invalid value for string type: %v", v) 252 } 253 case Blob: 254 if _, ok := v.([]byte); !ok { 255 return errors.Errorf("invalid value for blob type: %v", v) 256 } 257 case Bool: 258 if _, ok := v.(bool); !ok { 259 return errors.Errorf("invalid value for bool type: %v", v) 260 } 261 case Double: 262 if _, ok := v.(float64); !ok { 263 return errors.Errorf("invalid value for double/float64 type: %v", v) 264 } 265 case Timestamp: 266 if _, ok := v.(time.Time); !ok { 267 return errors.Errorf("invalid value for timestamp type: %v", v) 268 } 269 default: 270 // will not happen unless we have a bug 271 panic("invalid type") 272 } 273 return nil 274 }