github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/shard_map.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package netpollmux
    18  
    19  import (
    20  	"sync"
    21  
    22  	"github.com/cloudwego/kitex/pkg/remote"
    23  )
    24  
// EventHandler is used to handle events.
//
// Within this file it is the value type that shardMap stores per sequence ID.
// NOTE(review): Recv presumably receives the buffer for that sequence ID, or a
// non-nil err when no message can be delivered; confirm against the
// implementations and callers.
type EventHandler interface {
	Recv(bufReader remote.ByteBuffer, err error) error
}
    29  
// shardMap is a concurrency-safe map from sequence ID (int32) to EventHandler.
// To avoid a single lock bottleneck, the map is divided into `size` shards,
// each guarded by its own lock; getShard routes a sequence ID to its shard.
type shardMap struct {
	size   int32    // number of shards; newShardMap sets it to len(shards)
	shards []*shard // per-shard storage, selected by getShard
}
    36  
// shard is one partition of a shardMap: a sequence-ID -> EventHandler map
// protected by the embedded RWMutex (readers use RLock, writers use Lock).
type shard struct {
	msgs map[int32]EventHandler
	sync.RWMutex
}
    42  
    43  // Creates a new concurrent map.
    44  func newShardMap(size int) *shardMap {
    45  	m := &shardMap{
    46  		size:   int32(size),
    47  		shards: make([]*shard, size),
    48  	}
    49  	for i := range m.shards {
    50  		m.shards[i] = &shard{
    51  			msgs: make(map[int32]EventHandler),
    52  		}
    53  	}
    54  	return m
    55  }
    56  
    57  // getShard returns shard under given seq id
    58  func (m *shardMap) getShard(seqID int32) *shard {
    59  	return m.shards[abs(seqID)%m.size]
    60  }
    61  
    62  // store stores msg under given seq id.
    63  func (m *shardMap) store(seqID int32, msg EventHandler) {
    64  	if seqID == 0 {
    65  		return
    66  	}
    67  	// Get map shard.
    68  	shard := m.getShard(seqID)
    69  	shard.Lock()
    70  	shard.msgs[seqID] = msg
    71  	shard.Unlock()
    72  }
    73  
    74  // load loads the msg under the seq id.
    75  func (m *shardMap) load(seqID int32) (msg EventHandler, ok bool) {
    76  	if seqID == 0 {
    77  		return nil, false
    78  	}
    79  	shard := m.getShard(seqID)
    80  	shard.RLock()
    81  	msg, ok = shard.msgs[seqID]
    82  	shard.RUnlock()
    83  	return msg, ok
    84  }
    85  
    86  // delete deletes the msg under the given seq id.
    87  func (m *shardMap) delete(seqID int32) {
    88  	if seqID == 0 {
    89  		return
    90  	}
    91  	shard := m.getShard(seqID)
    92  	shard.Lock()
    93  	delete(shard.msgs, seqID)
    94  	shard.Unlock()
    95  }
    96  
    97  // rangeMap iterates over the map.
    98  func (m *shardMap) rangeMap(fn func(seqID int32, msg EventHandler)) {
    99  	for _, shard := range m.shards {
   100  		shard.Lock()
   101  		for k, v := range shard.msgs {
   102  			fn(k, v)
   103  		}
   104  		shard.Unlock()
   105  	}
   106  }
   107  
   108  func abs(n int32) int32 {
   109  	if n < 0 {
   110  		return -n
   111  	}
   112  	return n
   113  }