vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/limit.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"io"
    23  
    24  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    25  
    26  	"vitess.io/vitess/go/sqltypes"
    27  
    28  	querypb "vitess.io/vitess/go/vt/proto/query"
    29  )
    30  
    31  var _ Primitive = (*Limit)(nil)
    32  
    33  // Limit is a primitive that performs the LIMIT operation.
    34  type Limit struct {
    35  	Count  evalengine.Expr
    36  	Offset evalengine.Expr
    37  	Input  Primitive
    38  }
    39  
    40  // RouteType returns a description of the query routing type used by the primitive
    41  func (l *Limit) RouteType() string {
    42  	return l.Input.RouteType()
    43  }
    44  
    45  // GetKeyspaceName specifies the Keyspace that this primitive routes to.
    46  func (l *Limit) GetKeyspaceName() string {
    47  	return l.Input.GetKeyspaceName()
    48  }
    49  
    50  // GetTableName specifies the table that this primitive routes to.
    51  func (l *Limit) GetTableName() string {
    52  	return l.Input.GetTableName()
    53  }
    54  
    55  // TryExecute satisfies the Primitive interface.
    56  func (l *Limit) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    57  	count, offset, err := l.getCountAndOffset(vcursor, bindVars)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	// When offset is present, we hijack the limit value so we can calculate
    62  	// the offset in memory from the result of the scatter query with count + offset.
    63  	bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))
    64  
    65  	result, err := vcursor.ExecutePrimitive(ctx, l.Input, bindVars, wantfields)
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  
    70  	// There are more rows in the response than limit + offset
    71  	if count+offset <= len(result.Rows) {
    72  		result.Rows = result.Rows[offset : count+offset]
    73  		return result, nil
    74  	}
    75  	// Remove extra rows from response
    76  	if offset <= len(result.Rows) {
    77  		result.Rows = result.Rows[offset:]
    78  		return result, nil
    79  	}
    80  	// offset is beyond the result set
    81  	result.Rows = nil
    82  	return result, nil
    83  }
    84  
    85  // TryStreamExecute satisfies the Primitive interface.
    86  func (l *Limit) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    87  	count, offset, err := l.getCountAndOffset(vcursor, bindVars)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	bindVars = copyBindVars(bindVars)
    93  
    94  	// When offset is present, we hijack the limit value so we can calculate
    95  	// the offset in memory from the result of the scatter query with count + offset.
    96  	bindVars["__upper_limit"] = sqltypes.Int64BindVariable(int64(count + offset))
    97  
    98  	err = vcursor.StreamExecutePrimitive(ctx, l.Input, bindVars, wantfields, func(qr *sqltypes.Result) error {
    99  		if len(qr.Fields) != 0 {
   100  			if err := callback(&sqltypes.Result{Fields: qr.Fields}); err != nil {
   101  				return err
   102  			}
   103  		}
   104  		inputSize := len(qr.Rows)
   105  		if inputSize == 0 {
   106  			return nil
   107  		}
   108  
   109  		// we've still not seen all rows we need to see before we can return anything to the client
   110  		if offset > 0 {
   111  			if inputSize <= offset {
   112  				// not enough to return anything yet
   113  				offset -= inputSize
   114  				return nil
   115  			}
   116  			qr.Rows = qr.Rows[offset:]
   117  			offset = 0
   118  		}
   119  
   120  		if count == 0 {
   121  			return io.EOF
   122  		}
   123  
   124  		// reduce count till 0.
   125  		result := &sqltypes.Result{Rows: qr.Rows}
   126  		resultSize := len(result.Rows)
   127  		if count > resultSize {
   128  			count -= resultSize
   129  			return callback(result)
   130  		}
   131  		result.Rows = result.Rows[:count]
   132  		count = 0
   133  		if err := callback(result); err != nil {
   134  			return err
   135  		}
   136  		return io.EOF
   137  	})
   138  
   139  	if err == io.EOF {
   140  		// We may get back the EOF we returned in the callback.
   141  		// If so, suppress it.
   142  		return nil
   143  	}
   144  	if err != nil {
   145  		return err
   146  	}
   147  	return nil
   148  }
   149  
   150  // GetFields implements the Primitive interface.
   151  func (l *Limit) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   152  	return l.Input.GetFields(ctx, vcursor, bindVars)
   153  }
   154  
   155  // Inputs returns the input to limit
   156  func (l *Limit) Inputs() []Primitive {
   157  	return []Primitive{l.Input}
   158  }
   159  
   160  // NeedsTransaction implements the Primitive interface.
   161  func (l *Limit) NeedsTransaction() bool {
   162  	return l.Input.NeedsTransaction()
   163  }
   164  
   165  func (l *Limit) getCountAndOffset(vcursor VCursor, bindVars map[string]*querypb.BindVariable) (count int, offset int, err error) {
   166  	env := evalengine.EnvWithBindVars(bindVars, vcursor.ConnCollation())
   167  	count, err = getIntFrom(env, l.Count)
   168  	if err != nil {
   169  		return
   170  	}
   171  	offset, err = getIntFrom(env, l.Offset)
   172  	if err != nil {
   173  		return
   174  	}
   175  	return
   176  }
   177  
   178  func getIntFrom(env *evalengine.ExpressionEnv, expr evalengine.Expr) (int, error) {
   179  	if expr == nil {
   180  		return 0, nil
   181  	}
   182  	evalResult, err := env.Evaluate(expr)
   183  	if err != nil {
   184  		return 0, err
   185  	}
   186  	value := evalResult.Value()
   187  	if value.IsNull() {
   188  		return 0, nil
   189  	}
   190  
   191  	num, err := value.ToUint64()
   192  	if err != nil {
   193  		return 0, err
   194  	}
   195  	count := int(num)
   196  	if count < 0 {
   197  		return 0, fmt.Errorf("requested limit is out of range: %v", num)
   198  	}
   199  	return count, nil
   200  }
   201  
   202  func (l *Limit) description() PrimitiveDescription {
   203  	other := map[string]any{}
   204  
   205  	if l.Count != nil {
   206  		other["Count"] = evalengine.FormatExpr(l.Count)
   207  	}
   208  	if l.Offset != nil {
   209  		other["Offset"] = evalengine.FormatExpr(l.Offset)
   210  	}
   211  
   212  	return PrimitiveDescription{
   213  		OperatorType: "Limit",
   214  		Other:        other,
   215  	}
   216  }