github.com/square/finch@v0.0.0-20240412205204-6530c03e2b96/compute/api.go (about)

     1  package compute
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"io/ioutil"
     8  	"log"
     9  	"net"
    10  	"net/http"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/square/finch"
    17  	"github.com/square/finch/config"
    18  	"github.com/square/finch/stats"
    19  )
    20  
// API is the HTTP server that remote client instances call (endpoints /boot,
// /file, /run, /stats, and /ping). The embedded mutex guards stage, the
// current stage, which the handlers read and Stage replaces.
type API struct {
	*sync.Mutex
	httpServer *http.Server
	stage      *stageMeta // current stage
	// NOTE(review): prev is not read or written anywhere in this file;
	// confirm whether it is still needed.
	prev       map[string]string
}
    27  
    28  const (
    29  	ready byte = iota
    30  	booting
    31  	runnable
    32  	running
    33  )
    34  
// stageMeta is the server's bookkeeping for one stage: its config, the
// channels used to synchronize with remote clients, and the clients assigned
// to it. The embedded mutex is held around reads and writes of booted, done,
// and clients in this file.
type stageMeta struct {
	*sync.Mutex
	cfg      config.Stage
	nRemotes uint               // stage is full when len(clients) reaches this
	bootChan chan ack           // 1. <-client after booting stage
	runChan  chan struct{}      // 2. server closes to signal clients to run
	doneChan chan ack           // 3. <-client after running stage
	stats    *stats.Collector   // receives stats from clients while running
	booted   bool               // once true, GET /boot assigns no more clients (not set in this file)
	done     bool               // stage stopped; handlers tell its clients to reset
	clients  map[string]*client // clients assigned to this stage, keyed by instance name
}
    47  
// client is the server's view of one remote instance.
type client struct {
	name  string     // instance name from the ?name= query param
	stage *stageMeta // stage the client was assigned to by GET /boot
	state byte       // protocol state: ready, booting, runnable, or running
}
    53  
    54  func NewAPI(addr string) *API {
    55  	a := &API{
    56  		Mutex: &sync.Mutex{},
    57  	}
    58  
    59  	// HTTP server that client instances calls
    60  	mux := http.NewServeMux()
    61  	mux.HandleFunc("/boot", a.boot)
    62  	mux.HandleFunc("/file", a.file)
    63  	mux.HandleFunc("/run", a.run)
    64  	mux.HandleFunc("/stats", a.stats)
    65  	mux.HandleFunc("/ping", a.ping)
    66  	a.httpServer = &http.Server{
    67  		Addr:    addr,
    68  		Handler: mux,
    69  	}
    70  
    71  	// Make sure we can bind to addr:port. ListenAndServe will return an error
    72  	// but it's run in a goroutine so that error will occur async to the boot,
    73  	// which is a poor experience: failure a millisecond after boot. This makes
    74  	// it sync, so nothing boots if it fails. ListenAndServe might still fail
    75  	// for other reasons, but that's unlikely, so this check is good enough.
    76  	ln, err := net.Listen("tcp", addr)
    77  	if err != nil {
    78  		log.Fatal(err)
    79  	}
    80  	ln.Close()
    81  	go func() {
    82  		if err := a.httpServer.ListenAndServe(); err != nil {
    83  			log.Fatal(err)
    84  		}
    85  		log.Println("Listening on", addr)
    86  	}()
    87  	return a
    88  }
    89  
    90  // ServeHTTP implements the http.HandlerFunc interface.
    91  func (a *API) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    92  	a.httpServer.Handler.ServeHTTP(w, r)
    93  }
    94  
    95  func (a *API) Stage(newStage *stageMeta) error {
    96  	if newStage != nil {
    97  		finch.Debug("new stage %s (%s)", newStage.cfg.Name, newStage.cfg.Id)
    98  	}
    99  
   100  	// If there's no current stage, set new one and done
   101  	a.Lock()
   102  	if a.stage == nil {
   103  		a.stage = newStage
   104  		a.Unlock()
   105  		return nil
   106  	}
   107  
   108  	// Stop current (old) stage before setting new stage.
   109  	oldStage := a.stage
   110  	a.Unlock()
   111  
   112  	// Signal clients that stage has stopped early
   113  	finch.Debug("stop old stage %s (%s)", oldStage.cfg.Name, oldStage.cfg.Id)
   114  	oldStage.Lock()
   115  	oldStage.done = true
   116  	if oldStage.cfg.Test {
   117  		close(oldStage.runChan)
   118  	}
   119  	oldStage.Unlock()
   120  
   121  	// Wait for clients to check in (GET /run), be signaled that stage.done=true,
   122  	// send final stats, then call POST /run to terminate
   123  	timeout := time.After(3 * time.Second)
   124  	for {
   125  		time.Sleep(100 * time.Millisecond)
   126  		select {
   127  		case <-timeout:
   128  			finch.Debug("timeout waiting for clients to reset")
   129  			break
   130  		default:
   131  		}
   132  		oldStage.Lock()
   133  		n := len(oldStage.clients)
   134  		oldStage.Unlock()
   135  		if n == 0 {
   136  			break
   137  		}
   138  	}
   139  
   140  	oldStage.Lock()
   141  	if len(oldStage.clients) > 0 {
   142  		log.Printf("%d clients did not stop, ignoring (stats will be lost): %v", len(oldStage.clients), oldStage.clients)
   143  	}
   144  	oldStage.Unlock()
   145  
   146  	// Set new stage now that old stage has stopped
   147  	a.Lock()
   148  	a.stage = newStage
   149  	a.Unlock()
   150  
   151  	return nil
   152  }
   153  
   154  func (a *API) boot(w http.ResponseWriter, r *http.Request) {
   155  	rc, get, ok := a.client(w, r, true) // true == allow new clients on GET /boot
   156  	if !ok {
   157  		return // client() wrote error response
   158  	}
   159  
   160  	if get {
   161  		// GET /boot: client is booting, waiting to receive config.File
   162  		if rc.state != ready {
   163  			w.WriteHeader(http.StatusPreconditionFailed)
   164  			return
   165  		}
   166  
   167  		// Wait until there's a stage that's not done booting (needs more instances)
   168  		log.Printf("Remote %s ready to boot\n", rc.name)
   169  		for {
   170  			// Has server set a stage?
   171  			a.Lock()
   172  			stage := a.stage // copy ptr
   173  			if stage == nil || stage.done {
   174  				a.Unlock()
   175  				goto RETRY // no stage
   176  			}
   177  
   178  			// Is the stage still booting (waiting for instances)?
   179  			stage.Lock()
   180  			if stage.booted || len(stage.clients) == int(stage.nRemotes) {
   181  				stage.Unlock()
   182  				a.Unlock()
   183  				goto RETRY // stage is full
   184  			}
   185  
   186  			// Stage is ready and there's a space for this client
   187  			stage.clients[rc.name] = rc
   188  			rc.stage = stage
   189  			rc.state = booting // advance client state
   190  
   191  			// Unwind locks before sending stage config via HTTP in case net is slow
   192  			stage.Unlock()
   193  			a.Unlock()
   194  
   195  			finch.Debug("assigned %s to stage %s (%s): %d of %d clients", rc.name, stage.cfg.Name, stage.cfg.Id,
   196  				len(stage.clients), stage.nRemotes)
   197  			json.NewEncoder(w).Encode(stage.cfg) // send stage config
   198  			return
   199  
   200  		RETRY:
   201  			time.Sleep(200 * time.Millisecond)
   202  		}
   203  	} else {
   204  		// POST /boot: client is ack'ing previous GET /boot; body is error message, if any
   205  		if rc.state != booting {
   206  			w.WriteHeader(http.StatusPreconditionFailed)
   207  			return
   208  		}
   209  
   210  		body, err := io.ReadAll(r.Body)
   211  		if err != nil {
   212  			log.Printf("error reading error from client: %s", err)
   213  			return
   214  		}
   215  		r.Body.Close()
   216  		w.WriteHeader(http.StatusOK)
   217  
   218  		// Remote might fail to boot. If that's the case, do not advance its state;
   219  		// it should call GET /boot again to reset itself and try again.
   220  		var clientErr error
   221  		if string(body) != "" {
   222  			// Don't advance state: client failed to boot, so it's not ready to run
   223  			clientErr = fmt.Errorf("%s", string(body))
   224  		} else {
   225  			rc.state = runnable // advance client state (successful boot)
   226  		}
   227  		rc.stage.bootChan <- ack{name: rc.name, err: clientErr}
   228  
   229  	}
   230  }
   231  
   232  func (a *API) file(w http.ResponseWriter, r *http.Request) {
   233  	rc, _, ok := a.client(w, r, false)
   234  	if !ok {
   235  		return // client() wrote error response
   236  	}
   237  
   238  	if rc.state != booting {
   239  		w.WriteHeader(http.StatusPreconditionFailed)
   240  		return
   241  	}
   242  
   243  	// Parse file ref 'stage=...&i=...' from URL
   244  	q := r.URL.Query()
   245  	finch.Debug("file params %+v", q)
   246  	vals, ok := q["stage"]
   247  	if !ok {
   248  		http.Error(w, "missing stage param in URL query: ?stage=...", http.StatusBadRequest)
   249  		return
   250  	}
   251  	if len(vals) == 0 {
   252  		http.Error(w, "stage param has no value, expected stage name", http.StatusBadRequest)
   253  		return
   254  	}
   255  
   256  	vals, ok = q["i"]
   257  	if !ok {
   258  		http.Error(w, "missing i param in URL query: i=N", http.StatusBadRequest)
   259  		return
   260  	}
   261  	if len(vals) == 0 {
   262  		http.Error(w, "i param has no value, expected file number", http.StatusBadRequest)
   263  		return
   264  	}
   265  	i, err := strconv.Atoi(clean(vals[0]))
   266  	if err != nil {
   267  		http.Error(w, "i param is not an integer", http.StatusBadRequest)
   268  		return
   269  	}
   270  	if i < 0 {
   271  		http.Error(w, "i param is negative", http.StatusBadRequest)
   272  		return
   273  	}
   274  	s := rc.stage.cfg // shortcut
   275  	if i > len(s.Trx)-1 {
   276  		http.Error(w, "i param out of range for stage "+s.Name, http.StatusBadRequest)
   277  		return
   278  	}
   279  
   280  	log.Printf("Sending file %s to %s...", s.Trx[i].File, rc.name)
   281  
   282  	// Read file and send it to the client instance
   283  	bytes, err := ioutil.ReadFile(s.Trx[i].File)
   284  	if err != nil {
   285  		http.Error(w, err.Error(), http.StatusInternalServerError)
   286  		return
   287  	}
   288  
   289  	w.Write(bytes)
   290  	log.Printf("Sent file %s to %s", s.Trx[i].File, rc.name)
   291  }
   292  
   293  func (a *API) run(w http.ResponseWriter, r *http.Request) {
   294  	rc, get, ok := a.client(w, r, false)
   295  	if !ok {
   296  		return // client() wrote error response
   297  	}
   298  
   299  	if get {
   300  		// GET /run: client is waiting for signal to run previously booted stage
   301  		if rc.state != runnable {
   302  			w.WriteHeader(http.StatusPreconditionFailed)
   303  			return
   304  		}
   305  
   306  		// Remote is waiting for next stage to run
   307  		log.Printf("Remote %s waiting to start...", rc.name)
   308  		<-rc.stage.runChan // closed in Server.Run, or api.Stage if --test
   309  
   310  		// If boot --test and there's a new stage, Server.Boot calls api.Stage
   311  		// which will stop the old stage and trigger this block.
   312  		rc.stage.Lock()
   313  		if rc.stage.done {
   314  			delete(rc.stage.clients, rc.name)
   315  			rc.stage.Unlock()
   316  			w.WriteHeader(http.StatusResetContent) // reset
   317  			return
   318  		}
   319  		rc.stage.Unlock()
   320  
   321  		w.WriteHeader(http.StatusOK)
   322  		if _, err := w.Write([]byte{0}); err != nil {
   323  			log.Printf("Lost client %s on stage %s, but it will return\n", rc.name, rc.stage.cfg.Name)
   324  			return
   325  		}
   326  
   327  		log.Printf("Started client %s on stage %s\n", rc.name, rc.stage.cfg.Name)
   328  		rc.state = running // advance client state
   329  	} else {
   330  		// POST /run: client is done running stage
   331  		if rc.state != running {
   332  			w.WriteHeader(http.StatusPreconditionFailed)
   333  			return
   334  		}
   335  
   336  		body, err := io.ReadAll(r.Body)
   337  		if err != nil {
   338  			// Ignore error; it doesn't change fact that client is done
   339  			log.Printf("Error reading error from client on POST /run, ignoring: %s", err)
   340  		}
   341  		r.Body.Close()
   342  		w.WriteHeader(http.StatusOK)
   343  
   344  		rc.stage.Lock()
   345  		delete(rc.stage.clients, rc.name)
   346  		rc.stage.Unlock()
   347  
   348  		// Tell server client completed stage
   349  		var clientErr error
   350  		if string(body) != "" {
   351  			clientErr = fmt.Errorf("%s", string(body))
   352  		}
   353  		rc.stage.doneChan <- ack{name: rc.name, err: clientErr}
   354  		rc.state = ready // advance client state (ready to run another stage)
   355  	}
   356  }
   357  
   358  func (a *API) stats(w http.ResponseWriter, r *http.Request) {
   359  	rc, _, ok := a.client(w, r, false)
   360  	if !ok {
   361  		return // client() wrote error response
   362  	}
   363  
   364  	// Stats are sent only while running. If this error occurs, it might just be
   365  	// a network issue that delayed stats sent earlier (before client stopped running).
   366  	// If it happens frequently, then it's probably a bug in Finch.
   367  	if rc.state != running {
   368  		w.WriteHeader(http.StatusPreconditionFailed)
   369  		return
   370  	}
   371  
   372  	body, err := io.ReadAll(r.Body)
   373  	if err != nil {
   374  		log.Printf("error reading error from client: %s", err)
   375  		return
   376  	}
   377  	r.Body.Close()
   378  	w.WriteHeader(http.StatusOK)
   379  
   380  	var s stats.Instance
   381  	if err := json.Unmarshal(body, &s); err != nil {
   382  		log.Printf("Invalid stats from %s: %s", rc.name, err)
   383  		return
   384  	}
   385  
   386  	if rc.stage.stats != nil {
   387  		rc.stage.stats.Recv(s)
   388  	}
   389  }
   390  
   391  func (a *API) ping(w http.ResponseWriter, r *http.Request) {
   392  	rc, _, ok := a.client(w, r, false)
   393  	if !ok {
   394  		return // client() wrote error response
   395  	}
   396  	rc.stage.Lock()
   397  	done := rc.stage.done
   398  	rc.stage.Unlock()
   399  	if done {
   400  		log.Printf("Stage done, resetting %s", rc.name)
   401  		w.WriteHeader(http.StatusResetContent) // reset
   402  		return
   403  	}
   404  	w.WriteHeader(http.StatusOK) // keep running
   405  }
   406  
   407  // --------------------------------------------------------------------------
   408  
   409  func (a *API) client(w http.ResponseWriter, r *http.Request, boot bool) (*client, bool, bool) {
   410  	finch.Debug("%v", r)
   411  
   412  	// GET or POST
   413  	get := true
   414  	switch r.Method {
   415  	case http.MethodGet: // allowed
   416  	case http.MethodPost: // allowed
   417  		get = false
   418  	default:
   419  		w.WriteHeader(http.StatusMethodNotAllowed)
   420  		return nil, false, false
   421  	}
   422  
   423  	// ?name=...
   424  	q := r.URL.Query()
   425  	if len(q) == 0 {
   426  		http.Error(w, "missing URL query: ?name=...", http.StatusBadRequest)
   427  		return nil, false, false
   428  	}
   429  	vals, ok := q["name"]
   430  	if !ok {
   431  		http.Error(w, "missing name param in URL query: ?name=...", http.StatusBadRequest)
   432  		return nil, false, false
   433  	}
   434  	if len(vals) == 0 {
   435  		http.Error(w, "name param has no value, expected instance name", http.StatusBadRequest)
   436  		return nil, false, false
   437  	}
   438  	name := clean(vals[0])
   439  
   440  	vals, ok = q["stage-id"]
   441  	if !ok {
   442  		http.Error(w, "missing stage-id param in URL query: ?stage-id=...", http.StatusBadRequest)
   443  		return nil, false, false
   444  	}
   445  	if len(vals) == 0 {
   446  		http.Error(w, "stage-id param has no value", http.StatusBadRequest)
   447  		return nil, false, false
   448  	}
   449  	sid := clean(vals[0])
   450  
   451  	a.Lock()
   452  	defer a.Unlock()
   453  
   454  	// Has server set a stage? Instances can connect before server is ready.
   455  	if a.stage == nil {
   456  		w.WriteHeader(http.StatusGone)
   457  		return nil, false, false
   458  	}
   459  
   460  	a.stage.Lock()
   461  	defer a.stage.Unlock()
   462  
   463  	// Is instance assigned to the current stage?
   464  	rc, ok := a.stage.clients[name]
   465  	if !ok {
   466  
   467  		// Instance not assigned to the stage, but that's ok if it's trying
   468  		// to boot and join the stage.
   469  		if get && boot {
   470  			finch.Debug("new client")
   471  			rc = &client{
   472  				name:  name,
   473  				state: ready,
   474  			}
   475  			// Do not add to stage.clients; that's done in boot() if this client
   476  			// is assigned to the stage
   477  			return rc, get, true // success (new client)
   478  		}
   479  
   480  		// Instance not assigned to stage and not booting, so it's out of sync
   481  		log.Printf("Unknown client: %s", name)
   482  		w.WriteHeader(http.StatusGone) // reset
   483  		return nil, false, false
   484  	}
   485  
   486  	// Instance is assigned to the stage, but check stage ID to make sure a bad
   487  	// network partition (or some other net delay/weirdness) hasn't caused a
   488  	// _past_ query from the instance to finally reach us now after the stage
   489  	// has changed.
   490  	if !a.stage.done && a.stage.cfg.Id != sid {
   491  		log.Printf("Wrong stage ID: %s: client %s != current %s", name, sid, a.stage.cfg.Id)
   492  		w.WriteHeader(http.StatusGone) // reset
   493  		return nil, false, false
   494  	}
   495  
   496  	return rc, get, true // success
   497  }
   498  
   499  // clean removes \n\r to avoid code scanning alert "Log entries created from user input".
   500  func clean(s string) string {
   501  	c := strings.Replace(s, "\n", "", -1)
   502  	return strings.Replace(c, "\r", "", -1)
   503  }