github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/regexp_replace.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package function 16 17 import ( 18 "fmt" 19 "strings" 20 21 "gopkg.in/src-d/go-errors.v1" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/types" 25 ) 26 27 // RegexpReplace implements the REGEXP_REPLACE function. 28 // https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace 29 type RegexpReplace struct { 30 args []sql.Expression 31 } 32 33 var _ sql.FunctionExpression = (*RegexpReplace)(nil) 34 var _ sql.CollationCoercible = (*RegexpReplace)(nil) 35 36 // NewRegexpReplace creates a new RegexpReplace expression. 37 func NewRegexpReplace(args ...sql.Expression) (sql.Expression, error) { 38 if len(args) < 3 || len(args) > 6 { 39 return nil, sql.ErrInvalidArgumentNumber.New("regexp_replace", "3,4,5 or 6", len(args)) 40 } 41 42 return &RegexpReplace{args: args}, nil 43 } 44 45 // FunctionName implements sql.FunctionExpression 46 func (r *RegexpReplace) FunctionName() string { 47 return "regexp_replace" 48 } 49 50 // Description implements sql.FunctionExpression 51 func (r *RegexpReplace) Description() string { 52 return "replaces substrings matching regular expression." 53 } 54 55 // Type implements the sql.Expression interface. 56 func (r *RegexpReplace) Type() sql.Type { return types.LongText } 57 58 // CollationCoercibility implements the interface sql.CollationCoercible. 59 func (r *RegexpReplace) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 60 if len(r.args) == 0 { 61 return sql.Collation_binary, 6 62 } 63 collation, coercibility = sql.GetCoercibility(ctx, r.args[0]) 64 for i := 1; i < len(r.args) && i < 3; i++ { 65 nextCollation, nextCoercibility := sql.GetCoercibility(ctx, r.args[i]) 66 collation, coercibility = sql.ResolveCoercibility(collation, coercibility, nextCollation, nextCoercibility) 67 } 68 return collation, coercibility 69 } 70 71 // IsNullable implements the sql.Expression interface. 72 func (r *RegexpReplace) IsNullable() bool { return true } 73 74 // Children implements the sql.Expression interface. 75 func (r *RegexpReplace) Children() []sql.Expression { 76 return r.args 77 } 78 79 // Resolved implements the sql.Expression interface. 80 func (r *RegexpReplace) Resolved() bool { 81 for _, arg := range r.args { 82 if !arg.Resolved() { 83 return false 84 } 85 } 86 return true 87 } 88 89 // WithChildren implements the sql.Expression interface. 90 func (r *RegexpReplace) WithChildren(children ...sql.Expression) (sql.Expression, error) { 91 if len(children) != len(r.args) { 92 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), len(r.args)) 93 } 94 return NewRegexpReplace(children...) 95 } 96 97 func (r *RegexpReplace) String() string { 98 var args []string 99 for _, e := range r.args { 100 args = append(args, e.String()) 101 } 102 return fmt.Sprintf("%s(%s)", r.FunctionName(), strings.Join(args, ",")) 103 } 104 105 // Eval implements the sql.Expression interface. 106 func (r *RegexpReplace) Eval(ctx *sql.Context, row sql.Row) (val interface{}, err error) { 107 // Evaluate string value 108 str, err := r.args[0].Eval(ctx, row) 109 if err != nil { 110 return nil, err 111 } 112 if str == nil { 113 return nil, nil 114 } 115 str, _, err = types.LongText.Convert(str) 116 if err != nil { 117 return nil, err 118 } 119 120 // Convert to string 121 _str := str.(string) 122 123 // Handle flags 124 var flags sql.Expression = nil 125 if len(r.args) == 6 { 126 flags = r.args[5] 127 } 128 129 // Create regex, should handle null pattern and null flags 130 re, compileErr := compileRegex(ctx, r.args[1], r.args[0], flags, r.FunctionName(), row) 131 if compileErr != nil { 132 return nil, compileErr 133 } 134 if re == nil { 135 return nil, nil 136 } 137 defer func() { 138 if nErr := re.Close(); err == nil { 139 err = nErr 140 } 141 }() 142 if err = re.SetMatchString(ctx, _str); err != nil { 143 return nil, err 144 } 145 146 // Evaluate ReplaceStr 147 replaceStr, err := r.args[2].Eval(ctx, row) 148 if err != nil { 149 return nil, err 150 } 151 if replaceStr == nil { 152 return nil, nil 153 } 154 replaceStr, _, err = types.LongText.Convert(replaceStr) 155 if err != nil { 156 return nil, err 157 } 158 159 // Convert to string 160 _replaceStr := replaceStr.(string) 161 162 // Do nothing if str is empty 163 if len(_str) == 0 { 164 return _str, nil 165 } 166 167 // Default position is 1 168 _pos := 1 169 170 // Check if position argument was provided 171 if len(r.args) >= 4 { 172 // Evaluate position argument 173 pos, err := r.args[3].Eval(ctx, row) 174 if err != nil { 175 return nil, err 176 } 177 if pos == nil { 178 return nil, nil 179 } 180 181 // Convert to int32 182 pos, _, err = types.Int32.Convert(pos) 183 if err != nil { 184 return nil, err 185 } 186 // Convert to int 187 _pos = int(pos.(int32)) 188 } 189 190 // Non-positive position throws incorrect parameter 191 if _pos <= 0 { 192 return nil, sql.ErrInvalidArgumentDetails.New(r.FunctionName(), fmt.Sprintf("%d", _pos)) 193 } 194 195 // Handle out of bounds 196 if _pos > len(_str) { 197 return nil, errors.NewKind("Index out of bounds for regular expression search.").New() 198 } 199 200 // Default occurrence is 0 (replace all occurrences) 201 _occ := 0 202 203 // Check if Occurrence argument was provided 204 if len(r.args) >= 5 { 205 occ, err := r.args[4].Eval(ctx, row) 206 if err != nil { 207 return nil, err 208 } 209 if occ == nil { 210 return nil, nil 211 } 212 213 // Convert occurrence to int32 214 occ, _, err = types.Int32.Convert(occ) 215 if err != nil { 216 return nil, err 217 } 218 219 // Convert to int 220 _occ = int(occ.(int32)) 221 } 222 223 return re.Replace(ctx, _replaceStr, _pos, _occ) 224 }