github.com/jpetazzo/etcd@v0.2.1-0.20140113055439-97f1363afac5/mod/lock/v2/acquire_handler.go (about)

     1  package v2
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"net/http"
     7  	"path"
     8  	"strconv"
     9  	"time"
    10  
    11  	"github.com/coreos/go-etcd/etcd"
    12  	"github.com/gorilla/mux"
    13  )
    14  
    15  // acquireHandler attempts to acquire a lock on the given key.
    16  // The "key" parameter specifies the resource to lock.
    17  // The "value" parameter specifies a value to associate with the lock.
    18  // The "ttl" parameter specifies how long the lock will persist for.
    19  // The "timeout" parameter specifies how long the request should wait for the lock.
    20  func (h *handler) acquireHandler(w http.ResponseWriter, req *http.Request) {
    21  	h.client.SyncCluster()
    22  
    23  	// Setup connection watcher.
    24  	closeNotifier, _ := w.(http.CloseNotifier)
    25  	closeChan := closeNotifier.CloseNotify()
    26  	stopChan := make(chan bool)
    27  
    28  	// Parse the lock "key".
    29  	vars := mux.Vars(req)
    30  	keypath := path.Join(prefix, vars["key"])
    31  	value := req.FormValue("value")
    32  
    33  	// Parse "timeout" parameter.
    34  	var timeout int
    35  	var err error
    36  	if req.FormValue("timeout") == "" {
    37  		timeout = -1
    38  	} else if timeout, err = strconv.Atoi(req.FormValue("timeout")); err != nil {
    39  		http.Error(w, "invalid timeout: " + req.FormValue("timeout"), http.StatusInternalServerError)
    40  		return
    41  	}
    42  	timeout = timeout + 1
    43  
    44  	// Parse TTL.
    45  	ttl, err := strconv.Atoi(req.FormValue("ttl"))
    46  	if err != nil {
    47  		http.Error(w, "invalid ttl: " + req.FormValue("ttl"), http.StatusInternalServerError)
    48  		return
    49  	}
    50  
    51  	// If node exists then just watch it. Otherwise create the node and watch it.
    52  	node, index, pos := h.findExistingNode(keypath, value)
    53  	if index > 0 {
    54  		if pos == 0 {
    55  			// If lock is already acquired then update the TTL.
    56  			h.client.Update(node.Key, node.Value, uint64(ttl))
    57  		} else {
    58  			// Otherwise watch until it becomes acquired (or errors).
    59  			err = h.watch(keypath, index, nil)
    60  		}
    61  	} else {
    62  		index, err = h.createNode(keypath, value, ttl, closeChan, stopChan)
    63  	}
    64  
    65  	// Stop all goroutines.
    66  	close(stopChan)
    67  
    68  	// Write response.
    69  	if err != nil {
    70  		http.Error(w, err.Error(), http.StatusInternalServerError)
    71  	} else {
    72  		w.Write([]byte(strconv.Itoa(index)))
    73  	}
    74  }
    75  
    76  // createNode creates a new lock node and watches it until it is acquired or acquisition fails.
    77  func (h *handler) createNode(keypath string, value string, ttl int, closeChan <- chan bool, stopChan chan bool) (int, error) {
    78  	// Default the value to "-" if it is blank.
    79  	if len(value) == 0 {
    80  		value = "-"
    81  	}
    82  
    83  	// Create an incrementing id for the lock.
    84  	resp, err := h.client.AddChild(keypath, value, uint64(ttl))
    85  	if err != nil {
    86  		return 0, errors.New("acquire lock index error: " + err.Error())
    87  	}
    88  	indexpath := resp.Node.Key
    89  	index, _ := strconv.Atoi(path.Base(indexpath))
    90  
    91  	// Keep updating TTL to make sure lock request is not expired before acquisition.
    92  	go h.ttlKeepAlive(indexpath, value, ttl, stopChan)
    93  
    94  	// Watch until we acquire or fail.
    95  	err = h.watch(keypath, index, closeChan)
    96  
    97  	// Check for connection disconnect before we write the lock index.
    98  	if err != nil {
    99  		select {
   100  		case <-closeChan:
   101  			err = errors.New("acquire lock error: user interrupted")
   102  		default:
   103  		}
   104  	}
   105  
   106  	// Update TTL one last time if acquired. Otherwise delete.
   107  	if err == nil {
   108  		h.client.Update(indexpath, value, uint64(ttl))
   109  	} else {
   110  		h.client.Delete(indexpath, false)
   111  	}
   112  
   113  	return index, err
   114  }
   115  
   116  // findExistingNode search for a node on the lock with the given value.
   117  func (h *handler) findExistingNode(keypath string, value string) (*etcd.Node, int, int) {
   118  	if len(value) > 0 {
   119  		resp, err := h.client.Get(keypath, true, true)
   120  		if err == nil {
   121  			nodes := lockNodes{resp.Node.Nodes}
   122  			if node, pos := nodes.FindByValue(value); node != nil {
   123  				index, _ := strconv.Atoi(path.Base(node.Key))
   124  				return node, index, pos
   125  			}
   126  		}
   127  	}
   128  	return nil, 0, 0
   129  }
   130  
   131  // ttlKeepAlive continues to update a key's TTL until the stop channel is closed.
   132  func (h *handler) ttlKeepAlive(k string, value string, ttl int, stopChan chan bool) {
   133  	for {
   134  		select {
   135  		case <-time.After(time.Duration(ttl / 2) * time.Second):
   136  			h.client.Update(k, value, uint64(ttl))
   137  		case <-stopChan:
   138  			return
   139  		}
   140  	}
   141  }
   142  
   143  // watch continuously waits for a given lock index to be acquired or until lock fails.
   144  // Returns a boolean indicating success.
   145  func (h *handler) watch(keypath string, index int, closeChan <- chan bool) error {
   146  	// Wrap close chan so we can pass it to Client.Watch().
   147  	stopWatchChan := make(chan bool)
   148  	go func() {
   149  		select {
   150  		case <- closeChan:
   151  			stopWatchChan <- true
   152  		case <- stopWatchChan:
   153  		}
   154  	}()
   155  	defer close(stopWatchChan)
   156  
   157  	for {
   158  		// Read all nodes for the lock.
   159  		resp, err := h.client.Get(keypath, true, true)
   160  		if err != nil {
   161  			return fmt.Errorf("lock watch lookup error: %s", err.Error())
   162  		}
   163  		waitIndex := resp.Node.ModifiedIndex
   164  		nodes := lockNodes{resp.Node.Nodes}
   165  		prevIndex := nodes.PrevIndex(index)
   166  
   167  		// If there is no previous index then we have the lock.
   168  		if prevIndex == 0 {
   169  			return nil
   170  		}
   171  
   172  		// Watch previous index until it's gone.
   173  		_, err = h.client.Watch(path.Join(keypath, strconv.Itoa(prevIndex)), waitIndex, false, nil, stopWatchChan)
   174  		if err == etcd.ErrWatchStoppedByUser {
   175  			return fmt.Errorf("lock watch closed")
   176  		} else if err != nil {
   177  			return fmt.Errorf("lock watch error:%s", err.Error())
   178  		}
   179  	}
   180  }