github.com/kaleido-io/firefly@v0.0.0-20210622132723-8b4b6aacb971/internal/database/sqlcommon/filter_sql.go (about) 1 // Copyright © 2021 Kaleido, Inc. 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 17 package sqlcommon 18 19 import ( 20 "context" 21 "fmt" 22 "strings" 23 24 sq "github.com/Masterminds/squirrel" 25 "github.com/kaleido-io/firefly/internal/i18n" 26 "github.com/kaleido-io/firefly/pkg/database" 27 ) 28 29 func (s *SQLCommon) filterSelect(ctx context.Context, tableName string, sel sq.SelectBuilder, filter database.Filter, typeMap map[string]string, preconditions ...sq.Sqlizer) (sq.SelectBuilder, error) { 30 fi, err := filter.Finalize() 31 if err != nil { 32 return sel, err 33 } 34 if len(fi.Sort) == 0 { 35 fi.Sort = []string{"sequence"} 36 fi.Descending = true 37 } 38 sel, err = s.filterSelectFinalized(ctx, tableName, sel, fi, typeMap, preconditions...) 39 direction := "" 40 if fi.Descending { 41 direction = " DESC" 42 } 43 sort := make([]string, len(fi.Sort)) 44 for i, field := range fi.Sort { 45 sort[i] = s.mapField(tableName, field, typeMap) 46 } 47 sel = sel.OrderBy(fmt.Sprintf("%s%s", strings.Join(sort, ","), direction)) 48 if err == nil { 49 if fi.Skip > 0 { 50 sel = sel.Offset(fi.Skip) 51 } 52 if fi.Limit > 0 { 53 sel = sel.Limit(fi.Limit) 54 } 55 } 56 return sel, err 57 } 58 59 func (s *SQLCommon) filterSelectFinalized(ctx context.Context, tableName string, sel sq.SelectBuilder, fi *database.FilterInfo, tm map[string]string, preconditions ...sq.Sqlizer) (sq.SelectBuilder, error) { 60 fop, err := s.filterOp(ctx, tableName, fi, tm) 61 if err != nil { 62 return sel, err 63 } 64 if len(preconditions) > 0 { 65 and := make(sq.And, len(preconditions)+1) 66 for i, p := range preconditions { 67 and[i] = p 68 } 69 and[len(preconditions)] = fop 70 fop = and 71 } 72 return sel.Where(fop), nil 73 } 74 75 func (s *SQLCommon) buildUpdate(sel sq.UpdateBuilder, update database.Update, typeMap map[string]string) (sq.UpdateBuilder, error) { 76 ui, err := update.Finalize() 77 if err != nil { 78 return sel, err 79 } 80 for _, so := range ui.SetOperations { 81 82 sel = sel.Set(s.mapField("", so.Field, typeMap), so.Value) 83 } 84 return sel, nil 85 } 86 87 func (s *SQLCommon) filterUpdate(ctx context.Context, tableName string, update sq.UpdateBuilder, filter database.Filter, typeMap map[string]string) (sq.UpdateBuilder, error) { 88 fi, err := filter.Finalize() 89 var fop sq.Sqlizer 90 if err == nil { 91 fop, err = s.filterOp(ctx, tableName, fi, typeMap) 92 } 93 if err != nil { 94 return update, err 95 } 96 return update.Where(fop), nil 97 } 98 99 func (s *SQLCommon) escapeLike(value database.FieldSerialization) string { 100 v, _ := value.Value() 101 vs, _ := v.(string) 102 vs = strings.ReplaceAll(vs, "[", "[[]") 103 vs = strings.ReplaceAll(vs, "%", "[%]") 104 vs = strings.ReplaceAll(vs, "_", "[_]") 105 return vs 106 } 107 108 func (s *SQLCommon) mapField(tableName, fieldName string, tm map[string]string) string { 109 if fieldName == "sequence" { 110 return s.provider.SequenceField(tableName) 111 } 112 var field = fieldName 113 if tm != nil { 114 if mf, ok := tm[fieldName]; ok { 115 field = mf 116 } 117 } 118 if tableName != "" { 119 field = fmt.Sprintf("%s.%s", tableName, field) 120 } 121 return field 122 } 123 124 func (s *SQLCommon) filterOp(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) { 125 switch op.Op { 126 case database.FilterOpOr: 127 return s.filterOr(ctx, tableName, op, tm) 128 case database.FilterOpAnd: 129 return s.filterAnd(ctx, tableName, op, tm) 130 case database.FilterOpEq: 131 return sq.Eq{s.mapField(tableName, op.Field, tm): op.Value}, nil 132 case database.FilterOpIn: 133 return sq.Eq{s.mapField(tableName, op.Field, tm): op.Values}, nil 134 case database.FilterOpNe: 135 return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Value}, nil 136 case database.FilterOpNotIn: 137 return sq.NotEq{s.mapField(tableName, op.Field, tm): op.Values}, nil 138 case database.FilterOpCont: 139 return sq.Like{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil 140 case database.FilterOpNotCont: 141 return sq.NotLike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil 142 case database.FilterOpICont: 143 return sq.ILike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil 144 case database.FilterOpNotICont: 145 return sq.NotILike{s.mapField(tableName, op.Field, tm): fmt.Sprintf("%%%s%%", s.escapeLike(op.Value))}, nil 146 case database.FilterOpGt: 147 return sq.Gt{s.mapField(tableName, op.Field, tm): op.Value}, nil 148 case database.FilterOpGte: 149 return sq.GtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil 150 case database.FilterOpLt: 151 return sq.Lt{s.mapField(tableName, op.Field, tm): op.Value}, nil 152 case database.FilterOpLte: 153 return sq.LtOrEq{s.mapField(tableName, op.Field, tm): op.Value}, nil 154 default: 155 return nil, i18n.NewError(ctx, i18n.MsgUnsupportedSQLOpInFilter, op.Op) 156 } 157 } 158 159 func (s *SQLCommon) filterOr(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) { 160 var err error 161 or := make(sq.Or, len(op.Children)) 162 for i, c := range op.Children { 163 if or[i], err = s.filterOp(ctx, tableName, c, tm); err != nil { 164 return nil, err 165 } 166 } 167 return or, nil 168 } 169 170 func (s *SQLCommon) filterAnd(ctx context.Context, tableName string, op *database.FilterInfo, tm map[string]string) (sq.Sqlizer, error) { 171 var err error 172 and := make(sq.And, len(op.Children)) 173 for i, c := range op.Children { 174 if and[i], err = s.filterOp(ctx, tableName, c, tm); err != nil { 175 return nil, err 176 } 177 } 178 return and, nil 179 }