github.com/mailru/activerecord@v1.12.2/pkg/activerecord/connection.go (about)

     1  package activerecord
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sync"
     7  )
     8  
     9  type ConnectionInterface interface {
    10  	Close()
    11  	Done() <-chan struct{}
    12  }
    13  
    14  type connectionPool struct {
    15  	lock      sync.Mutex
    16  	container map[string]ConnectionInterface
    17  }
    18  
    19  func newConnectionPool() *connectionPool {
    20  	return &connectionPool{
    21  		lock:      sync.Mutex{},
    22  		container: make(map[string]ConnectionInterface),
    23  	}
    24  }
    25  
    26  // TODO при долгом неиспользовании какого то пула надо закрывать его. Это для случаев когда в конфиге поменялась конфигурация
    27  // надо зачищать старые пулы, что бы освободить конекты.
    28  // если будут колбеки о том, что сменилась конфигурация то можно подчищать по этим колбекам.
    29  func (cp *connectionPool) add(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) {
    30  	if _, ex := cp.container[shard.ParamsID]; ex {
    31  		return nil, fmt.Errorf("attempt to add duplicate connID: %s", shard.ParamsID)
    32  	}
    33  
    34  	pool, err := connector(shard.Options)
    35  	if err != nil {
    36  		return nil, fmt.Errorf("error add connection to shard: %w", err)
    37  	}
    38  
    39  	cp.container[shard.ParamsID] = pool
    40  
    41  	return pool, nil
    42  }
    43  
    44  func (cp *connectionPool) Add(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) {
    45  	cp.lock.Lock()
    46  	defer cp.lock.Unlock()
    47  
    48  	return cp.add(shard, connector)
    49  }
    50  
    51  func (cp *connectionPool) GetOrAdd(shard ShardInstance, connector func(interface{}) (ConnectionInterface, error)) (ConnectionInterface, error) {
    52  	cp.lock.Lock()
    53  	defer cp.lock.Unlock()
    54  
    55  	var err error
    56  
    57  	conn := cp.Get(shard)
    58  	if conn == nil {
    59  		conn, err = cp.add(shard, connector)
    60  	}
    61  
    62  	return conn, err
    63  }
    64  
    65  func (cp *connectionPool) Get(shard ShardInstance) ConnectionInterface {
    66  	if conn, ex := cp.container[shard.ParamsID]; ex {
    67  		return conn
    68  	}
    69  
    70  	return nil
    71  }
    72  
    73  func (cp *connectionPool) CloseConnection(ctx context.Context) {
    74  	cp.lock.Lock()
    75  
    76  	for name, pool := range cp.container {
    77  		pool.Close()
    78  		Logger().Debug(ctx, "connection close: %s", name)
    79  	}
    80  
    81  	for _, pool := range cp.container {
    82  		<-pool.Done()
    83  		Logger().Debug(ctx, "pool closed done")
    84  	}
    85  
    86  	cp.lock.Unlock()
    87  }