github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/testutil/driver.go (about)

     1  package testutil
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  
     9  	"github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations"
    10  	"google.golang.org/grpc"
    11  	"google.golang.org/grpc/metadata"
    12  	"google.golang.org/protobuf/proto"
    13  	"google.golang.org/protobuf/types/known/anypb"
    14  
    15  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint"
    16  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    17  )
    18  
    19  var ErrNotImplemented = xerrors.Wrap(fmt.Errorf("testutil: not implemented"))
    20  
    21  type MethodCode uint
    22  
    23  func (m MethodCode) String() string {
    24  	if method, ok := codeToString[m]; ok {
    25  		return method
    26  	}
    27  
    28  	return ""
    29  }
    30  
    31  type Method string
    32  
    33  func (m Method) Code() MethodCode {
    34  	if code, ok := grpcMethodToCode[m]; ok {
    35  		return code
    36  	}
    37  
    38  	return UnknownMethod
    39  }
    40  
    41  const (
    42  	UnknownMethod MethodCode = iota
    43  	TableCreateSession
    44  	TableDeleteSession
    45  	TableKeepAlive
    46  	TableCreateTable
    47  	TableDropTable
    48  	TableAlterTable
    49  	TableCopyTable
    50  	TableDescribeTable
    51  	TableExplainDataQuery
    52  	TablePrepareDataQuery
    53  	TableExecuteDataQuery
    54  	TableExecuteSchemeQuery
    55  	TableBeginTransaction
    56  	TableCommitTransaction
    57  	TableRollbackTransaction
    58  	TableDescribeTableOptions
    59  	TableStreamReadTable
    60  	TableStreamExecuteScanQuery
    61  )
    62  
    63  var grpcMethodToCode = map[Method]MethodCode{
    64  	"/Ydb.Table.V1.TableService/CreateSession":          TableCreateSession,
    65  	"/Ydb.Table.V1.TableService/DeleteSession":          TableDeleteSession,
    66  	"/Ydb.Table.V1.TableService/KeepAlive":              TableKeepAlive,
    67  	"/Ydb.Table.V1.TableService/CreateTable":            TableCreateTable,
    68  	"/Ydb.Table.V1.TableService/DropTable":              TableDropTable,
    69  	"/Ydb.Table.V1.TableService/AlterTable":             TableAlterTable,
    70  	"/Ydb.Table.V1.TableService/CopyTable":              TableCopyTable,
    71  	"/Ydb.Table.V1.TableService/DescribeTable":          TableDescribeTable,
    72  	"/Ydb.Table.V1.TableService/ExplainDataQuery":       TableExplainDataQuery,
    73  	"/Ydb.Table.V1.TableService/PrepareDataQuery":       TablePrepareDataQuery,
    74  	"/Ydb.Table.V1.TableService/ExecuteDataQuery":       TableExecuteDataQuery,
    75  	"/Ydb.Table.V1.TableService/ExecuteSchemeQuery":     TableExecuteSchemeQuery,
    76  	"/Ydb.Table.V1.TableService/BeginTransaction":       TableBeginTransaction,
    77  	"/Ydb.Table.V1.TableService/CommitTransaction":      TableCommitTransaction,
    78  	"/Ydb.Table.V1.TableService/RollbackTransaction":    TableRollbackTransaction,
    79  	"/Ydb.Table.V1.TableService/DescribeTableOptions":   TableDescribeTableOptions,
    80  	"/Ydb.Table.V1.TableService/StreamReadTable":        TableStreamReadTable,
    81  	"/Ydb.Table.V1.TableService/StreamExecuteScanQuery": TableStreamExecuteScanQuery,
    82  }
    83  
    84  var codeToString = map[MethodCode]string{
    85  	TableCreateSession:          lastSegment("/Ydb.Table.V1.TableService/CreateSession"),
    86  	TableDeleteSession:          lastSegment("/Ydb.Table.V1.TableService/DeleteSession"),
    87  	TableKeepAlive:              lastSegment("/Ydb.Table.V1.TableService/KeepAlive"),
    88  	TableCreateTable:            lastSegment("/Ydb.Table.V1.TableService/CreateTable"),
    89  	TableDropTable:              lastSegment("/Ydb.Table.V1.TableService/DropTable"),
    90  	TableAlterTable:             lastSegment("/Ydb.Table.V1.TableService/AlterTable"),
    91  	TableCopyTable:              lastSegment("/Ydb.Table.V1.TableService/CopyTable"),
    92  	TableDescribeTable:          lastSegment("/Ydb.Table.V1.TableService/DescribeTable"),
    93  	TableExplainDataQuery:       lastSegment("/Ydb.Table.V1.TableService/ExplainDataQuery"),
    94  	TablePrepareDataQuery:       lastSegment("/Ydb.Table.V1.TableService/PrepareDataQuery"),
    95  	TableExecuteDataQuery:       lastSegment("/Ydb.Table.V1.TableService/ExecuteDataQuery"),
    96  	TableExecuteSchemeQuery:     lastSegment("/Ydb.Table.V1.TableService/ExecuteSchemeQuery"),
    97  	TableBeginTransaction:       lastSegment("/Ydb.Table.V1.TableService/BeginTransaction"),
    98  	TableCommitTransaction:      lastSegment("/Ydb.Table.V1.TableService/CommitTransaction"),
    99  	TableRollbackTransaction:    lastSegment("/Ydb.Table.V1.TableService/RollbackTransaction"),
   100  	TableDescribeTableOptions:   lastSegment("/Ydb.Table.V1.TableService/DescribeTableOptions"),
   101  	TableStreamReadTable:        lastSegment("/Ydb.Table.V1.TableService/StreamReadTable"),
   102  	TableStreamExecuteScanQuery: lastSegment("/Ydb.Table.V1.TableService/StreamExecuteScanQuery"),
   103  }
   104  
   105  func setField(name string, dst, value interface{}) {
   106  	x := reflect.ValueOf(dst).Elem()
   107  	t := x.Type()
   108  	f, ok := t.FieldByName(name)
   109  	if !ok {
   110  		panic(fmt.Sprintf(
   111  			"struct %s has no field %q",
   112  			t, name,
   113  		))
   114  	}
   115  	v := reflect.ValueOf(value)
   116  	if f.Type.Kind() != v.Type().Kind() {
   117  		panic(fmt.Sprintf(
   118  			"struct %s field %q is types of %s, not %s",
   119  			t, name, f.Type, v.Type(),
   120  		))
   121  	}
   122  	x.FieldByName(f.Name).Set(v)
   123  }
   124  
   125  type balancerStub struct {
   126  	onInvoke func(
   127  		ctx context.Context,
   128  		method string,
   129  		args interface{},
   130  		reply interface{},
   131  		opts ...grpc.CallOption,
   132  	) error
   133  	onNewStream func(
   134  		ctx context.Context,
   135  		desc *grpc.StreamDesc,
   136  		method string,
   137  		opts ...grpc.CallOption,
   138  	) (grpc.ClientStream, error)
   139  }
   140  
   141  func (b *balancerStub) HasNode(id uint32) bool {
   142  	return true
   143  }
   144  
   145  func (b *balancerStub) Invoke(
   146  	ctx context.Context,
   147  	method string,
   148  	args interface{},
   149  	reply interface{},
   150  	opts ...grpc.CallOption,
   151  ) (err error) {
   152  	if b.onInvoke == nil {
   153  		return fmt.Errorf("database.onInvoke() not defined")
   154  	}
   155  
   156  	return b.onInvoke(ctx, method, args, reply, opts...)
   157  }
   158  
   159  func (b *balancerStub) NewStream(
   160  	ctx context.Context,
   161  	desc *grpc.StreamDesc,
   162  	method string,
   163  	opts ...grpc.CallOption,
   164  ) (_ grpc.ClientStream, err error) {
   165  	if b.onNewStream == nil {
   166  		return nil, fmt.Errorf("database.onNewStream() not defined")
   167  	}
   168  
   169  	return b.onNewStream(ctx, desc, method, opts...)
   170  }
   171  
   172  func (b *balancerStub) Get(context.Context) (conn grpc.ClientConnInterface, err error) {
   173  	cc := &clientConn{
   174  		onInvoke:    b.onInvoke,
   175  		onNewStream: b.onNewStream,
   176  	}
   177  
   178  	return cc, nil
   179  }
   180  
   181  func (b *balancerStub) Name() string {
   182  	return "testutil.database"
   183  }
   184  
   185  func (b *balancerStub) Close(ctx context.Context) error {
   186  	return nil
   187  }
   188  
   189  type (
   190  	InvokeHandlers    map[MethodCode]func(request interface{}) (result proto.Message, err error)
   191  	NewStreamHandlers map[MethodCode]func(desc *grpc.StreamDesc) (grpc.ClientStream, error)
   192  )
   193  
   194  type balancerOption func(c *balancerStub)
   195  
   196  func WithInvokeHandlers(invokeHandlers InvokeHandlers) balancerOption {
   197  	return func(r *balancerStub) {
   198  		r.onInvoke = func(
   199  			ctx context.Context,
   200  			method string,
   201  			args interface{},
   202  			reply interface{},
   203  			opts ...grpc.CallOption,
   204  		) (err error) {
   205  			if handler, ok := invokeHandlers[Method(method).Code()]; ok {
   206  				var result proto.Message
   207  				result, err = handler(args)
   208  				if err != nil {
   209  					return xerrors.WithStackTrace(err)
   210  				}
   211  				var anyResult *anypb.Any
   212  				anyResult, err = anypb.New(result)
   213  				if err != nil {
   214  					return xerrors.WithStackTrace(err)
   215  				}
   216  				setField(
   217  					"Operation",
   218  					reply,
   219  					&Ydb_Operations.Operation{
   220  						Result: anyResult,
   221  					},
   222  				)
   223  
   224  				return nil
   225  			}
   226  
   227  			return fmt.Errorf("method '%s' not implemented", method)
   228  		}
   229  	}
   230  }
   231  
   232  func WithNewStreamHandlers(newStreamHandlers NewStreamHandlers) balancerOption {
   233  	return func(r *balancerStub) {
   234  		r.onNewStream = func(
   235  			ctx context.Context,
   236  			desc *grpc.StreamDesc,
   237  			method string,
   238  			opts ...grpc.CallOption,
   239  		) (_ grpc.ClientStream, err error) {
   240  			if handler, ok := newStreamHandlers[Method(method).Code()]; ok {
   241  				return handler(desc)
   242  			}
   243  
   244  			return nil, fmt.Errorf("method '%s' not implemented", method)
   245  		}
   246  	}
   247  }
   248  
   249  func NewBalancer(opts ...balancerOption) *balancerStub {
   250  	c := &balancerStub{}
   251  	for _, opt := range opts {
   252  		if opt != nil {
   253  			opt(c)
   254  		}
   255  	}
   256  
   257  	return c
   258  }
   259  
   260  func (b *balancerStub) OnUpdate(func(context.Context, []endpoint.Info)) {
   261  }
   262  
   263  type clientConn struct {
   264  	onInvoke func(
   265  		ctx context.Context,
   266  		method string,
   267  		args interface{},
   268  		reply interface{},
   269  		opts ...grpc.CallOption,
   270  	) error
   271  	onNewStream func(
   272  		ctx context.Context,
   273  		desc *grpc.StreamDesc,
   274  		method string,
   275  		opts ...grpc.CallOption,
   276  	) (grpc.ClientStream, error)
   277  	onAddress func() string
   278  }
   279  
   280  func (c *clientConn) Address() string {
   281  	if c.onAddress != nil {
   282  		return c.onAddress()
   283  	}
   284  
   285  	return ""
   286  }
   287  
   288  func (c *clientConn) Invoke(
   289  	ctx context.Context,
   290  	method string,
   291  	args interface{},
   292  	reply interface{},
   293  	opts ...grpc.CallOption,
   294  ) error {
   295  	if c.onInvoke == nil {
   296  		return fmt.Errorf("onInvoke not implemented (method: %s, request: %v, response: %v)", method, args, reply)
   297  	}
   298  
   299  	return c.onInvoke(ctx, method, args, reply, opts...)
   300  }
   301  
   302  func (c *clientConn) NewStream(
   303  	ctx context.Context,
   304  	desc *grpc.StreamDesc,
   305  	method string,
   306  	opts ...grpc.CallOption,
   307  ) (grpc.ClientStream, error) {
   308  	if c.onNewStream == nil {
   309  		return nil, fmt.Errorf("onNewStream not implemented (method: %s, desc: %v)", method, desc)
   310  	}
   311  
   312  	return c.onNewStream(ctx, desc, method, opts...)
   313  }
   314  
   315  type ClientStream struct {
   316  	OnHeader    func() (metadata.MD, error)
   317  	OnTrailer   func() metadata.MD
   318  	OnCloseSend func() error
   319  	OnContext   func() context.Context
   320  	OnSendMsg   func(m interface{}) error
   321  	OnRecvMsg   func(m interface{}) error
   322  }
   323  
   324  func (s *ClientStream) Header() (metadata.MD, error) {
   325  	if s.OnHeader == nil {
   326  		return nil, xerrors.WithStackTrace(ErrNotImplemented)
   327  	}
   328  
   329  	return s.OnHeader()
   330  }
   331  
   332  func (s *ClientStream) Trailer() metadata.MD {
   333  	if s.OnTrailer == nil {
   334  		return nil
   335  	}
   336  
   337  	return s.OnTrailer()
   338  }
   339  
   340  func (s *ClientStream) CloseSend() error {
   341  	if s.OnCloseSend == nil {
   342  		return xerrors.WithStackTrace(ErrNotImplemented)
   343  	}
   344  
   345  	return s.OnCloseSend()
   346  }
   347  
   348  func (s *ClientStream) Context() context.Context {
   349  	if s.OnContext == nil {
   350  		return nil
   351  	}
   352  
   353  	return s.OnContext()
   354  }
   355  
   356  func (s *ClientStream) SendMsg(m interface{}) error {
   357  	if s.OnSendMsg == nil {
   358  		return xerrors.WithStackTrace(ErrNotImplemented)
   359  	}
   360  
   361  	return s.OnSendMsg(m)
   362  }
   363  
   364  func (s *ClientStream) RecvMsg(m interface{}) error {
   365  	if s.OnRecvMsg == nil {
   366  		return xerrors.WithStackTrace(ErrNotImplemented)
   367  	}
   368  
   369  	return s.OnRecvMsg(m)
   370  }
   371  
   372  func lastSegment(m string) string {
   373  	s := strings.Split(m, "/")
   374  
   375  	return s[len(s)-1]
   376  }