github.com/ydb-platform/ydb-go-sdk/v3@v3.57.0/testutil/driver.go (about) 1 package testutil 2 3 import ( 4 "context" 5 "fmt" 6 "reflect" 7 "strings" 8 9 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations" 10 "google.golang.org/grpc" 11 "google.golang.org/grpc/metadata" 12 "google.golang.org/protobuf/proto" 13 "google.golang.org/protobuf/types/known/anypb" 14 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 17 ) 18 19 var ErrNotImplemented = xerrors.Wrap(fmt.Errorf("testutil: not implemented")) 20 21 type MethodCode uint 22 23 func (m MethodCode) String() string { 24 if method, ok := codeToString[m]; ok { 25 return method 26 } 27 28 return "" 29 } 30 31 type Method string 32 33 func (m Method) Code() MethodCode { 34 if code, ok := grpcMethodToCode[m]; ok { 35 return code 36 } 37 38 return UnknownMethod 39 } 40 41 const ( 42 UnknownMethod MethodCode = iota 43 TableCreateSession 44 TableDeleteSession 45 TableKeepAlive 46 TableCreateTable 47 TableDropTable 48 TableAlterTable 49 TableCopyTable 50 TableDescribeTable 51 TableExplainDataQuery 52 TablePrepareDataQuery 53 TableExecuteDataQuery 54 TableExecuteSchemeQuery 55 TableBeginTransaction 56 TableCommitTransaction 57 TableRollbackTransaction 58 TableDescribeTableOptions 59 TableStreamReadTable 60 TableStreamExecuteScanQuery 61 ) 62 63 var grpcMethodToCode = map[Method]MethodCode{ 64 "/Ydb.Table.V1.TableService/CreateSession": TableCreateSession, 65 "/Ydb.Table.V1.TableService/DeleteSession": TableDeleteSession, 66 "/Ydb.Table.V1.TableService/KeepAlive": TableKeepAlive, 67 "/Ydb.Table.V1.TableService/CreateTable": TableCreateTable, 68 "/Ydb.Table.V1.TableService/DropTable": TableDropTable, 69 "/Ydb.Table.V1.TableService/AlterTable": TableAlterTable, 70 "/Ydb.Table.V1.TableService/CopyTable": TableCopyTable, 71 "/Ydb.Table.V1.TableService/DescribeTable": TableDescribeTable, 72 "/Ydb.Table.V1.TableService/ExplainDataQuery": TableExplainDataQuery, 73 "/Ydb.Table.V1.TableService/PrepareDataQuery": TablePrepareDataQuery, 74 "/Ydb.Table.V1.TableService/ExecuteDataQuery": TableExecuteDataQuery, 75 "/Ydb.Table.V1.TableService/ExecuteSchemeQuery": TableExecuteSchemeQuery, 76 "/Ydb.Table.V1.TableService/BeginTransaction": TableBeginTransaction, 77 "/Ydb.Table.V1.TableService/CommitTransaction": TableCommitTransaction, 78 "/Ydb.Table.V1.TableService/RollbackTransaction": TableRollbackTransaction, 79 "/Ydb.Table.V1.TableService/DescribeTableOptions": TableDescribeTableOptions, 80 "/Ydb.Table.V1.TableService/StreamReadTable": TableStreamReadTable, 81 "/Ydb.Table.V1.TableService/StreamExecuteScanQuery": TableStreamExecuteScanQuery, 82 } 83 84 var codeToString = map[MethodCode]string{ 85 TableCreateSession: lastSegment("/Ydb.Table.V1.TableService/CreateSession"), 86 TableDeleteSession: lastSegment("/Ydb.Table.V1.TableService/DeleteSession"), 87 TableKeepAlive: lastSegment("/Ydb.Table.V1.TableService/KeepAlive"), 88 TableCreateTable: lastSegment("/Ydb.Table.V1.TableService/CreateTable"), 89 TableDropTable: lastSegment("/Ydb.Table.V1.TableService/DropTable"), 90 TableAlterTable: lastSegment("/Ydb.Table.V1.TableService/AlterTable"), 91 TableCopyTable: lastSegment("/Ydb.Table.V1.TableService/CopyTable"), 92 TableDescribeTable: lastSegment("/Ydb.Table.V1.TableService/DescribeTable"), 93 TableExplainDataQuery: lastSegment("/Ydb.Table.V1.TableService/ExplainDataQuery"), 94 TablePrepareDataQuery: lastSegment("/Ydb.Table.V1.TableService/PrepareDataQuery"), 95 TableExecuteDataQuery: lastSegment("/Ydb.Table.V1.TableService/ExecuteDataQuery"), 96 TableExecuteSchemeQuery: lastSegment("/Ydb.Table.V1.TableService/ExecuteSchemeQuery"), 97 TableBeginTransaction: lastSegment("/Ydb.Table.V1.TableService/BeginTransaction"), 98 TableCommitTransaction: lastSegment("/Ydb.Table.V1.TableService/CommitTransaction"), 99 TableRollbackTransaction: lastSegment("/Ydb.Table.V1.TableService/RollbackTransaction"), 100 TableDescribeTableOptions: lastSegment("/Ydb.Table.V1.TableService/DescribeTableOptions"), 101 TableStreamReadTable: lastSegment("/Ydb.Table.V1.TableService/StreamReadTable"), 102 TableStreamExecuteScanQuery: lastSegment("/Ydb.Table.V1.TableService/StreamExecuteScanQuery"), 103 } 104 105 func setField(name string, dst, value interface{}) { 106 x := reflect.ValueOf(dst).Elem() 107 t := x.Type() 108 f, ok := t.FieldByName(name) 109 if !ok { 110 panic(fmt.Sprintf( 111 "struct %s has no field %q", 112 t, name, 113 )) 114 } 115 v := reflect.ValueOf(value) 116 if f.Type.Kind() != v.Type().Kind() { 117 panic(fmt.Sprintf( 118 "struct %s field %q is types of %s, not %s", 119 t, name, f.Type, v.Type(), 120 )) 121 } 122 x.FieldByName(f.Name).Set(v) 123 } 124 125 type balancerStub struct { 126 onInvoke func( 127 ctx context.Context, 128 method string, 129 args interface{}, 130 reply interface{}, 131 opts ...grpc.CallOption, 132 ) error 133 onNewStream func( 134 ctx context.Context, 135 desc *grpc.StreamDesc, 136 method string, 137 opts ...grpc.CallOption, 138 ) (grpc.ClientStream, error) 139 } 140 141 func (b *balancerStub) HasNode(id uint32) bool { 142 return true 143 } 144 145 func (b *balancerStub) Invoke( 146 ctx context.Context, 147 method string, 148 args interface{}, 149 reply interface{}, 150 opts ...grpc.CallOption, 151 ) (err error) { 152 if b.onInvoke == nil { 153 return fmt.Errorf("database.onInvoke() not defined") 154 } 155 156 return b.onInvoke(ctx, method, args, reply, opts...) 157 } 158 159 func (b *balancerStub) NewStream( 160 ctx context.Context, 161 desc *grpc.StreamDesc, 162 method string, 163 opts ...grpc.CallOption, 164 ) (_ grpc.ClientStream, err error) { 165 if b.onNewStream == nil { 166 return nil, fmt.Errorf("database.onNewStream() not defined") 167 } 168 169 return b.onNewStream(ctx, desc, method, opts...) 170 } 171 172 func (b *balancerStub) Get(context.Context) (conn grpc.ClientConnInterface, err error) { 173 cc := &clientConn{ 174 onInvoke: b.onInvoke, 175 onNewStream: b.onNewStream, 176 } 177 178 return cc, nil 179 } 180 181 func (b *balancerStub) Name() string { 182 return "testutil.database" 183 } 184 185 func (b *balancerStub) Close(ctx context.Context) error { 186 return nil 187 } 188 189 type ( 190 InvokeHandlers map[MethodCode]func(request interface{}) (result proto.Message, err error) 191 NewStreamHandlers map[MethodCode]func(desc *grpc.StreamDesc) (grpc.ClientStream, error) 192 ) 193 194 type balancerOption func(c *balancerStub) 195 196 func WithInvokeHandlers(invokeHandlers InvokeHandlers) balancerOption { 197 return func(r *balancerStub) { 198 r.onInvoke = func( 199 ctx context.Context, 200 method string, 201 args interface{}, 202 reply interface{}, 203 opts ...grpc.CallOption, 204 ) (err error) { 205 if handler, ok := invokeHandlers[Method(method).Code()]; ok { 206 var result proto.Message 207 result, err = handler(args) 208 if err != nil { 209 return xerrors.WithStackTrace(err) 210 } 211 var anyResult *anypb.Any 212 anyResult, err = anypb.New(result) 213 if err != nil { 214 return xerrors.WithStackTrace(err) 215 } 216 setField( 217 "Operation", 218 reply, 219 &Ydb_Operations.Operation{ 220 Result: anyResult, 221 }, 222 ) 223 224 return nil 225 } 226 227 return fmt.Errorf("method '%s' not implemented", method) 228 } 229 } 230 } 231 232 func WithNewStreamHandlers(newStreamHandlers NewStreamHandlers) balancerOption { 233 return func(r *balancerStub) { 234 r.onNewStream = func( 235 ctx context.Context, 236 desc *grpc.StreamDesc, 237 method string, 238 opts ...grpc.CallOption, 239 ) (_ grpc.ClientStream, err error) { 240 if handler, ok := newStreamHandlers[Method(method).Code()]; ok { 241 return handler(desc) 242 } 243 244 return nil, fmt.Errorf("method '%s' not implemented", method) 245 } 246 } 247 } 248 249 func NewBalancer(opts ...balancerOption) *balancerStub { 250 c := &balancerStub{} 251 for _, opt := range opts { 252 if opt != nil { 253 opt(c) 254 } 255 } 256 257 return c 258 } 259 260 func (b *balancerStub) OnUpdate(func(context.Context, []endpoint.Info)) { 261 } 262 263 type clientConn struct { 264 onInvoke func( 265 ctx context.Context, 266 method string, 267 args interface{}, 268 reply interface{}, 269 opts ...grpc.CallOption, 270 ) error 271 onNewStream func( 272 ctx context.Context, 273 desc *grpc.StreamDesc, 274 method string, 275 opts ...grpc.CallOption, 276 ) (grpc.ClientStream, error) 277 onAddress func() string 278 } 279 280 func (c *clientConn) Address() string { 281 if c.onAddress != nil { 282 return c.onAddress() 283 } 284 285 return "" 286 } 287 288 func (c *clientConn) Invoke( 289 ctx context.Context, 290 method string, 291 args interface{}, 292 reply interface{}, 293 opts ...grpc.CallOption, 294 ) error { 295 if c.onInvoke == nil { 296 return fmt.Errorf("onInvoke not implemented (method: %s, request: %v, response: %v)", method, args, reply) 297 } 298 299 return c.onInvoke(ctx, method, args, reply, opts...) 300 } 301 302 func (c *clientConn) NewStream( 303 ctx context.Context, 304 desc *grpc.StreamDesc, 305 method string, 306 opts ...grpc.CallOption, 307 ) (grpc.ClientStream, error) { 308 if c.onNewStream == nil { 309 return nil, fmt.Errorf("onNewStream not implemented (method: %s, desc: %v)", method, desc) 310 } 311 312 return c.onNewStream(ctx, desc, method, opts...) 313 } 314 315 type ClientStream struct { 316 OnHeader func() (metadata.MD, error) 317 OnTrailer func() metadata.MD 318 OnCloseSend func() error 319 OnContext func() context.Context 320 OnSendMsg func(m interface{}) error 321 OnRecvMsg func(m interface{}) error 322 } 323 324 func (s *ClientStream) Header() (metadata.MD, error) { 325 if s.OnHeader == nil { 326 return nil, xerrors.WithStackTrace(ErrNotImplemented) 327 } 328 329 return s.OnHeader() 330 } 331 332 func (s *ClientStream) Trailer() metadata.MD { 333 if s.OnTrailer == nil { 334 return nil 335 } 336 337 return s.OnTrailer() 338 } 339 340 func (s *ClientStream) CloseSend() error { 341 if s.OnCloseSend == nil { 342 return xerrors.WithStackTrace(ErrNotImplemented) 343 } 344 345 return s.OnCloseSend() 346 } 347 348 func (s *ClientStream) Context() context.Context { 349 if s.OnContext == nil { 350 return nil 351 } 352 353 return s.OnContext() 354 } 355 356 func (s *ClientStream) SendMsg(m interface{}) error { 357 if s.OnSendMsg == nil { 358 return xerrors.WithStackTrace(ErrNotImplemented) 359 } 360 361 return s.OnSendMsg(m) 362 } 363 364 func (s *ClientStream) RecvMsg(m interface{}) error { 365 if s.OnRecvMsg == nil { 366 return xerrors.WithStackTrace(ErrNotImplemented) 367 } 368 369 return s.OnRecvMsg(m) 370 } 371 372 func lastSegment(m string) string { 373 s := strings.Split(m, "/") 374 375 return s[len(s)-1] 376 }