github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/regexp_like.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package function
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  
    23  	regex "github.com/dolthub/go-icu-regex"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/go-mysql-server/sql"
    27  	"github.com/dolthub/go-mysql-server/sql/expression"
    28  	"github.com/dolthub/go-mysql-server/sql/types"
    29  )
    30  
    31  // RegexpLike implements the REGEXP_LIKE function.
    32  // https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-like
    33  type RegexpLike struct {
    34  	Text    sql.Expression
    35  	Pattern sql.Expression
    36  	Flags   sql.Expression
    37  
    38  	cachedVal   atomic.Value
    39  	re          regex.Regex
    40  	compileOnce sync.Once
    41  	compileErr  error
    42  }
    43  
    44  var _ sql.FunctionExpression = (*RegexpLike)(nil)
    45  var _ sql.CollationCoercible = (*RegexpLike)(nil)
    46  var _ sql.Closer = (*RegexpLike)(nil)
    47  
    48  // NewRegexpLike creates a new RegexpLike expression.
    49  func NewRegexpLike(args ...sql.Expression) (sql.Expression, error) {
    50  	var r *RegexpLike
    51  	switch len(args) {
    52  	case 3:
    53  		r = &RegexpLike{
    54  			Text:    args[0],
    55  			Pattern: args[1],
    56  			Flags:   args[2],
    57  		}
    58  	case 2:
    59  		r = &RegexpLike{
    60  			Text:    args[0],
    61  			Pattern: args[1],
    62  		}
    63  	default:
    64  		return nil, sql.ErrInvalidArgumentNumber.New("regexp_like", "2 or 3", len(args))
    65  	}
    66  	return r, nil
    67  }
    68  
    69  // FunctionName implements sql.FunctionExpression
    70  func (r *RegexpLike) FunctionName() string {
    71  	return "regexp_like"
    72  }
    73  
    74  // Description implements sql.FunctionExpression
    75  func (r *RegexpLike) Description() string {
    76  	return "returns whether string matches regular expression."
    77  }
    78  
    79  // Type implements the sql.Expression interface.
    80  func (r *RegexpLike) Type() sql.Type { return types.Int8 }
    81  
    82  // CollationCoercibility implements the interface sql.CollationCoercible.
    83  func (r *RegexpLike) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
    84  	leftCollation, leftCoercibility := sql.GetCoercibility(ctx, r.Text)
    85  	rightCollation, rightCoercibility := sql.GetCoercibility(ctx, r.Pattern)
    86  	return sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility)
    87  }
    88  
    89  // IsNullable implements the sql.Expression interface.
    90  func (r *RegexpLike) IsNullable() bool { return true }
    91  
    92  // Children implements the sql.Expression interface.
    93  func (r *RegexpLike) Children() []sql.Expression {
    94  	var result = []sql.Expression{r.Text, r.Pattern}
    95  	if r.Flags != nil {
    96  		result = append(result, r.Flags)
    97  	}
    98  	return result
    99  }
   100  
   101  // Resolved implements the sql.Expression interface.
   102  func (r *RegexpLike) Resolved() bool {
   103  	return r.Text.Resolved() && r.Pattern.Resolved() && (r.Flags == nil || r.Flags.Resolved())
   104  }
   105  
   106  // WithChildren implements the sql.Expression interface.
   107  func (r *RegexpLike) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   108  	required := 2
   109  	if r.Flags != nil {
   110  		required = 3
   111  	}
   112  	if len(children) != required {
   113  		return nil, sql.ErrInvalidChildrenNumber.New(r, len(children), required)
   114  	}
   115  	return NewRegexpLike(children...)
   116  }
   117  
   118  func (r *RegexpLike) String() string {
   119  	var args []string
   120  	for _, e := range r.Children() {
   121  		args = append(args, e.String())
   122  	}
   123  	return fmt.Sprintf("%s(%s)", r.FunctionName(), strings.Join(args, ","))
   124  }
   125  
   126  func (r *RegexpLike) compile(ctx *sql.Context) {
   127  	r.compileOnce.Do(func() {
   128  		r.re, r.compileErr = compileRegex(ctx, r.Pattern, r.Text, r.Flags, r.FunctionName(), nil)
   129  	})
   130  }
   131  
   132  // Eval implements the sql.Expression interface.
   133  func (r *RegexpLike) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   134  	span, ctx := ctx.Span("function.RegexpLike")
   135  	defer span.End()
   136  
   137  	cached := r.cachedVal.Load()
   138  	if cached != nil {
   139  		return cached, nil
   140  	}
   141  
   142  	r.compile(ctx)
   143  	if r.compileErr != nil {
   144  		return nil, r.compileErr
   145  	}
   146  	if r.re == nil {
   147  		return nil, nil
   148  	}
   149  
   150  	text, err := r.Text.Eval(ctx, row)
   151  	if err != nil {
   152  		return nil, err
   153  	}
   154  	if text == nil {
   155  		return nil, nil
   156  	}
   157  	text, _, err = types.LongText.Convert(text)
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	err = r.re.SetMatchString(ctx, text.(string))
   163  	if err != nil {
   164  		return nil, err
   165  	}
   166  	ok, err := r.re.Matches(ctx, 0, 0)
   167  	if err != nil {
   168  		return nil, err
   169  	}
   170  	var outVal int8
   171  	if ok {
   172  		outVal = int8(1)
   173  	} else {
   174  		outVal = int8(0)
   175  	}
   176  
   177  	if canBeCached(r.Text) {
   178  		r.cachedVal.Store(outVal)
   179  	}
   180  	return outVal, nil
   181  }
   182  
   183  // Close implements the sql.Closer interface.
   184  func (r *RegexpLike) Close(ctx *sql.Context) error {
   185  	if r.re != nil {
   186  		return r.re.Close()
   187  	}
   188  	return nil
   189  }
   190  
   191  func compileRegex(ctx *sql.Context, pattern, text, flags sql.Expression, funcName string, row sql.Row) (regex.Regex, error) {
   192  	patternVal, err := pattern.Eval(ctx, row)
   193  	if err != nil {
   194  		return nil, err
   195  	}
   196  	if patternVal == nil {
   197  		return nil, nil
   198  	}
   199  	patternVal, _, err = types.LongText.Convert(patternVal)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  
   204  	// Empty regex, throw illegal argument
   205  	if len(patternVal.(string)) == 0 {
   206  		return nil, errors.NewKind("Illegal argument to regular expression.").New()
   207  	}
   208  
   209  	// It appears that MySQL ONLY uses the collation to determine case-sensitivity and character set. We don't need to
   210  	// worry about the character set since we convert all strings to UTF-8 for internal consistency. At the time of
   211  	// writing this comment, all case-insensitive collations end with "_ci", so we can just check for that.
   212  	leftCollation, leftCoercibility := sql.GetCoercibility(ctx, text)
   213  	rightCollation, rightCoercibility := sql.GetCoercibility(ctx, pattern)
   214  	resolvedCollation, _ := sql.ResolveCoercibility(leftCollation, leftCoercibility, rightCollation, rightCoercibility)
   215  	flagsStr := ""
   216  	if strings.HasSuffix(resolvedCollation.String(), "_ci") {
   217  		flagsStr = "i"
   218  	}
   219  
   220  	if flags != nil {
   221  		f, err := flags.Eval(ctx, row)
   222  		if err != nil {
   223  			return nil, err
   224  		}
   225  		if f == nil {
   226  			return nil, nil
   227  		}
   228  		f, _, err = types.LongText.Convert(f)
   229  		if err != nil {
   230  			return nil, err
   231  		}
   232  
   233  		flagsStr = f.(string)
   234  		flagsStr, err = consolidateRegexpFlags(flagsStr, funcName)
   235  		if err != nil {
   236  			return nil, err
   237  		}
   238  	}
   239  	regexFlags := regex.RegexFlags_None
   240  	for _, flag := range flagsStr {
   241  		// The 'c' flag is the default behavior, so we don't need to set anything in that case.
   242  		// Any illegal flags will have been caught by consolidateRegexpFlags.
   243  		switch flag {
   244  		case 'i':
   245  			regexFlags |= regex.RegexFlags_Case_Insensitive
   246  		case 'm':
   247  			regexFlags |= regex.RegexFlags_Multiline
   248  		case 'n':
   249  			regexFlags |= regex.RegexFlags_Dot_All
   250  		case 'u':
   251  			regexFlags |= regex.RegexFlags_Unix_Lines
   252  		}
   253  	}
   254  
   255  	bufferSize := uint32(524288)
   256  	if _, val, ok := sql.SystemVariables.GetGlobal("regexp_buffer_size"); ok {
   257  		bufferSize = uint32(val.(uint64))
   258  	} else {
   259  		ctx.Warn(1193, `System variable for regular expressions "regexp_buffer_size" is missing`)
   260  	}
   261  	re := regex.CreateRegex(bufferSize)
   262  	if err = re.SetRegexString(ctx, patternVal.(string), regexFlags); err != nil {
   263  		_ = re.Close()
   264  		return nil, err
   265  	}
   266  	return re, nil
   267  }
   268  
   269  // consolidateRegexpFlags consolidates regexp flags by removing duplicates, resolving order of conflicting flags, and
   270  // verifying that all flags are valid.
   271  func consolidateRegexpFlags(flags, funcName string) (string, error) {
   272  	flagSet := make(map[string]struct{})
   273  	for _, flag := range flags {
   274  		switch flag {
   275  		case 'c':
   276  			delete(flagSet, "i")
   277  		case 'i':
   278  			flagSet["i"] = struct{}{}
   279  		case 'm':
   280  			flagSet["m"] = struct{}{}
   281  		case 'n':
   282  			flagSet["n"] = struct{}{}
   283  		case 'u':
   284  			flagSet["u"] = struct{}{}
   285  		default:
   286  			return "", sql.ErrInvalidArgument.New(funcName)
   287  		}
   288  	}
   289  	flags = ""
   290  	for flag := range flagSet {
   291  		flags += flag
   292  	}
   293  	return flags, nil
   294  }
   295  
   296  func canBeCached(e sql.Expression) bool {
   297  	hasCols := false
   298  	sql.Inspect(e, func(e sql.Expression) bool {
   299  		switch e.(type) {
   300  		case *expression.GetField, *expression.UserVar, *expression.SystemVar, *expression.ProcedureParam:
   301  			hasCols = true
   302  		}
   303  		return true
   304  	})
   305  	return !hasCols
   306  }