github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/internal/table/scanner/result_test.go (about)

     1  package scanner
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"github.com/stretchr/testify/require"
    12  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    13  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_TableStats"
    14  
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/allocator"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/types"
    17  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
    18  	"github.com/ydb-platform/ydb-go-sdk/v3/table/options"
    19  )
    20  
    21  func TestResultAny(t *testing.T) {
    22  	for _, test := range []struct {
    23  		name    string
    24  		columns []options.Column
    25  		values  []value.Value
    26  		exp     []interface{}
    27  	}{
    28  		{
    29  			columns: []options.Column{
    30  				{
    31  					Name:   "column0",
    32  					Type:   types.NewOptional(types.Uint32),
    33  					Family: "family0",
    34  				},
    35  			},
    36  			values: []value.Value{
    37  				value.OptionalValue(value.Uint32Value(43)),
    38  				value.NullValue(types.Uint32),
    39  			},
    40  			exp: []interface{}{
    41  				uint32(43),
    42  				nil,
    43  			},
    44  		},
    45  	} {
    46  		t.Run(test.name, func(t *testing.T) {
    47  			a := allocator.New()
    48  			defer a.Free()
    49  			res := NewUnary(
    50  				[]*Ydb.ResultSet{
    51  					NewResultSet(a,
    52  						WithColumns(test.columns...),
    53  						WithValues(test.values...),
    54  					),
    55  				},
    56  				nil,
    57  			)
    58  			var i int
    59  			var act interface{}
    60  			for res.NextResultSet(context.Background()) {
    61  				for res.NextRow() {
    62  					err := res.ScanWithDefaults(&act)
    63  					if err != nil {
    64  						t.Fatal(err)
    65  					}
    66  					if exp := test.exp[i]; !reflect.DeepEqual(act, exp) {
    67  						t.Errorf(
    68  							"unexpected Any() result: %[1]v (%[1]T); want %[2]v (%[2]T)",
    69  							act, exp,
    70  						)
    71  					}
    72  					i++
    73  				}
    74  			}
    75  			if err := res.Err(); err != nil {
    76  				t.Fatal(err)
    77  			}
    78  		})
    79  	}
    80  }
    81  
    82  func TestResultOUint32(t *testing.T) {
    83  	for _, test := range []struct {
    84  		name    string
    85  		columns []options.Column
    86  		values  []value.Value
    87  		exp     []uint32
    88  	}{
    89  		{
    90  			columns: []options.Column{
    91  				{
    92  					Name:   "column0",
    93  					Type:   types.NewOptional(types.Uint32),
    94  					Family: "family0",
    95  				},
    96  				{
    97  					Name:   "column1",
    98  					Type:   types.Uint32,
    99  					Family: "family0",
   100  				},
   101  			},
   102  			values: []value.Value{
   103  				value.OptionalValue(value.Uint32Value(43)),
   104  				value.Uint32Value(43),
   105  			},
   106  			exp: []uint32{
   107  				43,
   108  				43,
   109  			},
   110  		},
   111  	} {
   112  		t.Run(test.name, func(t *testing.T) {
   113  			a := allocator.New()
   114  			defer a.Free()
   115  			res := NewUnary(
   116  				[]*Ydb.ResultSet{
   117  					NewResultSet(a,
   118  						WithColumns(test.columns...),
   119  						WithValues(test.values...),
   120  					),
   121  				},
   122  				nil,
   123  			)
   124  			var i int
   125  			var act uint32
   126  			for res.NextResultSet(context.Background()) {
   127  				for res.NextRow() {
   128  					_ = res.ScanWithDefaults(&act)
   129  					if exp := test.exp[i]; !reflect.DeepEqual(act, exp) {
   130  						t.Errorf(
   131  							"unexpected OUint32() result: %[1]v (%[1]T); want %[2]v (%[2]T)",
   132  							act, exp,
   133  						)
   134  					}
   135  					i++
   136  				}
   137  			}
   138  			if err := res.Err(); err != nil {
   139  				t.Fatal(err)
   140  			}
   141  		})
   142  	}
   143  }
   144  
   145  type resultSetDesc Ydb.ResultSet
   146  
   147  type ResultSetOption func(*resultSetDesc, *allocator.Allocator)
   148  
   149  func WithColumns(cs ...options.Column) ResultSetOption {
   150  	return func(r *resultSetDesc, a *allocator.Allocator) {
   151  		for _, c := range cs {
   152  			r.Columns = append(r.Columns, &Ydb.Column{
   153  				Name: c.Name,
   154  				Type: types.TypeToYDB(c.Type, a),
   155  			})
   156  		}
   157  	}
   158  }
   159  
   160  func WithValues(vs ...value.Value) ResultSetOption {
   161  	return func(r *resultSetDesc, a *allocator.Allocator) {
   162  		n := len(r.Columns)
   163  		if n == 0 {
   164  			panic("empty columns")
   165  		}
   166  		if len(vs)%n != 0 {
   167  			panic("malformed values set")
   168  		}
   169  		var row *Ydb.Value
   170  		for i, v := range vs {
   171  			j := i % n
   172  			if j == 0 && i > 0 {
   173  				r.Rows = append(r.Rows, row)
   174  			}
   175  			if j == 0 {
   176  				row = &Ydb.Value{
   177  					Items: make([]*Ydb.Value, n),
   178  				}
   179  			}
   180  			tv := value.ToYDB(v, a)
   181  			act := types.TypeFromYDB(tv.GetType())
   182  			exp := types.TypeFromYDB(r.Columns[j].GetType())
   183  			if !types.Equal(act, exp) {
   184  				panic(fmt.Sprintf(
   185  					"unexpected types for #%d column: %s; want %s",
   186  					j, act, exp,
   187  				))
   188  			}
   189  			row.Items[j] = tv.GetValue()
   190  		}
   191  		if row != nil {
   192  			r.Rows = append(r.Rows, row)
   193  		}
   194  	}
   195  }
   196  
   197  func NewResultSet(a *allocator.Allocator, opts ...ResultSetOption) *Ydb.ResultSet {
   198  	var d resultSetDesc
   199  	for _, opt := range opts {
   200  		if opt != nil {
   201  			opt(&d, a)
   202  		}
   203  	}
   204  
   205  	return (*Ydb.ResultSet)(&d)
   206  }
   207  
   208  func TestNewStreamWithRecvFirstResultSet(t *testing.T) {
   209  	for _, tt := range []struct {
   210  		ctx         context.Context
   211  		recvCounter int
   212  		err         error
   213  	}{
   214  		{
   215  			ctx: context.Background(),
   216  			err: nil,
   217  		},
   218  		{
   219  			ctx: func() context.Context {
   220  				ctx, cancel := context.WithCancel(context.Background())
   221  				cancel()
   222  
   223  				return ctx
   224  			}(),
   225  			err: context.Canceled,
   226  		},
   227  		{
   228  			ctx: func() context.Context {
   229  				ctx, cancel := context.WithTimeout(context.Background(), 0)
   230  				cancel()
   231  
   232  				return ctx
   233  			}(),
   234  			err: context.DeadlineExceeded,
   235  		},
   236  		{
   237  			ctx: func() context.Context {
   238  				ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
   239  				cancel()
   240  
   241  				return ctx
   242  			}(),
   243  			err: context.Canceled,
   244  		},
   245  	} {
   246  		t.Run("", func(t *testing.T) {
   247  			result, err := NewStream(tt.ctx,
   248  				func(ctx context.Context) (*Ydb.ResultSet, *Ydb_TableStats.QueryStats, error) {
   249  					tt.recvCounter++
   250  					if tt.recvCounter > 1000 {
   251  						return nil, nil, io.EOF
   252  					}
   253  
   254  					return &Ydb.ResultSet{}, nil, ctx.Err()
   255  				},
   256  				func(err error) error {
   257  					return err
   258  				},
   259  			)
   260  			if tt.err != nil {
   261  				require.ErrorIs(t, err, tt.err)
   262  				require.Nil(t, result)
   263  			} else {
   264  				require.NoError(t, err)
   265  				require.NotNil(t, result)
   266  				require.EqualValues(t, 1, tt.recvCounter)
   267  				require.EqualValues(t, 1, result.(*streamResult).nextResultSetCounter.Load())
   268  				for i := range make([]struct{}, 1000) {
   269  					err = result.NextResultSetErr(tt.ctx)
   270  					require.NoError(t, err)
   271  					require.Equal(t, i+1, tt.recvCounter)
   272  					require.Equal(t, i+2, int(result.(*streamResult).nextResultSetCounter.Load()))
   273  				}
   274  				err = result.NextResultSetErr(tt.ctx)
   275  				require.ErrorIs(t, err, io.EOF)
   276  				require.True(t, err == io.EOF) //nolint:errorlint,testifylint
   277  				require.Equal(t, 1001, tt.recvCounter)
   278  				require.Equal(t, 1002, int(result.(*streamResult).nextResultSetCounter.Load()))
   279  			}
   280  		})
   281  	}
   282  }