vitess.io/vitess@v0.16.2/go/vt/vttablet/tabletserver/rules/rules.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package rules
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"reflect"
    25  	"regexp"
    26  	"strconv"
    27  
    28  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    29  
    30  	"vitess.io/vitess/go/sqltypes"
    31  	"vitess.io/vitess/go/vt/sqlparser"
    32  	"vitess.io/vitess/go/vt/vterrors"
    33  	"vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder"
    34  
    35  	querypb "vitess.io/vitess/go/vt/proto/query"
    36  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    37  )
    38  
    39  //-----------------------------------------------
    40  
    41  const (
    42  	bufferedTableRuleName = "buffered_table"
    43  )
    44  
    45  // Rules is used to store and execute rules for the tabletserver.
    46  type Rules struct {
    47  	rules []*Rule
    48  }
    49  
    50  // New creates a new Rules.
    51  func New() *Rules {
    52  	return &Rules{}
    53  }
    54  
    55  // Equal returns true if other is equal to this object, otherwise false.
    56  func (qrs *Rules) Equal(other *Rules) bool {
    57  	if len(qrs.rules) != len(other.rules) {
    58  		return false
    59  	}
    60  	for i := 0; i < len(qrs.rules); i++ {
    61  		if !qrs.rules[i].Equal(other.rules[i]) {
    62  			return false
    63  		}
    64  	}
    65  	return true
    66  }
    67  
    68  // Copy performs a deep copy of Rules.
    69  // A nil input produces a nil output.
    70  func (qrs *Rules) Copy() (newqrs *Rules) {
    71  	newqrs = New()
    72  	if qrs.rules != nil {
    73  		newqrs.rules = make([]*Rule, 0, len(qrs.rules))
    74  		for _, qr := range qrs.rules {
    75  			newqrs.rules = append(newqrs.rules, qr.Copy())
    76  		}
    77  	}
    78  	return newqrs
    79  }
    80  
    81  // CopyUnderlying makes a copy of the underlying rule array and returns it to
    82  // the caller.
    83  func (qrs *Rules) CopyUnderlying() []*Rule {
    84  	cpy := make([]*Rule, 0, len(qrs.rules))
    85  	for _, r := range qrs.rules {
    86  		cpy = append(cpy, r.Copy())
    87  	}
    88  	return cpy
    89  }
    90  
    91  // Append merges the rules from another Rules into the receiver
    92  func (qrs *Rules) Append(otherqrs *Rules) {
    93  	qrs.rules = append(qrs.rules, otherqrs.rules...)
    94  }
    95  
    96  // Add adds a Rule to Rules. It does not check
    97  // for duplicates.
    98  func (qrs *Rules) Add(qr *Rule) {
    99  	qrs.rules = append(qrs.rules, qr)
   100  }
   101  
   102  // Find finds the first occurrence of a Rule by matching
   103  // the Name field. It returns nil if the rule was not found.
   104  func (qrs *Rules) Find(name string) (qr *Rule) {
   105  	for _, qr = range qrs.rules {
   106  		if qr.Name == name {
   107  			return qr
   108  		}
   109  	}
   110  	return nil
   111  }
   112  
   113  // Delete deletes a Rule by name and returns the rule
   114  // that was deleted. It returns nil if the rule was not found.
   115  func (qrs *Rules) Delete(name string) (qr *Rule) {
   116  	for i, qr := range qrs.rules {
   117  		if qr.Name == name {
   118  			for j := i; j < len(qrs.rules)-i-1; j++ {
   119  				qrs.rules[j] = qrs.rules[j+1]
   120  			}
   121  			qrs.rules = qrs.rules[:len(qrs.rules)-1]
   122  			return qr
   123  		}
   124  	}
   125  	return nil
   126  }
   127  
   128  // UnmarshalJSON unmarshals Rules.
   129  func (qrs *Rules) UnmarshalJSON(data []byte) (err error) {
   130  	var rulesInfo []map[string]any
   131  	dec := json.NewDecoder(bytes.NewReader(data))
   132  	dec.UseNumber()
   133  	err = dec.Decode(&rulesInfo)
   134  	if err != nil {
   135  		return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "%v", err)
   136  	}
   137  	for _, ruleInfo := range rulesInfo {
   138  		qr, err := BuildQueryRule(ruleInfo)
   139  		if err != nil {
   140  			return err
   141  		}
   142  		qrs.Add(qr)
   143  	}
   144  	return nil
   145  }
   146  
   147  // MarshalJSON marshals to JSON.
   148  func (qrs *Rules) MarshalJSON() ([]byte, error) {
   149  	b := bytes.NewBuffer(nil)
   150  	_, _ = b.WriteString("[")
   151  	for i, rule := range qrs.rules {
   152  		if i != 0 {
   153  			_, _ = b.WriteString(",")
   154  		}
   155  		safeEncode(b, "", rule)
   156  	}
   157  	_, _ = b.WriteString("]")
   158  	return b.Bytes(), nil
   159  }
   160  
   161  // FilterByPlan creates a new Rules by prefiltering on the query and planId. This allows
   162  // us to create query plan specific Rules out of the original Rules. In the new rules,
   163  // query, plans and tableNames predicates are empty.
   164  func (qrs *Rules) FilterByPlan(query string, planid planbuilder.PlanType, tableNames ...string) (newqrs *Rules) {
   165  	var newrules []*Rule
   166  	for _, qr := range qrs.rules {
   167  		if newrule := qr.FilterByPlan(query, planid, tableNames); newrule != nil {
   168  			newrules = append(newrules, newrule)
   169  		}
   170  	}
   171  	return &Rules{newrules}
   172  }
   173  
   174  // GetAction runs the input against the rules engine and returns the action to be performed.
   175  func (qrs *Rules) GetAction(
   176  	ip,
   177  	user string,
   178  	bindVars map[string]*querypb.BindVariable,
   179  	marginComments sqlparser.MarginComments,
   180  ) (action Action, cancelCtx context.Context, desc string) {
   181  	for _, qr := range qrs.rules {
   182  		if act := qr.GetAction(ip, user, bindVars, marginComments); act != QRContinue {
   183  			return act, qr.cancelCtx, qr.Description
   184  		}
   185  	}
   186  	return QRContinue, nil, ""
   187  }
   188  
   189  //-----------------------------------------------
   190  
   191  // Rule represents one rule (conditions-action).
   192  // Name is meant to uniquely identify a rule.
   193  // Description is a human readable comment that describes the rule.
   194  // For a Rule to fire, all conditions of the Rule
   195  // have to match. For example, an empty Rule will match
   196  // all requests.
   197  // Every Rule has an associated Action. If all the conditions
   198  // of the Rule are met, then the Action is triggerred.
   199  type Rule struct {
   200  	Description string
   201  	Name        string
   202  
   203  	// All defined conditions must match for the rule to fire (AND).
   204  
   205  	// Regexp conditions. nil conditions are ignored (TRUE).
   206  	requestIP, user, query, leadingComment, trailingComment namedRegexp
   207  
   208  	// Any matched plan will make this condition true (OR)
   209  	plans []planbuilder.PlanType
   210  
   211  	// Any matched tableNames will make this condition true (OR)
   212  	tableNames []string
   213  
   214  	// All BindVar conditions have to be fulfilled to make this true (AND)
   215  	bindVarConds []BindVarCond
   216  
   217  	// Action to be performed on trigger
   218  	act Action
   219  
   220  	// a rule can be dynamically cancelled. This function determines whether it is cancelled
   221  	cancelCtx context.Context
   222  }
   223  
   224  type namedRegexp struct {
   225  	name string
   226  	*regexp.Regexp
   227  }
   228  
   229  // MarshalJSON marshals to JSON.
   230  func (nr namedRegexp) MarshalJSON() ([]byte, error) {
   231  	return json.Marshal(nr.name)
   232  }
   233  
   234  // Equal returns true if other is equal to this namedRegexp, otherwise false.
   235  func (nr namedRegexp) Equal(other namedRegexp) bool {
   236  	if nr.Regexp == nil || other.Regexp == nil {
   237  		return nr.Regexp == nil && other.Regexp == nil && nr.name == other.name
   238  	}
   239  	return nr.name == other.name && nr.String() == other.String()
   240  }
   241  
   242  // NewQueryRule creates a new Rule.
   243  func NewQueryRule(description, name string, act Action) (qr *Rule) {
   244  	// We ignore act because there's only one action right now
   245  	return &Rule{Description: description, Name: name, act: act}
   246  }
   247  
   248  // NewBufferedTableQueryRule creates a new buffer Rule.
   249  func NewBufferedTableQueryRule(cancelCtx context.Context, tableName string, description string) (qr *Rule) {
   250  	// We ignore act because there's only one action right now
   251  	return &Rule{cancelCtx: cancelCtx, Description: description, Name: bufferedTableRuleName, tableNames: []string{tableName}, act: QRBuffer}
   252  }
   253  
   254  // Equal returns true if other is equal to this Rule, otherwise false.
   255  func (qr *Rule) Equal(other *Rule) bool {
   256  	if qr == nil || other == nil {
   257  		return qr == nil && other == nil
   258  	}
   259  	return (qr.Description == other.Description &&
   260  		qr.Name == other.Name &&
   261  		qr.requestIP.Equal(other.requestIP) &&
   262  		qr.user.Equal(other.user) &&
   263  		qr.query.Equal(other.query) &&
   264  		qr.leadingComment.Equal(other.leadingComment) &&
   265  		qr.trailingComment.Equal(other.trailingComment) &&
   266  		reflect.DeepEqual(qr.plans, other.plans) &&
   267  		reflect.DeepEqual(qr.tableNames, other.tableNames) &&
   268  		reflect.DeepEqual(qr.bindVarConds, other.bindVarConds) &&
   269  		qr.act == other.act)
   270  }
   271  
   272  // Copy performs a deep copy of a Rule.
   273  func (qr *Rule) Copy() (newqr *Rule) {
   274  	newqr = &Rule{
   275  		Description:     qr.Description,
   276  		Name:            qr.Name,
   277  		requestIP:       qr.requestIP,
   278  		user:            qr.user,
   279  		query:           qr.query,
   280  		leadingComment:  qr.leadingComment,
   281  		trailingComment: qr.trailingComment,
   282  		act:             qr.act,
   283  		cancelCtx:       qr.cancelCtx,
   284  	}
   285  	if qr.plans != nil {
   286  		newqr.plans = make([]planbuilder.PlanType, len(qr.plans))
   287  		copy(newqr.plans, qr.plans)
   288  	}
   289  	if qr.tableNames != nil {
   290  		newqr.tableNames = make([]string, len(qr.tableNames))
   291  		copy(newqr.tableNames, qr.tableNames)
   292  	}
   293  	if qr.bindVarConds != nil {
   294  		newqr.bindVarConds = make([]BindVarCond, len(qr.bindVarConds))
   295  		copy(newqr.bindVarConds, qr.bindVarConds)
   296  	}
   297  	return newqr
   298  }
   299  
   300  // MarshalJSON marshals to JSON.
   301  func (qr *Rule) MarshalJSON() ([]byte, error) {
   302  	b := bytes.NewBuffer(nil)
   303  	safeEncode(b, `{"Description":`, qr.Description)
   304  	safeEncode(b, `,"Name":`, qr.Name)
   305  	if qr.requestIP.Regexp != nil {
   306  		safeEncode(b, `,"RequestIP":`, qr.requestIP)
   307  	}
   308  	if qr.user.Regexp != nil {
   309  		safeEncode(b, `,"User":`, qr.user)
   310  	}
   311  	if qr.query.Regexp != nil {
   312  		safeEncode(b, `,"Query":`, qr.query)
   313  	}
   314  	if qr.leadingComment.Regexp != nil {
   315  		safeEncode(b, `,"LeadingComment":`, qr.leadingComment)
   316  	}
   317  	if qr.trailingComment.Regexp != nil {
   318  		safeEncode(b, `,"TrailingComment":`, qr.trailingComment)
   319  	}
   320  	if qr.plans != nil {
   321  		safeEncode(b, `,"Plans":`, qr.plans)
   322  	}
   323  	if qr.tableNames != nil {
   324  		safeEncode(b, `,"TableNames":`, qr.tableNames)
   325  	}
   326  	if qr.bindVarConds != nil {
   327  		safeEncode(b, `,"BindVarConds":`, qr.bindVarConds)
   328  	}
   329  	if qr.act != QRContinue {
   330  		safeEncode(b, `,"Action":`, qr.act)
   331  	}
   332  	_, _ = b.WriteString("}")
   333  	return b.Bytes(), nil
   334  }
   335  
   336  // SetIPCond adds a regular expression condition for the client IP.
   337  // It has to be a full match (not substring).
   338  func (qr *Rule) SetIPCond(pattern string) (err error) {
   339  	qr.requestIP.name = pattern
   340  	qr.requestIP.Regexp, err = regexp.Compile(makeExact(pattern))
   341  	return err
   342  }
   343  
   344  // SetUserCond adds a regular expression condition for the user name
   345  // used by the client.
   346  func (qr *Rule) SetUserCond(pattern string) (err error) {
   347  	qr.user.name = pattern
   348  	qr.user.Regexp, err = regexp.Compile(makeExact(pattern))
   349  	return
   350  }
   351  
   352  // AddPlanCond adds to the list of plans that can be matched for
   353  // the rule to fire.
   354  // This function acts as an OR: Any plan id match is considered a match.
   355  func (qr *Rule) AddPlanCond(planType planbuilder.PlanType) {
   356  	qr.plans = append(qr.plans, planType)
   357  }
   358  
   359  // AddTableCond adds to the list of tableNames that can be matched for
   360  // the rule to fire.
   361  // This function acts as an OR: Any tableName match is considered a match.
   362  func (qr *Rule) AddTableCond(tableName string) {
   363  	qr.tableNames = append(qr.tableNames, tableName)
   364  }
   365  
   366  // SetQueryCond adds a regular expression condition for the query.
   367  func (qr *Rule) SetQueryCond(pattern string) (err error) {
   368  	qr.query.name = pattern
   369  	qr.query.Regexp, err = regexp.Compile(makeExact(pattern))
   370  	return
   371  }
   372  
   373  // SetLeadingCommentCond adds a regular expression condition for a leading query comment.
   374  func (qr *Rule) SetLeadingCommentCond(pattern string) (err error) {
   375  	qr.leadingComment.name = pattern
   376  	qr.leadingComment.Regexp, err = regexp.Compile(makeExact(pattern))
   377  	return
   378  }
   379  
   380  // SetTrailingCommentCond adds a regular expression condition for a trailing query comment.
   381  func (qr *Rule) SetTrailingCommentCond(pattern string) (err error) {
   382  	qr.trailingComment.name = pattern
   383  	qr.trailingComment.Regexp, err = regexp.Compile(makeExact(pattern))
   384  	return
   385  }
   386  
   387  // makeExact forces a full string match for the regex instead of substring
   388  func makeExact(pattern string) string {
   389  	return fmt.Sprintf("^%s$", pattern)
   390  }
   391  
   392  // AddBindVarCond adds a bind variable restriction to the Rule.
   393  // All bind var conditions have to be satisfied for the Rule
   394  // to be a match.
   395  // name represents the name (not regexp) of the bind variable.
   396  // onAbsent specifies the value of the condition if the
   397  // bind variable is absent.
   398  // onMismatch specifies the value of the condition if there's
   399  // a type mismatch on the condition.
   400  // For inequalities, the bindvar is the left operand and the value
   401  // in the condition is the right operand: bindVar Operator value.
   402  // Value & operator rules
   403  // Type     Operators                              Bindvar
   404  // nil      ""                                     any type
   405  // uint64   ==, !=, <, >=, >, <=                   whole numbers
   406  // int64    ==, !=, <, >=, >, <=                   whole numbers
   407  // string   ==, !=, <, >=, >, <=, MATCH, NOMATCH   []byte, string
   408  // whole numbers can be: int, int8, int16, int32, int64, uint64
   409  func (qr *Rule) AddBindVarCond(name string, onAbsent, onMismatch bool, op Operator, value any) error {
   410  	var converted bvcValue
   411  	if op == QRNoOp {
   412  		qr.bindVarConds = append(qr.bindVarConds, BindVarCond{name, onAbsent, onMismatch, op, nil})
   413  		return nil
   414  	}
   415  	switch v := value.(type) {
   416  	case uint64:
   417  		if op < QREqual || op > QRLessEqual {
   418  			goto Error
   419  		}
   420  		converted = bvcuint64(v)
   421  	case int64:
   422  		if op < QREqual || op > QRLessEqual {
   423  			goto Error
   424  		}
   425  		converted = bvcint64(v)
   426  	case string:
   427  		if op >= QREqual && op <= QRLessEqual {
   428  			converted = bvcstring(v)
   429  		} else if op >= QRMatch && op <= QRNoMatch {
   430  			var err error
   431  			// Change the value to compiled regexp
   432  			re, err := regexp.Compile(makeExact(v))
   433  			if err != nil {
   434  				return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "processing %s: %v", v, err)
   435  			}
   436  			converted = bvcre{re}
   437  		} else {
   438  			goto Error
   439  		}
   440  	default:
   441  		return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "type %T not allowed as condition operand (%v)", value, value)
   442  	}
   443  	qr.bindVarConds = append(qr.bindVarConds, BindVarCond{name, onAbsent, onMismatch, op, converted})
   444  	return nil
   445  
   446  Error:
   447  	return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid operator %v for type %T (%v)", op, value, value)
   448  }
   449  
   450  // FilterByPlan returns a new Rule if the query and planid match.
   451  // The new Rule will contain all the original constraints other
   452  // than the plan and query. If the plan and query don't match the Rule,
   453  // then it returns nil.
   454  func (qr *Rule) FilterByPlan(query string, planid planbuilder.PlanType, tableNames []string) (newqr *Rule) {
   455  	if !reMatch(qr.query.Regexp, query) {
   456  		return nil
   457  	}
   458  	if !planMatch(qr.plans, planid) {
   459  		return nil
   460  	}
   461  	if !tableMatch(qr.tableNames, tableNames) {
   462  		return nil
   463  	}
   464  	newqr = qr.Copy()
   465  	newqr.query = namedRegexp{}
   466  	// Note we explicitly don't remove the leading/trailing comments as they
   467  	// must be evaluated at execution time.
   468  	newqr.plans = nil
   469  	newqr.tableNames = nil
   470  	return newqr
   471  }
   472  
   473  // GetAction returns the action for a single rule.
   474  func (qr *Rule) GetAction(
   475  	ip,
   476  	user string,
   477  	bindVars map[string]*querypb.BindVariable,
   478  	marginComments sqlparser.MarginComments,
   479  ) Action {
   480  	if qr.cancelCtx != nil {
   481  		select {
   482  		case <-qr.cancelCtx.Done():
   483  			// rule was cancelled. Nothing else to check
   484  			return QRContinue
   485  		default:
   486  			// rule will be cancelled in the future. Until then, it applies!
   487  			// proceed to evaluate rules
   488  		}
   489  	}
   490  	if !reMatch(qr.leadingComment.Regexp, marginComments.Leading) {
   491  		return QRContinue
   492  	}
   493  	if !reMatch(qr.trailingComment.Regexp, marginComments.Trailing) {
   494  		return QRContinue
   495  	}
   496  	if !reMatch(qr.requestIP.Regexp, ip) {
   497  		return QRContinue
   498  	}
   499  	if !reMatch(qr.user.Regexp, user) {
   500  		return QRContinue
   501  	}
   502  	for _, bvcond := range qr.bindVarConds {
   503  		if !bvMatch(bvcond, bindVars) {
   504  			return QRContinue
   505  		}
   506  	}
   507  	return qr.act
   508  }
   509  
   510  func reMatch(re *regexp.Regexp, val string) bool {
   511  	return re == nil || re.MatchString(val)
   512  }
   513  
   514  func planMatch(plans []planbuilder.PlanType, plan planbuilder.PlanType) bool {
   515  	if plans == nil {
   516  		return true
   517  	}
   518  	for _, p := range plans {
   519  		if p == plan {
   520  			return true
   521  		}
   522  	}
   523  	return false
   524  }
   525  
   526  func tableMatch(tableNames []string, otherNames []string) bool {
   527  	if tableNames == nil {
   528  		return true
   529  	}
   530  	otherNamesMap := map[string]bool{}
   531  	for _, name := range otherNames {
   532  		otherNamesMap[name] = true
   533  	}
   534  	for _, name := range tableNames {
   535  		if otherNamesMap[name] {
   536  			return true
   537  		}
   538  	}
   539  	return false
   540  }
   541  
   542  func bvMatch(bvcond BindVarCond, bindVars map[string]*querypb.BindVariable) bool {
   543  	bv, ok := bindVars[bvcond.name]
   544  	if !ok {
   545  		return bvcond.onAbsent
   546  	}
   547  	if bvcond.op == QRNoOp {
   548  		return !bvcond.onAbsent
   549  	}
   550  	return bvcond.value.eval(bv, bvcond.op, bvcond.onMismatch)
   551  }
   552  
   553  //-----------------------------------------------
   554  // Support types for Rule
   555  
   556  // Action speficies the list of actions to perform
   557  // when a Rule is triggered.
   558  type Action int
   559  
   560  // These are actions.
   561  const (
   562  	QRContinue = Action(iota)
   563  	QRFail
   564  	QRFailRetry
   565  	QRBuffer
   566  )
   567  
   568  // MarshalJSON marshals to JSON.
   569  func (act Action) MarshalJSON() ([]byte, error) {
   570  	// If we add more actions, we'll need to use a map.
   571  	var str string
   572  	switch act {
   573  	case QRFail:
   574  		str = "FAIL"
   575  	case QRFailRetry:
   576  		str = "FAIL_RETRY"
   577  	case QRBuffer:
   578  		str = "BUFFER"
   579  	default:
   580  		str = "INVALID"
   581  	}
   582  	return json.Marshal(str)
   583  }
   584  
   585  // BindVarCond represents a bind var condition.
   586  type BindVarCond struct {
   587  	name       string
   588  	onAbsent   bool
   589  	onMismatch bool
   590  	op         Operator
   591  	value      bvcValue
   592  }
   593  
   594  // MarshalJSON marshals to JSON.
   595  func (bvc BindVarCond) MarshalJSON() ([]byte, error) {
   596  	b := bytes.NewBuffer(nil)
   597  	safeEncode(b, `{"Name":`, bvc.name)
   598  	safeEncode(b, `,"OnAbsent":`, bvc.onAbsent)
   599  	if bvc.op != QRNoOp {
   600  		safeEncode(b, `,"OnMismatch":`, bvc.onMismatch)
   601  	}
   602  	safeEncode(b, `,"Operator":`, bvc.op)
   603  	if bvc.op != QRNoOp {
   604  		safeEncode(b, `,"Value":`, bvc.value)
   605  	}
   606  	_, _ = b.WriteString("}")
   607  	return b.Bytes(), nil
   608  }
   609  
   610  // Operator represents the list of operators.
   611  type Operator int
   612  
   613  // These are comparison operators.
   614  const (
   615  	QRNoOp = Operator(iota)
   616  	QREqual
   617  	QRNotEqual
   618  	QRLessThan
   619  	QRGreaterEqual
   620  	QRGreaterThan
   621  	QRLessEqual
   622  	QRMatch
   623  	QRNoMatch
   624  	QRNumOp
   625  )
   626  
   627  var opmap = map[string]Operator{
   628  	"":        QRNoOp,
   629  	"==":      QREqual,
   630  	"!=":      QRNotEqual,
   631  	"<":       QRLessThan,
   632  	">=":      QRGreaterEqual,
   633  	">":       QRGreaterThan,
   634  	"<=":      QRLessEqual,
   635  	"MATCH":   QRMatch,
   636  	"NOMATCH": QRNoMatch,
   637  }
   638  
   639  var opnames []string
   640  
   641  func init() {
   642  	opnames = make([]string, QRNumOp)
   643  	for k, v := range opmap {
   644  		opnames[v] = k
   645  	}
   646  }
   647  
   648  // These are return statii.
   649  const (
   650  	QROK = iota
   651  	QRMismatch
   652  	QROutOfRange
   653  )
   654  
   655  // MarshalJSON marshals to JSON.
   656  func (op Operator) MarshalJSON() ([]byte, error) {
   657  	return json.Marshal(opnames[op])
   658  }
   659  
   660  // bvcValue defines the common interface
   661  // for all bind var condition values
   662  type bvcValue interface {
   663  	eval(bv *querypb.BindVariable, op Operator, onMismatch bool) bool
   664  }
   665  
   666  type bvcuint64 uint64
   667  
   668  func (uval bvcuint64) eval(bv *querypb.BindVariable, op Operator, onMismatch bool) bool {
   669  	num, status := getuint64(bv)
   670  	switch op {
   671  	case QREqual:
   672  		switch status {
   673  		case QROK:
   674  			return num == uint64(uval)
   675  		case QROutOfRange:
   676  			return false
   677  		}
   678  	case QRNotEqual:
   679  		switch status {
   680  		case QROK:
   681  			return num != uint64(uval)
   682  		case QROutOfRange:
   683  			return true
   684  		}
   685  	case QRLessThan:
   686  		switch status {
   687  		case QROK:
   688  			return num < uint64(uval)
   689  		case QROutOfRange:
   690  			return true
   691  		}
   692  	case QRGreaterEqual:
   693  		switch status {
   694  		case QROK:
   695  			return num >= uint64(uval)
   696  		case QROutOfRange:
   697  			return false
   698  		}
   699  	case QRGreaterThan:
   700  		switch status {
   701  		case QROK:
   702  			return num > uint64(uval)
   703  		case QROutOfRange:
   704  			return false
   705  		}
   706  	case QRLessEqual:
   707  		switch status {
   708  		case QROK:
   709  			return num <= uint64(uval)
   710  		case QROutOfRange:
   711  			return true
   712  		}
   713  	default:
   714  		panic("unreachable")
   715  	}
   716  
   717  	return onMismatch
   718  }
   719  
   720  type bvcint64 int64
   721  
   722  func (ival bvcint64) eval(bv *querypb.BindVariable, op Operator, onMismatch bool) bool {
   723  	num, status := getint64(bv)
   724  	switch op {
   725  	case QREqual:
   726  		switch status {
   727  		case QROK:
   728  			return num == int64(ival)
   729  		case QROutOfRange:
   730  			return false
   731  		}
   732  	case QRNotEqual:
   733  		switch status {
   734  		case QROK:
   735  			return num != int64(ival)
   736  		case QROutOfRange:
   737  			return true
   738  		}
   739  	case QRLessThan:
   740  		switch status {
   741  		case QROK:
   742  			return num < int64(ival)
   743  		case QROutOfRange:
   744  			return false
   745  		}
   746  	case QRGreaterEqual:
   747  		switch status {
   748  		case QROK:
   749  			return num >= int64(ival)
   750  		case QROutOfRange:
   751  			return true
   752  		}
   753  	case QRGreaterThan:
   754  		switch status {
   755  		case QROK:
   756  			return num > int64(ival)
   757  		case QROutOfRange:
   758  			return true
   759  		}
   760  	case QRLessEqual:
   761  		switch status {
   762  		case QROK:
   763  			return num <= int64(ival)
   764  		case QROutOfRange:
   765  			return false
   766  		}
   767  	default:
   768  		panic("unreachable")
   769  	}
   770  
   771  	return onMismatch
   772  }
   773  
   774  type bvcstring string
   775  
   776  func (sval bvcstring) eval(bv *querypb.BindVariable, op Operator, onMismatch bool) bool {
   777  	str, status := getstring(bv)
   778  	if status != QROK {
   779  		return onMismatch
   780  	}
   781  	switch op {
   782  	case QREqual:
   783  		return str == string(sval)
   784  	case QRNotEqual:
   785  		return str != string(sval)
   786  	case QRLessThan:
   787  		return str < string(sval)
   788  	case QRGreaterEqual:
   789  		return str >= string(sval)
   790  	case QRGreaterThan:
   791  		return str > string(sval)
   792  	case QRLessEqual:
   793  		return str <= string(sval)
   794  	}
   795  	panic("unreachable")
   796  }
   797  
   798  type bvcre struct {
   799  	re *regexp.Regexp
   800  }
   801  
   802  func (reval bvcre) eval(bv *querypb.BindVariable, op Operator, onMismatch bool) bool {
   803  	str, status := getstring(bv)
   804  	if status != QROK {
   805  		return onMismatch
   806  	}
   807  	switch op {
   808  	case QRMatch:
   809  		return reval.re.MatchString(str)
   810  	case QRNoMatch:
   811  		return !reval.re.MatchString(str)
   812  	}
   813  	panic("unreachable")
   814  }
   815  
   816  // getuint64 returns QROutOfRange for negative values
   817  func getuint64(val *querypb.BindVariable) (uv uint64, status int) {
   818  	bv, err := sqltypes.BindVariableToValue(val)
   819  	if err != nil {
   820  		return 0, QROutOfRange
   821  	}
   822  	v, err := evalengine.ToUint64(bv)
   823  	if err != nil {
   824  		return 0, QROutOfRange
   825  	}
   826  	return v, QROK
   827  }
   828  
   829  // getint64 returns QROutOfRange if a uint64 is too large
   830  func getint64(val *querypb.BindVariable) (iv int64, status int) {
   831  	bv, err := sqltypes.BindVariableToValue(val)
   832  	if err != nil {
   833  		return 0, QROutOfRange
   834  	}
   835  	v, err := evalengine.ToInt64(bv)
   836  	if err != nil {
   837  		return 0, QROutOfRange
   838  	}
   839  	return v, QROK
   840  }
   841  
   842  // TODO(sougou): this is inefficient. Optimize to use []byte.
   843  func getstring(val *querypb.BindVariable) (s string, status int) {
   844  	if sqltypes.IsIntegral(val.Type) || sqltypes.IsFloat(val.Type) || sqltypes.IsText(val.Type) || sqltypes.IsBinary(val.Type) {
   845  		return string(val.Value), QROK
   846  	}
   847  	return "", QRMismatch
   848  }
   849  
   850  //-----------------------------------------------
   851  // Support functions for JSON
   852  
   853  // MapStrOperator maps a string representation to an Operator.
   854  func MapStrOperator(strop string) (op Operator, err error) {
   855  	if op, ok := opmap[strop]; ok {
   856  		return op, nil
   857  	}
   858  	return QRNoOp, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid Operator %s", strop)
   859  }
   860  
   861  // BuildQueryRule builds a query rule from a ruleInfo.
   862  func BuildQueryRule(ruleInfo map[string]any) (qr *Rule, err error) {
   863  	qr = NewQueryRule("", "", QRFail)
   864  	for k, v := range ruleInfo {
   865  		var sv string
   866  		var lv []any
   867  		var ok bool
   868  		switch k {
   869  		case "Name", "Description", "RequestIP", "User", "Query", "Action", "LeadingComment", "TrailingComment":
   870  			sv, ok = v.(string)
   871  			if !ok {
   872  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for %s", k)
   873  			}
   874  		case "Plans", "BindVarConds", "TableNames":
   875  			lv, ok = v.([]any)
   876  			if !ok {
   877  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want list for %s", k)
   878  			}
   879  		default:
   880  			return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "unrecognized tag %s", k)
   881  		}
   882  		switch k {
   883  		case "Name":
   884  			qr.Name = sv
   885  		case "Description":
   886  			qr.Description = sv
   887  		case "RequestIP":
   888  			err = qr.SetIPCond(sv)
   889  			if err != nil {
   890  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set IP condition: %v", sv)
   891  			}
   892  		case "User":
   893  			err = qr.SetUserCond(sv)
   894  			if err != nil {
   895  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set User condition: %v", sv)
   896  			}
   897  		case "Query":
   898  			err = qr.SetQueryCond(sv)
   899  			if err != nil {
   900  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set Query condition: %v", sv)
   901  			}
   902  		case "LeadingComment":
   903  			err = qr.SetLeadingCommentCond(sv)
   904  			if err != nil {
   905  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set LeadingComment condition: %v", sv)
   906  			}
   907  		case "TrailingComment":
   908  			err = qr.SetTrailingCommentCond(sv)
   909  			if err != nil {
   910  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set TrailingComment condition: %v", sv)
   911  			}
   912  		case "Plans":
   913  			for _, p := range lv {
   914  				pv, ok := p.(string)
   915  				if !ok {
   916  					return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for Plans")
   917  				}
   918  				pt, ok := planbuilder.PlanByName(pv)
   919  				if !ok {
   920  					return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid plan name: %s", pv)
   921  				}
   922  				qr.AddPlanCond(pt)
   923  			}
   924  		case "TableNames":
   925  			for _, t := range lv {
   926  				tableName, ok := t.(string)
   927  				if !ok {
   928  					return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for TableNames")
   929  				}
   930  				qr.AddTableCond(tableName)
   931  			}
   932  		case "BindVarConds":
   933  			for _, bvc := range lv {
   934  				name, onAbsent, onMismatch, op, value, err := buildBindVarCondition(bvc)
   935  				if err != nil {
   936  					return nil, err
   937  				}
   938  				err = qr.AddBindVarCond(name, onAbsent, onMismatch, op, value)
   939  				if err != nil {
   940  					return nil, err
   941  				}
   942  			}
   943  		case "Action":
   944  			switch sv {
   945  			case "FAIL":
   946  				qr.act = QRFail
   947  			case "FAIL_RETRY":
   948  				qr.act = QRFailRetry
   949  			case "BUFFER":
   950  				qr.act = QRBuffer
   951  			default:
   952  				return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "invalid Action %s", sv)
   953  			}
   954  		}
   955  	}
   956  	return qr, nil
   957  }
   958  
   959  func buildBindVarCondition(bvc any) (name string, onAbsent, onMismatch bool, op Operator, value any, err error) {
   960  	bvcinfo, ok := bvc.(map[string]any)
   961  	if !ok {
   962  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want json object for bind var conditions")
   963  		return
   964  	}
   965  
   966  	var v any
   967  	v, ok = bvcinfo["Name"]
   968  	if !ok {
   969  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Name missing in BindVarConds")
   970  		return
   971  	}
   972  	name, ok = v.(string)
   973  	if !ok {
   974  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for Name in BindVarConds")
   975  		return
   976  	}
   977  
   978  	v, ok = bvcinfo["OnAbsent"]
   979  	if !ok {
   980  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "OnAbsent missing in BindVarConds")
   981  		return
   982  	}
   983  	onAbsent, ok = v.(bool)
   984  	if !ok {
   985  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want bool for OnAbsent")
   986  		return
   987  	}
   988  
   989  	v, ok = bvcinfo["Operator"]
   990  	if !ok {
   991  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Operator missing in BindVarConds")
   992  		return
   993  	}
   994  	strop, ok := v.(string)
   995  	if !ok {
   996  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for Operator")
   997  		return
   998  	}
   999  	op, err = MapStrOperator(strop)
  1000  	if err != nil {
  1001  		return
  1002  	}
  1003  	if op == QRNoOp {
  1004  		return
  1005  	}
  1006  	v, ok = bvcinfo["Value"]
  1007  	if !ok {
  1008  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Value missing in BindVarConds")
  1009  		return
  1010  	}
  1011  	if op >= QREqual && op <= QRLessEqual {
  1012  		switch v := v.(type) {
  1013  		case json.Number:
  1014  			value, err = v.Int64()
  1015  			if err != nil {
  1016  				// Maybe uint64
  1017  				value, err = strconv.ParseUint(string(v), 10, 64)
  1018  				if err != nil {
  1019  					err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want int64/uint64: %s", string(v))
  1020  					return
  1021  				}
  1022  			}
  1023  		case string:
  1024  			value = v
  1025  		default:
  1026  			err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string or number: %v", v)
  1027  			return
  1028  		}
  1029  	} else if op == QRMatch || op == QRNoMatch {
  1030  		strvalue, ok := v.(string)
  1031  		if !ok {
  1032  			err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string: %v", v)
  1033  			return
  1034  		}
  1035  		value = strvalue
  1036  	}
  1037  
  1038  	v, ok = bvcinfo["OnMismatch"]
  1039  	if !ok {
  1040  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "OnMismatch missing in BindVarConds")
  1041  		return
  1042  	}
  1043  	onMismatch, ok = v.(bool)
  1044  	if !ok {
  1045  		err = vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want bool for OnMismatch")
  1046  		return
  1047  	}
  1048  	return
  1049  }
  1050  
  1051  func safeEncode(b *bytes.Buffer, prefix string, v any) {
  1052  	enc := json.NewEncoder(b)
  1053  	_, _ = b.WriteString(prefix)
  1054  	if err := enc.Encode(v); err != nil {
  1055  		_ = enc.Encode(err.Error())
  1056  	}
  1057  }