github.com/projecteru2/core@v0.0.0-20240321043226-06bcc1c23f58/cluster/calcium/lock.go (about)

     1  package calcium
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"sort"
     7  	"time"
     8  
     9  	"github.com/projecteru2/core/cluster"
    10  	"github.com/projecteru2/core/lock"
    11  	"github.com/projecteru2/core/log"
    12  	"github.com/projecteru2/core/types"
    13  	"github.com/projecteru2/core/utils"
    14  )
    15  
    16  func (c *Calcium) doLock(ctx context.Context, name string, timeout time.Duration) (lock lock.DistributedLock, rCtx context.Context, err error) {
    17  	if lock, err = c.store.CreateLock(name, timeout); err != nil {
    18  		return lock, rCtx, err
    19  	}
    20  	defer func() {
    21  		if err != nil {
    22  			rollbackCtx, cancel := context.WithTimeout(context.TODO(), timeout)
    23  			defer cancel()
    24  			rollbackCtx = utils.InheritTracingInfo(rollbackCtx, ctx)
    25  			if e := lock.Unlock(rollbackCtx); e != nil {
    26  				log.WithFunc("calcium.doLock").Errorf(rollbackCtx, err, "failed to unlock %s", name)
    27  			}
    28  		}
    29  	}()
    30  	rCtx, err = lock.Lock(ctx)
    31  	return lock, rCtx, err
    32  }
    33  
    34  func (c *Calcium) doUnlock(ctx context.Context, lock lock.DistributedLock, msg string) error {
    35  	log.WithFunc("calcium.doUnlock").Debugf(ctx, "Unlock %s", msg)
    36  	return lock.Unlock(ctx)
    37  }
    38  
    39  func (c *Calcium) doUnlockAll(ctx context.Context, locks map[string]lock.DistributedLock, order ...string) {
    40  	logger := log.WithFunc("calcium.doUnlockAll")
    41  	// unlock in the reverse order
    42  	if len(order) != len(locks) {
    43  		logger.Warn(ctx, "order length not match lock map")
    44  		order = []string{}
    45  		for key := range locks {
    46  			order = append(order, key)
    47  		}
    48  	}
    49  	for _, key := range order {
    50  		if err := c.doUnlock(ctx, locks[key], key); err != nil {
    51  			logger.Errorf(ctx, err, "Unlock %s failed", key)
    52  			continue
    53  		}
    54  	}
    55  }
    56  
    57  func (c *Calcium) withWorkloadLocked(ctx context.Context, ID string, ignoreLock bool, f func(context.Context, *types.Workload) error) error {
    58  	return c.withWorkloadsLocked(ctx, ignoreLock, []string{ID}, func(ctx context.Context, workloads map[string]*types.Workload) error {
    59  		if c, ok := workloads[ID]; ok {
    60  			return f(ctx, c)
    61  		}
    62  		return types.ErrWorkloadNotExists
    63  	})
    64  }
    65  
    66  func (c *Calcium) withWorkloadsLocked(ctx context.Context, ignoreLock bool, IDs []string, f func(context.Context, map[string]*types.Workload) error) error {
    67  	workloads := map[string]*types.Workload{}
    68  	locks := map[string]lock.DistributedLock{}
    69  	logger := log.WithFunc("calcium.withWorkloadsLocked")
    70  
    71  	// sort + unique
    72  	sort.Strings(IDs)
    73  	IDs = IDs[:utils.Unique(IDs, func(i int) string { return IDs[i] })]
    74  
    75  	defer logger.Debugf(ctx, "Workloads %+v unlocked", IDs)
    76  	defer func() {
    77  		utils.Reverse(IDs)
    78  		c.doUnlockAll(utils.NewInheritCtx(ctx), locks, IDs...)
    79  	}()
    80  	cs, err := c.store.GetWorkloads(ctx, IDs)
    81  	if err != nil {
    82  		return err
    83  	}
    84  	var lock lock.DistributedLock
    85  	for _, workload := range cs {
    86  		if !ignoreLock {
    87  			lock, ctx, err = c.doLock(ctx, fmt.Sprintf(cluster.WorkloadLock, workload.ID), c.config.LockTimeout)
    88  			if err != nil {
    89  				return err
    90  			}
    91  			logger.Debugf(ctx, "Workload %s locked", workload.ID)
    92  			locks[workload.ID] = lock
    93  		}
    94  		workloads[workload.ID] = workload
    95  	}
    96  	return f(ctx, workloads)
    97  }
    98  
    99  func (c *Calcium) withNodePodLocked(ctx context.Context, nodename string, f func(context.Context, *types.Node) error) error {
   100  	nodeFilter := &types.NodeFilter{
   101  		Includes: []string{nodename},
   102  		All:      true,
   103  	}
   104  	return c.withNodesPodLocked(ctx, nodeFilter, func(ctx context.Context, nodes map[string]*types.Node) error {
   105  		if n, ok := nodes[nodename]; ok {
   106  			return f(ctx, n)
   107  		}
   108  		return types.ErrNodeNotExists
   109  	})
   110  }
   111  
   112  func (c *Calcium) withNodeOperationLocked(ctx context.Context, nodename string, f func(context.Context, *types.Node) error) error {
   113  	nodeFilter := &types.NodeFilter{
   114  		Includes: []string{nodename},
   115  		All:      true,
   116  	}
   117  	return c.withNodesOperationLocked(ctx, nodeFilter, func(ctx context.Context, nodes map[string]*types.Node) error {
   118  		if n, ok := nodes[nodename]; ok {
   119  			return f(ctx, n)
   120  		}
   121  		return types.ErrNodeNotExists
   122  	})
   123  }
   124  
   125  func (c *Calcium) withNodesOperationLocked(ctx context.Context, nodeFilter *types.NodeFilter, f func(context.Context, map[string]*types.Node) error) error { //nolint
   126  	genKey := func(node *types.Node) string {
   127  		return fmt.Sprintf(cluster.NodeOperationLock, node.Podname, node.Name)
   128  	}
   129  	return c.withNodesLocked(ctx, nodeFilter, genKey, f)
   130  }
   131  
   132  func (c *Calcium) withNodesPodLocked(ctx context.Context, nodeFilter *types.NodeFilter, f func(context.Context, map[string]*types.Node) error) error {
   133  	genKey := func(node *types.Node) string {
   134  		return fmt.Sprintf(cluster.PodLock, node.Podname)
   135  	}
   136  	return c.withNodesLocked(ctx, nodeFilter, genKey, f)
   137  }
   138  
   139  func (c *Calcium) withNodesLocked(ctx context.Context, nodeFilter *types.NodeFilter, genKey func(*types.Node) string, f func(context.Context, map[string]*types.Node) error) error {
   140  	nodes := map[string]*types.Node{}
   141  	locks := map[string]lock.DistributedLock{}
   142  	lockKeys := []string{}
   143  	logger := log.WithFunc("calcium.withNodesLocked")
   144  
   145  	defer func() {
   146  		utils.Reverse(lockKeys)
   147  		c.doUnlockAll(utils.NewInheritCtx(ctx), locks, lockKeys...)
   148  		logger.Debugf(ctx, "keys %+v unlocked", lockKeys)
   149  	}()
   150  
   151  	ns, err := c.filterNodes(ctx, nodeFilter)
   152  	if err != nil {
   153  		return err
   154  	}
   155  
   156  	var lock lock.DistributedLock
   157  	for _, n := range ns {
   158  		key := genKey(n)
   159  		if _, ok := locks[key]; !ok {
   160  			lock, ctx, err = c.doLock(ctx, key, c.config.LockTimeout)
   161  			if err != nil {
   162  				return err
   163  			}
   164  			logger.Debugf(ctx, "key %s locked", key)
   165  			locks[key] = lock
   166  			lockKeys = append(lockKeys, key)
   167  		}
   168  		nodes[n.Name] = n
   169  	}
   170  	return f(ctx, nodes)
   171  }