github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/internal/query/result.go (about)

     1  package query
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  
    10  	"github.com/ydb-platform/ydb-go-genproto/Ydb_Query_V1"
    11  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Query"
    12  
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/query/result"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stack"
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/stats"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xiter"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/query"
    19  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    20  )
    21  
    22  var (
    23  	_ result.Result = (*streamResult)(nil)
    24  	_ result.Result = (*materializedResult)(nil)
    25  )
    26  
    27  type (
    28  	materializedResult struct {
    29  		resultSets []result.Set
    30  		idx        int
    31  	}
    32  	streamResult struct {
    33  		stream         Ydb_Query_V1.QueryService_ExecuteQueryClient
    34  		closeOnce      func()
    35  		lastPart       *Ydb_Query.ExecuteQueryResponsePart
    36  		resultSetIndex int64
    37  		closed         chan struct{}
    38  		trace          *trace.Query
    39  		statsCallback  func(queryStats stats.QueryStats)
    40  		onNextPartErr  []func(err error)
    41  		onTxMeta       []func(txMeta *Ydb_Query.TransactionMeta)
    42  	}
    43  	resultOption func(s *streamResult)
    44  )
    45  
    46  func rangeResultSets(ctx context.Context, r result.Result) xiter.Seq2[result.Set, error] {
    47  	return func(yield func(result.Set, error) bool) {
    48  		for {
    49  			rs, err := r.NextResultSet(ctx)
    50  			if err != nil {
    51  				if xerrors.Is(err, io.EOF) {
    52  					return
    53  				}
    54  			}
    55  			cont := yield(rs, err)
    56  			if !cont || err != nil {
    57  				return
    58  			}
    59  		}
    60  	}
    61  }
    62  
    63  func (r *materializedResult) ResultSets(ctx context.Context) xiter.Seq2[result.Set, error] {
    64  	return rangeResultSets(ctx, r)
    65  }
    66  
    67  func (r *streamResult) ResultSets(ctx context.Context) xiter.Seq2[result.Set, error] {
    68  	return rangeResultSets(ctx, r)
    69  }
    70  
    71  func (r *materializedResult) Close(ctx context.Context) error {
    72  	return nil
    73  }
    74  
    75  func (r *materializedResult) NextResultSet(ctx context.Context) (result.Set, error) {
    76  	if r.idx == len(r.resultSets) {
    77  		return nil, xerrors.WithStackTrace(io.EOF)
    78  	}
    79  
    80  	defer func() {
    81  		r.idx++
    82  	}()
    83  
    84  	return r.resultSets[r.idx], nil
    85  }
    86  
    87  func withTrace(t *trace.Query) resultOption {
    88  	return func(s *streamResult) {
    89  		s.trace = t
    90  	}
    91  }
    92  
    93  func withStatsCallback(callback func(queryStats stats.QueryStats)) resultOption {
    94  	return func(s *streamResult) {
    95  		s.statsCallback = callback
    96  	}
    97  }
    98  
    99  func onNextPartErr(callback func(err error)) resultOption {
   100  	return func(s *streamResult) {
   101  		s.onNextPartErr = append(s.onNextPartErr, callback)
   102  	}
   103  }
   104  
   105  func onTxMeta(callback func(txMeta *Ydb_Query.TransactionMeta)) resultOption {
   106  	return func(s *streamResult) {
   107  		s.onTxMeta = append(s.onTxMeta, callback)
   108  	}
   109  }
   110  
   111  func newResult(
   112  	ctx context.Context,
   113  	stream Ydb_Query_V1.QueryService_ExecuteQueryClient,
   114  	opts ...resultOption,
   115  ) (_ *streamResult, finalErr error) {
   116  	r := streamResult{
   117  		stream:         stream,
   118  		closed:         make(chan struct{}),
   119  		resultSetIndex: -1,
   120  	}
   121  	r.closeOnce = sync.OnceFunc(func() {
   122  		close(r.closed)
   123  		r.stream = nil
   124  	})
   125  
   126  	for _, opt := range opts {
   127  		if opt != nil {
   128  			opt(&r)
   129  		}
   130  	}
   131  
   132  	if r.trace != nil {
   133  		onDone := trace.QueryOnResultNew(r.trace, &ctx,
   134  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.newResult"),
   135  		)
   136  		defer func() {
   137  			onDone(finalErr)
   138  		}()
   139  	}
   140  
   141  	select {
   142  	case <-ctx.Done():
   143  		return nil, xerrors.WithStackTrace(ctx.Err())
   144  	default:
   145  		part, err := r.nextPart(ctx)
   146  		if err != nil {
   147  			return nil, xerrors.WithStackTrace(err)
   148  		}
   149  
   150  		r.lastPart = part
   151  
   152  		if r.statsCallback != nil {
   153  			r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
   154  		}
   155  
   156  		return &r, nil
   157  	}
   158  }
   159  
   160  func (r *streamResult) nextPart(ctx context.Context) (
   161  	part *Ydb_Query.ExecuteQueryResponsePart, err error,
   162  ) {
   163  	if r.trace != nil {
   164  		onDone := trace.QueryOnResultNextPart(r.trace, &ctx,
   165  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).nextPart"),
   166  		)
   167  		defer func() {
   168  			onDone(part.GetExecStats(), err)
   169  		}()
   170  	}
   171  
   172  	select {
   173  	case <-r.closed:
   174  		return nil, xerrors.WithStackTrace(io.EOF)
   175  	default:
   176  		part, err = nextPart(r.stream)
   177  		if err != nil {
   178  			r.closeOnce()
   179  
   180  			for _, callback := range r.onNextPartErr {
   181  				callback(err)
   182  			}
   183  
   184  			return nil, xerrors.WithStackTrace(err)
   185  		}
   186  
   187  		if txMeta := part.GetTxMeta(); txMeta != nil {
   188  			for _, f := range r.onTxMeta {
   189  				f(txMeta)
   190  			}
   191  		}
   192  
   193  		return part, nil
   194  	}
   195  }
   196  
   197  func nextPart(stream Ydb_Query_V1.QueryService_ExecuteQueryClient) (
   198  	part *Ydb_Query.ExecuteQueryResponsePart, err error,
   199  ) {
   200  	part, err = stream.Recv()
   201  	if err != nil {
   202  		return nil, xerrors.WithStackTrace(err)
   203  	}
   204  
   205  	return part, nil
   206  }
   207  
   208  func (r *streamResult) Close(ctx context.Context) (finalErr error) {
   209  	defer r.closeOnce()
   210  
   211  	if r.trace != nil {
   212  		onDone := trace.QueryOnResultClose(r.trace, &ctx,
   213  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).Close"),
   214  		)
   215  		defer func() {
   216  			onDone(finalErr)
   217  		}()
   218  	}
   219  
   220  	for {
   221  		select {
   222  		case <-r.closed:
   223  			return nil
   224  		default:
   225  			_, err := r.nextPart(ctx)
   226  			if err != nil {
   227  				if xerrors.Is(err, io.EOF) {
   228  					return nil
   229  				}
   230  
   231  				return xerrors.WithStackTrace(err)
   232  			}
   233  		}
   234  	}
   235  }
   236  
   237  func (r *streamResult) nextResultSet(ctx context.Context) (_ *resultSet, err error) {
   238  	nextResultSetIndex := r.resultSetIndex + 1
   239  	for {
   240  		select {
   241  		case <-r.closed:
   242  			return nil, xerrors.WithStackTrace(io.EOF)
   243  		case <-ctx.Done():
   244  			return nil, xerrors.WithStackTrace(ctx.Err())
   245  		default:
   246  			if resultSetIndex := r.lastPart.GetResultSetIndex(); resultSetIndex >= nextResultSetIndex {
   247  				r.resultSetIndex = resultSetIndex
   248  
   249  				return newResultSet(r.nextPartFunc(ctx, nextResultSetIndex), r.lastPart), nil
   250  			}
   251  			if r.stream == nil {
   252  				return nil, xerrors.WithStackTrace(io.EOF)
   253  			}
   254  			part, err := r.nextPart(ctx)
   255  			if err != nil {
   256  				return nil, xerrors.WithStackTrace(err)
   257  			}
   258  			if part.GetExecStats() != nil && r.statsCallback != nil {
   259  				r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
   260  			}
   261  			if part.GetResultSetIndex() < r.resultSetIndex {
   262  				r.closeOnce()
   263  
   264  				return nil, xerrors.WithStackTrace(fmt.Errorf(
   265  					"next result set rowIndex %d less than last result set index %d: %w",
   266  					part.GetResultSetIndex(), r.resultSetIndex, errWrongNextResultSetIndex,
   267  				))
   268  			}
   269  			r.lastPart = part
   270  			r.resultSetIndex = part.GetResultSetIndex()
   271  		}
   272  	}
   273  }
   274  
   275  func (r *streamResult) nextPartFunc(
   276  	ctx context.Context,
   277  	nextResultSetIndex int64,
   278  ) func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) {
   279  	return func() (_ *Ydb_Query.ExecuteQueryResponsePart, err error) {
   280  		select {
   281  		case <-r.closed:
   282  			return nil, xerrors.WithStackTrace(io.EOF)
   283  		default:
   284  			if r.stream == nil {
   285  				return nil, xerrors.WithStackTrace(io.EOF)
   286  			}
   287  			part, err := r.nextPart(ctx)
   288  			if err != nil {
   289  				return nil, xerrors.WithStackTrace(err)
   290  			}
   291  			r.lastPart = part
   292  			if part.GetExecStats() != nil && r.statsCallback != nil {
   293  				r.statsCallback(stats.FromQueryStats(part.GetExecStats()))
   294  			}
   295  			if part.GetResultSetIndex() > nextResultSetIndex {
   296  				return nil, xerrors.WithStackTrace(fmt.Errorf(
   297  					"result set (index=%d) receive part (index=%d) for next result set: %w",
   298  					nextResultSetIndex, part.GetResultSetIndex(), io.EOF,
   299  				))
   300  			}
   301  
   302  			return part, nil
   303  		}
   304  	}
   305  }
   306  
   307  func (r *streamResult) NextResultSet(ctx context.Context) (_ result.Set, err error) {
   308  	if r.trace != nil {
   309  		onDone := trace.QueryOnResultNextResultSet(r.trace, &ctx,
   310  			stack.FunctionID("github.com/ydb-platform/ydb-go-sdk/v3/internal/query.(*streamResult).NextResultSet"),
   311  		)
   312  		defer func() {
   313  			onDone(err)
   314  		}()
   315  	}
   316  
   317  	return r.nextResultSet(ctx)
   318  }
   319  
   320  func exactlyOneRowFromResult(ctx context.Context, r result.Result) (row result.Row, err error) {
   321  	rs, err := r.NextResultSet(ctx)
   322  	if err != nil {
   323  		return nil, xerrors.WithStackTrace(err)
   324  	}
   325  	row, err = rs.NextRow(ctx)
   326  	if err != nil {
   327  		return nil, xerrors.WithStackTrace(err)
   328  	}
   329  
   330  	_, err = rs.NextRow(ctx)
   331  	switch {
   332  	case err == nil:
   333  		return nil, xerrors.WithStackTrace(errMoreThanOneRow)
   334  	case errors.Is(err, io.EOF):
   335  		// pass
   336  	default:
   337  		return nil, xerrors.WithStackTrace(err)
   338  	}
   339  
   340  	_, err = r.NextResultSet(ctx)
   341  	switch {
   342  	case err == nil:
   343  		return nil, xerrors.WithStackTrace(errMoreThanOneRow)
   344  	case errors.Is(err, io.EOF):
   345  		// pass
   346  	default:
   347  		return nil, xerrors.WithStackTrace(err)
   348  	}
   349  
   350  	return row, nil
   351  }
   352  
   353  func exactlyOneResultSetFromResult(ctx context.Context, r result.Result) (rs result.Set, err error) {
   354  	var rows []query.Row
   355  	rs, err = r.NextResultSet(ctx)
   356  	if err != nil {
   357  		if xerrors.Is(err, io.EOF) {
   358  			return nil, xerrors.WithStackTrace(errNoResultSets)
   359  		}
   360  
   361  		return nil, xerrors.WithStackTrace(err)
   362  	}
   363  
   364  	var row query.Row
   365  	for {
   366  		row, err = rs.NextRow(ctx)
   367  		if err != nil {
   368  			if xerrors.Is(err, io.EOF) {
   369  				break
   370  			}
   371  
   372  			return nil, xerrors.WithStackTrace(err)
   373  		}
   374  
   375  		rows = append(rows, row)
   376  	}
   377  
   378  	_, err = r.NextResultSet(ctx)
   379  	switch {
   380  	case err == nil:
   381  		return nil, xerrors.WithStackTrace(errMoreThanOneResultSet)
   382  	case errors.Is(err, io.EOF):
   383  		// pass
   384  	default:
   385  		return nil, xerrors.WithStackTrace(err)
   386  	}
   387  
   388  	return MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows), nil
   389  }
   390  
   391  func resultToMaterializedResult(ctx context.Context, r result.Result) (result.Result, error) {
   392  	var resultSets []result.Set
   393  
   394  	for {
   395  		rs, err := r.NextResultSet(ctx)
   396  		if err != nil {
   397  			if xerrors.Is(err, io.EOF) {
   398  				break
   399  			}
   400  
   401  			return nil, xerrors.WithStackTrace(err)
   402  		}
   403  
   404  		var rows []query.Row
   405  		for {
   406  			row, err := rs.NextRow(ctx)
   407  			if err != nil {
   408  				if xerrors.Is(err, io.EOF) {
   409  					break
   410  				}
   411  
   412  				return nil, xerrors.WithStackTrace(err)
   413  			}
   414  
   415  			rows = append(rows, row)
   416  		}
   417  
   418  		resultSets = append(resultSets, MaterializedResultSet(rs.Index(), rs.Columns(), rs.ColumnTypes(), rows))
   419  	}
   420  
   421  	return &materializedResult{
   422  		resultSets: resultSets,
   423  	}, nil
   424  }