github.com/minio/minio@v0.0.0-20240328213742-3f72439b8a27/internal/s3select/sql/statement.go (about)

     1  // Copyright (c) 2015-2021 MinIO, Inc.
     2  //
     3  // This file is part of MinIO Object Storage stack
     4  //
     5  // This program is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Affero General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // This program is distributed in the hope that it will be useful
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13  // GNU Affero General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Affero General Public License
    16  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package sql
    19  
    20  import (
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  
    25  	"github.com/bcicen/jstream"
    26  	"github.com/minio/simdjson-go"
    27  )
    28  
// errBadLimitSpecified is returned by parseLimit when a LIMIT value is
// present but has no integer component, or when that integer is
// negative. Note that a value of zero is accepted.
var errBadLimitSpecified = errors.New("Limit value must be a positive integer")

const (
	// baseTableName is the table name required in the FROM clause.
	// validateTableName compares against it case-insensitively.
	baseTableName = "s3object"
)
    34  
// SelectStatement is the top level parsed and analyzed structure
type SelectStatement struct {
	// Parsed AST of the query; populated by ParseSelectStatement.
	selectAST *Select

	// Analysis result of the statement
	selectQProp qProp

	// Result of parsing the limit clause if one is present
	// (otherwise -1)
	limitValue int64

	// Count of rows that have been output. Eval increments it for
	// every record it returns, and LimitReached compares it against
	// limitValue.
	outputCount int64

	// Table alias, copied from the FROM clause of the parsed query.
	tableAlias string
}
    52  
    53  // ParseSelectStatement - parses a select query from the given string
    54  // and analyzes it.
    55  func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
    56  	var selectAST Select
    57  	err = SQLParser.ParseString(s, &selectAST)
    58  	if err != nil {
    59  		err = errQueryParseFailure(err)
    60  		return
    61  	}
    62  
    63  	// Check if select is "SELECT s.* from S3Object s"
    64  	if !selectAST.Expression.All &&
    65  		len(selectAST.Expression.Expressions) == 1 &&
    66  		len(selectAST.Expression.Expressions[0].Expression.And) == 1 &&
    67  		len(selectAST.Expression.Expressions[0].Expression.And[0].Condition) == 1 &&
    68  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand != nil &&
    69  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left != nil &&
    70  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left != nil &&
    71  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary != nil &&
    72  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr != nil {
    73  		if selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr.String() == selectAST.From.As+".*" {
    74  			selectAST.Expression.All = true
    75  		}
    76  	}
    77  	stmt.selectAST = &selectAST
    78  
    79  	// Check the parsed limit value
    80  	stmt.limitValue, err = parseLimit(selectAST.Limit)
    81  	if err != nil {
    82  		err = errQueryAnalysisFailure(err)
    83  		return
    84  	}
    85  
    86  	// Analyze where clause
    87  	if selectAST.Where != nil {
    88  		whereQProp := selectAST.Where.analyze(&selectAST)
    89  		if whereQProp.err != nil {
    90  			err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %w", whereQProp.err))
    91  			return
    92  		}
    93  
    94  		if whereQProp.isAggregation {
    95  			err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
    96  			return
    97  		}
    98  	}
    99  
   100  	// Validate table name
   101  	err = validateTableName(selectAST.From)
   102  	if err != nil {
   103  		return
   104  	}
   105  
   106  	// Analyze main select expression
   107  	stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
   108  	err = stmt.selectQProp.err
   109  	if err != nil {
   110  		err = errQueryAnalysisFailure(err)
   111  	}
   112  
   113  	// Set table alias
   114  	stmt.tableAlias = selectAST.From.As
   115  	// Remove quotes from column aliases
   116  	if selectAST.Expression != nil {
   117  		for _, exp := range selectAST.Expression.Expressions {
   118  			if strings.HasSuffix(exp.As, "'") && strings.HasPrefix(exp.As, "'") && len(exp.As) >= 2 {
   119  				exp.As = exp.As[1 : len(exp.As)-1]
   120  			}
   121  		}
   122  	}
   123  	return
   124  }
   125  
   126  func validateTableName(from *TableExpression) error {
   127  	if !strings.EqualFold(from.Table.BaseKey.String(), baseTableName) {
   128  		return errBadTableName(errors.New("table name must be `s3object`"))
   129  	}
   130  
   131  	if len(from.Table.PathExpr) > 0 {
   132  		if !from.Table.PathExpr[0].ArrayWildcard {
   133  			return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
   134  		}
   135  	}
   136  	return nil
   137  }
   138  
   139  func parseLimit(v *LitValue) (int64, error) {
   140  	switch {
   141  	case v == nil:
   142  		return -1, nil
   143  	case v.Int == nil:
   144  		return -1, errBadLimitSpecified
   145  	default:
   146  		r := int64(*v.Int)
   147  		if r < 0 {
   148  			return -1, errBadLimitSpecified
   149  		}
   150  		return r, nil
   151  	}
   152  }
   153  
   154  // EvalFrom evaluates the From clause on the input record. It only
   155  // applies to JSON input data format (currently).
   156  func (e *SelectStatement) EvalFrom(format string, input Record) ([]*Record, error) {
   157  	if !e.selectAST.From.HasKeypath() {
   158  		return []*Record{&input}, nil
   159  	}
   160  	_, rawVal := input.Raw()
   161  
   162  	if format != "json" {
   163  		return nil, errDataSource(errors.New("path not supported"))
   164  	}
   165  	switch rec := rawVal.(type) {
   166  	case jstream.KVS:
   167  		txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
   168  		if err != nil {
   169  			return nil, err
   170  		}
   171  
   172  		var kvs jstream.KVS
   173  		switch v := txedRec.(type) {
   174  		case jstream.KVS:
   175  			kvs = v
   176  
   177  		case []interface{}:
   178  			recs := make([]*Record, len(v))
   179  			for i, val := range v {
   180  				tmpRec := input.Clone(nil)
   181  				if err = tmpRec.Replace(val); err != nil {
   182  					return nil, err
   183  				}
   184  				recs[i] = &tmpRec
   185  			}
   186  			return recs, nil
   187  
   188  		default:
   189  			kvs = jstream.KVS{jstream.KV{Key: "_1", Value: v}}
   190  		}
   191  
   192  		if err = input.Replace(kvs); err != nil {
   193  			return nil, err
   194  		}
   195  
   196  		return []*Record{&input}, nil
   197  	case simdjson.Object:
   198  		txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
   199  		if err != nil {
   200  			return nil, err
   201  		}
   202  
   203  		switch v := txedRec.(type) {
   204  		case simdjson.Object:
   205  			err := input.Replace(v)
   206  			if err != nil {
   207  				return nil, err
   208  			}
   209  
   210  		case []interface{}:
   211  			recs := make([]*Record, len(v))
   212  			for i, val := range v {
   213  				tmpRec := input.Clone(nil)
   214  				if err = tmpRec.Replace(val); err != nil {
   215  					return nil, err
   216  				}
   217  				recs[i] = &tmpRec
   218  			}
   219  			return recs, nil
   220  
   221  		default:
   222  			input.Reset()
   223  			input, err = input.Set("_1", &Value{value: v})
   224  			if err != nil {
   225  				return nil, err
   226  			}
   227  		}
   228  		return []*Record{&input}, nil
   229  	}
   230  	return nil, errDataSource(errors.New("unexpected non JSON input"))
   231  }
   232  
   233  // IsAggregated returns if the statement involves SQL aggregation
   234  func (e *SelectStatement) IsAggregated() bool {
   235  	return e.selectQProp.isAggregation
   236  }
   237  
   238  // AggregateResult - returns the aggregated result after all input
   239  // records have been processed. Applies only to aggregation queries.
   240  func (e *SelectStatement) AggregateResult(output Record) error {
   241  	for i, expr := range e.selectAST.Expression.Expressions {
   242  		v, err := expr.evalNode(nil, e.tableAlias)
   243  		if err != nil {
   244  			return err
   245  		}
   246  		if expr.As != "" {
   247  			output, err = output.Set(expr.As, v)
   248  		} else {
   249  			output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
   250  		}
   251  		if err != nil {
   252  			return err
   253  		}
   254  	}
   255  	return nil
   256  }
   257  
   258  func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
   259  	if e.selectAST.Where == nil {
   260  		return true, nil
   261  	}
   262  	value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
   263  	if err != nil {
   264  		return false, err
   265  	}
   266  
   267  	b, ok := value.ToBool()
   268  	if !ok {
   269  		err = fmt.Errorf("WHERE expression did not return bool")
   270  		return false, err
   271  	}
   272  
   273  	return b, nil
   274  }
   275  
   276  // AggregateRow - aggregates the input record. Applies only to
   277  // aggregation queries.
   278  func (e *SelectStatement) AggregateRow(input Record) error {
   279  	ok, err := e.isPassingWhereClause(input)
   280  	if err != nil {
   281  		return err
   282  	}
   283  	if !ok {
   284  		return nil
   285  	}
   286  
   287  	for _, expr := range e.selectAST.Expression.Expressions {
   288  		err := expr.aggregateRow(input, e.tableAlias)
   289  		if err != nil {
   290  			return err
   291  		}
   292  	}
   293  	return nil
   294  }
   295  
   296  // Eval - evaluates the Select statement for the given record. It
   297  // applies only to non-aggregation queries.
   298  // The function returns whether the statement passed the WHERE clause and should be outputted.
   299  func (e *SelectStatement) Eval(input, output Record) (Record, error) {
   300  	ok, err := e.isPassingWhereClause(input)
   301  	if err != nil || !ok {
   302  		// Either error or row did not pass where clause
   303  		return nil, err
   304  	}
   305  
   306  	if e.selectAST.Expression.All {
   307  		// Return the input record for `SELECT * FROM
   308  		// .. WHERE ..`
   309  
   310  		// Update count of records output.
   311  		e.outputCount++
   312  
   313  		return input.Clone(output), nil
   314  	}
   315  
   316  	for i, expr := range e.selectAST.Expression.Expressions {
   317  		v, err := expr.evalNode(input, e.tableAlias)
   318  		if err != nil {
   319  			return nil, err
   320  		}
   321  
   322  		// Pick output column names
   323  		if expr.As != "" {
   324  			output, err = output.Set(expr.As, v)
   325  		} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
   326  			output, err = output.Set(comp, v)
   327  		} else {
   328  			output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
   329  		}
   330  		if err != nil {
   331  			return nil, err
   332  		}
   333  	}
   334  
   335  	// Update count of records output.
   336  	e.outputCount++
   337  
   338  	return output, nil
   339  }
   340  
   341  // LimitReached - returns true if the number of records output has
   342  // reached the value of the `LIMIT` clause.
   343  func (e *SelectStatement) LimitReached() bool {
   344  	if e.limitValue == -1 {
   345  		return false
   346  	}
   347  	return e.outputCount >= e.limitValue
   348  }