github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/find_in_set.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 "github.com/dolthub/go-mysql-server/sql/types" 24 ) 25 26 // FindInSet takes out the specified unit(s) from the time expression. 27 type FindInSet struct { 28 expression.BinaryExpressionStub 29 } 30 31 var _ sql.FunctionExpression = (*FindInSet)(nil) 32 var _ sql.CollationCoercible = (*FindInSet)(nil) 33 34 // NewFindInSet creates a new FindInSet expression. 35 func NewFindInSet(e1, e2 sql.Expression) sql.Expression { 36 return &FindInSet{ 37 expression.BinaryExpressionStub{ 38 LeftChild: e1, 39 RightChild: e2, 40 }, 41 } 42 } 43 44 // FunctionName implements sql.FunctionExpression 45 func (f *FindInSet) FunctionName() string { 46 return "find_in_set" 47 } 48 49 // Description implements sql.FunctionExpression 50 func (f *FindInSet) Description() string { 51 return "returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings" 52 } 53 54 // Type implements the Expression interface. 55 func (f *FindInSet) Type() sql.Type { return types.Int64 } 56 57 // CollationCoercibility implements the interface sql.CollationCoercible. 58 func (*FindInSet) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 59 return ctx.GetCollation(), 5 60 } 61 62 func (f *FindInSet) String() string { 63 return fmt.Sprintf("%s(%s from %s)", f.FunctionName(), f.LeftChild, f.RightChild) 64 } 65 66 // WithChildren implements the Expression interface. 67 func (f *FindInSet) WithChildren(children ...sql.Expression) (sql.Expression, error) { 68 if len(children) != 2 { 69 return nil, sql.ErrInvalidChildrenNumber.New(f, len(children), 2) 70 } 71 return NewFindInSet(children[0], children[1]), nil 72 } 73 74 // Eval implements the Expression interface. 75 func (f *FindInSet) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 76 if f.LeftChild == nil || f.RightChild == nil { 77 return nil, nil 78 } 79 80 left, err := f.LeftChild.Eval(ctx, row) 81 if err != nil { 82 return nil, err 83 } 84 85 right, err := f.RightChild.Eval(ctx, row) 86 if err != nil { 87 return nil, err 88 } 89 90 if left == nil || right == nil { 91 return nil, nil 92 } 93 94 lVal, _, err := types.LongText.Convert(left) 95 if err != nil { 96 return nil, err 97 } 98 l := lVal.(string) 99 100 // always returns 0 when left contains a comma 101 if strings.Contains(l, ",") { 102 return 0, nil 103 } 104 105 var r string 106 rType := f.RightChild.Type() 107 if setType, ok := rType.(types.SetType); ok { 108 // TODO: set type should take advantage of bit arithmetic 109 r, err = setType.BitsToString(right.(uint64)) 110 if err != nil { 111 return nil, err 112 } 113 } else if enumType, ok := rType.(types.EnumType); ok { 114 r, ok = enumType.At(int(right.(uint16))) 115 if !ok { 116 return nil, fmt.Errorf("enum missing index %v", r) 117 } 118 } else { 119 var rVal interface{} 120 rVal, _, err = types.LongText.Convert(right) 121 if err != nil { 122 return nil, err 123 } 124 r = rVal.(string) 125 } 126 127 leftColl, leftCoer := sql.GetCoercibility(ctx, f.LeftChild) 128 rightColl, rightCoer := sql.GetCoercibility(ctx, f.RightChild) 129 collPref, _ := sql.ResolveCoercibility(leftColl, leftCoer, rightColl, rightCoer) 130 131 strType := types.CreateLongText(collPref) 132 for i, r := range strings.Split(r, ",") { 133 cmp, err := strType.Compare(l, r) 134 if err != nil { 135 return nil, err 136 } 137 if cmp == 0 { 138 return i + 1, nil 139 } 140 } 141 142 return 0, nil 143 }