github.com/mailru/activerecord@v1.12.2/pkg/activerecord/connection_w_test.go (about) 1 package activerecord 2 3 import ( 4 "reflect" 5 "sync" 6 "testing" 7 ) 8 9 type TestOptions struct { 10 hash string 11 mode ServerModeType 12 } 13 14 func (to *TestOptions) InstanceMode() ServerModeType { 15 return to.mode 16 } 17 18 func (to *TestOptions) GetConnectionID() string { 19 return to.hash 20 } 21 22 type TestConnection struct { 23 ch chan struct{} 24 id string 25 } 26 27 func (tc *TestConnection) Close() { 28 tc.ch <- struct{}{} 29 } 30 31 func (tc *TestConnection) Done() <-chan struct{} { 32 return tc.ch 33 } 34 35 var connectionCall = 0 36 37 func connectorFunc(options interface{}) (ConnectionInterface, error) { 38 connectionCall++ 39 to, _ := options.(*TestOptions) 40 return &TestConnection{id: to.hash}, nil 41 } 42 43 func Test_connectionPool_Add(t *testing.T) { 44 to1 := &TestOptions{hash: "testopt1"} 45 46 var clusterInfo = NewClusterInfo( 47 WithShard([]OptionInterface{to1}, []OptionInterface{}), 48 ) 49 50 type args struct { 51 shard ShardInstance 52 connector func(interface{}) (ConnectionInterface, error) 53 } 54 55 tests := []struct { 56 name string 57 args args 58 want ConnectionInterface 59 wantErr bool 60 wantCnt int 61 }{ 62 { 63 name: "first connection", 64 args: args{ 65 shard: clusterInfo.NextMaster(0), 66 connector: connectorFunc, 67 }, 68 wantErr: false, 69 want: &TestConnection{id: "testopt1"}, 70 wantCnt: 1, 71 }, 72 { 73 name: "again first connection", 74 args: args{ 75 shard: clusterInfo.NextMaster(0), 76 connector: connectorFunc, 77 }, 78 wantErr: true, 79 wantCnt: 1, 80 }, 81 } 82 83 cp := connectionPool{ 84 lock: sync.Mutex{}, 85 container: map[string]ConnectionInterface{}, 86 } 87 for _, tt := range tests { 88 t.Run(tt.name, func(t *testing.T) { 89 got, err := cp.Add(tt.args.shard, tt.args.connector) 90 if (err != nil) != tt.wantErr { 91 t.Errorf("connectionPool.Add() error = %v, wantErr %v", err, tt.wantErr) 92 return 93 } 94 if !reflect.DeepEqual(got, tt.want) { 95 t.Errorf("connectionPool.Add() = %+v, want %+v", got, tt.want) 96 } 97 if connectionCall != tt.wantCnt { 98 t.Errorf("connectionPool.Add() connectionCnt = %v, want %v", connectionCall, tt.wantCnt) 99 } 100 }) 101 } 102 }