vitess.io/vitess@v0.16.2/go/vt/vitessdriver/streaming_rows.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package vitessdriver
    18  
    19  import (
    20  	"database/sql/driver"
    21  	"errors"
    22  
    23  	"vitess.io/vitess/go/sqltypes"
    24  
    25  	querypb "vitess.io/vitess/go/vt/proto/query"
    26  )
    27  
    28  // streamingRows creates a database/sql/driver compliant Row iterator
    29  // for a streaming query.
    30  type streamingRows struct {
    31  	stream  sqltypes.ResultStream
    32  	failed  error
    33  	fields  []*querypb.Field
    34  	qr      *sqltypes.Result
    35  	index   int
    36  	convert *converter
    37  }
    38  
    39  // newStreamingRows creates a new streamingRows from stream.
    40  func newStreamingRows(stream sqltypes.ResultStream, conv *converter) driver.Rows {
    41  	return &streamingRows{
    42  		stream:  stream,
    43  		convert: conv,
    44  	}
    45  }
    46  
    47  func (ri *streamingRows) Columns() []string {
    48  	if ri.failed != nil {
    49  		return nil
    50  	}
    51  	if err := ri.checkFields(); err != nil {
    52  		_ = ri.setErr(err)
    53  		return nil
    54  	}
    55  	cols := make([]string, 0, len(ri.fields))
    56  	for _, field := range ri.fields {
    57  		cols = append(cols, field.Name)
    58  	}
    59  	return cols
    60  }
    61  
    62  func (ri *streamingRows) Close() error {
    63  	return nil
    64  }
    65  
    66  func (ri *streamingRows) Next(dest []driver.Value) error {
    67  	if ri.failed != nil {
    68  		return ri.failed
    69  	}
    70  	if err := ri.checkFields(); err != nil {
    71  		return ri.setErr(err)
    72  	}
    73  	// If no results were fetched or rows exhausted,
    74  	// loop until we get a non-zero number of rows.
    75  	for ri.qr == nil || ri.index >= len(ri.qr.Rows) {
    76  		qr, err := ri.stream.Recv()
    77  		if err != nil {
    78  			return ri.setErr(err)
    79  		}
    80  		ri.qr = qr
    81  		ri.index = 0
    82  	}
    83  	if err := ri.convert.populateRow(dest, ri.qr.Rows[ri.index]); err != nil {
    84  		return err
    85  	}
    86  	ri.index++
    87  	return nil
    88  }
    89  
    90  // checkFields fetches the first packet from the channel, which
    91  // should contain the field info.
    92  func (ri *streamingRows) checkFields() error {
    93  	if ri.fields != nil {
    94  		return nil
    95  	}
    96  	qr, err := ri.stream.Recv()
    97  	if err != nil {
    98  		return err
    99  	}
   100  	ri.fields = qr.Fields
   101  	if ri.fields == nil {
   102  		return errors.New("first packet did not return fields")
   103  	}
   104  	return nil
   105  }
   106  
   107  func (ri *streamingRows) setErr(err error) error {
   108  	ri.failed = err
   109  	return err
   110  }