storj.io/minio@v0.0.0-20230509071714-0cbc90f649b1/pkg/s3select/sql/statement.go (about)

     1  /*
     2   * MinIO Cloud Storage, (C) 2019 MinIO, Inc.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package sql
    18  
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"strings"
    23  
    24  	"github.com/bcicen/jstream"
    25  	"github.com/minio/simdjson-go"
    26  )
    27  
    28  var (
    29  	errBadLimitSpecified = errors.New("Limit value must be a positive integer")
    30  )
    31  
    32  const (
    33  	baseTableName = "s3object"
    34  )
    35  
    36  // SelectStatement is the top level parsed and analyzed structure
    37  type SelectStatement struct {
    38  	selectAST *Select
    39  
    40  	// Analysis result of the statement
    41  	selectQProp qProp
    42  
    43  	// Result of parsing the limit clause if one is present
    44  	// (otherwise -1)
    45  	limitValue int64
    46  
    47  	// Count of rows that have been output.
    48  	outputCount int64
    49  
    50  	// Table alias
    51  	tableAlias string
    52  }
    53  
    54  // ParseSelectStatement - parses a select query from the given string
    55  // and analyzes it.
    56  func ParseSelectStatement(s string) (stmt SelectStatement, err error) {
    57  	var selectAST Select
    58  	err = SQLParser.ParseString(s, &selectAST)
    59  	if err != nil {
    60  		err = errQueryParseFailure(err)
    61  		return
    62  	}
    63  
    64  	// Check if select is "SELECT s.* from S3Object s"
    65  	if !selectAST.Expression.All &&
    66  		len(selectAST.Expression.Expressions) == 1 &&
    67  		len(selectAST.Expression.Expressions[0].Expression.And) == 1 &&
    68  		len(selectAST.Expression.Expressions[0].Expression.And[0].Condition) == 1 &&
    69  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand != nil &&
    70  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left != nil &&
    71  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left != nil &&
    72  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary != nil &&
    73  		selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr != nil {
    74  		if selectAST.Expression.Expressions[0].Expression.And[0].Condition[0].Operand.Operand.Left.Left.Primary.JPathExpr.String() == selectAST.From.As+".*" {
    75  			selectAST.Expression.All = true
    76  		}
    77  	}
    78  	stmt.selectAST = &selectAST
    79  
    80  	// Check the parsed limit value
    81  	stmt.limitValue, err = parseLimit(selectAST.Limit)
    82  	if err != nil {
    83  		err = errQueryAnalysisFailure(err)
    84  		return
    85  	}
    86  
    87  	// Analyze where clause
    88  	if selectAST.Where != nil {
    89  		whereQProp := selectAST.Where.analyze(&selectAST)
    90  		if whereQProp.err != nil {
    91  			err = errQueryAnalysisFailure(fmt.Errorf("Where clause error: %w", whereQProp.err))
    92  			return
    93  		}
    94  
    95  		if whereQProp.isAggregation {
    96  			err = errQueryAnalysisFailure(errors.New("WHERE clause cannot have an aggregation"))
    97  			return
    98  		}
    99  	}
   100  
   101  	// Validate table name
   102  	err = validateTableName(selectAST.From)
   103  	if err != nil {
   104  		return
   105  	}
   106  
   107  	// Analyze main select expression
   108  	stmt.selectQProp = selectAST.Expression.analyze(&selectAST)
   109  	err = stmt.selectQProp.err
   110  	if err != nil {
   111  		err = errQueryAnalysisFailure(err)
   112  	}
   113  
   114  	// Set table alias
   115  	stmt.tableAlias = selectAST.From.As
   116  	return
   117  }
   118  
   119  func validateTableName(from *TableExpression) error {
   120  	if strings.ToLower(from.Table.BaseKey.String()) != baseTableName {
   121  		return errBadTableName(errors.New("table name must be `s3object`"))
   122  	}
   123  
   124  	if len(from.Table.PathExpr) > 0 {
   125  		if !from.Table.PathExpr[0].ArrayWildcard {
   126  			return errBadTableName(errors.New("keypath table name is invalid - please check the service documentation"))
   127  		}
   128  	}
   129  	return nil
   130  }
   131  
   132  func parseLimit(v *LitValue) (int64, error) {
   133  	switch {
   134  	case v == nil:
   135  		return -1, nil
   136  	case v.Int == nil:
   137  		return -1, errBadLimitSpecified
   138  	default:
   139  		r := int64(*v.Int)
   140  		if r < 0 {
   141  			return -1, errBadLimitSpecified
   142  		}
   143  		return r, nil
   144  	}
   145  }
   146  
   147  // EvalFrom evaluates the From clause on the input record. It only
   148  // applies to JSON input data format (currently).
   149  func (e *SelectStatement) EvalFrom(format string, input Record) ([]*Record, error) {
   150  	if !e.selectAST.From.HasKeypath() {
   151  		return []*Record{&input}, nil
   152  	}
   153  	_, rawVal := input.Raw()
   154  
   155  	if format != "json" {
   156  		return nil, errDataSource(errors.New("path not supported"))
   157  	}
   158  	switch rec := rawVal.(type) {
   159  	case jstream.KVS:
   160  		txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
   161  		if err != nil {
   162  			return nil, err
   163  		}
   164  
   165  		var kvs jstream.KVS
   166  		switch v := txedRec.(type) {
   167  		case jstream.KVS:
   168  			kvs = v
   169  
   170  		case []interface{}:
   171  			recs := make([]*Record, len(v))
   172  			for i, val := range v {
   173  				tmpRec := input.Clone(nil)
   174  				if err = tmpRec.Replace(val); err != nil {
   175  					return nil, err
   176  				}
   177  				recs[i] = &tmpRec
   178  			}
   179  			return recs, nil
   180  
   181  		default:
   182  			kvs = jstream.KVS{jstream.KV{Key: "_1", Value: v}}
   183  		}
   184  
   185  		if err = input.Replace(kvs); err != nil {
   186  			return nil, err
   187  		}
   188  
   189  		return []*Record{&input}, nil
   190  	case simdjson.Object:
   191  		txedRec, _, err := jsonpathEval(e.selectAST.From.Table.PathExpr[1:], rec)
   192  		if err != nil {
   193  			return nil, err
   194  		}
   195  
   196  		switch v := txedRec.(type) {
   197  		case simdjson.Object:
   198  			err := input.Replace(v)
   199  			if err != nil {
   200  				return nil, err
   201  			}
   202  
   203  		case []interface{}:
   204  			recs := make([]*Record, len(v))
   205  			for i, val := range v {
   206  				tmpRec := input.Clone(nil)
   207  				if err = tmpRec.Replace(val); err != nil {
   208  					return nil, err
   209  				}
   210  				recs[i] = &tmpRec
   211  			}
   212  			return recs, nil
   213  
   214  		default:
   215  			input.Reset()
   216  			input, err = input.Set("_1", &Value{value: v})
   217  			if err != nil {
   218  				return nil, err
   219  			}
   220  		}
   221  		return []*Record{&input}, nil
   222  	}
   223  	return nil, errDataSource(errors.New("unexpected non JSON input"))
   224  }
   225  
   226  // IsAggregated returns if the statement involves SQL aggregation
   227  func (e *SelectStatement) IsAggregated() bool {
   228  	return e.selectQProp.isAggregation
   229  }
   230  
   231  // AggregateResult - returns the aggregated result after all input
   232  // records have been processed. Applies only to aggregation queries.
   233  func (e *SelectStatement) AggregateResult(output Record) error {
   234  	for i, expr := range e.selectAST.Expression.Expressions {
   235  		v, err := expr.evalNode(nil, e.tableAlias)
   236  		if err != nil {
   237  			return err
   238  		}
   239  		if expr.As != "" {
   240  			output, err = output.Set(expr.As, v)
   241  		} else {
   242  			output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
   243  		}
   244  		if err != nil {
   245  			return err
   246  		}
   247  	}
   248  	return nil
   249  }
   250  
   251  func (e *SelectStatement) isPassingWhereClause(input Record) (bool, error) {
   252  	if e.selectAST.Where == nil {
   253  		return true, nil
   254  	}
   255  	value, err := e.selectAST.Where.evalNode(input, e.tableAlias)
   256  	if err != nil {
   257  		return false, err
   258  	}
   259  
   260  	b, ok := value.ToBool()
   261  	if !ok {
   262  		err = fmt.Errorf("WHERE expression did not return bool")
   263  		return false, err
   264  	}
   265  
   266  	return b, nil
   267  }
   268  
   269  // AggregateRow - aggregates the input record. Applies only to
   270  // aggregation queries.
   271  func (e *SelectStatement) AggregateRow(input Record) error {
   272  	ok, err := e.isPassingWhereClause(input)
   273  	if err != nil {
   274  		return err
   275  	}
   276  	if !ok {
   277  		return nil
   278  	}
   279  
   280  	for _, expr := range e.selectAST.Expression.Expressions {
   281  		err := expr.aggregateRow(input, e.tableAlias)
   282  		if err != nil {
   283  			return err
   284  		}
   285  	}
   286  	return nil
   287  }
   288  
   289  // Eval - evaluates the Select statement for the given record. It
   290  // applies only to non-aggregation queries.
   291  // The function returns whether the statement passed the WHERE clause and should be outputted.
   292  func (e *SelectStatement) Eval(input, output Record) (Record, error) {
   293  	ok, err := e.isPassingWhereClause(input)
   294  	if err != nil || !ok {
   295  		// Either error or row did not pass where clause
   296  		return nil, err
   297  	}
   298  
   299  	if e.selectAST.Expression.All {
   300  		// Return the input record for `SELECT * FROM
   301  		// .. WHERE ..`
   302  
   303  		// Update count of records output.
   304  		if e.limitValue > -1 {
   305  			e.outputCount++
   306  		}
   307  		return input.Clone(output), nil
   308  	}
   309  
   310  	for i, expr := range e.selectAST.Expression.Expressions {
   311  		v, err := expr.evalNode(input, e.tableAlias)
   312  		if err != nil {
   313  			return nil, err
   314  		}
   315  
   316  		// Pick output column names
   317  		if expr.As != "" {
   318  			output, err = output.Set(expr.As, v)
   319  		} else if comp, ok := getLastKeypathComponent(expr.Expression); ok {
   320  			output, err = output.Set(comp, v)
   321  		} else {
   322  			output, err = output.Set(fmt.Sprintf("_%d", i+1), v)
   323  		}
   324  		if err != nil {
   325  			return nil, err
   326  		}
   327  	}
   328  
   329  	// Update count of records output.
   330  	if e.limitValue > -1 {
   331  		e.outputCount++
   332  	}
   333  
   334  	return output, nil
   335  }
   336  
   337  // LimitReached - returns true if the number of records output has
   338  // reached the value of the `LIMIT` clause.
   339  func (e *SelectStatement) LimitReached() bool {
   340  	if e.limitValue == -1 {
   341  		return false
   342  	}
   343  	return e.outputCount >= e.limitValue
   344  }