// Source listing: github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/regexp_like.go

// Copyright 2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package function

import (
	"fmt"
	"strings"
	"sync"
	"sync/atomic"

	regex "github.com/dolthub/go-icu-regex"
	"gopkg.in/src-d/go-errors.v1"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/types"
)

// RegexpLike implements the REGEXP_LIKE function.
// https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-like
type RegexpLike struct {
	// Text is the first argument: the expression whose value is matched
	// against the pattern. It is evaluated on every call to Eval.
	Text sql.Expression

	// Pattern is the second argument: the regular expression. It is evaluated
	// once, with a nil row, the first time Eval triggers compilation.
	Pattern sql.Expression

	// Flags is the optional third argument (match type). It is nil when the
	// function was constructed with two arguments.
	Flags sql.Expression

	// cachedVal holds the int8 result stored by Eval when canBeCached reports
	// that Text can be cached; later Eval calls return it directly.
	cachedVal atomic.Value

	// re is the compiled regex produced by compileRegex. It stays nil when
	// compilation failed or when the pattern or flags evaluated to NULL.
	re regex.Regex

	// compileOnce guards the single compilation performed by compile.
	compileOnce sync.Once

	// compileErr is the error returned by compileRegex, if any. Eval returns
	// it on every call once set.
	compileErr error
}

var _ sql.FunctionExpression = (*RegexpLike)(nil)
var _ sql.CollationCoercible = (*RegexpLike)(nil)
var _ sql.Closer = (*RegexpLike)(nil)

// NewRegexpLike creates a new RegexpLike expression.
49 func NewRegexpLike(args ...sql.Expression) (sql.Expression, error) { 50 var r *RegexpLike 51 switch len(args) { 52 case 3: 53 r = &RegexpLike{ 54 Text: args[0], 55 Pattern: args[1], 56 Flags: args[2], 57 } 58 case 2: 59 r = &RegexpLike{ 60 Text: args[0], 61 Pattern: args[1], 62 } 63 default: 64 return nil, sql.ErrInvalidArgumentNumber.New("regexp_like", "2 or 3", len(args)) 65 } 66 return r, nil 67 } 68 69 // FunctionName implements sql.FunctionExpression 70 func (r *RegexpLike) FunctionName() string { 71 return "regexp_like" 72 } 73 74 // Description implements sql.FunctionExpression 75 func (r *RegexpLike) Description() string { 76 return "returns whether string matches regular expression." 77 } 78 79 // Type implements the sql.Expression interface. 80 func (r *RegexpLike) Type() sql.Type { return types.Int8 } 81 82 // CollationCoercibility implements the interface sql.CollationCoercible. 83 func (r *RegexpLike) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 84 leftCollation, leftCoercibility := sql.GetCoercibility(ctx, r.Text) 85 rightCollation, rightCoercibility := sql.GetCoercibility(ctx, r.Pattern) 86 return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility) 87 } 88 89 // IsNullable implements the sql.Expression interface. 90 func (r *RegexpLike) IsNullable() bool { return true } 91 92 // Children implements the sql.Expression interface. 93 func (r *RegexpLike) Children() []sql.Expression { 94 var result = []sql.Expression{r.Text, r.Pattern} 95 if r.Flags != nil { 96 result = append(result, r.Flags) 97 } 98 return result 99 } 100 101 // Resolved implements the sql.Expression interface. 102 func (r *RegexpLike) Resolved() bool { 103 return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved()) 104 } 105 106 // WithChildren implements the sql.Expression interface. 
107 func (r *RegexpLike) WithChildren(children ...sql.Expression) (sql.Expression, error) { 108 required := 2 109 if r.Flags != nil { 110 required = 3 111 } 112 if len(children) != required { 113 return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required) 114 } 115 return NewRegexpLike(children...) 116 } 117 118 func (r *RegexpLike) String() string { 119 var args []string 120 for _, e := range r.Children() { 121 args = append(args, e.String()) 122 } 123 return fmt.Sprintf("%s(%s)", r.FunctionName(), strings.Join(args, ",")) 124 } 125 126 func (r *RegexpLike) compile(ctx *sql.Context) { 127 r.compileOnce.Do(func() { 128 r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), nil) 129 }) 130 } 131 132 // Eval implements the sql.Expression interface. 133 func (r *RegexpLike) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 134 span, ctx := ctx.Span("function.RegexpLike") 135 defer span.End() 136 137 cached := r.cachedVal.Load() 138 if cached != nil { 139 return cached, nil 140 } 141 142 r.compile(ctx) 143 if r.compileErr != nil { 144 return nil, r.compileErr 145 } 146 if r.re == nil { 147 return nil, nil 148 } 149 150 text, err := r.Text.Eval(ctx, row) 151 if err != nil { 152 return nil, err 153 } 154 if text == nil { 155 return nil, nil 156 } 157 text, _, err = types.LongText.Convert(text) 158 if err != nil { 159 return nil, err 160 } 161 162 err = r.re.SetMatchString(ctx, text.(string)) 163 if err != nil { 164 return nil, err 165 } 166 ok, err := r.re.Matches(ctx, 0, 0) 167 if err != nil { 168 return nil, err 169 } 170 var outVal int8 171 if ok { 172 outVal = int8(1) 173 } else { 174 outVal = int8(0) 175 } 176 177 if canBeCached(r.Text) { 178 r.cachedVal.Store(outVal) 179 } 180 return outVal, nil 181 } 182 183 // Close implements the sql.Closer interface. 
184 func (r *RegexpLike) Close(ctx *sql.Context) error { 185 if r.re != nil { 186 return r.re.Close() 187 } 188 return nil 189 } 190 191 func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcName string, row sql.Row) (regex.Regex, error) { 192 patternVal, err := pattern.Eval(ctx, row) 193 if err != nil { 194 return nil, err 195 } 196 if patternVal == nil { 197 return nil, nil 198 } 199 patternVal, _, err = types.LongText.Convert(patternVal) 200 if err != nil { 201 return nil, err 202 } 203 204 // Empty regex, throw illegal argument 205 if len(patternVal.(string)) == 0 { 206 return nil, errors.NewKind("Illegal argument to regular expression.").New() 207 } 208 209 // It appears that MySQL ONLY uses the collation to determine case-sensitivity and character set. We don't need to 210 // worry about the character set since we convert all strings to UTF-8 for internal consistency. At the time of 211 // writing this comment, all case-insensitive collations end with "_ci", so we can just check for that. 212 leftCollation, leftCoercibility := sql.GetCoercibility(ctx, text) 213 rightCollation, rightCoercibility := sql.GetCoercibility(ctx, pattern) 214 resolvedCollation, _ := sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility) 215 flagsStr := "" 216 if strings.HasSuffix(resolvedCollation.String(), "_ci") { 217 flagsStr = "i" 218 } 219 220 if flags != nil { 221 f, err := flags.Eval(ctx, row) 222 if err != nil { 223 return nil, err 224 } 225 if f == nil { 226 return nil, nil 227 } 228 f, _, err = types.LongText.Convert(f) 229 if err != nil { 230 return nil, err 231 } 232 233 flagsStr = f.(string) 234 flagsStr, err = consolidateRegexpFlags(flagsStr, funcName) 235 if err != nil { 236 return nil, err 237 } 238 } 239 regexFlags := regex.RegexFlags_None 240 for _, flag := range flagsStr { 241 // The 'c' flag is the default behavior, so we don't need to set anything in that case. 
242 // Any illegal flags will have been caught by consolidateRegexpFlags. 243 switch flag { 244 case 'i': 245 regexFlags |= regex.RegexFlags_Case_Insensitive 246 case 'm': 247 regexFlags |= regex.RegexFlags_Multiline 248 case 'n': 249 regexFlags |= regex.RegexFlags_Dot_All 250 case 'u': 251 regexFlags |= regex.RegexFlags_Unix_Lines 252 } 253 } 254 255 bufferSize := uint32(524288) 256 if _, val, ok := sql.SystemVariables.GetGlobal("regexp_buffer_size"); ok { 257 bufferSize = uint32(val.(uint64)) 258 } else { 259 ctx.Warn(1193, `System variable for regular expressions "regexp_buffer_size" is missing`) 260 } 261 re := regex.CreateRegex(bufferSize) 262 if err = re.SetRegexString(ctx, patternVal.(string), regexFlags); err != nil { 263 _ = re.Close() 264 return nil, err 265 } 266 return re, nil 267 } 268 269 // consolidateRegexpFlags consolidates regexp flags by removing duplicates, resolving order of conflicting flags, and 270 // verifying that all flags are valid. 271 func consolidateRegexpFlags(flags, funcName string) (string, error) { 272 flagSet := make(map[string]struct{}) 273 for _, flag := range flags { 274 switch flag { 275 case 'c': 276 delete(flagSet, "i") 277 case 'i': 278 flagSet["i"] = struct{}{} 279 case 'm': 280 flagSet["m"] = struct{}{} 281 case 'n': 282 flagSet["n"] = struct{}{} 283 case 'u': 284 flagSet["u"] = struct{}{} 285 default: 286 return "", sql.ErrInvalidArgument.New(funcName) 287 } 288 } 289 flags = "" 290 for flag := range flagSet { 291 flags += flag 292 } 293 return flags, nil 294 } 295 296 func canBeCached(e sql.Expression) bool { 297 hasCols := false 298 sql.Inspect(e, func(e sql.Expression) bool { 299 switch e.(type) { 300 case *expression.GetField, *expression.UserVar, *expression.SystemVar, *expression.ProcedureParam: 301 hasCols = true 302 } 303 return true 304 }) 305 return !hasCols 306 }