trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/shardmap.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  //go:build linux || freebsd || dragonfly || darwin
    15  // +build linux freebsd dragonfly darwin
    16  
    17  package multiplex
    18  
    19  import (
    20  	"runtime"
    21  	"sync"
    22  	"sync/atomic"
    23  )
    24  
    25  var defaultShardSize = uint32(runtime.GOMAXPROCS(0))
    26  
    27  // shardMap is a concurrent safe <id,*virConn> map.
    28  // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards.
    29  type shardMap struct {
    30  	size   uint32
    31  	len    uint32
    32  	shards []*shard
    33  }
    34  
    35  // shard is a concurrent safe map.
    36  type shard struct {
    37  	idToVirConn map[uint32]*virtualConnection
    38  	mu          sync.RWMutex
    39  }
    40  
    41  // newShardMap creates a new shardMap.
    42  func newShardMap(size uint32) *shardMap {
    43  	m := &shardMap{
    44  		size:   size,
    45  		shards: make([]*shard, size),
    46  	}
    47  	for i := range m.shards {
    48  		m.shards[i] = &shard{
    49  			idToVirConn: make(map[uint32]*virtualConnection),
    50  		}
    51  	}
    52  	return m
    53  }
    54  
    55  // getShard returns shard of given id.
    56  func (m *shardMap) getShard(id uint32) *shard {
    57  	return m.shards[id%m.size]
    58  }
    59  
    60  // loadOrStore returns the existing virtual connection for the id if present.
    61  // Otherwise, it stores and returns the given vc. The loaded result is true if
    62  // the vc was loaded, false if stored.
    63  func (m *shardMap) loadOrStore(id uint32, vc *virtualConnection) (actual *virtualConnection, loaded bool) {
    64  	shard := m.getShard(id)
    65  	// Generally the ids are always different, here directly add the write lock.
    66  	shard.mu.Lock()
    67  	defer shard.mu.Unlock()
    68  	if actual, ok := shard.idToVirConn[id]; ok {
    69  		return actual, true
    70  	}
    71  	atomic.AddUint32(&m.len, 1)
    72  	shard.idToVirConn[id] = vc
    73  	return vc, false
    74  }
    75  
    76  // store stores virConn.
    77  func (m *shardMap) store(id uint32, vc *virtualConnection) {
    78  	shard := m.getShard(id)
    79  	shard.mu.Lock()
    80  	defer shard.mu.Unlock()
    81  	if _, ok := shard.idToVirConn[id]; !ok {
    82  		atomic.AddUint32(&m.len, 1)
    83  	}
    84  	shard.idToVirConn[id] = vc
    85  }
    86  
    87  // load loads the virConn of the given id.
    88  func (m *shardMap) load(id uint32) (*virtualConnection, bool) {
    89  	shard := m.getShard(id)
    90  	shard.mu.RLock()
    91  	defer shard.mu.RUnlock()
    92  	vc, ok := shard.idToVirConn[id]
    93  	return vc, ok
    94  }
    95  
    96  // delete deletes the virConn of the given id.
    97  func (m *shardMap) delete(id uint32) {
    98  	shard := m.getShard(id)
    99  	shard.mu.Lock()
   100  	defer shard.mu.Unlock()
   101  	if _, ok := shard.idToVirConn[id]; !ok {
   102  		return
   103  	}
   104  	atomic.AddUint32(&m.len, ^uint32(0))
   105  	delete(shard.idToVirConn, id)
   106  }
   107  
   108  // reset deletes all virConns in the shardMap.
   109  func (m *shardMap) reset() {
   110  	if m.length() == 0 {
   111  		return
   112  	}
   113  	atomic.StoreUint32(&m.len, 0)
   114  	for _, shard := range m.shards {
   115  		shard.mu.Lock()
   116  		shard.idToVirConn = make(map[uint32]*virtualConnection)
   117  		shard.mu.Unlock()
   118  	}
   119  }
   120  
   121  // length returns number of all virConns in the shardMap.
   122  func (m *shardMap) length() uint32 {
   123  	return atomic.LoadUint32(&m.len)
   124  }
   125  
   126  // loadAll returns all virConns in the shardMap.
   127  func (m *shardMap) loadAll() []*virtualConnection {
   128  	var conns []*virtualConnection
   129  	for _, shard := range m.shards {
   130  		shard.mu.RLock()
   131  		for _, v := range shard.idToVirConn {
   132  			conns = append(conns, v)
   133  		}
   134  		shard.mu.RUnlock()
   135  	}
   136  	return conns
   137  }