github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/query/result.go (about) 1 package query 2 3 import ( 4 "context" 5 "fmt" 6 "io" 7 "sync" 8 9 "github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1" 10 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb" 11 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query" 12 13 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 14 "github.com/ydb-platform/ydb-go-sdk/v3/query" 15 ) 16 17 var _ query.Result = (*result)(nil) 18 19 type result struct { 20 stream Ydb_Query_V1.QueryService_ExecuteQueryClient 21 interrupt func() 22 close func() 23 lastPart *Ydb_Query.ExecuteQueryResponsePart 24 resultSetIndex int64 25 errs []error 26 interrupted chan struct{} 27 closed chan struct{} 28 } 29 30 func newResult( 31 ctx context.Context, 32 stream Ydb_Query_V1.QueryService_ExecuteQueryClient, 33 streamCancel func(), 34 ) (_ *result, txID string, _ error) { 35 interrupted := make(chan struct{}) 36 r := result{ 37 stream: stream, 38 resultSetIndex: -1, 39 interrupted: interrupted, 40 closed: make(chan struct{}), 41 interrupt: sync.OnceFunc(func() { 42 close(interrupted) 43 streamCancel() 44 }), 45 } 46 select { 47 case <-ctx.Done(): 48 return nil, txID, xerrors.WithStackTrace(ctx.Err()) 49 default: 50 part, err := nextPart(stream) 51 if err != nil { 52 return nil, txID, xerrors.WithStackTrace(err) 53 } 54 r.lastPart = part 55 r.close = sync.OnceFunc(func() { 56 r.interrupt() 57 close(r.closed) 58 }) 59 60 return &r, part.GetTxMeta().GetId(), nil 61 } 62 } 63 64 func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (*Ydb_Query.ExecuteQueryResponsePart, error) { 65 part, err := stream.Recv() 66 if err != nil { 67 if xerrors.Is(err, io.EOF) { 68 return nil, xerrors.WithStackTrace(err) 69 } 70 71 return nil, xerrors.WithStackTrace(xerrors.Transport(err)) 72 } 73 if status := part.GetStatus(); status != Ydb.StatusIds_SUCCESS { 74 return nil, xerrors.WithStackTrace( 75 xerrors.FromOperation(part), 76 ) 77 } 78 79 return part, nil 80 } 81 82 func (r *result) Close(ctx context.Context) error { 83 r.close() 84 85 return nil 86 } 87 88 func (r *result) nextResultSet(ctx context.Context) (_ *resultSet, err error) { 89 defer func() { 90 if err != nil && !xerrors.Is(err, 91 io.EOF, errClosedResult, context.Canceled, 92 ) { 93 r.errs = append(r.errs, err) 94 } 95 }() 96 nextResultSetIndex := r.resultSetIndex + 1 97 for { 98 select { 99 case <-r.closed: 100 return nil, xerrors.WithStackTrace(errClosedResult) 101 case <-ctx.Done(): 102 return nil, xerrors.WithStackTrace(ctx.Err()) 103 default: 104 select { 105 case <-r.interrupted: 106 return nil, xerrors.WithStackTrace(errInterruptedStream) 107 default: 108 if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex { //nolint:nestif 109 r.resultSetIndex = resultSetIndex 110 111 return newResultSet(func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) { 112 defer func() { 113 if err != nil && !xerrors.Is(err, 114 io.EOF, context.Canceled, 115 ) { 116 r.errs = append(r.errs, err) 117 } 118 }() 119 select { 120 case <-r.closed: 121 return nil, errClosedResult 122 case <-r.interrupted: 123 return nil, errInterruptedStream 124 default: 125 part, err := nextPart(r.stream) 126 if err != nil { 127 if xerrors.Is(err, io.EOF) { 128 r.close() 129 } 130 131 return nil, xerrors.WithStackTrace(err) 132 } 133 r.lastPart = part 134 if part.GetResultSetIndex() > nextResultSetIndex { 135 return nil, xerrors.WithStackTrace(fmt.Errorf( 136 "result set (index=%d) receive part (index=%d) for next result set: %w", 137 nextResultSetIndex, part.GetResultSetIndex(), io.EOF, 138 )) 139 } 140 141 return part, nil 142 } 143 }, r.lastPart), nil 144 } 145 part, err := nextPart(r.stream) 146 if err != nil { 147 return nil, xerrors.WithStackTrace(err) 148 } 149 if part.GetResultSetIndex() < r.resultSetIndex { 150 return nil, xerrors.WithStackTrace(fmt.Errorf( 151 "next result set index %d less than last result set index %d: %w", 152 part.GetResultSetIndex(), r.resultSetIndex, errWrongNextResultSetIndex, 153 )) 154 } 155 r.lastPart = part 156 r.resultSetIndex = part.GetResultSetIndex() 157 } 158 } 159 } 160 } 161 162 func (r *result) NextResultSet(ctx context.Context) (query.ResultSet, error) { 163 return r.nextResultSet(ctx) 164 } 165 166 func (r *result) Err() error { 167 switch { 168 case len(r.errs) == 0: 169 return nil 170 case len(r.errs) == 1: 171 return r.errs[0] 172 default: 173 return xerrors.WithStackTrace(xerrors.Join(r.errs...)) 174 } 175 }