github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/query/result.go (about) 1 package query 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "sync" 9 10 "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" 11 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" 12 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result" 14 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stack" 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/stats" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 17 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter" 18 "github.com/ydb-platform/ydb-go-sdk/v3/query" 19 "github.com/ydb-platform/ydb-go-sdk/v3/trace" 20 ) 21 22 var ( 23 _ result.Result = (*streamResult)(nil) 24 _ result.Result = (*materializedResult)(nil) 25 ) 26 27 type ( 28 materializedResult struct { 29 resultSets []result.Set 30 idx int 31 } 32 streamResult struct { 33 stream Ydb_Query_V1.QueryService_ExecuteQueryClient 34 closeOnce func() 35 lastPart *Ydb_Query.ExecuteQueryResponsePart 36 resultSetIndex int64 37 closed chan struct{} 38 trace *trace.Query 39 statsCallback func(queryStats stats.QueryStats) 40 onNextPartErr []func(err error) 41 onTxMeta []func(txMeta *Ydb_Query.TransactionMeta) 42 } 43 resultOption func(s *streamResult) 44 ) 45 46 func rangeResultSets(ctx context.Context, r result.Result) xiter.Seq2[result.Set, error] { 47 return func(yield func(result.Set, error) bool) { 48 for { 49 rs, err := r.NextResultSet(ctx) 50 if err != nil { 51 if xerrors.Is(err, io.EOF) { 52 return 53 } 54 } 55 cont := yield(rs, err) 56 if !cont || err != nil { 57 return 58 } 59 } 60 } 61 } 62 63 func (r *materializedResult) ResultSets(ctx context.Context) xiter.Seq2[result.Set, error] { 64 return rangeResultSets(ctx, r) 65 } 66 67 func (r *streamResult) ResultSets(ctx context.Context) xiter.Seq2[result.Set, error] { 68 return rangeResultSets(ctx, r) 69 } 70 71 func (r *materializedResult) Close(ctx context.Context) error { 72 return nil 73 } 74 75 func (r *materializedResult) NextResultSet(ctx context.Context) (result.Set, error) { 76 if r.idx == len(r.resultSets) { 77 return nil, xerrors.WithStackTrace(io.EOF) 78 } 79 80 defer func() { 81 r.idx++ 82 }() 83 84 return r.resultSets[r.idx], nil 85 } 86 87 func withTrace(t *trace.Query) resultOption { 88 return func(s *streamResult) { 89 s.trace = t 90 } 91 } 92 93 func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption { 94 return func(s *streamResult) { 95 s.statsCallback = callback 96 } 97 } 98 99 func onNextPartErr(callback func(err error)) resultOption { 100 return func(s *streamResult) { 101 s.onNextPartErr = append(s.onNextPartErr, callback) 102 } 103 } 104 105 func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption { 106 return func(s *streamResult) { 107 s.onTxMeta = append(s.onTxMeta, callback) 108 } 109 } 110 111 func newResult( 112 ctx context.Context, 113 stream Ydb_Query_V1.QueryService_ExecuteQueryClient, 114 opts ...resultOption, 115 ) (_ *streamResult, finalErr error) { 116 r := streamResult{ 117 stream: stream, 118 closed: make(chan struct{}), 119 resultSetIndex: -1, 120 } 121 r.closeOnce = sync.OnceFunc(func() { 122 close(r.closed) 123 r.stream = nil 124 }) 125 126 for _, opt := range opts { 127 if opt != nil { 128 opt(&r) 129 } 130 } 131 132 if r.trace != nil { 133 onDone := trace.QueryOnResultNew(r.trace, &ctx, 134 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.newResult"), 135 ) 136 defer func() { 137 onDone(finalErr) 138 }() 139 } 140 141 select { 142 case <-ctx.Done(): 143 return nil, xerrors.WithStackTrace(ctx.Err()) 144 default: 145 part, err := r.nextPart(ctx) 146 if err != nil { 147 return nil, xerrors.WithStackTrace(err) 148 } 149 150 r.lastPart = part 151 152 if r.statsCallback != nil { 153 r.statsCallback(stats.FromQueryStats(part.GetExecStats())) 154 } 155 156 return &r, nil 157 } 158 } 159 160 func (r *streamResult) nextPart(ctx context.Context) ( 161 part *Ydb_Query.ExecuteQueryResponsePart, err error, 162 ) { 163 if r.trace != nil { 164 onDone := trace.QueryOnResultNextPart(r.trace, &ctx, 165 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).nextPart"), 166 ) 167 defer func() { 168 onDone(part.GetExecStats(), err) 169 }() 170 } 171 172 select { 173 case <-r.closed: 174 return nil, xerrors.WithStackTrace(io.EOF) 175 default: 176 part, err = nextPart(r.stream) 177 if err != nil { 178 r.closeOnce() 179 180 for _, callback := range r.onNextPartErr { 181 callback(err) 182 } 183 184 return nil, xerrors.WithStackTrace(err) 185 } 186 187 if txMeta := part.GetTxMeta(); txMeta != nil { 188 for _, f := range r.onTxMeta { 189 f(txMeta) 190 } 191 } 192 193 return part, nil 194 } 195 } 196 197 func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) ( 198 part *Ydb_Query.ExecuteQueryResponsePart, err error, 199 ) { 200 part, err = stream.Recv() 201 if err != nil { 202 return nil, xerrors.WithStackTrace(err) 203 } 204 205 return part, nil 206 } 207 208 func (r *streamResult) Close(ctx context.Context) (finalErr error) { 209 defer r.closeOnce() 210 211 if r.trace != nil { 212 onDone := trace.QueryOnResultClose(r.trace, &ctx, 213 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close"), 214 ) 215 defer func() { 216 onDone(finalErr) 217 }() 218 } 219 220 for { 221 select { 222 case <-r.closed: 223 return nil 224 default: 225 _, err := r.nextPart(ctx) 226 if err != nil { 227 if xerrors.Is(err, io.EOF) { 228 return nil 229 } 230 231 return xerrors.WithStackTrace(err) 232 } 233 } 234 } 235 } 236 237 func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err error) { 238 nextResultSetIndex := r.resultSetIndex + 1 239 for { 240 select { 241 case <-r.closed: 242 return nil, xerrors.WithStackTrace(io.EOF) 243 case <-ctx.Done(): 244 return nil, xerrors.WithStackTrace(ctx.Err()) 245 default: 246 if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex { 247 r.resultSetIndex = resultSetIndex 248 249 return newResultSet(r.nextPartFunc(ctx, nextResultSetIndex), r.lastPart), nil 250 } 251 if r.stream == nil { 252 return nil, xerrors.WithStackTrace(io.EOF) 253 } 254 part, err := r.nextPart(ctx) 255 if err != nil { 256 return nil, xerrors.WithStackTrace(err) 257 } 258 if part.GetExecStats() != nil && r.statsCallback != nil { 259 r.statsCallback(stats.FromQueryStats(part.GetExecStats())) 260 } 261 if part.GetResultSetIndex() < r.resultSetIndex { 262 r.closeOnce() 263 264 return nil, xerrors.WithStackTrace(fmt.Errorf( 265 "next result set rowIndex %d less than last result set index %d: %w", 266 part.GetResultSetIndex(), r.resultSetIndex, errWrongNextResultSetIndex, 267 )) 268 } 269 r.lastPart = part 270 r.resultSetIndex = part.GetResultSetIndex() 271 } 272 } 273 } 274 275 func (r *streamResult) nextPartFunc( 276 ctx context.Context, 277 nextResultSetIndex int64, 278 ) func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) { 279 return func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) { 280 select { 281 case <-r.closed: 282 return nil, xerrors.WithStackTrace(io.EOF) 283 default: 284 if r.stream == nil { 285 return nil, xerrors.WithStackTrace(io.EOF) 286 } 287 part, err := r.nextPart(ctx) 288 if err != nil { 289 return nil, xerrors.WithStackTrace(err) 290 } 291 r.lastPart = part 292 if part.GetExecStats() != nil && r.statsCallback != nil { 293 r.statsCallback(stats.FromQueryStats(part.GetExecStats())) 294 } 295 if part.GetResultSetIndex() > nextResultSetIndex { 296 return nil, xerrors.WithStackTrace(fmt.Errorf( 297 "result set (index=%d) receive part (index=%d) for next result set: %w", 298 nextResultSetIndex, part.GetResultSetIndex(), io.EOF, 299 )) 300 } 301 302 return part, nil 303 } 304 } 305 } 306 307 func (r *streamResult) NextResultSet(ctx context.Context) (_ result.Set, err error) { 308 if r.trace != nil { 309 onDone := trace.QueryOnResultNextResultSet(r.trace, &ctx, 310 stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).NextResultSet"), 311 ) 312 defer func() { 313 onDone(err) 314 }() 315 } 316 317 return r.nextResultSet(ctx) 318 } 319 320 func exactlyOneRowFromResult(ctx context.Context, r result.Result) (row result.Row, err error) { 321 rs, err := r.NextResultSet(ctx) 322 if err != nil { 323 return nil, xerrors.WithStackTrace(err) 324 } 325 row, err = rs.NextRow(ctx) 326 if err != nil { 327 return nil, xerrors.WithStackTrace(err) 328 } 329 330 _, err = rs.NextRow(ctx) 331 switch { 332 case err == nil: 333 return nil, xerrors.WithStackTrace(errMoreThanOneRow) 334 case errors.Is(err, io.EOF): 335 // pass 336 default: 337 return nil, xerrors.WithStackTrace(err) 338 } 339 340 _, err = r.NextResultSet(ctx) 341 switch { 342 case err == nil: 343 return nil, xerrors.WithStackTrace(errMoreThanOneRow) 344 case errors.Is(err, io.EOF): 345 // pass 346 default: 347 return nil, xerrors.WithStackTrace(err) 348 } 349 350 return row, nil 351 } 352 353 func exactlyOneResultSetFromResult(ctx context.Context, r result.Result) (rs result.Set, err error) { 354 var rows []query.Row 355 rs, err = r.NextResultSet(ctx) 356 if err != nil { 357 if xerrors.Is(err, io.EOF) { 358 return nil, xerrors.WithStackTrace(errNoResultSets) 359 } 360 361 return nil, xerrors.WithStackTrace(err) 362 } 363 364 var row query.Row 365 for { 366 row, err = rs.NextRow(ctx) 367 if err != nil { 368 if xerrors.Is(err, io.EOF) { 369 break 370 } 371 372 return nil, xerrors.WithStackTrace(err) 373 } 374 375 rows = append(rows, row) 376 } 377 378 _, err = r.NextResultSet(ctx) 379 switch { 380 case err == nil: 381 return nil, xerrors.WithStackTrace(errMoreThanOneResultSet) 382 case errors.Is(err, io.EOF): 383 // pass 384 default: 385 return nil, xerrors.WithStackTrace(err) 386 } 387 388 return MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows), nil 389 } 390 391 func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Result, error) { 392 var resultSets []result.Set 393 394 for { 395 rs, err := r.NextResultSet(ctx) 396 if err != nil { 397 if xerrors.Is(err, io.EOF) { 398 break 399 } 400 401 return nil, xerrors.WithStackTrace(err) 402 } 403 404 var rows []query.Row 405 for { 406 row, err := rs.NextRow(ctx) 407 if err != nil { 408 if xerrors.Is(err, io.EOF) { 409 break 410 } 411 412 return nil, xerrors.WithStackTrace(err) 413 } 414 415 rows = append(rows, row) 416 } 417 418 resultSets = append(resultSets, MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows)) 419 } 420 421 return &materializedResult{ 422 resultSets: resultSets, 423 }, nil 424 }