github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/query/result_set.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  
     8  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
     9  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
    10  
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/query"
    13  )
    14  
    15  var _ query.ResultSet = (*resultSet)(nil)
    16  
    17  type resultSet struct {
    18  	index       int64
    19  	recv        func() (*Ydb_Query.ExecuteQueryResponsePart, error)
    20  	columns     []*Ydb.Column
    21  	currentPart *Ydb_Query.ExecuteQueryResponsePart
    22  	rowIndex    int
    23  	done        chan struct{}
    24  }
    25  
    26  func newResultSet(
    27  	recv func() (
    28  		*Ydb_Query.ExecuteQueryResponsePart, error,
    29  	),
    30  	part *Ydb_Query.ExecuteQueryResponsePart,
    31  ) *resultSet {
    32  	return &resultSet{
    33  		index:       part.GetResultSetIndex(),
    34  		recv:        recv,
    35  		currentPart: part,
    36  		rowIndex:    -1,
    37  		columns:     part.GetResultSet().GetColumns(),
    38  		done:        make(chan struct{}),
    39  	}
    40  }
    41  
    42  func (rs *resultSet) next(ctx context.Context) (*row, error) {
    43  	rs.rowIndex++
    44  	select {
    45  	case <-rs.done:
    46  		return nil, io.EOF
    47  	case <-ctx.Done():
    48  		return nil, xerrors.WithStackTrace(ctx.Err())
    49  	default:
    50  		if rs.rowIndex == len(rs.currentPart.GetResultSet().GetRows()) {
    51  			part, err := rs.recv()
    52  			if err != nil {
    53  				if xerrors.Is(err, io.EOF) {
    54  					close(rs.done)
    55  				}
    56  
    57  				return nil, xerrors.WithStackTrace(err)
    58  			}
    59  			rs.rowIndex = 0
    60  			rs.currentPart = part
    61  		}
    62  		if rs.index != rs.currentPart.GetResultSetIndex() {
    63  			close(rs.done)
    64  
    65  			return nil, xerrors.WithStackTrace(fmt.Errorf(
    66  				"received part with result set index = %d, current result set index = %d: %w",
    67  				rs.index, rs.currentPart.GetResultSetIndex(), errWrongResultSetIndex,
    68  			))
    69  		}
    70  
    71  		return newRow(rs.columns, rs.currentPart.GetResultSet().GetRows()[rs.rowIndex])
    72  	}
    73  }
    74  
    75  func (rs *resultSet) NextRow(ctx context.Context) (query.Row, error) {
    76  	return rs.next(ctx)
    77  }