github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/database/sqlcommon/filter_sql.go (about)

     1  // Copyright © 2021 Kaleido, Inc.
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  //
     5  // Licensed under the Apache License, Version 2.0 (the "License");
     6  // you may not use this file except in compliance with the License.
     7  // You may obtain a copy of the License at
     8  //
     9  //     http://www.apache.org/licenses/LICENSE-2.0
    10  //
    11  // Unless required by applicable law or agreed to in writing, software
    12  // distributed under the License is distributed on an "AS IS" BASIS,
    13  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  // See the License for the specific language governing permissions and
    15  // limitations under the License.
    16  
    17  package sqlcommon
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"strings"
    23  
    24  	sq "github.com/Masterminds/squirrel"
    25  	"github.com/kaleido-io/firefly/internal/i18n"
    26  	"github.com/kaleido-io/firefly/pkg/database"
    27  )
    28  
    29  func (s *SQLCommon) filterSelect(ctx context.Context, tableName string, sel sq.SelectBuilder, filter database.Filter, typeMap map[string]string, preconditions ...sq.Sqlizer) (sq.SelectBuilder, error) {
    30  	fi, err := filter.Finalize()
    31  	if err != nil {
    32  		return sel, err
    33  	}
    34  	if len(fi.Sort) == 0 {
    35  		fi.Sort = []string{"sequence"}
    36  		fi.Descending = true
    37  	}
    38  	sel, err = s.filterSelectFinalized(ctx, tableName, sel, fi, typeMap, preconditions...)
    39  	direction := ""
    40  	if fi.Descending {
    41  		direction = " DESC"
    42  	}
    43  	sort := make([]string, len(fi.Sort))
    44  	for i, field := range fi.Sort {
    45  		sort[i] = s.mapField(tableName, field, typeMap)
    46  	}
    47  	sel = sel.OrderBy(fmt.Sprintf("%s%s", strings.Join(sort, ","), direction))
    48  	if err == nil {
    49  		if fi.Skip > 0 {
    50  			sel = sel.Offset(fi.Skip)
    51  		}
    52  		if fi.Limit > 0 {
    53  			sel = sel.Limit(fi.Limit)
    54  		}
    55  	}
    56  	return sel, err
    57  }
    58  
    59  func (s *SQLCommon) filterSelectFinalized(ctx context.Context, tableName string, sel sq.SelectBuilder, fi *database.FilterInfo, tm map[string]string, preconditions ...sq.Sqlizer) (sq.SelectBuilder, error) {
    60  	fop, err := s.filterOp(ctx, tableName, fi, tm)
    61  	if err != nil {
    62  		return sel, err
    63  	}
    64  	if len(preconditions) > 0 {
    65  		and := make(sq.And, len(preconditions)+1)
    66  		for i, p := range preconditions {
    67  			and[i] = p
    68  		}
    69  		and[len(preconditions)] = fop
    70  		fop = and
    71  	}
    72  	return sel.Where(fop), nil
    73  }
    74  
    75  func (s *SQLCommon) buildUpdate(sel sq.UpdateBuilder, update database.Update, typeMap map[string]string) (sq.UpdateBuilder, error) {
    76  	ui, err := update.Finalize()
    77  	if err != nil {
    78  		return sel, err
    79  	}
    80  	for _, so := range ui.SetOperations {
    81  
    82  		sel = sel.Set(s.mapField("", so.Field, typeMap), so.Value)
    83  	}
    84  	return sel, nil
    85  }
    86  
    87  func (s *SQLCommon) filterUpdate(ctx context.Context, tableName string, update sq.UpdateBuilder, filter database.Filter, typeMap map[string]string) (sq.UpdateBuilder, error) {
    88  	fi, err := filter.Finalize()
    89  	var fop sq.Sqlizer
    90  	if err == nil {
    91  		fop, err = s.filterOp(ctx, tableName, fi, typeMap)
    92  	}
    93  	if err != nil {
    94  		return update, err
    95  	}
    96  	return update.Where(fop), nil
    97  }
    98  
    99  func (s *SQLCommon) escapeLike(value database.FieldSerialization) string {
   100  	v, _ := value.Value()
   101  	vs, _ := v.(string)
   102  	vs = strings.ReplaceAll(vs, "[", "[[]")
   103  	vs = strings.ReplaceAll(vs, "%", "[%]")
   104  	vs = strings.ReplaceAll(vs, "_", "[_]")
   105  	return vs
   106  }
   107  
   108  func (s *SQLCommon) mapField(tableName, fieldName string, tm map[string]string) string {
   109  	if fieldName == "sequence" {
   110  		return s.provider.SequenceField(tableName)
   111  	}
   112  	var field = fieldName
   113  	if tm != nil {
   114  		if mf, ok := tm[fieldName]; ok {
   115  			field = mf
   116  		}
   117  	}
   118  	if tableName != "" {
   119  		field = fmt.Sprintf("%s.%s", tableName, field)
   120  	}
   121  	return field
   122  }
   123  
   124  func (s *SQLCommon) filterOp(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) {
   125  	switch op.Op {
   126  	case database.FilterOpOr:
   127  		return s.filterOr(ctx, tableName, op, tm)
   128  	case database.FilterOpAnd:
   129  		return s.filterAnd(ctx, tableName, op, tm)
   130  	case database.FilterOpEq:
   131  		return sq.Eq{s.mapField(tableName, op.Field, tm): op.Value}, nil
   132  	case database.FilterOpIn:
   133  		return sq.Eq{s.mapField(tableName, op.Field, tm): op.Values}, nil
   134  	case database.FilterOpNe:
   135  		return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
   136  	case database.FilterOpNotIn:
   137  		return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil
   138  	case database.FilterOpCont:
   139  		return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
   140  	case database.FilterOpNotCont:
   141  		return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
   142  	case database.FilterOpICont:
   143  		return sq.ILike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
   144  	case database.FilterOpNotICont:
   145  		return sq.NotILike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil
   146  	case database.FilterOpGt:
   147  		return sq.Gt{s.mapField(tableName, op.Field, tm): op.Value}, nil
   148  	case database.FilterOpGte:
   149  		return sq.GtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
   150  	case database.FilterOpLt:
   151  		return sq.Lt{s.mapField(tableName, op.Field, tm): op.Value}, nil
   152  	case database.FilterOpLte:
   153  		return sq.LtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil
   154  	default:
   155  		return nil, i18n.NewError(ctx, i18n.MsgUnsupportedSQLOpInFilter, op.Op)
   156  	}
   157  }
   158  
   159  func (s *SQLCommon) filterOr(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) {
   160  	var err error
   161  	or := make(sq.Or, len(op.Children))
   162  	for i, c := range op.Children {
   163  		if or[i], err = s.filterOp(ctx, tableName, c, tm); err != nil {
   164  			return nil, err
   165  		}
   166  	}
   167  	return or, nil
   168  }
   169  
   170  func (s *SQLCommon) filterAnd(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) {
   171  	var err error
   172  	and := make(sq.And, len(op.Children))
   173  	for i, c := range op.Children {
   174  		if and[i], err = s.filterOp(ctx, tableName, c, tm); err != nil {
   175  			return nil, err
   176  		}
   177  	}
   178  	return and, nil
   179  }