github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/filter/filter.go (about)

     1  package filter
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"reflect"
     7  	"strconv"
     8  	"strings"
     9  )
    10  
    11  // Mapper maps
    12  type Mapper map[string]interface{}
    13  
    14  type Fielder func(field string) reflect.Type
    15  
    16  type _Filter struct {
    17  	mapper    Mapper
    18  	fielder   Fielder
    19  	tableName string
    20  }
    21  
    22  // Filter news an filter and calls its Filter method.
    23  func Filter(fielder Fielder, filter string, mapper Mapper, tableName string) (query string, args []interface{}, err error) {
    24  	f := &_Filter{
    25  		mapper:    mapper,
    26  		fielder:   fielder,
    27  		tableName: tableName,
    28  	}
    29  	return f.Filter(filter)
    30  }
    31  
    32  // Filter the result of query according to some conditions.
    33  //
    34  // Mapper prototypes are:
    35  //
    36  //   (strval)   => (strval)
    37  //   (strval)   => (key, strval)
    38  //   (strval)   => (intval)
    39  //   (strval)   => (key, intval)
    40  //   (strval)   => (boolval)
    41  //   (strval)   => (key, boolval)
    42  //   (intval)   => (intval)
    43  //   (intval)   => (key, intval)
    44  //   ()         => (enum)
    45  //   ()         => (key, enum)
    46  //   (enumItem) => (enumItem, enum)
    47  //   (enumItem) => (key, enumItem, enum)
    48  // Override:
    49  //   (operator, strval)   => (query, args)
    50  //   (operator, strval)   => ()
    51  //   (strval)             => ()
    52  //   (intval)             => ()
    53  //
    54  func (i *_Filter) Filter(filter string) (string, []interface{}, error) {
    55  	tokenizer := NewTokenizer(filter)
    56  	parser := NewParser(tokenizer)
    57  
    58  	ast, err := parser.Parse()
    59  	if err != nil {
    60  		return "", nil, err
    61  	}
    62  
    63  	var query string
    64  	var args []interface{}
    65  
    66  	for index, and := range ast.AndExprs {
    67  		orQuery, orArgs, err := i.filterAndExpression(and)
    68  		if err != nil {
    69  			return "", nil, err
    70  		}
    71  		if orQuery == "" {
    72  			continue
    73  		}
    74  		if index != 0 && query != "" {
    75  			query += " AND "
    76  		}
    77  		query += "(" + orQuery + ")"
    78  		args = append(args, orArgs...)
    79  	}
    80  
    81  	return query, args, nil
    82  }
    83  
    84  func (i *_Filter) callMapper(expr *Expression) error {
    85  	mapper, ok := i.mapper[expr.Name]
    86  	if !ok {
    87  		return nil
    88  	}
    89  
    90  	// just in case. Already converted? no matter who did this.
    91  	if expr.Type != ValueTypeRaw {
    92  		return nil
    93  	}
    94  	raw, ok := expr.Value.(string)
    95  	// shouldn't happen
    96  	if !ok {
    97  		return nil
    98  	}
    99  
   100  	switch typed := mapper.(type) {
   101  	// (strval) => (strval)
   102  	case func(string) string:
   103  		val := typed(raw)
   104  		expr.SetString(val)
   105  	// (strval) => (key, strval)
   106  	case func(string) (string, string):
   107  		key, val := typed(raw)
   108  		expr.SetName(key)
   109  		expr.SetString(val)
   110  	// (strval) => (intval)
   111  	case func(string) int64:
   112  		val := typed(raw)
   113  		expr.SetNumber(val)
   114  	// (strval) => (key, intval)
   115  	case func(string) (string, int64):
   116  		key, val := typed(raw)
   117  		expr.SetName(key)
   118  		expr.SetNumber(val)
   119  	// (strval) => (boolval)
   120  	case func(string) bool:
   121  		val := typed(raw)
   122  		expr.SetBoolean(val)
   123  	// (strval) => (key, boolval)
   124  	case func(string) (string, bool):
   125  		key, val := typed(raw)
   126  		expr.SetName(key)
   127  		expr.SetBoolean(val)
   128  	// (intval) => (intval)
   129  	case func(int64) int64:
   130  		num, err := strconv.ParseInt(raw, 10, 64)
   131  		if err != nil {
   132  			return err
   133  		}
   134  		val := typed(num)
   135  		expr.SetNumber(val)
   136  	// (intval) => (key, intval)
   137  	case func(int64) (string, int64):
   138  		num, err := strconv.ParseInt(raw, 10, 64)
   139  		if err != nil {
   140  			return err
   141  		}
   142  		key, val := typed(num)
   143  		expr.SetName(key)
   144  		expr.SetNumber(val)
   145  	// () => enum
   146  	case map[string]int32:
   147  		val, ok := typed[raw]
   148  		if !ok {
   149  			return newEnumNotFoundError(expr.Name, raw)
   150  		}
   151  		expr.SetNumber(int64(val))
   152  	// () => enum
   153  	case func() map[string]int32:
   154  		val, ok := typed()[raw]
   155  		if !ok {
   156  			return newEnumNotFoundError(expr.Name, raw)
   157  		}
   158  		expr.SetNumber(int64(val))
   159  	// () => (key, enum)
   160  	case func() (string, map[string]int32):
   161  		key, values := typed()
   162  		val, ok := values[raw]
   163  		if !ok {
   164  			return newEnumNotFoundError(key, raw)
   165  		}
   166  		expr.SetName(key)
   167  		expr.SetNumber(int64(val))
   168  	// (enumItem) => (enumItem, enum)
   169  	case func(string) (string, map[string]int32):
   170  		enumItem, values := typed(raw)
   171  		val, ok := values[enumItem]
   172  		if !ok {
   173  			return newEnumNotFoundError(expr.Name, enumItem)
   174  		}
   175  		expr.SetNumber(int64(val))
   176  	// (enumItem) => (key, enumItem, enum)
   177  	case func(string) (string, string, map[string]int32):
   178  		key, enumItem, values := typed(raw)
   179  		val, ok := values[enumItem]
   180  		if !ok {
   181  			return newEnumNotFoundError(key, enumItem)
   182  		}
   183  		expr.SetName(key)
   184  		expr.SetNumber(int64(val))
   185  	// (operator, strval) => (query, args)
   186  	case func(TokenType, string) (string, []interface{}):
   187  		query, args := typed(expr.Operator.TokenType, raw)
   188  		expr.overrider = &_ExpressionOverrider{
   189  			Query: query,
   190  			Args:  args,
   191  		}
   192  	// (operator, strval) => ()
   193  	case func(TokenType, string):
   194  		typed(expr.Operator.TokenType, raw)
   195  		return errSkipFilter
   196  	case func(string):
   197  		if expr.Operator.TokenType != TokenTypeEqual {
   198  			return &InvalidOperatorError{
   199  				Name:     expr.Name,
   200  				Operator: expr.Operator.TokenType,
   201  			}
   202  		}
   203  		typed(raw)
   204  		return errSkipFilter
   205  	case func(int64):
   206  		if expr.Operator.TokenType != TokenTypeEqual {
   207  			return &InvalidOperatorError{
   208  				Name:     expr.Name,
   209  				Operator: expr.Operator.TokenType,
   210  			}
   211  		}
   212  		num, err := strconv.ParseInt(raw, 10, 64)
   213  		if err != nil {
   214  			return err
   215  		}
   216  		typed(num)
   217  		return errSkipFilter
   218  	default:
   219  		return &UnknownMapperError{}
   220  	}
   221  	return nil
   222  }
   223  
   224  func (i *_Filter) filterAndExpression(andExpr *AndExpression) (query string, args []interface{}, err error) {
   225  	where := bytes.NewBuffer(nil)
   226  
   227  	for index, expr := range andExpr.OrExprs {
   228  		condition := ""
   229  
   230  		// calls to mapper to see if user want to do some customizing
   231  		if err := i.callMapper(expr); err != nil {
   232  			switch err {
   233  			default:
   234  				return "", nil, err
   235  			case errSkipFilter:
   236  				continue
   237  			}
   238  		}
   239  
   240  		if expr.overrider == nil {
   241  			fType := i.fielder(expr.Name)
   242  			if fType == nil {
   243  				return "", nil, fmt.Errorf("filter: unknown field: %s", expr.Name)
   244  			}
   245  			vType := ValueTypeRaw
   246  			switch fType.Kind() {
   247  			case reflect.Int, reflect.Uint, reflect.Int8, reflect.Uint8, reflect.Int16, reflect.Uint16,
   248  				reflect.Int32, reflect.Uint32, reflect.Int64, reflect.Uint64, reflect.Float32, reflect.Float64:
   249  				vType = ValueTypeNumber
   250  			case reflect.String:
   251  				vType = ValueTypeString
   252  			case reflect.Bool:
   253  				vType = ValueTypeBoolean
   254  			default:
   255  				return "", nil, fmt.Errorf("filter: invalid field type to filter")
   256  			}
   257  
   258  			columnName := expr.Name
   259  
   260  			if err := expr.convertTo(vType); err != nil {
   261  				return "", nil, err
   262  			}
   263  
   264  			switch expr.Operator.TokenType {
   265  			case TokenTypeEqual:
   266  				condition = fmt.Sprintf("%s.%s = ?", i.tableName, columnName)
   267  			case TokenTypeNotEqual:
   268  				condition = fmt.Sprintf("%s.%s <> ?", i.tableName, columnName)
   269  			case TokenTypeInclude:
   270  				condition = fmt.Sprintf("%s.%s LIKE ?", i.tableName, columnName)
   271  			case TokenTypeNotInclude:
   272  				condition = fmt.Sprintf("%s.%s NOT LIKE ?", i.tableName, columnName)
   273  			case TokenTypeStartsWith:
   274  				condition = fmt.Sprintf("%s.%s LIKE ?", i.tableName, columnName)
   275  			case TokenTypeEndsWith:
   276  				condition = fmt.Sprintf("%s.%s LIKE ?", i.tableName, columnName)
   277  			case TokenTypeMatch, TokenTypeNotMatch:
   278  				return "", nil, fmt.Errorf("not supported operator: %s", expr.Operator.TokenValue)
   279  			case TokenTypeGreaterThan:
   280  				condition = fmt.Sprintf("%s.%s > ?", i.tableName, columnName)
   281  			case TokenTypeLessThan:
   282  				condition = fmt.Sprintf("%s.%s < ?", i.tableName, columnName)
   283  			case TokenTypeGreaterThanOrEqual:
   284  				condition = fmt.Sprintf("%s.%s >= ?", i.tableName, columnName)
   285  			case TokenTypeLessThanOrEqual:
   286  				condition = fmt.Sprintf("%s.%s <= ?", i.tableName, columnName)
   287  			default:
   288  				return "", nil, fmt.Errorf("unknown operator: %s", expr.Operator.TokenValue)
   289  			}
   290  
   291  			where.WriteString(condition)
   292  
   293  			switch expr.Operator.TokenType {
   294  			// must be string
   295  			case TokenTypeInclude, TokenTypeNotInclude, TokenTypeStartsWith, TokenTypeEndsWith:
   296  				search := strings.Replace(fmt.Sprintf("%v", expr.Value), "%", "%%", -1)
   297  				format := "%s"
   298  				switch expr.Operator.TokenType {
   299  				case TokenTypeInclude, TokenTypeNotInclude:
   300  					format = "%%%s%%"
   301  				case TokenTypeStartsWith:
   302  					format = "%s%%"
   303  				case TokenTypeEndsWith:
   304  					format = "%%%s"
   305  				}
   306  				args = append(args, fmt.Sprintf(format, search))
   307  			default:
   308  				switch expr.Type {
   309  				case ValueTypeBoolean:
   310  					// by default, we assume that boolean is stored as tinyint(1).
   311  					if expr.Value.(bool) {
   312  						args = append(args, 1)
   313  					} else {
   314  						args = append(args, 0)
   315  					}
   316  				default:
   317  					args = append(args, expr.Value)
   318  				}
   319  			}
   320  		} else {
   321  			where.WriteString(expr.overrider.Query)
   322  			args = append(args, expr.overrider.Args...)
   323  		}
   324  
   325  		if index < len(andExpr.OrExprs)-1 {
   326  			where.WriteString(" OR ")
   327  		}
   328  	}
   329  
   330  	return where.String(), args, nil
   331  }