github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/query/result_set.go (about) 1 package query 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 8 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 9 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" 10 11 "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result" 12 "github.com/ydb-platform/ydb-go-sdk/v3/internal/types" 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter" 15 "github.com/ydb-platform/ydb-go-sdk/v3/query" 16 ) 17 18 var ( 19 _ query.ResultSet = (*resultSet)(nil) 20 _ query.ResultSet = (*materializedResultSet)(nil) 21 ) 22 23 type ( 24 materializedResultSet struct { 25 index int 26 columnNames []string 27 columnTypes []types.Type 28 rows []query.Row 29 rowIndex int 30 } 31 resultSet struct { 32 index int64 33 recv func() (*Ydb_Query.ExecuteQueryResponsePart, error) 34 columns []*Ydb.Column 35 currentPart *Ydb_Query.ExecuteQueryResponsePart 36 rowIndex int 37 done chan struct{} 38 } 39 resultSetWithClose struct { 40 *resultSet 41 close func(ctx context.Context) error 42 } 43 ) 44 45 func rangeRows(ctx context.Context, rs result.Set) xiter.Seq2[result.Row, error] { 46 return func(yield func(result.Row, error) bool) { 47 for { 48 rs, err := rs.NextRow(ctx) 49 if err != nil { 50 if xerrors.Is(err, io.EOF) { 51 return 52 } 53 } 54 cont := yield(rs, err) 55 if !cont || err != nil { 56 return 57 } 58 } 59 } 60 } 61 62 func (*materializedResultSet) Close(context.Context) error { 63 return nil 64 } 65 66 func (rs *resultSetWithClose) Close(ctx context.Context) error { 67 return rs.close(ctx) 68 } 69 70 func (rs *materializedResultSet) Rows(ctx context.Context) xiter.Seq2[result.Row, error] { 71 return rangeRows(ctx, rs) 72 } 73 74 func (rs *resultSet) Rows(ctx context.Context) xiter.Seq2[result.Row, error] { 75 return rangeRows(ctx, rs) 76 } 77 78 func (rs *materializedResultSet) Columns() (columnNames []string) { 79 return rs.columnNames 80 } 81 82 func (rs *materializedResultSet) ColumnTypes() []types.Type { 83 return rs.columnTypes 84 } 85 86 func (rs *resultSet) ColumnTypes() (columnTypes []types.Type) { 87 columnTypes = make([]types.Type, len(rs.columns)) 88 for i := range rs.columns { 89 columnTypes[i] = types.TypeFromYDB(rs.columns[i].GetType()) 90 } 91 92 return columnTypes 93 } 94 95 func (rs *resultSet) Columns() (columnNames []string) { 96 columnNames = make([]string, len(rs.columns)) 97 for i := range rs.columns { 98 columnNames[i] = rs.columns[i].GetName() 99 } 100 101 return columnNames 102 } 103 104 func (rs *materializedResultSet) NextRow(ctx context.Context) (query.Row, error) { 105 if rs.rowIndex == len(rs.rows) { 106 return nil, xerrors.WithStackTrace(io.EOF) 107 } 108 109 defer func() { 110 rs.rowIndex++ 111 }() 112 113 return rs.rows[rs.rowIndex], nil 114 } 115 116 func (rs *materializedResultSet) Index() int { 117 if rs == nil { 118 return -1 119 } 120 121 return rs.index 122 } 123 124 func MaterializedResultSet( 125 index int, 126 columnNames []string, 127 columnTypes []types.Type, 128 rows []query.Row, 129 ) *materializedResultSet { 130 return &materializedResultSet{ 131 index: index, 132 columnNames: columnNames, 133 columnTypes: columnTypes, 134 rows: rows, 135 } 136 } 137 138 func newResultSet( 139 recv func() (*Ydb_Query.ExecuteQueryResponsePart, error), 140 part *Ydb_Query.ExecuteQueryResponsePart, 141 ) *resultSet { 142 return &resultSet{ 143 index: part.GetResultSetIndex(), 144 recv: recv, 145 currentPart: part, 146 rowIndex: -1, 147 columns: part.GetResultSet().GetColumns(), 148 done: make(chan struct{}), 149 } 150 } 151 152 func (rs *resultSet) nextRow(ctx context.Context) (*Row, error) { 153 rs.rowIndex++ 154 for { 155 select { 156 case <-rs.done: 157 return nil, io.EOF 158 case <-ctx.Done(): 159 return nil, xerrors.WithStackTrace(ctx.Err()) 160 default: 161 if rs.rowIndex == len(rs.currentPart.GetResultSet().GetRows()) { 162 part, err := rs.recv() 163 if err != nil { 164 if xerrors.Is(err, io.EOF) { 165 close(rs.done) 166 } 167 168 return nil, xerrors.WithStackTrace(err) 169 } 170 rs.rowIndex = 0 171 rs.currentPart = part 172 if part == nil { 173 close(rs.done) 174 175 return nil, xerrors.WithStackTrace(io.EOF) 176 } 177 } 178 if rs.currentPart.GetResultSet() != nil && rs.index != rs.currentPart.GetResultSetIndex() { 179 close(rs.done) 180 181 return nil, xerrors.WithStackTrace(fmt.Errorf( 182 "received part with result set index = %d, current result set index = %d: %w", 183 rs.index, rs.currentPart.GetResultSetIndex(), errWrongResultSetIndex, 184 )) 185 } 186 187 if rs.rowIndex < len(rs.currentPart.GetResultSet().GetRows()) { 188 return NewRow(rs.columns, rs.currentPart.GetResultSet().GetRows()[rs.rowIndex]), nil 189 } 190 } 191 } 192 } 193 194 func (rs *resultSet) NextRow(ctx context.Context) (_ query.Row, err error) { 195 return rs.nextRow(ctx) 196 } 197 198 func (rs *resultSet) Index() int { 199 if rs == nil { 200 return -1 201 } 202 203 return int(rs.index) 204 }