github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/locate.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 "github.com/dolthub/go-mysql-server/sql/expression" 23 "github.com/dolthub/go-mysql-server/sql/types" 24 ) 25 26 // Locate returns the position of the first occurrence of a substring in a string. 27 // If the substring is not found within the original string, this function returns 0. 28 // This function performs a case-insensitive search. 29 type Locate struct { 30 expression.NaryExpression 31 } 32 33 var _ sql.FunctionExpression = (*Locate)(nil) 34 var _ sql.CollationCoercible = (*Locate)(nil) 35 36 // NewLocate returns a new Locate function. 37 func NewLocate(exprs ...sql.Expression) (sql.Expression, error) { 38 if len(exprs) < 2 || len(exprs) > 3 { 39 return nil, sql.ErrInvalidArgumentNumber.New("LOCATE", "2 or 3", len(exprs)) 40 } 41 42 return &Locate{expression.NaryExpression{ChildExpressions: exprs}}, nil 43 } 44 45 // FunctionName implements sql.FunctionExpression 46 func (l *Locate) FunctionName() string { 47 return "locate" 48 } 49 50 // Description implements sql.FunctionExpression 51 func (l *Locate) Description() string { 52 return "returns the position of the first occurrence of a substring in a string." 53 } 54 55 // WithChildren implements the Expression interface. 56 func (l *Locate) WithChildren(children ...sql.Expression) (sql.Expression, error) { 57 if len(children) < 2 || len(children) > 3 { 58 return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 2) 59 } 60 61 return &Locate{expression.NaryExpression{ChildExpressions: children}}, nil 62 } 63 64 // Type implements the sql.Expression interface. 65 func (l *Locate) Type() sql.Type { return types.Int32 } 66 67 // CollationCoercibility implements the interface sql.CollationCoercible. 68 func (*Locate) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 69 return sql.Collation_binary, 5 70 } 71 72 func (l *Locate) String() string { 73 switch len(l.ChildExpressions) { 74 case 2: 75 return fmt.Sprintf("%s(%s,%s)", l.FunctionName(), l.ChildExpressions[0], l.ChildExpressions[1]) 76 case 3: 77 return fmt.Sprintf("%s(%s,%s,%s)", l.FunctionName(), l.ChildExpressions[0], l.ChildExpressions[1], l.ChildExpressions[2]) 78 } 79 return "" 80 } 81 82 func (l *Locate) DebugString() string { 83 switch len(l.ChildExpressions) { 84 case 2: 85 return fmt.Sprintf("%s(%s,%s)", l.FunctionName(), sql.DebugString(l.ChildExpressions[0]), sql.DebugString(l.ChildExpressions[1])) 86 case 3: 87 return fmt.Sprintf("%s(%s,%s,%s)", l.FunctionName(), sql.DebugString(l.ChildExpressions[0]), sql.DebugString(l.ChildExpressions[1]), sql.DebugString(l.ChildExpressions[2])) 88 } 89 return "" 90 } 91 92 // Eval implements the sql.Expression interface. 93 func (l *Locate) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 94 if len(l.ChildExpressions) < 2 || len(l.ChildExpressions) > 3 { 95 return nil, nil 96 } 97 98 substrVal, err := l.ChildExpressions[0].Eval(ctx, row) 99 if err != nil { 100 return nil, err 101 } 102 103 if substrVal == nil { 104 return nil, nil 105 } 106 107 substr, ok := substrVal.(string) 108 if !ok { 109 return nil, sql.ErrInvalidArgumentDetails.New("locate", "substring must be a string") 110 } 111 112 strVal, err := l.ChildExpressions[1].Eval(ctx, row) 113 if err != nil { 114 return nil, err 115 } 116 117 if strVal == nil { 118 return nil, nil 119 } 120 121 str, ok := strVal.(string) 122 if !ok { 123 return nil, sql.ErrInvalidArgumentDetails.New("locate", "string must be a string") 124 } 125 126 position := 1 127 128 if len(l.ChildExpressions) == 3 { 129 posVal, err := l.ChildExpressions[2].Eval(ctx, row) 130 if err != nil { 131 return nil, err 132 } 133 134 if posVal != nil { 135 posInt, _, err := types.Int32.Convert(posVal) 136 if err != nil { 137 return nil, sql.ErrInvalidArgumentDetails.New("locate", "start must be an integer") 138 } 139 position = int(posInt.(int32)) 140 } 141 } 142 143 // Edge cases that cannot be handled by strings.Index. 144 switch { 145 // Position 0 doesn't exist. 146 case position == 0 || (len(str) > 0 && position > len(str)): 147 return int32(0), nil 148 // Locate("", "") returns 1 if start is 1. 149 case len(substr) == 0 && len(str) == 0: 150 if position == 1 { 151 return int32(1), nil 152 } 153 return int32(0), nil 154 } 155 156 res := strings.Index(strings.ToLower(str[position-1:]), strings.ToLower(substr)) 157 if res == -1 { 158 return int32(0), nil 159 } 160 return int32(res + position), nil 161 }