github.com/willyham/dosa@v2.3.1-0.20171024181418-1e446d37ee71+incompatible/range_conditions.go (about)

     1  // Copyright (c) 2017 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package dosa
    22  
    23  import (
    24  	"time"
    25  
    26  	"bytes"
    27  	"strings"
    28  
    29  	"github.com/pkg/errors"
    30  )
    31  
    // Condition holds an operator and a value for a condition on a field.
// Op is applied with Value as its operand. The range-query validators in this
// file additionally require Value's dynamic Go type to match the column's Type
// (see ensureTypeMatch).
type Condition struct {
	Op    Operator   // comparison applied to the field
	Value FieldValue // operand the field is compared against
}
    37  
    38  // EnsureValidRangeConditions checks if the conditions for a range query is valid.
    39  // The transform arg is a function to transform the column name to a better representation for error message under
    40  // different circumstances. For example, on client side it can transform the column name to actual go struct field name;
    41  // and on the server side, an identity transformer func can be used.
    42  func EnsureValidRangeConditions(ed *EntityDefinition, pk *PrimaryKey, columnConditions map[string][]*Condition, transform func(string) string) error {
    43  	unconstrainedPartitionKeySet := pk.PartitionKeySet()
    44  	columnTypes := ed.ColumnTypes()
    45  
    46  	clusteringKeyConditions := make([][]*Condition, len(pk.ClusteringKeys))
    47  
    48  COND:
    49  	for column, conds := range columnConditions {
    50  		if _, ok := unconstrainedPartitionKeySet[column]; ok {
    51  			delete(unconstrainedPartitionKeySet, column)
    52  			if err := ensureExactOneEqCondition(columnTypes[column], conds); err != nil {
    53  				return errors.Wrapf(err, "invalid conditions for partition key: %s", transform(column))
    54  			}
    55  			continue
    56  		}
    57  
    58  		for i, c := range pk.ClusteringKeys {
    59  			if column == c.Name {
    60  				clusteringKeyConditions[i] = conds
    61  				continue COND
    62  			}
    63  		}
    64  
    65  		return errors.Errorf("cannot enforce condition on non-key column: %s", transform(column))
    66  	}
    67  
    68  	if len(unconstrainedPartitionKeySet) > 0 {
    69  		names := []string{}
    70  		for k := range unconstrainedPartitionKeySet {
    71  			names = append(names, transform(k))
    72  		}
    73  		return errors.Errorf("missing Eq condition on partition keys: %v", names)
    74  	}
    75  
    76  	if err := ensureClusteringKeyConditions(pk.ClusteringKeys, columnTypes, clusteringKeyConditions, transform); err != nil {
    77  		return errors.Wrap(err, "conditions for clustering keys are invalid")
    78  	}
    79  
    80  	return nil
    81  }
    82  
    83  func ensureExactOneEqCondition(t Type, conditions []*Condition) error {
    84  	if len(conditions) != 1 {
    85  		return errors.Errorf("expect exact one Eq condition, found: %v", conditions)
    86  	}
    87  
    88  	r := conditions[0]
    89  	if r.Op != Eq {
    90  		return errors.Errorf("only Eq condition is allowed on this column for this query, found: %s", r.Op.String())
    91  	}
    92  
    93  	if err := ensureTypeMatch(t, r.Value); err != nil {
    94  		return errors.Wrap(err, "the value in condition does not have expected type")
    95  	}
    96  	return nil
    97  }
    98  
    99  func ensureClusteringKeyConditions(cks []*ClusteringKey, columnTypes map[string]Type,
   100  	clusteringKeyConditions [][]*Condition, transform func(string) string) error {
   101  	// ensure conditions are applied to consecutive clustering keys
   102  	lastConstrainedIndex := -1
   103  	for i, conditions := range clusteringKeyConditions {
   104  		if len(conditions) > 0 {
   105  			if lastConstrainedIndex != i-1 {
   106  				return errors.Errorf("conditions must be applied consecutively on clustering keys, "+
   107  					"but at least one clustering key is unconstrained before: %s", transform(cks[i].Name))
   108  			}
   109  			lastConstrainedIndex = i
   110  		}
   111  	}
   112  
   113  	// ensure only Eq is applied to clustering keys except for the last constrained one
   114  	for i := 0; i < lastConstrainedIndex; i++ {
   115  		name := cks[i].Name
   116  		if err := ensureExactOneEqCondition(columnTypes[name], clusteringKeyConditions[i]); err != nil {
   117  			return errors.Wrapf(err, "exact one Eq condition can be applied except for the last "+
   118  				"constrained clustering key, found invalid condition for key: %s", transform(name))
   119  		}
   120  	}
   121  
   122  	// ensure the last constrained clustering key has valid conditions
   123  	if lastConstrainedIndex >= 0 {
   124  		name := cks[lastConstrainedIndex].Name
   125  		if err := ensureValidConditions(columnTypes[name], clusteringKeyConditions[lastConstrainedIndex]); err != nil {
   126  			return errors.Wrapf(err, "invalid or unsupported conditions for clustering key: %s", transform(name))
   127  		}
   128  	}
   129  
   130  	return nil
   131  }
   132  
   // conditionsRule states, in prose, which operator combinations
// ensureValidConditions accepts on a single column. It is appended verbatim to
// that function's "unsupported conditions" and "too many conditions" errors.
const conditionsRule = `
If you have a Lt or LtOrEq operator on a column, you can also have a Gt or GtOrEq on the same column.
No other combinations of operators are permitted.
`
   137  
   138  // Start with simple rules as specified in `conditionsRule` above.
   139  // Hence, the length of valid conditions slice is either one or two (won't be called if zero length).
   140  func ensureValidConditions(t Type, conditions []*Condition) error {
   141  	// check type sanity
   142  	for _, r := range conditions {
   143  		if err := ensureTypeMatch(t, r.Value); err != nil {
   144  			return errors.Wrap(err, "invalid condition")
   145  		}
   146  	}
   147  
   148  	switch {
   149  	case len(conditions) == 1:
   150  		return nil // single condition is always valid
   151  	case len(conditions) > 2:
   152  		return errors.Errorf("conditions: %v, rules: %s", conditions, conditionsRule)
   153  	}
   154  
   155  	r0 := conditions[0]
   156  	r1 := conditions[1]
   157  	// sort conditions according to operators so we have few cases to handle
   158  	if r0.Op >= r1.Op {
   159  		r0, r1 = r1, r0
   160  	}
   161  
   162  	op0 := r0.Op
   163  	v0 := r0.Value
   164  	op1 := r1.Op
   165  	v1 := r1.Value
   166  
   167  	switch {
   168  	//  v1 < fv < v0, v1 <= fv < v0, v1 < fv <= v0    ===> v0 > v1
   169  	case op0 == Lt && op1 == Gt, op0 == Lt && op1 == GtOrEq, op0 == LtOrEq && op1 == Gt:
   170  		if compare(t, v0, v1) <= 0 {
   171  			return errors.Errorf("invalid range: %v", conditions)
   172  		}
   173  		// v1 <= fv <= v0   ===> v0 >= v1
   174  	case op0 == LtOrEq && op1 == GtOrEq:
   175  		if compare(t, v0, v1) < 0 {
   176  			return errors.Errorf("invalid range: %v", conditions)
   177  		}
   178  	default: // invalid combination of operators
   179  		return errors.Errorf("unsupported conditions: %v, rules: %s", conditions, conditionsRule)
   180  
   181  	}
   182  
   183  	return nil
   184  }
   185  
   186  // compare compares two values; return 0 if equal, -1 if <, 1 if >.
   187  // Assumes args are valid.
   188  func compare(t Type, a, b interface{}) int {
   189  	switch t {
   190  	case TUUID:
   191  		// TODO: make sure if comparison for UUID like below makes sense.
   192  		return strings.Compare(string(a.(UUID)), string(b.(UUID)))
   193  	case Int64:
   194  		return int(a.(int64) - b.(int64))
   195  	case Int32:
   196  		return int(a.(int32) - b.(int32))
   197  	case String:
   198  		return strings.Compare(a.(string), b.(string))
   199  	case Blob:
   200  		return bytes.Compare(a.([]byte), b.([]byte))
   201  	case Bool:
   202  		// TODO: we don't need to order bools for range query and should report error if people do dumb things
   203  		var ia, ib int
   204  		if a.(bool) {
   205  			ia = 1
   206  		}
   207  		if b.(bool) {
   208  			ib = 1
   209  		}
   210  		return ia - ib
   211  	case Double:
   212  		fa := a.(float64)
   213  		fb := b.(float64)
   214  		if fa < fb {
   215  			return -1
   216  		}
   217  		if fa > fb {
   218  			return 1
   219  		}
   220  		return 0
   221  	case Timestamp:
   222  		ta := a.(time.Time)
   223  		tb := b.(time.Time)
   224  		if ta.Before(tb) {
   225  			return -1
   226  		}
   227  		if ta.After(tb) {
   228  			return 1
   229  		}
   230  		return 0
   231  	}
   232  	panic("invalid type") // shouldn't reach here
   233  }
   234  
   235  func ensureTypeMatch(t Type, v FieldValue) error {
   236  	switch t {
   237  	case TUUID:
   238  		if _, ok := v.(UUID); !ok {
   239  			return errors.Errorf("invalid value for UUID type: %v", v)
   240  		}
   241  	case Int64:
   242  		if _, ok := v.(int64); !ok {
   243  			return errors.Errorf("invalid value for int64 type: %v", v)
   244  		}
   245  	case Int32:
   246  		if _, ok := v.(int32); !ok {
   247  			return errors.Errorf("invalid value for int32 type: %v", v)
   248  		}
   249  	case String:
   250  		if _, ok := v.(string); !ok {
   251  			return errors.Errorf("invalid value for string type: %v", v)
   252  		}
   253  	case Blob:
   254  		if _, ok := v.([]byte); !ok {
   255  			return errors.Errorf("invalid value for blob type: %v", v)
   256  		}
   257  	case Bool:
   258  		if _, ok := v.(bool); !ok {
   259  			return errors.Errorf("invalid value for bool type: %v", v)
   260  		}
   261  	case Double:
   262  		if _, ok := v.(float64); !ok {
   263  			return errors.Errorf("invalid value for double/float64 type: %v", v)
   264  		}
   265  	case Timestamp:
   266  		if _, ok := v.(time.Time); !ok {
   267  			return errors.Errorf("invalid value for timestamp type: %v", v)
   268  		}
   269  	default:
   270  		// will not happen unless we have a bug
   271  		panic("invalid type")
   272  	}
   273  	return nil
   274  }