github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/testutil/driver.go (about) 1 package testutil 2 3 import ( 4 "context" 5 "fmt" 6 "reflect" 7 "strings" 8 9 "github.com/ydb-platform/ydb-go-genproto/protos/Ydb_Operations" 10 "google.golang.org/grpc" 11 "google.golang.org/grpc/metadata" 12 "google.golang.org/protobuf/proto" 13 "google.golang.org/protobuf/types/known/anypb" 14 15 "github.com/ydb-platform/ydb-go-sdk/v3/internal/endpoint" 16 "github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors" 17 ) 18 19 var ErrNotImplemented = xerrors.Wrap(fmt.Errorf("testutil: not implemented")) 20 21 type MethodCode uint 22 23 func (m MethodCode) String() string { 24 if method, ok := codeToString[m]; ok { 25 return method 26 } 27 28 return "" 29 } 30 31 type Method string 32 33 func (m Method) Code() MethodCode { 34 if code, ok := grpcMethodToCode[m]; ok { 35 return code 36 } 37 38 return UnknownMethod 39 } 40 41 const ( 42 UnknownMethod MethodCode = iota 43 TableCreateSession 44 TableDeleteSession 45 TableKeepAlive 46 TableCreateTable 47 TableDropTable 48 TableAlterTable 49 TableCopyTable 50 TableDescribeTable 51 TableExplainDataQuery 52 TablePrepareDataQuery 53 TableExecuteDataQuery 54 TableExecuteSchemeQuery 55 TableBeginTransaction 56 TableCommitTransaction 57 TableRollbackTransaction 58 TableDescribeTableOptions 59 TableStreamReadTable 60 TableStreamExecuteScanQuery 61 ) 62 63 var grpcMethodToCode = map[Method]MethodCode{ 64 "/Ydb.Table.V1.TableService/CreateSession": TableCreateSession, 65 "/Ydb.Table.V1.TableService/DeleteSession": TableDeleteSession, 66 "/Ydb.Table.V1.TableService/KeepAlive": TableKeepAlive, 67 "/Ydb.Table.V1.TableService/CreateTable": TableCreateTable, 68 "/Ydb.Table.V1.TableService/DropTable": TableDropTable, 69 "/Ydb.Table.V1.TableService/AlterTable": TableAlterTable, 70 "/Ydb.Table.V1.TableService/CopyTable": TableCopyTable, 71 "/Ydb.Table.V1.TableService/DescribeTable": TableDescribeTable, 72 "/Ydb.Table.V1.TableService/ExplainDataQuery": TableExplainDataQuery, 73 "/Ydb.Table.V1.TableService/PrepareDataQuery": TablePrepareDataQuery, 74 "/Ydb.Table.V1.TableService/ExecuteDataQuery": TableExecuteDataQuery, 75 "/Ydb.Table.V1.TableService/ExecuteSchemeQuery": TableExecuteSchemeQuery, 76 "/Ydb.Table.V1.TableService/BeginTransaction": TableBeginTransaction, 77 "/Ydb.Table.V1.TableService/CommitTransaction": TableCommitTransaction, 78 "/Ydb.Table.V1.TableService/RollbackTransaction": TableRollbackTransaction, 79 "/Ydb.Table.V1.TableService/DescribeTableOptions": TableDescribeTableOptions, 80 "/Ydb.Table.V1.TableService/StreamReadTable": TableStreamReadTable, 81 "/Ydb.Table.V1.TableService/StreamExecuteScanQuery": TableStreamExecuteScanQuery, 82 } 83 84 var codeToString = map[MethodCode]string{ 85 TableCreateSession: lastSegment("/Ydb.Table.V1.TableService/CreateSession"), 86 TableDeleteSession: lastSegment("/Ydb.Table.V1.TableService/DeleteSession"), 87 TableKeepAlive: lastSegment("/Ydb.Table.V1.TableService/KeepAlive"), 88 TableCreateTable: lastSegment("/Ydb.Table.V1.TableService/CreateTable"), 89 TableDropTable: lastSegment("/Ydb.Table.V1.TableService/DropTable"), 90 TableAlterTable: lastSegment("/Ydb.Table.V1.TableService/AlterTable"), 91 TableCopyTable: lastSegment("/Ydb.Table.V1.TableService/CopyTable"), 92 TableDescribeTable: lastSegment("/Ydb.Table.V1.TableService/DescribeTable"), 93 TableExplainDataQuery: lastSegment("/Ydb.Table.V1.TableService/ExplainDataQuery"), 94 TablePrepareDataQuery: lastSegment("/Ydb.Table.V1.TableService/PrepareDataQuery"), 95 TableExecuteDataQuery: lastSegment("/Ydb.Table.V1.TableService/ExecuteDataQuery"), 96 TableExecuteSchemeQuery: lastSegment("/Ydb.Table.V1.TableService/ExecuteSchemeQuery"), 97 TableBeginTransaction: lastSegment("/Ydb.Table.V1.TableService/BeginTransaction"), 98 TableCommitTransaction: lastSegment("/Ydb.Table.V1.TableService/CommitTransaction"), 99 TableRollbackTransaction: lastSegment("/Ydb.Table.V1.TableService/RollbackTransaction"), 100 TableDescribeTableOptions: lastSegment("/Ydb.Table.V1.TableService/DescribeTableOptions"), 101 TableStreamReadTable: lastSegment("/Ydb.Table.V1.TableService/StreamReadTable"), 102 TableStreamExecuteScanQuery: lastSegment("/Ydb.Table.V1.TableService/StreamExecuteScanQuery"), 103 } 104 105 func setField(name string, dst, value interface{}) { 106 x := reflect.ValueOf(dst).Elem() 107 t := x.Type() 108 f, ok := t.FieldByName(name) 109 if !ok { 110 panic(fmt.Sprintf( 111 "struct %s has no field %q", 112 t, name, 113 )) 114 } 115 v := reflect.ValueOf(value) 116 if f.Type.Kind() != v.Type().Kind() { 117 panic(fmt.Sprintf( 118 "struct %s field %q is types of %s, not %s", 119 t, name, f.Type, v.Type(), 120 )) 121 } 122 x.FieldByName(f.Name).Set(v) 123 } 124 125 type balancerStub struct { 126 onInvoke func( 127 ctx context.Context, 128 method string, 129 args interface{}, 130 reply interface{}, 131 opts ...grpc.CallOption, 132 ) error 133 onNewStream func( 134 ctx context.Context, 135 desc *grpc.StreamDesc, 136 method string, 137 opts ...grpc.CallOption, 138 ) (grpc.ClientStream, error) 139 } 140 141 func (b *balancerStub) Invoke( 142 ctx context.Context, 143 method string, 144 args interface{}, 145 reply interface{}, 146 opts ...grpc.CallOption, 147 ) (err error) { 148 if b.onInvoke == nil { 149 return fmt.Errorf("database.onInvoke() not defined") 150 } 151 152 return b.onInvoke(ctx, method, args, reply, opts...) 153 } 154 155 func (b *balancerStub) NewStream( 156 ctx context.Context, 157 desc *grpc.StreamDesc, 158 method string, 159 opts ...grpc.CallOption, 160 ) (_ grpc.ClientStream, err error) { 161 if b.onNewStream == nil { 162 return nil, fmt.Errorf("database.onNewStream() not defined") 163 } 164 165 return b.onNewStream(ctx, desc, method, opts...) 166 } 167 168 func (b *balancerStub) Get(context.Context) (conn grpc.ClientConnInterface, err error) { 169 cc := &clientConn{ 170 onInvoke: b.onInvoke, 171 onNewStream: b.onNewStream, 172 } 173 174 return cc, nil 175 } 176 177 func (b *balancerStub) Name() string { 178 return "testutil.database" 179 } 180 181 func (b *balancerStub) Close(ctx context.Context) error { 182 return nil 183 } 184 185 type ( 186 InvokeHandlers map[MethodCode]func(request interface{}) (result proto.Message, err error) 187 NewStreamHandlers map[MethodCode]func(desc *grpc.StreamDesc) (grpc.ClientStream, error) 188 ) 189 190 type balancerOption func(c *balancerStub) 191 192 func WithInvokeHandlers(invokeHandlers InvokeHandlers) balancerOption { 193 return func(r *balancerStub) { 194 r.onInvoke = func( 195 ctx context.Context, 196 method string, 197 args interface{}, 198 reply interface{}, 199 opts ...grpc.CallOption, 200 ) (err error) { 201 if handler, ok := invokeHandlers[Method(method).Code()]; ok { 202 var result proto.Message 203 result, err = handler(args) 204 if err != nil { 205 return xerrors.WithStackTrace(err) 206 } 207 var anyResult *anypb.Any 208 anyResult, err = anypb.New(result) 209 if err != nil { 210 return xerrors.WithStackTrace(err) 211 } 212 setField( 213 "Operation", 214 reply, 215 &Ydb_Operations.Operation{ 216 Result: anyResult, 217 }, 218 ) 219 220 return nil 221 } 222 223 return fmt.Errorf("method '%s' not implemented", method) 224 } 225 } 226 } 227 228 func WithNewStreamHandlers(newStreamHandlers NewStreamHandlers) balancerOption { 229 return func(r *balancerStub) { 230 r.onNewStream = func( 231 ctx context.Context, 232 desc *grpc.StreamDesc, 233 method string, 234 opts ...grpc.CallOption, 235 ) (_ grpc.ClientStream, err error) { 236 if handler, ok := newStreamHandlers[Method(method).Code()]; ok { 237 return handler(desc) 238 } 239 240 return nil, fmt.Errorf("method '%s' not implemented", method) 241 } 242 } 243 } 244 245 func NewBalancer(opts ...balancerOption) *balancerStub { 246 c := &balancerStub{} 247 for _, opt := range opts { 248 if opt != nil { 249 opt(c) 250 } 251 } 252 253 return c 254 } 255 256 func (b *balancerStub) OnUpdate(func(context.Context, []endpoint.Info)) { 257 } 258 259 type clientConn struct { 260 onInvoke func( 261 ctx context.Context, 262 method string, 263 args interface{}, 264 reply interface{}, 265 opts ...grpc.CallOption, 266 ) error 267 onNewStream func( 268 ctx context.Context, 269 desc *grpc.StreamDesc, 270 method string, 271 opts ...grpc.CallOption, 272 ) (grpc.ClientStream, error) 273 onAddress func() string 274 } 275 276 func (c *clientConn) Address() string { 277 if c.onAddress != nil { 278 return c.onAddress() 279 } 280 281 return "" 282 } 283 284 func (c *clientConn) Invoke( 285 ctx context.Context, 286 method string, 287 args interface{}, 288 reply interface{}, 289 opts ...grpc.CallOption, 290 ) error { 291 if c.onInvoke == nil { 292 return fmt.Errorf("onInvoke not implemented (method: %s, request: %v, response: %v)", method, args, reply) 293 } 294 295 return c.onInvoke(ctx, method, args, reply, opts...) 296 } 297 298 func (c *clientConn) NewStream( 299 ctx context.Context, 300 desc *grpc.StreamDesc, 301 method string, 302 opts ...grpc.CallOption, 303 ) (grpc.ClientStream, error) { 304 if c.onNewStream == nil { 305 return nil, fmt.Errorf("onNewStream not implemented (method: %s, desc: %v)", method, desc) 306 } 307 308 return c.onNewStream(ctx, desc, method, opts...) 309 } 310 311 type ClientStream struct { 312 OnHeader func() (metadata.MD, error) 313 OnTrailer func() metadata.MD 314 OnCloseSend func() error 315 OnContext func() context.Context 316 OnSendMsg func(m interface{}) error 317 OnRecvMsg func(m interface{}) error 318 } 319 320 func (s *ClientStream) Header() (metadata.MD, error) { 321 if s.OnHeader == nil { 322 return nil, xerrors.WithStackTrace(ErrNotImplemented) 323 } 324 325 return s.OnHeader() 326 } 327 328 func (s *ClientStream) Trailer() metadata.MD { 329 if s.OnTrailer == nil { 330 return nil 331 } 332 333 return s.OnTrailer() 334 } 335 336 func (s *ClientStream) CloseSend() error { 337 if s.OnCloseSend == nil { 338 return xerrors.WithStackTrace(ErrNotImplemented) 339 } 340 341 return s.OnCloseSend() 342 } 343 344 func (s *ClientStream) Context() context.Context { 345 if s.OnContext == nil { 346 return nil 347 } 348 349 return s.OnContext() 350 } 351 352 func (s *ClientStream) SendMsg(m interface{}) error { 353 if s.OnSendMsg == nil { 354 return xerrors.WithStackTrace(ErrNotImplemented) 355 } 356 357 return s.OnSendMsg(m) 358 } 359 360 func (s *ClientStream) RecvMsg(m interface{}) error { 361 if s.OnRecvMsg == nil { 362 return xerrors.WithStackTrace(ErrNotImplemented) 363 } 364 365 return s.OnRecvMsg(m) 366 } 367 368 func lastSegment(m string) string { 369 s := strings.Split(m, "/") 370 371 return s[len(s)-1] 372 }