github.com/m-lab/locate@v0.17.6/cmd/heartbeat/main.go (about)

     1  package main
     2  
     3  import (
     4  	"context"
     5  	"flag"
     6  	"fmt"
     7  	"log"
     8  	"net/http"
     9  	"os"
    10  	"os/signal"
    11  	"syscall"
    12  	"time"
    13  
    14  	compute "cloud.google.com/go/compute/apiv1"
    15  	md "cloud.google.com/go/compute/metadata"
    16  	"github.com/gorilla/websocket"
    17  	"github.com/m-lab/go/flagx"
    18  	"github.com/m-lab/go/memoryless"
    19  	"github.com/m-lab/go/prometheusx"
    20  	"github.com/m-lab/go/rtx"
    21  	v2 "github.com/m-lab/locate/api/v2"
    22  	"github.com/m-lab/locate/cmd/heartbeat/health"
    23  	"github.com/m-lab/locate/cmd/heartbeat/metadata"
    24  	"github.com/m-lab/locate/cmd/heartbeat/registration"
    25  	"github.com/m-lab/locate/connection"
    26  	"github.com/m-lab/locate/metrics"
    27  	"github.com/m-lab/locate/static"
    28  )
    29  
    30  var (
    31  	heartbeatURL        string
    32  	hostname			flagx.StringFile
    33  	experiment          string
    34  	pod                 string
    35  	node                string
    36  	namespace           string
    37  	kubernetesAuth      = "/var/run/secrets/kubernetes.io/serviceaccount/"
    38  	kubernetesURL       = flagx.URL{}
    39  	registrationURL     = flagx.URL{}
    40  	services            = flagx.KeyValueArray{}
    41  	heartbeatPeriod     = static.HeartbeatPeriod
    42  	mainCtx, mainCancel = context.WithCancel(context.Background())
    43  	lbPath              = "/metadata/loadbalanced"
    44  )
    45  
// Checker is the source of the health score that this heartbeat instance
// reports. The score is expected to fall between 0 and 1; this interface does
// not enforce that range.
type Checker interface {
	// GetHealth returns the current health score. Within this file it is
	// called from getHealth with a context that times out after
	// heartbeatPeriod.
	GetHealth(ctx context.Context) float64 // Health score.
}
    50  
    51  func init() {
    52  	flag.StringVar(&heartbeatURL, "heartbeat-url", "ws://localhost:8080/v2/platform/heartbeat",
    53  		"URL for locate service")
    54  	flag.Var(&hostname, "hostname", "The service hostname (may be read from @/path/file)")
    55  	flag.StringVar(&experiment, "experiment", "", "Experiment name")
    56  	flag.StringVar(&pod, "pod", "", "Kubernetes pod name")
    57  	flag.StringVar(&node, "node", "", "Kubernetes node name")
    58  	flag.StringVar(&namespace, "namespace", "", "Kubernetes namespace")
    59  	flag.Var(&kubernetesURL, "kubernetes-url", "URL for Kubernetes API")
    60  	flag.Var(&registrationURL, "registration-url", "URL for site registration")
    61  	flag.Var(&services, "services", "Maps experiment target names to their set of services")
    62  }
    63  
    64  func main() {
    65  	flag.Parse()
    66  	rtx.Must(flagx.ArgsFromEnvWithLog(flag.CommandLine, false), "failed to read args from env")
    67  
    68  	// Start metrics server.
    69  	prom := prometheusx.MustServeMetrics()
    70  	defer prom.Close()
    71  
    72  	// Load registration data.
    73  	ldrConfig := memoryless.Config{
    74  		Min:      static.RegistrationLoadMin,
    75  		Expected: static.RegistrationLoadExpected,
    76  		Max:      static.RegistrationLoadMax,
    77  	}
    78  	svcs := services.Get()
    79  	ldr, err := registration.NewLoader(mainCtx, registrationURL.URL, hostname.Value, experiment, svcs, ldrConfig)
    80  	rtx.Must(err, "could not initialize registration loader")
    81  	r, err := ldr.GetRegistration(mainCtx)
    82  	rtx.Must(err, "could not load registration data")
    83  	hbm := v2.HeartbeatMessage{Registration: r}
    84  
    85  	// Establish a connection.
    86  	conn := connection.NewConn()
    87  	err = conn.Dial(heartbeatURL, http.Header{}, hbm)
    88  	rtx.Must(err, "failed to establish a websocket connection with %s", heartbeatURL)
    89  
    90  	probe := health.NewPortProbe(svcs)
    91  	ec := health.NewEndpointClient(static.HealthEndpointTimeout)
    92  	var hc Checker
    93  
    94  	// TODO(kinkade): cause a fatal error if lberr is not nil. Not fatally
    95  	// exiting on lberr is just a workaround to get this rolled out while we
    96  	// wait for every physical machine on the platform to actually have that
    97  	// file, which won't be the case until the rolling reboot in production
    98  	// completes in 4 or 5 days, as of this comment 2024-08-06.
    99  	lbbytes, lberr := os.ReadFile(lbPath)
   100  
   101  	// If the "loadbalanced" file exists, then make sure that the content of the
   102  	// file is "true". If the file doesn't exist, then, for now, just consider
   103  	// the machine as not loadbalanced.
   104  	if lberr == nil && string(lbbytes) == "true" {
   105  		gcpmd, err := metadata.NewGCPMetadata(md.NewClient(http.DefaultClient), hostname.Value)
   106  		rtx.Must(err, "failed to get VM metadata")
   107  		gceClient, err := compute.NewRegionBackendServicesRESTClient(mainCtx)
   108  		rtx.Must(err, "failed to create GCE client")
   109  		hc = health.NewGCPChecker(gceClient, gcpmd)
   110  	} else if kubernetesURL.URL == nil {
   111  		hc = health.NewChecker(probe, ec)
   112  	} else {
   113  		k8s := health.MustNewKubernetesClient(kubernetesURL.URL, pod, node, namespace, kubernetesAuth)
   114  		hc = health.NewCheckerK8S(probe, k8s, ec)
   115  	}
   116  
   117  	write(conn, hc, ldr)
   118  }
   119  
   120  // write starts a write loop to send health messages every
   121  // HeartbeatPeriod.
   122  func write(ws *connection.Conn, hc Checker, ldr *registration.Loader) {
   123  	defer ws.Close()
   124  	hbTicker := *time.NewTicker(heartbeatPeriod)
   125  	defer hbTicker.Stop()
   126  
   127  	// Register the channel to receive SIGTERM events.
   128  	sigterm := make(chan os.Signal, 1)
   129  	defer close(sigterm)
   130  	signal.Notify(sigterm, syscall.SIGTERM)
   131  
   132  	defer ldr.Ticker.Stop()
   133  
   134  	for {
   135  		select {
   136  		case <-mainCtx.Done():
   137  			log.Println("context cancelled")
   138  			sendExitMessage(ws)
   139  			return
   140  		case <-sigterm:
   141  			log.Println("received SIGTERM")
   142  			sendExitMessage(ws)
   143  			mainCancel()
   144  			return
   145  		case <-ldr.Ticker.C:
   146  			reg, err := ldr.GetRegistration(mainCtx)
   147  			if err != nil {
   148  				log.Printf("could not load registration data, err: %v", err)
   149  			}
   150  			if reg != nil {
   151  				sendMessage(ws, v2.HeartbeatMessage{Registration: reg}, "registration")
   152  				log.Printf("updated registration to %v", reg)
   153  			}
   154  		case <-hbTicker.C:
   155  			t := time.Now()
   156  			score := getHealth(hc)
   157  			healthMsg := v2.Health{Score: score}
   158  			hbm := v2.HeartbeatMessage{Health: &healthMsg}
   159  			sendMessage(ws, hbm, "health")
   160  
   161  			// Record duration metric.
   162  			fmtScore := fmt.Sprintf("%.1f", score)
   163  			metrics.HealthTransmissionDuration.WithLabelValues(fmtScore).Observe(time.Since(t).Seconds())
   164  		}
   165  	}
   166  }
   167  
   168  func getHealth(hc Checker) float64 {
   169  	ctx, cancel := context.WithTimeout(mainCtx, heartbeatPeriod)
   170  	defer cancel()
   171  	return hc.GetHealth(ctx)
   172  }
   173  
   174  func sendMessage(ws *connection.Conn, hbm v2.HeartbeatMessage, msgType string) {
   175  	// If a new registration message was found, update the websocket's dial message.
   176  	// The message is sent whenever the connection is restarted (i.e., once per hour in App Engine).
   177  	if msgType == "registration" {
   178  		ws.DialMessage = hbm
   179  	}
   180  
   181  	err := ws.WriteMessage(websocket.TextMessage, hbm)
   182  	if err != nil {
   183  		log.Printf("failed to write %s message, err: %v", msgType, err)
   184  	}
   185  }
   186  
   187  func sendExitMessage(ws *connection.Conn) {
   188  	// Notify the receiver that the health score should now be 0.
   189  	hbm := v2.HeartbeatMessage{
   190  		Health: &v2.Health{
   191  			Score: 0,
   192  		},
   193  	}
   194  	sendMessage(ws, hbm, "final health")
   195  }