// Source: github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/testutil/driver.go

     1  package testutil
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
    10  	"google.golang.org/grpc"
    11  	"google.golang.org/grpc/metadata"
    12  	"google.golang.org/protobuf/proto"
    13  	"google.golang.org/protobuf/types/known/anypb"
    14  
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    17  )
    18  
    19  var ErrNotImplemented = xerrors.Wrap(fmt.Errorf("testutil: not implemented"))
    20  
    21  type MethodCode uint
    22  
    23  func (m MethodCode) String() string {
    24  	if method, ok := codeToString[m]; ok {
    25  		return method
    26  	}
    27  
    28  	return ""
    29  }
    30  
    31  type Method string
    32  
    33  func (m Method) Code() MethodCode {
    34  	if code, ok := grpcMethodToCode[m]; ok {
    35  		return code
    36  	}
    37  
    38  	return UnknownMethod
    39  }
    40  
// MethodCode values for the YDB table-service RPCs that the stubs in this
// package can route.
//
// The values come from iota. UnknownMethod is therefore 0 (the zero value of
// MethodCode), and each following constant is one greater than the one
// before it. Append new entries at the end so that existing values stay
// stable.
const (
	// UnknownMethod is what Method.Code returns for names that are missing
	// from grpcMethodToCode.
	UnknownMethod MethodCode = iota
	TableCreateSession
	TableDeleteSession
	TableKeepAlive
	TableCreateTable
	TableDropTable
	TableAlterTable
	TableCopyTable
	TableDescribeTable
	TableExplainDataQuery
	TablePrepareDataQuery
	TableExecuteDataQuery
	TableExecuteSchemeQuery
	TableBeginTransaction
	TableCommitTransaction
	TableRollbackTransaction
	TableDescribeTableOptions
	TableStreamReadTable
	TableStreamExecuteScanQuery
)
    62  
    63  var grpcMethodToCode = map[Method]MethodCode{
    64  	"/Ydb.Table.V1.TableService/CreateSession":          TableCreateSession,
    65  	"/Ydb.Table.V1.TableService/DeleteSession":          TableDeleteSession,
    66  	"/Ydb.Table.V1.TableService/KeepAlive":              TableKeepAlive,
    67  	"/Ydb.Table.V1.TableService/CreateTable":            TableCreateTable,
    68  	"/Ydb.Table.V1.TableService/DropTable":              TableDropTable,
    69  	"/Ydb.Table.V1.TableService/AlterTable":             TableAlterTable,
    70  	"/Ydb.Table.V1.TableService/CopyTable":              TableCopyTable,
    71  	"/Ydb.Table.V1.TableService/DescribeTable":          TableDescribeTable,
    72  	"/Ydb.Table.V1.TableService/ExplainDataQuery":       TableExplainDataQuery,
    73  	"/Ydb.Table.V1.TableService/PrepareDataQuery":       TablePrepareDataQuery,
    74  	"/Ydb.Table.V1.TableService/ExecuteDataQuery":       TableExecuteDataQuery,
    75  	"/Ydb.Table.V1.TableService/ExecuteSchemeQuery":     TableExecuteSchemeQuery,
    76  	"/Ydb.Table.V1.TableService/BeginTransaction":       TableBeginTransaction,
    77  	"/Ydb.Table.V1.TableService/CommitTransaction":      TableCommitTransaction,
    78  	"/Ydb.Table.V1.TableService/RollbackTransaction":    TableRollbackTransaction,
    79  	"/Ydb.Table.V1.TableService/DescribeTableOptions":   TableDescribeTableOptions,
    80  	"/Ydb.Table.V1.TableService/StreamReadTable":        TableStreamReadTable,
    81  	"/Ydb.Table.V1.TableService/StreamExecuteScanQuery": TableStreamExecuteScanQuery,
    82  }
    83  
    84  var codeToString = map[MethodCode]string{
    85  	TableCreateSession:          lastSegment("/Ydb.Table.V1.TableService/CreateSession"),
    86  	TableDeleteSession:          lastSegment("/Ydb.Table.V1.TableService/DeleteSession"),
    87  	TableKeepAlive:              lastSegment("/Ydb.Table.V1.TableService/KeepAlive"),
    88  	TableCreateTable:            lastSegment("/Ydb.Table.V1.TableService/CreateTable"),
    89  	TableDropTable:              lastSegment("/Ydb.Table.V1.TableService/DropTable"),
    90  	TableAlterTable:             lastSegment("/Ydb.Table.V1.TableService/AlterTable"),
    91  	TableCopyTable:              lastSegment("/Ydb.Table.V1.TableService/CopyTable"),
    92  	TableDescribeTable:          lastSegment("/Ydb.Table.V1.TableService/DescribeTable"),
    93  	TableExplainDataQuery:       lastSegment("/Ydb.Table.V1.TableService/ExplainDataQuery"),
    94  	TablePrepareDataQuery:       lastSegment("/Ydb.Table.V1.TableService/PrepareDataQuery"),
    95  	TableExecuteDataQuery:       lastSegment("/Ydb.Table.V1.TableService/ExecuteDataQuery"),
    96  	TableExecuteSchemeQuery:     lastSegment("/Ydb.Table.V1.TableService/ExecuteSchemeQuery"),
    97  	TableBeginTransaction:       lastSegment("/Ydb.Table.V1.TableService/BeginTransaction"),
    98  	TableCommitTransaction:      lastSegment("/Ydb.Table.V1.TableService/CommitTransaction"),
    99  	TableRollbackTransaction:    lastSegment("/Ydb.Table.V1.TableService/RollbackTransaction"),
   100  	TableDescribeTableOptions:   lastSegment("/Ydb.Table.V1.TableService/DescribeTableOptions"),
   101  	TableStreamReadTable:        lastSegment("/Ydb.Table.V1.TableService/StreamReadTable"),
   102  	TableStreamExecuteScanQuery: lastSegment("/Ydb.Table.V1.TableService/StreamExecuteScanQuery"),
   103  }
   104  
   105  func setField(name string, dst, value interface{}) {
   106  	x := reflect.ValueOf(dst).Elem()
   107  	t := x.Type()
   108  	f, ok := t.FieldByName(name)
   109  	if !ok {
   110  		panic(fmt.Sprintf(
   111  			"struct %s has no field %q",
   112  			t, name,
   113  		))
   114  	}
   115  	v := reflect.ValueOf(value)
   116  	if f.Type.Kind() != v.Type().Kind() {
   117  		panic(fmt.Sprintf(
   118  			"struct %s field %q is types of %s, not %s",
   119  			t, name, f.Type, v.Type(),
   120  		))
   121  	}
   122  	x.FieldByName(f.Name).Set(v)
   123  }
   124  
   125  type balancerStub struct {
   126  	onInvoke func(
   127  		ctx context.Context,
   128  		method string,
   129  		args interface{},
   130  		reply interface{},
   131  		opts ...grpc.CallOption,
   132  	) error
   133  	onNewStream func(
   134  		ctx context.Context,
   135  		desc *grpc.StreamDesc,
   136  		method string,
   137  		opts ...grpc.CallOption,
   138  	) (grpc.ClientStream, error)
   139  }
   140  
   141  func (b *balancerStub) Invoke(
   142  	ctx context.Context,
   143  	method string,
   144  	args interface{},
   145  	reply interface{},
   146  	opts ...grpc.CallOption,
   147  ) (err error) {
   148  	if b.onInvoke == nil {
   149  		return fmt.Errorf("database.onInvoke() not defined")
   150  	}
   151  
   152  	return b.onInvoke(ctx, method, args, reply, opts...)
   153  }
   154  
   155  func (b *balancerStub) NewStream(
   156  	ctx context.Context,
   157  	desc *grpc.StreamDesc,
   158  	method string,
   159  	opts ...grpc.CallOption,
   160  ) (_ grpc.ClientStream, err error) {
   161  	if b.onNewStream == nil {
   162  		return nil, fmt.Errorf("database.onNewStream() not defined")
   163  	}
   164  
   165  	return b.onNewStream(ctx, desc, method, opts...)
   166  }
   167  
   168  func (b *balancerStub) Get(context.Context) (conn grpc.ClientConnInterface, err error) {
   169  	cc := &clientConn{
   170  		onInvoke:    b.onInvoke,
   171  		onNewStream: b.onNewStream,
   172  	}
   173  
   174  	return cc, nil
   175  }
   176  
   177  func (b *balancerStub) Name() string {
   178  	return "testutil.database"
   179  }
   180  
   181  func (b *balancerStub) Close(ctx context.Context) error {
   182  	return nil
   183  }
   184  
   185  type (
   186  	InvokeHandlers    map[MethodCode]func(request interface{}) (result proto.Message, err error)
   187  	NewStreamHandlers map[MethodCode]func(desc *grpc.StreamDesc) (grpc.ClientStream, error)
   188  )
   189  
   190  type balancerOption func(c *balancerStub)
   191  
   192  func WithInvokeHandlers(invokeHandlers InvokeHandlers) balancerOption {
   193  	return func(r *balancerStub) {
   194  		r.onInvoke = func(
   195  			ctx context.Context,
   196  			method string,
   197  			args interface{},
   198  			reply interface{},
   199  			opts ...grpc.CallOption,
   200  		) (err error) {
   201  			if handler, ok := invokeHandlers[Method(method).Code()]; ok {
   202  				var result proto.Message
   203  				result, err = handler(args)
   204  				if err != nil {
   205  					return xerrors.WithStackTrace(err)
   206  				}
   207  				var anyResult *anypb.Any
   208  				anyResult, err = anypb.New(result)
   209  				if err != nil {
   210  					return xerrors.WithStackTrace(err)
   211  				}
   212  				setField(
   213  					"Operation",
   214  					reply,
   215  					&Ydb_Operations.Operation{
   216  						Result: anyResult,
   217  					},
   218  				)
   219  
   220  				return nil
   221  			}
   222  
   223  			return fmt.Errorf("method '%s' not implemented", method)
   224  		}
   225  	}
   226  }
   227  
   228  func WithNewStreamHandlers(newStreamHandlers NewStreamHandlers) balancerOption {
   229  	return func(r *balancerStub) {
   230  		r.onNewStream = func(
   231  			ctx context.Context,
   232  			desc *grpc.StreamDesc,
   233  			method string,
   234  			opts ...grpc.CallOption,
   235  		) (_ grpc.ClientStream, err error) {
   236  			if handler, ok := newStreamHandlers[Method(method).Code()]; ok {
   237  				return handler(desc)
   238  			}
   239  
   240  			return nil, fmt.Errorf("method '%s' not implemented", method)
   241  		}
   242  	}
   243  }
   244  
   245  func NewBalancer(opts ...balancerOption) *balancerStub {
   246  	c := &balancerStub{}
   247  	for _, opt := range opts {
   248  		if opt != nil {
   249  			opt(c)
   250  		}
   251  	}
   252  
   253  	return c
   254  }
   255  
   256  func (b *balancerStub) OnUpdate(func(context.Context, []endpoint.Info)) {
   257  }
   258  
   259  type clientConn struct {
   260  	onInvoke func(
   261  		ctx context.Context,
   262  		method string,
   263  		args interface{},
   264  		reply interface{},
   265  		opts ...grpc.CallOption,
   266  	) error
   267  	onNewStream func(
   268  		ctx context.Context,
   269  		desc *grpc.StreamDesc,
   270  		method string,
   271  		opts ...grpc.CallOption,
   272  	) (grpc.ClientStream, error)
   273  	onAddress func() string
   274  }
   275  
   276  func (c *clientConn) Address() string {
   277  	if c.onAddress != nil {
   278  		return c.onAddress()
   279  	}
   280  
   281  	return ""
   282  }
   283  
   284  func (c *clientConn) Invoke(
   285  	ctx context.Context,
   286  	method string,
   287  	args interface{},
   288  	reply interface{},
   289  	opts ...grpc.CallOption,
   290  ) error {
   291  	if c.onInvoke == nil {
   292  		return fmt.Errorf("onInvoke not implemented (method: %s, request: %v, response: %v)", method, args, reply)
   293  	}
   294  
   295  	return c.onInvoke(ctx, method, args, reply, opts...)
   296  }
   297  
   298  func (c *clientConn) NewStream(
   299  	ctx context.Context,
   300  	desc *grpc.StreamDesc,
   301  	method string,
   302  	opts ...grpc.CallOption,
   303  ) (grpc.ClientStream, error) {
   304  	if c.onNewStream == nil {
   305  		return nil, fmt.Errorf("onNewStream not implemented (method: %s, desc: %v)", method, desc)
   306  	}
   307  
   308  	return c.onNewStream(ctx, desc, method, opts...)
   309  }
   310  
   311  type ClientStream struct {
   312  	OnHeader    func() (metadata.MD, error)
   313  	OnTrailer   func() metadata.MD
   314  	OnCloseSend func() error
   315  	OnContext   func() context.Context
   316  	OnSendMsg   func(m interface{}) error
   317  	OnRecvMsg   func(m interface{}) error
   318  }
   319  
   320  func (s *ClientStream) Header() (metadata.MD, error) {
   321  	if s.OnHeader == nil {
   322  		return nil, xerrors.WithStackTrace(ErrNotImplemented)
   323  	}
   324  
   325  	return s.OnHeader()
   326  }
   327  
   328  func (s *ClientStream) Trailer() metadata.MD {
   329  	if s.OnTrailer == nil {
   330  		return nil
   331  	}
   332  
   333  	return s.OnTrailer()
   334  }
   335  
   336  func (s *ClientStream) CloseSend() error {
   337  	if s.OnCloseSend == nil {
   338  		return xerrors.WithStackTrace(ErrNotImplemented)
   339  	}
   340  
   341  	return s.OnCloseSend()
   342  }
   343  
   344  func (s *ClientStream) Context() context.Context {
   345  	if s.OnContext == nil {
   346  		return nil
   347  	}
   348  
   349  	return s.OnContext()
   350  }
   351  
   352  func (s *ClientStream) SendMsg(m interface{}) error {
   353  	if s.OnSendMsg == nil {
   354  		return xerrors.WithStackTrace(ErrNotImplemented)
   355  	}
   356  
   357  	return s.OnSendMsg(m)
   358  }
   359  
   360  func (s *ClientStream) RecvMsg(m interface{}) error {
   361  	if s.OnRecvMsg == nil {
   362  		return xerrors.WithStackTrace(ErrNotImplemented)
   363  	}
   364  
   365  	return s.OnRecvMsg(m)
   366  }
   367  
   368  func lastSegment(m string) string {
   369  	s := strings.Split(m, "/")
   370  
   371  	return s[len(s)-1]
   372  }