vitess.io/vitess@v0.16.2/go/vt/vtgate/engine/sql_calc_found_rows.go (about)

     1  /*
     2  Copyright 2020 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package engine
    18  
    19  import (
    20  	"context"
    21  
    22  	"vitess.io/vitess/go/sqltypes"
    23  	querypb "vitess.io/vitess/go/vt/proto/query"
    24  	vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc"
    25  	"vitess.io/vitess/go/vt/vterrors"
    26  	"vitess.io/vitess/go/vt/vtgate/evalengine"
    27  )
    28  
    29  var _ Primitive = (*SQLCalcFoundRows)(nil)
    30  
    31  // SQLCalcFoundRows is a primitive to execute limit and count query as per their individual plan.
    32  type SQLCalcFoundRows struct {
    33  	LimitPrimitive Primitive
    34  	CountPrimitive Primitive
    35  }
    36  
    37  // RouteType implements the Primitive interface
    38  func (s SQLCalcFoundRows) RouteType() string {
    39  	return "SQLCalcFoundRows"
    40  }
    41  
    42  // GetKeyspaceName implements the Primitive interface
    43  func (s SQLCalcFoundRows) GetKeyspaceName() string {
    44  	return s.LimitPrimitive.GetKeyspaceName()
    45  }
    46  
    47  // GetTableName implements the Primitive interface
    48  func (s SQLCalcFoundRows) GetTableName() string {
    49  	return s.LimitPrimitive.GetTableName()
    50  }
    51  
    52  // TryExecute implements the Primitive interface
    53  func (s SQLCalcFoundRows) TryExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool) (*sqltypes.Result, error) {
    54  	limitQr, err := vcursor.ExecutePrimitive(ctx, s.LimitPrimitive, bindVars, wantfields)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	countQr, err := vcursor.ExecutePrimitive(ctx, s.CountPrimitive, bindVars, false)
    59  	if err != nil {
    60  		return nil, err
    61  	}
    62  	if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
    63  		return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar")
    64  	}
    65  	fr, err := evalengine.ToUint64(countQr.Rows[0][0])
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	vcursor.Session().SetFoundRows(fr)
    70  	return limitQr, nil
    71  }
    72  
    73  // TryStreamExecute implements the Primitive interface
    74  func (s SQLCalcFoundRows) TryStreamExecute(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable, wantfields bool, callback func(*sqltypes.Result) error) error {
    75  	err := vcursor.StreamExecutePrimitive(ctx, s.LimitPrimitive, bindVars, wantfields, callback)
    76  	if err != nil {
    77  		return err
    78  	}
    79  
    80  	var fr *uint64
    81  
    82  	err = vcursor.StreamExecutePrimitive(ctx, s.CountPrimitive, bindVars, wantfields, func(countQr *sqltypes.Result) error {
    83  		if len(countQr.Rows) == 0 && countQr.Fields != nil {
    84  			// this is the fields, which we can ignore
    85  			return nil
    86  		}
    87  		if len(countQr.Rows) != 1 || len(countQr.Rows[0]) != 1 {
    88  			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query is not a scalar")
    89  		}
    90  		toUint64, err := evalengine.ToUint64(countQr.Rows[0][0])
    91  		if err != nil {
    92  			return err
    93  		}
    94  		fr = &toUint64
    95  		return nil
    96  	})
    97  	if err != nil {
    98  		return err
    99  	}
   100  	if fr == nil {
   101  		return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "count query for SQL_CALC_FOUND_ROWS never returned a value")
   102  	}
   103  	vcursor.Session().SetFoundRows(*fr)
   104  	return nil
   105  }
   106  
   107  // GetFields implements the Primitive interface
   108  func (s SQLCalcFoundRows) GetFields(ctx context.Context, vcursor VCursor, bindVars map[string]*querypb.BindVariable) (*sqltypes.Result, error) {
   109  	return s.LimitPrimitive.GetFields(ctx, vcursor, bindVars)
   110  }
   111  
   112  // NeedsTransaction implements the Primitive interface
   113  func (s SQLCalcFoundRows) NeedsTransaction() bool {
   114  	return s.LimitPrimitive.NeedsTransaction()
   115  }
   116  
   117  // Inputs implements the Primitive interface
   118  func (s SQLCalcFoundRows) Inputs() []Primitive {
   119  	return []Primitive{s.LimitPrimitive, s.CountPrimitive}
   120  }
   121  
   122  func (s SQLCalcFoundRows) description() PrimitiveDescription {
   123  	return PrimitiveDescription{
   124  		OperatorType: "SQL_CALC_FOUND_ROWS",
   125  	}
   126  }