github.com/prysmaticlabs/prysm@v1.4.4/shared/mputil/multilock.go (about)

     1  /*
     2  Copyright 2017 Albert Tedja
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6      http://www.apache.org/licenses/LICENSE-2.0
     7  Unless required by applicable law or agreed to in writing, software
     8  distributed under the License is distributed on an "AS IS" BASIS,
     9  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  See the License for the specific language governing permissions and
    11  limitations under the License.
    12  */
    13  package mputil
    14  
    15  import (
    16  	"runtime"
    17  	"sort"
    18  )
    19  
    20  var locks = struct {
    21  	lock chan byte
    22  	list map[string]chan byte
    23  }{
    24  	lock: make(chan byte, 1),
    25  	list: make(map[string]chan byte),
    26  }
    27  
    28  type Lock struct {
    29  	keys   []string
    30  	chans  []chan byte
    31  	lock   chan byte
    32  	unlock chan byte
    33  }
    34  
    35  func (lk *Lock) Lock() {
    36  	lk.lock <- 1
    37  
    38  	// get the channels and attempt to acquire them
    39  	lk.chans = make([]chan byte, 0, len(lk.keys))
    40  	for i := 0; i < len(lk.keys); {
    41  		ch := getChan(lk.keys[i])
    42  		_, ok := <-ch
    43  		if ok {
    44  			lk.chans = append(lk.chans, ch)
    45  			i++
    46  		}
    47  	}
    48  
    49  	lk.unlock <- 1
    50  }
    51  
    52  // Unlocks this lock. Must be called after Lock.
    53  // Can only be invoked if there is a previous call to Lock.
    54  func (lk *Lock) Unlock() {
    55  	<-lk.unlock
    56  
    57  	if lk.chans != nil {
    58  		for _, ch := range lk.chans {
    59  			ch <- 1
    60  		}
    61  		lk.chans = nil
    62  	}
    63  	// Clean unused channels after the unlock.
    64  	Clean()
    65  	<-lk.lock
    66  }
    67  
    68  // Temporarily unlocks, gives up the cpu time to other goroutine, and attempts to lock again.
    69  func (lk *Lock) Yield() {
    70  	lk.Unlock()
    71  	runtime.Gosched()
    72  	lk.Lock()
    73  }
    74  
    75  // Creates a new multilock for the specified keys
    76  func NewMultilock(locks ...string) *Lock {
    77  	if len(locks) == 0 {
    78  		return nil
    79  	}
    80  
    81  	locks = unique(locks)
    82  	sort.Strings(locks)
    83  	return &Lock{
    84  		keys:   locks,
    85  		lock:   make(chan byte, 1),
    86  		unlock: make(chan byte, 1),
    87  	}
    88  }
    89  
    90  // Cleans old unused locks. Returns removed keys.
    91  func Clean() []string {
    92  	locks.lock <- 1
    93  	defer func() { <-locks.lock }()
    94  
    95  	toDelete := make([]string, 0, len(locks.list))
    96  	for key, ch := range locks.list {
    97  		select {
    98  		case <-ch:
    99  			close(ch)
   100  			toDelete = append(toDelete, key)
   101  		default:
   102  		}
   103  	}
   104  
   105  	for _, del := range toDelete {
   106  		delete(locks.list, del)
   107  	}
   108  
   109  	return toDelete
   110  }
   111  
   112  // Create and get the channel for the specified key.
   113  func getChan(key string) chan byte {
   114  	locks.lock <- 1
   115  	defer func() { <-locks.lock }()
   116  
   117  	if locks.list[key] == nil {
   118  		locks.list[key] = make(chan byte, 1)
   119  		locks.list[key] <- 1
   120  	}
   121  	return locks.list[key]
   122  }
   123  
   124  // Return a new string with unique elements.
   125  func unique(arr []string) []string {
   126  	if arr == nil || len(arr) <= 1 {
   127  		return arr
   128  	}
   129  
   130  	found := map[string]bool{}
   131  	result := make([]string, 0, len(arr))
   132  	for _, v := range arr {
   133  		if !found[v] {
   134  			found[v] = true
   135  			result = append(result, v)
   136  		}
   137  	}
   138  	return result
   139  }