github.com/tursom/GoCollections@v0.3.10/concurrent/ReentrantRWLock.go (about)

     1  /*
     2   * Copyright (c) 2022 tursom. All rights reserved.
     3   * Use of this source code is governed by a GPL-3
     4   * license that can be found in the LICENSE file.
     5   */
     6  
     7  package concurrent
     8  
     9  import (
    10  	"fmt"
    11  	"sync"
    12  
    13  	"github.com/tursom/GoCollections/exceptions"
    14  )
    15  
    16  type ReentrantRWLock struct {
    17  	lock      sync.Mutex
    18  	rlock     sync.RWMutex
    19  	cond      sync.Cond
    20  	recursion int32
    21  	host      int64
    22  }
    23  
    24  func NewReentrantRWLock() *ReentrantRWLock {
    25  	res := &ReentrantRWLock{
    26  		recursion: 0,
    27  		host:      0,
    28  	}
    29  	res.cond = *sync.NewCond(&res.lock)
    30  	return res
    31  }
    32  
    33  func (rt *ReentrantRWLock) Lock() {
    34  	id := GetGoroutineID()
    35  	rt.lock.Lock()
    36  	defer rt.lock.Unlock()
    37  
    38  	if rt.host == id {
    39  		rt.recursion++
    40  		return
    41  	}
    42  
    43  	for rt.recursion != 0 {
    44  		rt.cond.Wait()
    45  	}
    46  	rt.host = id
    47  	rt.recursion = 1
    48  	rt.rlock.Lock()
    49  }
    50  
    51  func (rt *ReentrantRWLock) TryLock() bool {
    52  	id := GetGoroutineID()
    53  	rt.lock.Lock()
    54  	defer rt.lock.Unlock()
    55  
    56  	if rt.host == id {
    57  		rt.recursion++
    58  		return true
    59  	}
    60  
    61  	if rt.recursion == 0 {
    62  		rt.host = id
    63  		rt.recursion = 1
    64  		rt.rlock.Lock()
    65  		return true
    66  	}
    67  
    68  	return false
    69  }
    70  
    71  func (rt *ReentrantRWLock) Unlock() {
    72  	rt.lock.Lock()
    73  	defer rt.lock.Unlock()
    74  
    75  	if rt.recursion == 0 || rt.host != GetGoroutineID() {
    76  		panic(exceptions.NewWrongCallHostException(fmt.Sprintf("the wrong call host: (%d); current_id: %d; recursion: %d", rt.host, GetGoroutineID(), rt.recursion)))
    77  	}
    78  
    79  	rt.recursion--
    80  	if rt.recursion == 0 {
    81  		rt.rlock.Unlock()
    82  		rt.cond.Signal()
    83  	}
    84  }
    85  
    86  func (rt *ReentrantRWLock) RLock() {
    87  	if rt.host == GetGoroutineID() {
    88  		return
    89  	}
    90  	rt.rlock.RLock()
    91  }
    92  
    93  func (rt *ReentrantRWLock) TryRLock() bool {
    94  	if rt.host == GetGoroutineID() {
    95  		return true
    96  	}
    97  	return rt.rlock.TryRLock()
    98  }
    99  
   100  func (rt *ReentrantRWLock) RUnlock() {
   101  	if rt.host == GetGoroutineID() {
   102  		return
   103  	}
   104  	rt.rlock.RUnlock()
   105  }