github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/tests/integration/connection_test.go (about)

     1  //go:build integration
     2  // +build integration
     3  
     4  package integration
     5  
     6  import (
     7  	"context"
     8  	"crypto/tls"
     9  	"fmt"
    10  	"os"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/stretchr/testify/require"
    15  	"github.com/ydb-platform/ydb-go-genproto/Ydb_Discovery_V1"
    16  	"github.com/ydb-platform/ydb-go-genproto/Ydb_Export_V1"
    17  	"github.com/ydb-platform/ydb-go-genproto/Ydb_Scripting_V1"
    18  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb"
    19  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Discovery"
    20  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Export"
    21  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
    22  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Scripting"
    23  	"google.golang.org/grpc"
    24  	grpcCodes "google.golang.org/grpc/codes"
    25  	"google.golang.org/grpc/metadata"
    26  	"google.golang.org/protobuf/proto"
    27  	"google.golang.org/protobuf/types/known/durationpb"
    28  
    29  	"github.com/ydb-platform/ydb-go-sdk/v3"
    30  	"github.com/ydb-platform/ydb-go-sdk/v3/config"
    31  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/meta"
    32  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
    33  	"github.com/ydb-platform/ydb-go-sdk/v3/log"
    34  	"github.com/ydb-platform/ydb-go-sdk/v3/retry"
    35  	"github.com/ydb-platform/ydb-go-sdk/v3/trace"
    36  )
    37  
    38  //nolint:gocyclo
    39  func TestConnection(sourceTest *testing.T) {
    40  	t := xtest.MakeSyncedTest(sourceTest)
    41  	const sumColumn = "sum"
    42  	var (
    43  		userAgent     = "connection user agent"
    44  		requestType   = "connection request type"
    45  		traceParentID = "test-traceparent-id"
    46  		checkMetadata = func(ctx context.Context) {
    47  			md, has := metadata.FromOutgoingContext(ctx)
    48  			if !has {
    49  				t.Fatalf("no medatada")
    50  			}
    51  			userAgents := md.Get(meta.HeaderApplicationName)
    52  			if len(userAgents) == 0 {
    53  				t.Fatalf("no user agent")
    54  			}
    55  			if userAgents[0] != userAgent {
    56  				t.Fatalf("unknown user agent: %s", userAgents[0])
    57  			}
    58  			requestTypes := md.Get(meta.HeaderRequestType)
    59  			if len(requestTypes) == 0 {
    60  				t.Fatalf("no request type")
    61  			}
    62  			if requestTypes[0] != requestType {
    63  				t.Fatalf("unknown request type: %s", requestTypes[0])
    64  			}
    65  			traceIDs := md.Get(meta.HeaderTraceID)
    66  			if len(traceIDs) == 0 {
    67  				t.Fatalf("no traceIDs")
    68  			}
    69  			if len(traceIDs[0]) == 0 {
    70  				t.Fatalf("empty traceID header")
    71  			}
    72  			traceParent := md.Get(meta.HeaderTraceParent)
    73  			if len(traceParent) == 0 {
    74  				t.Fatalf("no traceparent header")
    75  			}
    76  			if len(traceParent[0]) == 0 {
    77  				t.Fatalf("empty traceparent header")
    78  			}
    79  			if traceParent[0] != traceParentID {
    80  				t.Fatalf("unexpected traceparent header")
    81  			}
    82  		}
    83  		ctx = meta.WithTraceParent(xtest.Context(t), traceParentID)
    84  	)
    85  
    86  	t.RunSynced("ydb.New", func(t *xtest.SyncedTest) {
    87  		db, err := ydb.New(ctx, //nolint:gocritic
    88  			ydb.WithConnectionString(os.Getenv("YDB_CONNECTION_STRING")),
    89  			ydb.WithAccessTokenCredentials(
    90  				os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"),
    91  			),
    92  			ydb.With(
    93  				config.WithOperationTimeout(time.Second*2),
    94  				config.WithOperationCancelAfter(time.Second*2),
    95  			),
    96  			ydb.WithConnectionTTL(time.Millisecond*10000),
    97  			ydb.WithMinTLSVersion(tls.VersionTLS10),
    98  			ydb.WithLogger(
    99  				newLogger(t),
   100  				trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`),
   101  			),
   102  		)
   103  		if err != nil {
   104  			t.Fatal(err)
   105  		}
   106  		defer func() {
   107  			// cleanup connection
   108  			if e := db.Close(ctx); e != nil {
   109  				t.Fatalf("close failed: %+v", e)
   110  			}
   111  		}()
   112  	})
   113  	t.RunSynced("ydb.Open", func(t *xtest.SyncedTest) {
   114  		db, err := ydb.Open(ctx,
   115  			os.Getenv("YDB_CONNECTION_STRING"),
   116  			ydb.WithAccessTokenCredentials(
   117  				os.Getenv("YDB_ACCESS_TOKEN_CREDENTIALS"),
   118  			),
   119  			ydb.With(
   120  				config.WithOperationTimeout(time.Second*2),
   121  				config.WithOperationCancelAfter(time.Second*2),
   122  			),
   123  			ydb.WithConnectionTTL(time.Millisecond*10000),
   124  			ydb.WithMinTLSVersion(tls.VersionTLS10),
   125  			ydb.WithLogger(
   126  				newLoggerWithMinLevel(t, log.WARN),
   127  				trace.MatchDetails(`ydb\.(driver|discovery|retry|scheme).*`),
   128  			),
   129  			ydb.WithApplicationName(userAgent),
   130  			ydb.WithRequestsType(requestType),
   131  			ydb.With(
   132  				config.WithGrpcOptions(
   133  					grpc.WithUnaryInterceptor(func(
   134  						ctx context.Context,
   135  						method string,
   136  						req, reply interface{},
   137  						cc *grpc.ClientConn,
   138  						invoker grpc.UnaryInvoker,
   139  						opts ...grpc.CallOption,
   140  					) error {
   141  						checkMetadata(ctx)
   142  						return invoker(ctx, method, req, reply, cc, opts...)
   143  					}),
   144  					grpc.WithStreamInterceptor(func(
   145  						ctx context.Context,
   146  						desc *grpc.StreamDesc,
   147  						cc *grpc.ClientConn,
   148  						method string,
   149  						streamer grpc.Streamer,
   150  						opts ...grpc.CallOption,
   151  					) (grpc.ClientStream, error) {
   152  						checkMetadata(ctx)
   153  						return streamer(ctx, desc, cc, method, opts...)
   154  					}),
   155  				),
   156  			),
   157  		)
   158  		if err != nil {
   159  			t.Fatal(err)
   160  		}
   161  		defer func() {
   162  			// cleanup connection
   163  			if e := db.Close(ctx); e != nil {
   164  				t.Fatalf("close failed: %+v", e)
   165  			}
   166  		}()
   167  		t.Run("discovery.WhoAmI", func(t *testing.T) {
   168  			if err = retry.Retry(ctx, func(ctx context.Context) (err error) {
   169  				discoveryClient := Ydb_Discovery_V1.NewDiscoveryServiceClient(ydb.GRPCConn(db))
   170  				response, err := discoveryClient.WhoAmI(
   171  					ctx,
   172  					&Ydb_Discovery.WhoAmIRequest{IncludeGroups: true},
   173  				)
   174  				if err != nil {
   175  					return err
   176  				}
   177  				var result Ydb_Discovery.WhoAmIResult
   178  				err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result)
   179  				if err != nil {
   180  					return
   181  				}
   182  				return nil
   183  			}, retry.WithIdempotent(true)); err != nil {
   184  				t.Fatalf("Execute failed: %v", err)
   185  			}
   186  		})
   187  		t.Run("scripting.ExecuteYql", func(t *testing.T) {
   188  			if err = retry.Retry(ctx, func(ctx context.Context) (err error) {
   189  				scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db))
   190  				response, err := scriptingClient.ExecuteYql(
   191  					ctx,
   192  					&Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"},
   193  				)
   194  				if err != nil {
   195  					return err
   196  				}
   197  				var result Ydb_Scripting.ExecuteYqlResult
   198  				err = proto.Unmarshal(response.GetOperation().GetResult().GetValue(), &result)
   199  				if err != nil {
   200  					return
   201  				}
   202  				if len(result.GetResultSets()) != 1 {
   203  					return fmt.Errorf(
   204  						"unexpected result sets count: %d",
   205  						len(result.GetResultSets()),
   206  					)
   207  				}
   208  				if len(result.GetResultSets()[0].GetColumns()) != 1 {
   209  					return fmt.Errorf(
   210  						"unexpected colums count: %d",
   211  						len(result.GetResultSets()[0].GetColumns()),
   212  					)
   213  				}
   214  				if result.GetResultSets()[0].GetColumns()[0].GetName() != sumColumn {
   215  					return fmt.Errorf(
   216  						"unexpected colum name: %s",
   217  						result.GetResultSets()[0].GetColumns()[0].GetName(),
   218  					)
   219  				}
   220  				if len(result.GetResultSets()[0].GetRows()) != 1 {
   221  					return fmt.Errorf(
   222  						"unexpected rows count: %d",
   223  						len(result.GetResultSets()[0].GetRows()),
   224  					)
   225  				}
   226  				if result.GetResultSets()[0].GetRows()[0].GetItems()[0].GetInt32Value() != 101 {
   227  					return fmt.Errorf(
   228  						"unexpected result of select: %d",
   229  						result.GetResultSets()[0].GetRows()[0].GetInt64Value(),
   230  					)
   231  				}
   232  				return nil
   233  			}, retry.WithIdempotent(true)); err != nil {
   234  				t.Fatalf("Execute failed: %v", err)
   235  			}
   236  		})
   237  		t.Run("scripting.StreamExecuteYql", func(t *testing.T) {
   238  			if err = retry.Retry(ctx, func(ctx context.Context) (err error) {
   239  				scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(db))
   240  				client, err := scriptingClient.StreamExecuteYql(
   241  					ctx,
   242  					&Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"},
   243  				)
   244  				if err != nil {
   245  					return err
   246  				}
   247  				response, err := client.Recv()
   248  				if err != nil {
   249  					return err
   250  				}
   251  				if len(response.GetResult().GetResultSet().GetColumns()) != 1 {
   252  					return fmt.Errorf(
   253  						"unexpected colums count: %d",
   254  						len(response.GetResult().GetResultSet().GetColumns()),
   255  					)
   256  				}
   257  				if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn {
   258  					return fmt.Errorf(
   259  						"unexpected colum name: %s",
   260  						response.GetResult().GetResultSet().GetColumns()[0].GetName(),
   261  					)
   262  				}
   263  				if len(response.GetResult().GetResultSet().GetRows()) != 1 {
   264  					return fmt.Errorf(
   265  						"unexpected rows count: %d",
   266  						len(response.GetResult().GetResultSet().GetRows()),
   267  					)
   268  				}
   269  				if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 {
   270  					return fmt.Errorf(
   271  						"unexpected result of select: %d",
   272  						response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(),
   273  					)
   274  				}
   275  				return nil
   276  			}, retry.WithIdempotent(true)); err != nil {
   277  				t.Fatalf("Stream execute failed: %v", err)
   278  			}
   279  		})
   280  		t.Run("with.scripting.StreamExecuteYql", func(t *testing.T) {
   281  			var childDB *ydb.Driver
   282  			childDB, err = db.With(
   283  				ctx,
   284  				ydb.WithDialTimeout(time.Second*5),
   285  			)
   286  			if err != nil {
   287  				t.Fatalf("failed to open sub-connection: %v", err)
   288  			}
   289  			defer func() {
   290  				_ = childDB.Close(ctx)
   291  			}()
   292  			if err = retry.Retry(ctx, func(ctx context.Context) (err error) {
   293  				scriptingClient := Ydb_Scripting_V1.NewScriptingServiceClient(ydb.GRPCConn(childDB))
   294  				client, err := scriptingClient.StreamExecuteYql(
   295  					ctx,
   296  					&Ydb_Scripting.ExecuteYqlRequest{Script: "SELECT 1+100 AS sum"},
   297  				)
   298  				if err != nil {
   299  					return err
   300  				}
   301  				response, err := client.Recv()
   302  				if err != nil {
   303  					return err
   304  				}
   305  				if len(response.GetResult().GetResultSet().GetColumns()) != 1 {
   306  					return fmt.Errorf(
   307  						"unexpected colums count: %d",
   308  						len(response.GetResult().GetResultSet().GetColumns()),
   309  					)
   310  				}
   311  				if response.GetResult().GetResultSet().GetColumns()[0].GetName() != sumColumn {
   312  					return fmt.Errorf(
   313  						"unexpected colum name: %s",
   314  						response.GetResult().GetResultSet().GetColumns()[0].GetName(),
   315  					)
   316  				}
   317  				if len(response.GetResult().GetResultSet().GetRows()) != 1 {
   318  					return fmt.Errorf(
   319  						"unexpected rows count: %d",
   320  						len(response.GetResult().GetResultSet().GetRows()),
   321  					)
   322  				}
   323  				if response.GetResult().GetResultSet().GetRows()[0].GetItems()[0].GetInt32Value() != 101 {
   324  					return fmt.Errorf(
   325  						"unexpected result of select: %d",
   326  						response.GetResult().GetResultSet().GetRows()[0].GetInt64Value(),
   327  					)
   328  				}
   329  				return nil
   330  			}, retry.WithIdempotent(true)); err != nil {
   331  				t.Fatalf("Stream execute failed: %v", err)
   332  			}
   333  		})
   334  		t.Run("export.ExportToS3", func(t *testing.T) {
   335  			if err = retry.Retry(ctx, func(ctx context.Context) (err error) {
   336  				exportClient := Ydb_Export_V1.NewExportServiceClient(ydb.GRPCConn(db))
   337  				response, err := exportClient.ExportToS3(
   338  					ctx,
   339  					&Ydb_Export.ExportToS3Request{
   340  						OperationParams: &Ydb_Operations.OperationParams{
   341  							OperationTimeout: durationpb.New(time.Second),
   342  							CancelAfter:      durationpb.New(time.Second),
   343  						},
   344  						Settings: &Ydb_Export.ExportToS3Settings{},
   345  					},
   346  				)
   347  				if err != nil {
   348  					return err
   349  				}
   350  				if response.GetOperation().GetStatus() != Ydb.StatusIds_BAD_REQUEST {
   351  					return fmt.Errorf(
   352  						"operation must be BAD_REQUEST: %s",
   353  						response.GetOperation().GetStatus().String(),
   354  					)
   355  				}
   356  				return nil
   357  			}, retry.WithIdempotent(true)); err != nil {
   358  				t.Fatalf("check export failed: %v", err)
   359  			}
   360  		})
   361  	})
   362  }
   363  
   364  func TestZeroDialTimeout(t *testing.T) {
   365  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   366  	defer cancel()
   367  
   368  	var traceID string
   369  
   370  	db, err := ydb.Open(
   371  		ctx,
   372  		"grpc://non-existent.com:2135/some",
   373  		ydb.WithDialTimeout(0),
   374  		ydb.With(
   375  			config.WithGrpcOptions(
   376  				grpc.WithUnaryInterceptor(func(
   377  					ctx context.Context,
   378  					method string,
   379  					req, reply interface{},
   380  					cc *grpc.ClientConn,
   381  					invoker grpc.UnaryInvoker,
   382  					opts ...grpc.CallOption,
   383  				) error {
   384  					md, has := metadata.FromOutgoingContext(ctx)
   385  					if !has {
   386  						t.Fatalf("no medatada")
   387  					}
   388  					traceIDs := md.Get(meta.HeaderTraceID)
   389  					if len(traceIDs) == 0 {
   390  						t.Fatalf("no traceIDs")
   391  					}
   392  					traceID = traceIDs[0]
   393  					return invoker(ctx, method, req, reply, cc, opts...)
   394  				}),
   395  			),
   396  		),
   397  	)
   398  
   399  	require.Error(t, err)
   400  	require.ErrorContains(t, err, traceID)
   401  	require.Nil(t, db)
   402  	if !ydb.IsTransportError(err, grpcCodes.DeadlineExceeded) {
   403  		require.ErrorIs(t, err, context.DeadlineExceeded)
   404  	}
   405  }
   406  
   407  func TestClusterDiscoveryRetry(t *testing.T) {
   408  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   409  	defer cancel()
   410  
   411  	counter := 0
   412  
   413  	db, err := ydb.Open(ctx,
   414  		"grpc://non-existent.com:2135/some",
   415  		ydb.WithDialTimeout(time.Second),
   416  		ydb.WithTraceDriver(trace.Driver{
   417  			OnBalancerClusterDiscoveryAttempt: func(info trace.DriverBalancerClusterDiscoveryAttemptStartInfo) func(
   418  				trace.DriverBalancerClusterDiscoveryAttemptDoneInfo,
   419  			) {
   420  				counter++
   421  				return nil
   422  			},
   423  		}),
   424  	)
   425  	t.Logf("attempts: %d", counter)
   426  	t.Logf("err: %v", err)
   427  	require.Error(t, err)
   428  	require.Nil(t, db)
   429  	if !ydb.IsTransportError(err, grpcCodes.DeadlineExceeded) {
   430  		require.ErrorIs(t, err, context.DeadlineExceeded)
   431  	}
   432  	require.Greater(t, counter, 1)
   433  }