github.com/mailru/activerecord@v1.12.2/pkg/activerecord/connection_w_test.go (about)

     1  package activerecord
     2  
import (
	"fmt"
	"reflect"
	"sync"
	"testing"
)
     8  
     9  type TestOptions struct {
    10  	hash string
    11  	mode ServerModeType
    12  }
    13  
    14  func (to *TestOptions) InstanceMode() ServerModeType {
    15  	return to.mode
    16  }
    17  
    18  func (to *TestOptions) GetConnectionID() string {
    19  	return to.hash
    20  }
    21  
    22  type TestConnection struct {
    23  	ch chan struct{}
    24  	id string
    25  }
    26  
    27  func (tc *TestConnection) Close() {
    28  	tc.ch <- struct{}{}
    29  }
    30  
    31  func (tc *TestConnection) Done() <-chan struct{} {
    32  	return tc.ch
    33  }
    34  
    35  var connectionCall = 0
    36  
    37  func connectorFunc(options interface{}) (ConnectionInterface, error) {
    38  	connectionCall++
    39  	to, _ := options.(*TestOptions)
    40  	return &TestConnection{id: to.hash}, nil
    41  }
    42  
    43  func Test_connectionPool_Add(t *testing.T) {
    44  	to1 := &TestOptions{hash: "testopt1"}
    45  
    46  	var clusterInfo = NewClusterInfo(
    47  		WithShard([]OptionInterface{to1}, []OptionInterface{}),
    48  	)
    49  
    50  	type args struct {
    51  		shard     ShardInstance
    52  		connector func(interface{}) (ConnectionInterface, error)
    53  	}
    54  
    55  	tests := []struct {
    56  		name    string
    57  		args    args
    58  		want    ConnectionInterface
    59  		wantErr bool
    60  		wantCnt int
    61  	}{
    62  		{
    63  			name: "first connection",
    64  			args: args{
    65  				shard:     clusterInfo.NextMaster(0),
    66  				connector: connectorFunc,
    67  			},
    68  			wantErr: false,
    69  			want:    &TestConnection{id: "testopt1"},
    70  			wantCnt: 1,
    71  		},
    72  		{
    73  			name: "again first connection",
    74  			args: args{
    75  				shard:     clusterInfo.NextMaster(0),
    76  				connector: connectorFunc,
    77  			},
    78  			wantErr: true,
    79  			wantCnt: 1,
    80  		},
    81  	}
    82  
    83  	cp := connectionPool{
    84  		lock:      sync.Mutex{},
    85  		container: map[string]ConnectionInterface{},
    86  	}
    87  	for _, tt := range tests {
    88  		t.Run(tt.name, func(t *testing.T) {
    89  			got, err := cp.Add(tt.args.shard, tt.args.connector)
    90  			if (err != nil) != tt.wantErr {
    91  				t.Errorf("connectionPool.Add() error = %v, wantErr %v", err, tt.wantErr)
    92  				return
    93  			}
    94  			if !reflect.DeepEqual(got, tt.want) {
    95  				t.Errorf("connectionPool.Add() = %+v, want %+v", got, tt.want)
    96  			}
    97  			if connectionCall != tt.wantCnt {
    98  				t.Errorf("connectionPool.Add() connectionCnt = %v, want %v", connectionCall, tt.wantCnt)
    99  			}
   100  		})
   101  	}
   102  }