github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/query/result_set.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  
     8  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
     9  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
    10  
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/query"
    16  )
    17  
    18  var (
    19  	_ query.ResultSet = (*resultSet)(nil)
    20  	_ query.ResultSet = (*materializedResultSet)(nil)
    21  )
    22  
    23  type (
    24  	materializedResultSet struct {
    25  		index       int
    26  		columnNames []string
    27  		columnTypes []types.Type
    28  		rows        []query.Row
    29  		rowIndex    int
    30  	}
    31  	resultSet struct {
    32  		index       int64
    33  		recv        func() (*Ydb_Query.ExecuteQueryResponsePart, error)
    34  		columns     []*Ydb.Column
    35  		currentPart *Ydb_Query.ExecuteQueryResponsePart
    36  		rowIndex    int
    37  		done        chan struct{}
    38  	}
    39  	resultSetWithClose struct {
    40  		*resultSet
    41  		close func(ctx context.Context) error
    42  	}
    43  )
    44  
    45  func rangeRows(ctx context.Context, rs result.Set) xiter.Seq2[result.Row, error] {
    46  	return func(yield func(result.Row, error) bool) {
    47  		for {
    48  			rs, err := rs.NextRow(ctx)
    49  			if err != nil {
    50  				if xerrors.Is(err, io.EOF) {
    51  					return
    52  				}
    53  			}
    54  			cont := yield(rs, err)
    55  			if !cont || err != nil {
    56  				return
    57  			}
    58  		}
    59  	}
    60  }
    61  
    62  func (*materializedResultSet) Close(context.Context) error {
    63  	return nil
    64  }
    65  
    66  func (rs *resultSetWithClose) Close(ctx context.Context) error {
    67  	return rs.close(ctx)
    68  }
    69  
    70  func (rs *materializedResultSet) Rows(ctx context.Context) xiter.Seq2[result.Row, error] {
    71  	return rangeRows(ctx, rs)
    72  }
    73  
    74  func (rs *resultSet) Rows(ctx context.Context) xiter.Seq2[result.Row, error] {
    75  	return rangeRows(ctx, rs)
    76  }
    77  
    78  func (rs *materializedResultSet) Columns() (columnNames []string) {
    79  	return rs.columnNames
    80  }
    81  
    82  func (rs *materializedResultSet) ColumnTypes() []types.Type {
    83  	return rs.columnTypes
    84  }
    85  
    86  func (rs *resultSet) ColumnTypes() (columnTypes []types.Type) {
    87  	columnTypes = make([]types.Type, len(rs.columns))
    88  	for i := range rs.columns {
    89  		columnTypes[i] = types.TypeFromYDB(rs.columns[i].GetType())
    90  	}
    91  
    92  	return columnTypes
    93  }
    94  
    95  func (rs *resultSet) Columns() (columnNames []string) {
    96  	columnNames = make([]string, len(rs.columns))
    97  	for i := range rs.columns {
    98  		columnNames[i] = rs.columns[i].GetName()
    99  	}
   100  
   101  	return columnNames
   102  }
   103  
   104  func (rs *materializedResultSet) NextRow(ctx context.Context) (query.Row, error) {
   105  	if rs.rowIndex == len(rs.rows) {
   106  		return nil, xerrors.WithStackTrace(io.EOF)
   107  	}
   108  
   109  	defer func() {
   110  		rs.rowIndex++
   111  	}()
   112  
   113  	return rs.rows[rs.rowIndex], nil
   114  }
   115  
   116  func (rs *materializedResultSet) Index() int {
   117  	if rs == nil {
   118  		return -1
   119  	}
   120  
   121  	return rs.index
   122  }
   123  
   124  func MaterializedResultSet(
   125  	index int,
   126  	columnNames []string,
   127  	columnTypes []types.Type,
   128  	rows []query.Row,
   129  ) *materializedResultSet {
   130  	return &materializedResultSet{
   131  		index:       index,
   132  		columnNames: columnNames,
   133  		columnTypes: columnTypes,
   134  		rows:        rows,
   135  	}
   136  }
   137  
   138  func newResultSet(
   139  	recv func() (*Ydb_Query.ExecuteQueryResponsePart, error),
   140  	part *Ydb_Query.ExecuteQueryResponsePart,
   141  ) *resultSet {
   142  	return &resultSet{
   143  		index:       part.GetResultSetIndex(),
   144  		recv:        recv,
   145  		currentPart: part,
   146  		rowIndex:    -1,
   147  		columns:     part.GetResultSet().GetColumns(),
   148  		done:        make(chan struct{}),
   149  	}
   150  }
   151  
   152  func (rs *resultSet) nextRow(ctx context.Context) (*Row, error) {
   153  	rs.rowIndex++
   154  	for {
   155  		select {
   156  		case <-rs.done:
   157  			return nil, io.EOF
   158  		case <-ctx.Done():
   159  			return nil, xerrors.WithStackTrace(ctx.Err())
   160  		default:
   161  			if rs.rowIndex == len(rs.currentPart.GetResultSet().GetRows()) {
   162  				part, err := rs.recv()
   163  				if err != nil {
   164  					if xerrors.Is(err, io.EOF) {
   165  						close(rs.done)
   166  					}
   167  
   168  					return nil, xerrors.WithStackTrace(err)
   169  				}
   170  				rs.rowIndex = 0
   171  				rs.currentPart = part
   172  				if part == nil {
   173  					close(rs.done)
   174  
   175  					return nil, xerrors.WithStackTrace(io.EOF)
   176  				}
   177  			}
   178  			if rs.currentPart.GetResultSet() != nil && rs.index != rs.currentPart.GetResultSetIndex() {
   179  				close(rs.done)
   180  
   181  				return nil, xerrors.WithStackTrace(fmt.Errorf(
   182  					"received part with result set index = %d, current result set index = %d: %w",
   183  					rs.index, rs.currentPart.GetResultSetIndex(), errWrongResultSetIndex,
   184  				))
   185  			}
   186  
   187  			if rs.rowIndex < len(rs.currentPart.GetResultSet().GetRows()) {
   188  				return NewRow(rs.columns, rs.currentPart.GetResultSet().GetRows()[rs.rowIndex]), nil
   189  			}
   190  		}
   191  	}
   192  }
   193  
   194  func (rs *resultSet) NextRow(ctx context.Context) (_ query.Row, err error) {
   195  	return rs.nextRow(ctx)
   196  }
   197  
   198  func (rs *resultSet) Index() int {
   199  	if rs == nil {
   200  		return -1
   201  	}
   202  
   203  	return int(rs.index)
   204  }