github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/query/result.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"sync"
     8  
     9  	"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
    10  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    11  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/query"
    15  )
    16  
    17  var _ query.Result = (*result)(nil)
    18  
    19  type result struct {
    20  	stream         Ydb_Query_V1.QueryService_ExecuteQueryClient
    21  	interrupt      func()
    22  	close          func()
    23  	lastPart       *Ydb_Query.ExecuteQueryResponsePart
    24  	resultSetIndex int64
    25  	errs           []error
    26  	interrupted    chan struct{}
    27  	closed         chan struct{}
    28  }
    29  
    30  func newResult(
    31  	ctx context.Context,
    32  	stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
    33  	streamCancel func(),
    34  ) (_ *result, txID string, _ error) {
    35  	interrupted := make(chan struct{})
    36  	r := result{
    37  		stream:         stream,
    38  		resultSetIndex: -1,
    39  		interrupted:    interrupted,
    40  		closed:         make(chan struct{}),
    41  		interrupt: sync.OnceFunc(func() {
    42  			close(interrupted)
    43  			streamCancel()
    44  		}),
    45  	}
    46  	select {
    47  	case <-ctx.Done():
    48  		return nil, txID, xerrors.WithStackTrace(ctx.Err())
    49  	default:
    50  		part, err := nextPart(stream)
    51  		if err != nil {
    52  			return nil, txID, xerrors.WithStackTrace(err)
    53  		}
    54  		r.lastPart = part
    55  		r.close = sync.OnceFunc(func() {
    56  			r.interrupt()
    57  			close(r.closed)
    58  		})
    59  
    60  		return &r, part.GetTxMeta().GetId(), nil
    61  	}
    62  }
    63  
    64  func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (*Ydb_Query.ExecuteQueryResponsePart, error) {
    65  	part, err := stream.Recv()
    66  	if err != nil {
    67  		if xerrors.Is(err, io.EOF) {
    68  			return nil, xerrors.WithStackTrace(err)
    69  		}
    70  
    71  		return nil, xerrors.WithStackTrace(xerrors.Transport(err))
    72  	}
    73  	if status := part.GetStatus(); status != Ydb.StatusIds_SUCCESS {
    74  		return nil, xerrors.WithStackTrace(
    75  			xerrors.FromOperation(part),
    76  		)
    77  	}
    78  
    79  	return part, nil
    80  }
    81  
    82  func (r *result) Close(ctx context.Context) error {
    83  	r.close()
    84  
    85  	return nil
    86  }
    87  
    88  func (r *result) nextResultSet(ctx context.Context) (_ *resultSet, err error) {
    89  	defer func() {
    90  		if err != nil && !xerrors.Is(err,
    91  			io.EOF, errClosedResult, context.Canceled,
    92  		) {
    93  			r.errs = append(r.errs, err)
    94  		}
    95  	}()
    96  	nextResultSetIndex := r.resultSetIndex + 1
    97  	for {
    98  		select {
    99  		case <-r.closed:
   100  			return nil, xerrors.WithStackTrace(errClosedResult)
   101  		case <-ctx.Done():
   102  			return nil, xerrors.WithStackTrace(ctx.Err())
   103  		default:
   104  			select {
   105  			case <-r.interrupted:
   106  				return nil, xerrors.WithStackTrace(errInterruptedStream)
   107  			default:
   108  				if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex { //nolint:nestif
   109  					r.resultSetIndex = resultSetIndex
   110  
   111  					return newResultSet(func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) {
   112  						defer func() {
   113  							if err != nil && !xerrors.Is(err,
   114  								io.EOF, context.Canceled,
   115  							) {
   116  								r.errs = append(r.errs, err)
   117  							}
   118  						}()
   119  						select {
   120  						case <-r.closed:
   121  							return nil, errClosedResult
   122  						case <-r.interrupted:
   123  							return nil, errInterruptedStream
   124  						default:
   125  							part, err := nextPart(r.stream)
   126  							if err != nil {
   127  								if xerrors.Is(err, io.EOF) {
   128  									r.close()
   129  								}
   130  
   131  								return nil, xerrors.WithStackTrace(err)
   132  							}
   133  							r.lastPart = part
   134  							if part.GetResultSetIndex() > nextResultSetIndex {
   135  								return nil, xerrors.WithStackTrace(fmt.Errorf(
   136  									"result set (index=%d) receive part (index=%d) for next result set: %w",
   137  									nextResultSetIndex, part.GetResultSetIndex(), io.EOF,
   138  								))
   139  							}
   140  
   141  							return part, nil
   142  						}
   143  					}, r.lastPart), nil
   144  				}
   145  				part, err := nextPart(r.stream)
   146  				if err != nil {
   147  					return nil, xerrors.WithStackTrace(err)
   148  				}
   149  				if part.GetResultSetIndex() < r.resultSetIndex {
   150  					return nil, xerrors.WithStackTrace(fmt.Errorf(
   151  						"next result set index %d less than last result set index %d: %w",
   152  						part.GetResultSetIndex(), r.resultSetIndex, errWrongNextResultSetIndex,
   153  					))
   154  				}
   155  				r.lastPart = part
   156  				r.resultSetIndex = part.GetResultSetIndex()
   157  			}
   158  		}
   159  	}
   160  }
   161  
   162  func (r *result) NextResultSet(ctx context.Context) (query.ResultSet, error) {
   163  	return r.nextResultSet(ctx)
   164  }
   165  
   166  func (r *result) Err() error {
   167  	switch {
   168  	case len(r.errs) == 0:
   169  		return nil
   170  	case len(r.errs) == 1:
   171  		return r.errs[0]
   172  	default:
   173  		return xerrors.WithStackTrace(xerrors.Join(r.errs...))
   174  	}
   175  }