trpc.group/trpc-go/trpc-go@v1.0.3/transport/tnet/multiplex/shardmap.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 //go:build linux || freebsd || dragonfly || darwin 15 // +build linux freebsd dragonfly darwin 16 17 package multiplex 18 19 import ( 20 "runtime" 21 "sync" 22 "sync/atomic" 23 ) 24 25 var defaultShardSize = uint32(runtime.GOMAXPROCS(0)) 26 27 // shardMap is a concurrent safe <id,*virConn> map. 28 // To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. 29 type shardMap struct { 30 size uint32 31 len uint32 32 shards []*shard 33 } 34 35 // shard is a concurrent safe map. 36 type shard struct { 37 idToVirConn map[uint32]*virtualConnection 38 mu sync.RWMutex 39 } 40 41 // newShardMap creates a new shardMap. 42 func newShardMap(size uint32) *shardMap { 43 m := &shardMap{ 44 size: size, 45 shards: make([]*shard, size), 46 } 47 for i := range m.shards { 48 m.shards[i] = &shard{ 49 idToVirConn: make(map[uint32]*virtualConnection), 50 } 51 } 52 return m 53 } 54 55 // getShard returns shard of given id. 56 func (m *shardMap) getShard(id uint32) *shard { 57 return m.shards[id%m.size] 58 } 59 60 // loadOrStore returns the existing virtual connection for the id if present. 61 // Otherwise, it stores and returns the given vc. The loaded result is true if 62 // the vc was loaded, false if stored. 63 func (m *shardMap) loadOrStore(id uint32, vc *virtualConnection) (actual *virtualConnection, loaded bool) { 64 shard := m.getShard(id) 65 // Generally the ids are always different, here directly add the write lock. 66 shard.mu.Lock() 67 defer shard.mu.Unlock() 68 if actual, ok := shard.idToVirConn[id]; ok { 69 return actual, true 70 } 71 atomic.AddUint32(&m.len, 1) 72 shard.idToVirConn[id] = vc 73 return vc, false 74 } 75 76 // store stores virConn. 77 func (m *shardMap) store(id uint32, vc *virtualConnection) { 78 shard := m.getShard(id) 79 shard.mu.Lock() 80 defer shard.mu.Unlock() 81 if _, ok := shard.idToVirConn[id]; !ok { 82 atomic.AddUint32(&m.len, 1) 83 } 84 shard.idToVirConn[id] = vc 85 } 86 87 // load loads the virConn of the given id. 88 func (m *shardMap) load(id uint32) (*virtualConnection, bool) { 89 shard := m.getShard(id) 90 shard.mu.RLock() 91 defer shard.mu.RUnlock() 92 vc, ok := shard.idToVirConn[id] 93 return vc, ok 94 } 95 96 // delete deletes the virConn of the given id. 97 func (m *shardMap) delete(id uint32) { 98 shard := m.getShard(id) 99 shard.mu.Lock() 100 defer shard.mu.Unlock() 101 if _, ok := shard.idToVirConn[id]; !ok { 102 return 103 } 104 atomic.AddUint32(&m.len, ^uint32(0)) 105 delete(shard.idToVirConn, id) 106 } 107 108 // reset deletes all virConns in the shardMap. 109 func (m *shardMap) reset() { 110 if m.length() == 0 { 111 return 112 } 113 atomic.StoreUint32(&m.len, 0) 114 for _, shard := range m.shards { 115 shard.mu.Lock() 116 shard.idToVirConn = make(map[uint32]*virtualConnection) 117 shard.mu.Unlock() 118 } 119 } 120 121 // length returns number of all virConns in the shardMap. 122 func (m *shardMap) length() uint32 { 123 return atomic.LoadUint32(&m.len) 124 } 125 126 // loadAll returns all virConns in the shardMap. 127 func (m *shardMap) loadAll() []*virtualConnection { 128 var conns []*virtualConnection 129 for _, shard := range m.shards { 130 shard.mu.RLock() 131 for _, v := range shard.idToVirConn { 132 conns = append(conns, v) 133 } 134 shard.mu.RUnlock() 135 } 136 return conns 137 }