agones.dev/agones@v1.54.0/cmd/allocator/main.go (about)

     1  // Copyright 2019 Google LLC All Rights Reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package main
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"fmt"
    21  	"net"
    22  	"net/http"
    23  	"os"
    24  	"path/filepath"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	"agones.dev/agones/pkg"
    30  	"agones.dev/agones/pkg/allocation/converters"
    31  	pb "agones.dev/agones/pkg/allocation/go"
    32  	allocationv1 "agones.dev/agones/pkg/apis/allocation/v1"
    33  	"agones.dev/agones/pkg/client/clientset/versioned"
    34  	"agones.dev/agones/pkg/client/informers/externalversions"
    35  	"agones.dev/agones/pkg/gameserverallocations"
    36  	"agones.dev/agones/pkg/gameservers"
    37  	"agones.dev/agones/pkg/metrics"
    38  	"agones.dev/agones/pkg/processor"
    39  	"agones.dev/agones/pkg/util/fswatch"
    40  	"github.com/heptiolabs/healthcheck"
    41  	"github.com/pkg/errors"
    42  	"github.com/sirupsen/logrus"
    43  	"github.com/spf13/pflag"
    44  	"github.com/spf13/viper"
    45  	"go.opencensus.io/plugin/ocgrpc"
    46  	"google.golang.org/grpc"
    47  	"google.golang.org/grpc/codes"
    48  	"google.golang.org/grpc/credentials"
    49  	grpchealth "google.golang.org/grpc/health"
    50  	"google.golang.org/grpc/health/grpc_health_v1"
    51  	"google.golang.org/grpc/keepalive"
    52  	"google.golang.org/grpc/status"
    53  	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
    54  	k8sruntime "k8s.io/apimachinery/pkg/runtime"
    55  	"k8s.io/client-go/informers"
    56  	"k8s.io/client-go/kubernetes"
    57  	"k8s.io/client-go/rest"
    58  
    59  	"agones.dev/agones/pkg/util/httpserver"
    60  	"agones.dev/agones/pkg/util/runtime"
    61  	"agones.dev/agones/pkg/util/signals"
    62  )
    63  
var (
	// podReady gates the readiness probe result; it starts true and is
	// flipped to false by the SIGTERM handler so the pod is drained from
	// the Kubernetes Service before shutdown.
	podReady bool
	logger   = runtime.NewLoggerWithSource("main")
)

const (
	// certDir is the mounted folder holding client CA certificates accepted for mTLS.
	certDir = "/home/allocator/client-ca/"
	// tlsDir is the mounted folder holding the allocator's own serving certificate and key.
	tlsDir  = "/home/allocator/tls/"
)

// Flag names, which double as viper configuration keys (and, with dashes
// replaced by underscores, as environment variable names).
const (
	httpPortFlag                     = "http-port"
	grpcPortFlag                     = "grpc-port"
	enableStackdriverMetricsFlag     = "stackdriver-exporter"
	enablePrometheusMetricsFlag      = "prometheus-exporter"
	projectIDFlag                    = "gcp-project-id"
	stackdriverLabels                = "stackdriver-labels"
	mTLSDisabledFlag                 = "disable-mtls"
	tlsDisabledFlag                  = "disable-tls"
	remoteAllocationTimeoutFlag      = "remote-allocation-timeout"
	totalRemoteAllocationTimeoutFlag = "total-remote-allocation-timeout"
	apiServerSustainedQPSFlag        = "api-server-qps"
	apiServerBurstQPSFlag            = "api-server-qps-burst"
	logLevelFlag                     = "log-level"
	allocationBatchWaitTime          = "allocation-batch-wait-time"
	readinessShutdownDuration        = "readiness-shutdown-duration"
	httpUnallocatedStatusCode        = "http-unallocated-status-code"
	processorGRPCAddress             = "processor-grpc-address"
	processorGRPCPort                = "processor-grpc-port"
	processorMaxBatchSize            = "processor-max-batch-size"
)
    95  
    96  func parseEnvFlags() config {
    97  	viper.SetDefault(httpPortFlag, -1)
    98  	viper.SetDefault(grpcPortFlag, -1)
    99  	viper.SetDefault(apiServerSustainedQPSFlag, 400)
   100  	viper.SetDefault(apiServerBurstQPSFlag, 500)
   101  	viper.SetDefault(enablePrometheusMetricsFlag, true)
   102  	viper.SetDefault(enableStackdriverMetricsFlag, false)
   103  	viper.SetDefault(projectIDFlag, "")
   104  	viper.SetDefault(stackdriverLabels, "")
   105  	viper.SetDefault(mTLSDisabledFlag, false)
   106  	viper.SetDefault(tlsDisabledFlag, false)
   107  	viper.SetDefault(remoteAllocationTimeoutFlag, 10*time.Second)
   108  	viper.SetDefault(totalRemoteAllocationTimeoutFlag, 30*time.Second)
   109  	viper.SetDefault(logLevelFlag, "Info")
   110  	viper.SetDefault(allocationBatchWaitTime, 500*time.Millisecond)
   111  	viper.SetDefault(httpUnallocatedStatusCode, http.StatusTooManyRequests)
   112  	viper.SetDefault(processorGRPCAddress, "agones-processor.agones-system.svc.cluster.local")
   113  	viper.SetDefault(processorGRPCPort, 9090)
   114  	viper.SetDefault(processorMaxBatchSize, 100)
   115  
   116  	pflag.Int32(httpPortFlag, viper.GetInt32(httpPortFlag), "Port to listen on for REST requests")
   117  	pflag.Int32(grpcPortFlag, viper.GetInt32(grpcPortFlag), "Port to listen on for gRPC requests")
   118  	pflag.Int32(apiServerSustainedQPSFlag, viper.GetInt32(apiServerSustainedQPSFlag), "Maximum sustained queries per second to send to the API server")
   119  	pflag.Int32(apiServerBurstQPSFlag, viper.GetInt32(apiServerBurstQPSFlag), "Maximum burst queries per second to send to the API server")
   120  	pflag.Bool(enablePrometheusMetricsFlag, viper.GetBool(enablePrometheusMetricsFlag), "Flag to activate metrics of Agones. Can also use PROMETHEUS_EXPORTER env variable.")
   121  	pflag.Bool(enableStackdriverMetricsFlag, viper.GetBool(enableStackdriverMetricsFlag), "Flag to activate stackdriver monitoring metrics for Agones. Can also use STACKDRIVER_EXPORTER env variable.")
   122  	pflag.String(projectIDFlag, viper.GetString(projectIDFlag), "GCP ProjectID used for Stackdriver, if not specified ProjectID from Application Default Credentials would be used. Can also use GCP_PROJECT_ID env variable.")
   123  	pflag.String(stackdriverLabels, viper.GetString(stackdriverLabels), "A set of default labels to add to all stackdriver metrics generated. By default metadata are automatically added using Kubernetes API and GCP metadata enpoint.")
   124  	pflag.Bool(mTLSDisabledFlag, viper.GetBool(mTLSDisabledFlag), "Flag to enable/disable mTLS in the allocator.")
   125  	pflag.Bool(tlsDisabledFlag, viper.GetBool(tlsDisabledFlag), "Flag to enable/disable TLS in the allocator.")
   126  	pflag.Duration(remoteAllocationTimeoutFlag, viper.GetDuration(remoteAllocationTimeoutFlag), "Flag to set remote allocation call timeout.")
   127  	pflag.Duration(totalRemoteAllocationTimeoutFlag, viper.GetDuration(totalRemoteAllocationTimeoutFlag), "Flag to set total remote allocation timeout including retries.")
   128  	pflag.String(logLevelFlag, viper.GetString(logLevelFlag), "Agones Log level")
   129  	pflag.Duration(allocationBatchWaitTime, viper.GetDuration(allocationBatchWaitTime), "Flag to configure the waiting period between allocations batches")
   130  	pflag.Duration(readinessShutdownDuration, viper.GetDuration(readinessShutdownDuration), "Time in seconds for SIGTERM/SIGINT handler to sleep for.")
   131  	pflag.Int32(httpUnallocatedStatusCode, viper.GetInt32(httpUnallocatedStatusCode), "HTTP status code to return when no GameServer is available")
   132  	pflag.String(processorGRPCAddress, viper.GetString(processorGRPCAddress), "The gRPC address of the Agones Processor service")
   133  	pflag.Int32(processorGRPCPort, viper.GetInt32(processorGRPCPort), "The gRPC port of the Agones Processor service")
   134  	pflag.Int32(processorMaxBatchSize, viper.GetInt32(processorMaxBatchSize), "The maximum batch size to send to the Agones Processor service")
   135  
   136  	runtime.FeaturesBindFlags()
   137  	pflag.Parse()
   138  
   139  	viper.SetEnvKeyReplacer(strings.NewReplacer("-", "_"))
   140  	runtime.Must(viper.BindEnv(httpPortFlag))
   141  	runtime.Must(viper.BindEnv(grpcPortFlag))
   142  	runtime.Must(viper.BindEnv(apiServerSustainedQPSFlag))
   143  	runtime.Must(viper.BindEnv(apiServerBurstQPSFlag))
   144  	runtime.Must(viper.BindEnv(enablePrometheusMetricsFlag))
   145  	runtime.Must(viper.BindEnv(enableStackdriverMetricsFlag))
   146  	runtime.Must(viper.BindEnv(projectIDFlag))
   147  	runtime.Must(viper.BindEnv(stackdriverLabels))
   148  	runtime.Must(viper.BindEnv(mTLSDisabledFlag))
   149  	runtime.Must(viper.BindEnv(tlsDisabledFlag))
   150  	runtime.Must(viper.BindEnv(remoteAllocationTimeoutFlag))
   151  	runtime.Must(viper.BindEnv(totalRemoteAllocationTimeoutFlag))
   152  	runtime.Must(viper.BindEnv(logLevelFlag))
   153  	runtime.Must(viper.BindEnv(allocationBatchWaitTime))
   154  	runtime.Must(viper.BindEnv(readinessShutdownDuration))
   155  	runtime.Must(viper.BindEnv(httpUnallocatedStatusCode))
   156  	runtime.Must(viper.BindPFlags(pflag.CommandLine))
   157  	runtime.Must(runtime.FeaturesBindEnv())
   158  
   159  	runtime.Must(runtime.ParseFeaturesFromEnv())
   160  
   161  	return config{
   162  		HTTPPort:                     int(viper.GetInt32(httpPortFlag)),
   163  		GRPCPort:                     int(viper.GetInt32(grpcPortFlag)),
   164  		APIServerSustainedQPS:        int(viper.GetInt32(apiServerSustainedQPSFlag)),
   165  		APIServerBurstQPS:            int(viper.GetInt32(apiServerBurstQPSFlag)),
   166  		PrometheusMetrics:            viper.GetBool(enablePrometheusMetricsFlag),
   167  		Stackdriver:                  viper.GetBool(enableStackdriverMetricsFlag),
   168  		GCPProjectID:                 viper.GetString(projectIDFlag),
   169  		StackdriverLabels:            viper.GetString(stackdriverLabels),
   170  		MTLSDisabled:                 viper.GetBool(mTLSDisabledFlag),
   171  		TLSDisabled:                  viper.GetBool(tlsDisabledFlag),
   172  		LogLevel:                     viper.GetString(logLevelFlag),
   173  		remoteAllocationTimeout:      viper.GetDuration(remoteAllocationTimeoutFlag),
   174  		totalRemoteAllocationTimeout: viper.GetDuration(totalRemoteAllocationTimeoutFlag),
   175  		allocationBatchWaitTime:      viper.GetDuration(allocationBatchWaitTime),
   176  		ReadinessShutdownDuration:    viper.GetDuration(readinessShutdownDuration),
   177  		httpUnallocatedStatusCode:    int(viper.GetInt32(httpUnallocatedStatusCode)),
   178  		processorGRPCAddress:         viper.GetString(processorGRPCAddress),
   179  		processorGRPCPort:            int(viper.GetInt32(processorGRPCPort)),
   180  		processorMaxBatchSize:        int(viper.GetInt32(processorMaxBatchSize)),
   181  	}
   182  }
   183  
// config holds the resolved runtime configuration for the allocator service,
// populated from flags and environment variables by parseEnvFlags.
// Exported fields appear in the startup log line; lowercase fields do not
// need to be (and are not) accessed outside this package.
type config struct {
	GRPCPort                     int    // gRPC listener port; negative disables the listener
	HTTPPort                     int    // REST listener port; negative disables the listener
	APIServerSustainedQPS        int    // sustained QPS limit towards the Kubernetes API server
	APIServerBurstQPS            int    // burst QPS limit towards the Kubernetes API server
	TLSDisabled                  bool   // serve plain HTTP/gRPC instead of TLS
	MTLSDisabled                 bool   // skip client-certificate verification
	PrometheusMetrics            bool   // expose Prometheus metrics
	Stackdriver                  bool   // export metrics to Stackdriver
	GCPProjectID                 string // project for Stackdriver; empty uses default credentials
	StackdriverLabels            string // extra labels applied to Stackdriver metrics
	LogLevel                     string // logrus level name, e.g. "Info"
	totalRemoteAllocationTimeout time.Duration // total budget for remote allocation incl. retries
	remoteAllocationTimeout      time.Duration // per-call remote allocation timeout
	allocationBatchWaitTime      time.Duration // pause between allocation batches
	ReadinessShutdownDuration    time.Duration // drain time after SIGTERM before closing listeners
	httpUnallocatedStatusCode    int    // HTTP status returned when no GameServer is available
	processorGRPCAddress         string // Agones Processor service host
	processorGRPCPort            int    // Agones Processor service port
	processorMaxBatchSize        int    // max batch size sent to the Processor
}
   205  
   206  // grpcHandlerFunc returns an http.Handler that delegates to grpcServer on incoming gRPC
   207  // connections or otherHandler otherwise. Copied from https://github.com/philips/grpc-gateway-example.
   208  func grpcHandlerFunc(grpcServer http.Handler, otherHandler http.Handler) http.Handler {
   209  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   210  		// This is a partial recreation of gRPC's internal checks https://github.com/grpc/grpc-go/pull/514/files#diff-95e9a25b738459a2d3030e1e6fa2a718R61
   211  		// We switch on HTTP/1.1 or HTTP/2 by checking the ProtoMajor
   212  		if r.ProtoMajor == 2 && strings.Contains(r.Header.Get("Content-Type"), "application/grpc") {
   213  			grpcServer.ServeHTTP(w, r)
   214  		} else {
   215  			otherHandler.ServeHTTP(w, r)
   216  		}
   217  	})
   218  }
// main wires up and runs the agones-allocator service: it parses
// configuration, configures logging and metrics, builds the allocation
// handler (in-process, or processor-backed when the ProcessorAllocator
// feature gate is enabled), watches TLS/CA certificate files for hot reload,
// serves gRPC and/or REST listeners, and finally orchestrates shutdown.
func main() {
	conf := parseEnvFlags()

	logger.WithField("version", pkg.Version).WithField("ctlConf", conf).
		WithField("featureGates", runtime.EncodeFeatures()).
		Info("Starting agones-allocator")

	logger.WithField("logLevel", conf.LogLevel).Info("Setting LogLevel configuration")
	level, err := logrus.ParseLevel(strings.ToLower(conf.LogLevel))
	if err == nil {
		runtime.SetLevel(level)
	} else {
		logger.WithError(err).Info("Specified wrong Logging.SdkServer. Setting default loglevel - Info")
		runtime.SetLevel(logrus.InfoLevel)
	}

	// At least one of the two listeners must be enabled.
	if !validPort(conf.GRPCPort) && !validPort(conf.HTTPPort) {
		logger.WithField("grpc-port", conf.GRPCPort).WithField("http-port", conf.HTTPPort).Fatal("Must specify a valid gRPC port or an HTTP port for the allocator service")
	}
	healthserver := &httpserver.Server{Logger: logger}
	var health healthcheck.Handler

	metricsConf := metrics.Config{
		Stackdriver:       conf.Stackdriver,
		PrometheusMetrics: conf.PrometheusMetrics,
		GCPProjectID:      conf.GCPProjectID,
		StackdriverLabels: conf.StackdriverLabels,
	}
	health, closer := metrics.SetupMetrics(metricsConf, healthserver)
	defer closer()

	metrics.SetReportingPeriod(conf.PrometheusMetrics, conf.Stackdriver)

	kubeClient, agonesClient, err := getClients(conf)
	if err != nil {
		logger.WithError(err).Fatal("could not create clients")
	}

	// listenCtx governs the network listeners; cancelling it kicks off the
	// shutdown sequence at the bottom of main.
	listenCtx, cancelListenCtx := context.WithCancel(context.Background())

	// This will test the connection to agones on each readiness probe
	// so if one of the allocator pod can't reach Kubernetes it will be removed
	// from the Kubernetes service.
	podReady = true
	grpcHealth := grpchealth.NewServer() // only used for gRPC, ignored o/w
	health.AddReadinessCheck("allocator-agones-client", func() error {
		if !podReady {
			return errors.New("asked to shut down, failed readiness check")
		}
		_, err := agonesClient.ServerVersion()
		if err != nil {
			return fmt.Errorf("failed to reach Kubernetes: %w", err)
		}
		return nil
	})

	// On SIGTERM/SIGINT: fail the readiness probe, wait out the configured
	// drain period so the pod leaves the Service endpoints, then stop listening.
	signals.NewSigTermHandler(func() {
		logger.Info("Pod shutdown has been requested, failing readiness check")
		podReady = false
		grpcHealth.Shutdown()
		time.Sleep(conf.ReadinessShutdownDuration)
		cancelListenCtx()
	})

	// workerCtx governs background workers; it outlives listenCtx so
	// in-flight work can complete after listeners close.
	workerCtx, cancelWorkerCtx := context.WithCancel(context.Background())

	var h *serviceHandler
	if runtime.FeatureEnabled(runtime.FeatureProcessorAllocator) {
		// Processor mode: forward allocation requests to the external
		// Agones Processor service over gRPC.
		processorConfig := processor.Config{
			ClientID:          os.Getenv("POD_NAME"),
			ProcessorAddress:  fmt.Sprintf("%s:%d", conf.processorGRPCAddress, conf.processorGRPCPort),
			MaxBatchSize:      conf.processorMaxBatchSize,
			AllocationTimeout: 30 * time.Second,
			ReconnectInterval: 5 * time.Second,
		}

		processorClient := processor.NewClient(processorConfig, logger.WithField("component", "processor-client"))

		go func() {
			if err := processorClient.Run(workerCtx); err != nil {
				if workerCtx.Err() != nil {
					logger.WithError(err).Error("Processor client stopped due to context error")
					return
				}
				logger.WithError(err).Error("Processor client failed, initiating graceful shutdown")
			}
		}()

		h = newProcessorServiceHandler(processorClient, conf.MTLSDisabled, conf.TLSDisabled)
	} else {
		// Default mode: allocate in-process against the Kubernetes API.
		grpcUnallocatedStatusCode := grpcCodeFromHTTPStatus(conf.httpUnallocatedStatusCode)
		h = newServiceHandler(workerCtx, kubeClient, agonesClient, health, conf.MTLSDisabled, conf.TLSDisabled, conf.remoteAllocationTimeout, conf.totalRemoteAllocationTimeout, conf.allocationBatchWaitTime, grpcUnallocatedStatusCode)
	}

	// Hot-reload the serving certificate (and, for mTLS, the client CA pool)
	// whenever the mounted secret files change on disk.
	if !h.tlsDisabled {
		cancelTLS, err := fswatch.Watch(logger, tlsDir, time.Second, func() {
			tlsCert, err := readTLSCert()
			if err != nil {
				logger.WithError(err).Error("could not load TLS certs; keeping old one")
				return
			}
			h.tlsMutex.Lock()
			defer h.tlsMutex.Unlock()
			h.tlsCert = tlsCert
			logger.Info("TLS certs updated")
		})
		if err != nil {
			logger.WithError(err).Fatal("could not create watcher for TLS certs")
		}
		defer cancelTLS()

		if !h.mTLSDisabled {
			cancelCert, err := fswatch.Watch(logger, certDir, time.Second, func() {
				h.certMutex.Lock()
				defer h.certMutex.Unlock()
				caCertPool, err := getCACertPool(certDir)
				if err != nil {
					logger.WithError(err).Error("could not load CA certs; keeping old ones")
					return
				}
				h.caCertPool = caCertPool
				logger.Info("CA certs updated")
			})
			if err != nil {
				logger.WithError(err).Fatal("could not create watcher for CA certs")
			}
			defer cancelCert()
		}
	}

	// If grpc and http use the same port then use a mux.
	if conf.GRPCPort == conf.HTTPPort {
		runMux(listenCtx, workerCtx, h, grpcHealth, conf.HTTPPort)
	} else {
		// Otherwise, run each on a dedicated port.
		if validPort(conf.HTTPPort) {
			runREST(listenCtx, workerCtx, h, conf.HTTPPort)
		}
		if validPort(conf.GRPCPort) {
			runGRPC(listenCtx, h, grpcHealth, conf.GRPCPort)
		}
	}

	// Finally listen on 8080 (http), used to serve /live and /ready handlers for Kubernetes probes.
	healthserver.Handle("/", health)
	go func() { _ = healthserver.Run(listenCtx, 0) }()

	// TODO: This is messy. Contexts are the wrong way to handle this - we should be using shutdown,
	// and a cascading graceful shutdown instead of multiple contexts and sleeps.
	<-listenCtx.Done()
	logger.Infof("Listen context cancelled")
	time.Sleep(5 * time.Second)
	cancelWorkerCtx()
	logger.Infof("Worker context cancelled")
	time.Sleep(1 * time.Second)
	logger.Info("Shut down allocator")
}
   376  
   377  func validPort(port int) bool {
   378  	const maxPort = 65535
   379  	return port >= 0 && port < maxPort
   380  }
   381  
   382  func runMux(listenCtx context.Context, workerCtx context.Context, h *serviceHandler, grpcHealth *grpchealth.Server, httpPort int) {
   383  	logger.Infof("Running the mux handler on port %d", httpPort)
   384  	grpcServer := grpc.NewServer(h.getMuxServerOptions()...)
   385  	pb.RegisterAllocationServiceServer(grpcServer, h)
   386  	grpc_health_v1.RegisterHealthServer(grpcServer, grpcHealth)
   387  
   388  	mux := runtime.NewServerMux()
   389  	if err := pb.RegisterAllocationServiceHandlerServer(context.Background(), mux, h); err != nil {
   390  		panic(err)
   391  	}
   392  
   393  	runHTTP(listenCtx, workerCtx, h, httpPort, grpcHandlerFunc(grpcServer, mux))
   394  }
   395  
   396  func runREST(listenCtx context.Context, workerCtx context.Context, h *serviceHandler, httpPort int) {
   397  	logger.WithField("port", httpPort).Info("Running the rest handler")
   398  	mux := runtime.NewServerMux()
   399  	if err := pb.RegisterAllocationServiceHandlerServer(context.Background(), mux, h); err != nil {
   400  		panic(err)
   401  	}
   402  	runHTTP(listenCtx, workerCtx, h, httpPort, mux)
   403  }
   404  
   405  func runHTTP(listenCtx context.Context, workerCtx context.Context, h *serviceHandler, httpPort int, handler http.Handler) {
   406  	cfg := &tls.Config{}
   407  	if !h.tlsDisabled {
   408  		cfg.GetCertificate = h.getTLSCert
   409  	}
   410  	if !h.mTLSDisabled {
   411  		cfg.ClientAuth = tls.RequireAnyClientCert
   412  		cfg.VerifyPeerCertificate = h.verifyClientCertificate
   413  	}
   414  
   415  	// Create a Server instance to listen on the http port with the TLS config.
   416  	server := &http.Server{
   417  		Addr:      fmt.Sprintf(":%d", httpPort),
   418  		TLSConfig: cfg,
   419  		Handler:   handler,
   420  	}
   421  
   422  	go func() {
   423  		go func() {
   424  			<-listenCtx.Done()
   425  			_ = server.Shutdown(workerCtx)
   426  		}()
   427  
   428  		var err error
   429  		if !h.tlsDisabled {
   430  			err = server.ListenAndServeTLS("", "")
   431  		} else {
   432  			err = server.ListenAndServe()
   433  		}
   434  
   435  		if err == http.ErrServerClosed {
   436  			logger.WithError(err).Info("HTTP/HTTPS server closed")
   437  			os.Exit(0)
   438  		}
   439  		logger.WithError(err).Fatal("Unable to start HTTP/HTTPS listener")
   440  		os.Exit(1)
   441  
   442  	}()
   443  }
   444  
   445  func runGRPC(ctx context.Context, h *serviceHandler, grpcHealth *grpchealth.Server, grpcPort int) {
   446  	logger.WithField("port", grpcPort).Info("Running the grpc handler on port")
   447  	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", grpcPort))
   448  	if err != nil {
   449  		logger.WithError(err).Fatalf("failed to listen on TCP port %d", grpcPort)
   450  		os.Exit(1)
   451  	}
   452  
   453  	grpcServer := grpc.NewServer(h.getGRPCServerOptions()...)
   454  	pb.RegisterAllocationServiceServer(grpcServer, h)
   455  	grpc_health_v1.RegisterHealthServer(grpcServer, grpcHealth)
   456  
   457  	go func() {
   458  		go func() {
   459  			<-ctx.Done()
   460  			grpcServer.GracefulStop()
   461  		}()
   462  
   463  		err := grpcServer.Serve(listener)
   464  		if err != nil {
   465  			logger.WithError(err).Fatal("allocation service crashed")
   466  			os.Exit(1)
   467  		}
   468  		logger.Info("allocation server closed")
   469  		os.Exit(0)
   470  
   471  	}()
   472  }
   473  
   474  func newProcessorServiceHandler(processorClient processor.Client, mTLSDisabled, tlsDisabled bool) *serviceHandler {
   475  	h := serviceHandler{
   476  		mTLSDisabled:    mTLSDisabled,
   477  		tlsDisabled:     tlsDisabled,
   478  		processorClient: processorClient,
   479  	}
   480  
   481  	if !h.tlsDisabled {
   482  		tlsCert, err := readTLSCert()
   483  		if err != nil {
   484  			logger.WithError(err).Fatal("could not load TLS certs.")
   485  		}
   486  		h.tlsMutex.Lock()
   487  		h.tlsCert = tlsCert
   488  		h.tlsMutex.Unlock()
   489  
   490  		if !h.mTLSDisabled {
   491  			caCertPool, err := getCACertPool(certDir)
   492  			if err != nil {
   493  				logger.WithError(err).Fatal("could not load CA certs.")
   494  			}
   495  			h.certMutex.Lock()
   496  			h.caCertPool = caCertPool
   497  			h.certMutex.Unlock()
   498  		}
   499  	}
   500  
   501  	return &h
   502  }
   503  
// newServiceHandler builds a serviceHandler that performs allocations
// in-process: it wires up the shared informer factories and the game server
// allocator, starts them against ctx, and loads TLS/mTLS material when
// enabled. Fatal-logs (exiting the process) if the allocator or certificates
// cannot be initialized.
func newServiceHandler(ctx context.Context, kubeClient kubernetes.Interface, agonesClient versioned.Interface, health healthcheck.Handler, mTLSDisabled bool, tlsDisabled bool, remoteAllocationTimeout time.Duration, totalRemoteAllocationTimeout time.Duration, allocationBatchWaitTime time.Duration, grpcUnallocatedStatusCode codes.Code) *serviceHandler {
	defaultResync := 30 * time.Second
	agonesInformerFactory := externalversions.NewSharedInformerFactory(agonesClient, defaultResync)
	kubeInformerFactory := informers.NewSharedInformerFactory(kubeClient, defaultResync)
	gsCounter := gameservers.NewPerNodeCounter(kubeInformerFactory, agonesInformerFactory)

	allocator := gameserverallocations.NewAllocator(
		agonesInformerFactory.Multicluster().V1().GameServerAllocationPolicies(),
		kubeInformerFactory.Core().V1().Secrets(),
		agonesClient.AgonesV1(),
		kubeClient,
		gameserverallocations.NewAllocationCache(agonesInformerFactory.Agones().V1().GameServers(), gsCounter, health),
		remoteAllocationTimeout,
		totalRemoteAllocationTimeout,
		allocationBatchWaitTime)

	h := serviceHandler{
		// Each request is delegated to the allocator, bound to ctx (the
		// worker context) so allocations stop when workers are cancelled.
		allocationCallback: func(gsa *allocationv1.GameServerAllocation) (k8sruntime.Object, error) {
			return allocator.Allocate(ctx, gsa)
		},
		mTLSDisabled:              mTLSDisabled,
		tlsDisabled:               tlsDisabled,
		grpcUnallocatedStatusCode: grpcUnallocatedStatusCode,
	}

	// Start the informers and the allocator; Run blocks until caches sync.
	kubeInformerFactory.Start(ctx.Done())
	agonesInformerFactory.Start(ctx.Done())
	if err := allocator.Run(ctx); err != nil {
		logger.WithError(err).Fatal("starting allocator failed.")
	}

	// Load the initial TLS serving certificate; file watchers set up in main
	// keep it (and the CA pool below) refreshed afterwards.
	if !h.tlsDisabled {
		tlsCert, err := readTLSCert()
		if err != nil {
			logger.WithError(err).Fatal("could not load TLS certs.")
		}
		h.tlsMutex.Lock()
		h.tlsCert = tlsCert
		h.tlsMutex.Unlock()

		// With mTLS, also load the CA pool used to verify client certificates.
		if !h.mTLSDisabled {
			caCertPool, err := getCACertPool(certDir)
			if err != nil {
				logger.WithError(err).Fatal("could not load CA certs.")
			}
			h.certMutex.Lock()
			h.caCertPool = caCertPool
			h.certMutex.Unlock()
		}
	}

	return &h
}
   557  
   558  func readTLSCert() (*tls.Certificate, error) {
   559  	tlsCert, err := tls.LoadX509KeyPair(tlsDir+"tls.crt", tlsDir+"tls.key")
   560  	if err != nil {
   561  		return nil, err
   562  	}
   563  	return &tlsCert, nil
   564  }
   565  
   566  // getMuxServerOptions returns a list of GRPC server option to use when
   567  // serving gRPC and REST over an HTTP multiplexer.
   568  // Current options are opencensus stats handler.
   569  func (h *serviceHandler) getMuxServerOptions() []grpc.ServerOption {
   570  	// Add options for  OpenCensus stats handler to enable stats and tracing.
   571  	// The keepalive options are useful for efficiency purposes (keeping a single connection alive
   572  	// instead of constantly recreating connections), when placing the Agones allocator behind load balancers.
   573  	return []grpc.ServerOption{
   574  		grpc.StatsHandler(&ocgrpc.ServerHandler{}),
   575  		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
   576  			MinTime:             1 * time.Minute,
   577  			PermitWithoutStream: true,
   578  		}),
   579  		grpc.KeepaliveParams(keepalive.ServerParameters{
   580  			MaxConnectionIdle: 5 * time.Minute,
   581  			Timeout:           10 * time.Minute,
   582  		}),
   583  	}
   584  }
   585  
   586  // getGRPCServerOptions returns a list of GRPC server options to use when
   587  // only serving gRPC requests.
   588  // Current options are TLS certs and opencensus stats handler.
   589  func (h *serviceHandler) getGRPCServerOptions() []grpc.ServerOption {
   590  	// Add options for  OpenCensus stats handler to enable stats and tracing.
   591  	// The keepalive options are useful for efficiency purposes (keeping a single connection alive
   592  	// instead of constantly recreating connections), when placing the Agones allocator behind load balancers.
   593  	opts := []grpc.ServerOption{
   594  		grpc.StatsHandler(&ocgrpc.ServerHandler{}),
   595  		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
   596  			MinTime:             1 * time.Minute,
   597  			PermitWithoutStream: true,
   598  		}),
   599  		grpc.KeepaliveParams(keepalive.ServerParameters{
   600  			MaxConnectionIdle: 5 * time.Minute,
   601  			Timeout:           10 * time.Minute,
   602  		}),
   603  	}
   604  	if h.tlsDisabled {
   605  		return opts
   606  	}
   607  
   608  	cfg := &tls.Config{
   609  		GetCertificate: h.getTLSCert,
   610  	}
   611  
   612  	if !h.mTLSDisabled {
   613  		cfg.ClientAuth = tls.RequireAnyClientCert
   614  		cfg.VerifyPeerCertificate = h.verifyClientCertificate
   615  	}
   616  
   617  	return append([]grpc.ServerOption{grpc.Creds(credentials.NewTLS(cfg))}, opts...)
   618  }
   619  
   620  func (h *serviceHandler) getTLSCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
   621  	h.tlsMutex.RLock()
   622  	defer h.tlsMutex.RUnlock()
   623  	return h.tlsCert, nil
   624  }
   625  
   626  // verifyClientCertificate verifies that the client certificate is accepted
   627  // This method is used as GetConfigForClient is cross lang incompatible.
   628  func (h *serviceHandler) verifyClientCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error {
   629  	opts := x509.VerifyOptions{
   630  		Roots:         h.caCertPool,
   631  		CurrentTime:   time.Now(),
   632  		Intermediates: x509.NewCertPool(),
   633  		KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
   634  	}
   635  
   636  	for _, rawCert := range rawCerts[1:] {
   637  		cert, err := x509.ParseCertificate(rawCert)
   638  		if err != nil {
   639  			logger.WithError(err).Warning("cannot parse intermediate certificate")
   640  			return errors.New("bad intermediate certificate: " + err.Error())
   641  		}
   642  		opts.Intermediates.AddCert(cert)
   643  	}
   644  
   645  	c, err := x509.ParseCertificate(rawCerts[0])
   646  	if err != nil {
   647  		logger.WithError(err).Warning("cannot parse client certificate")
   648  		return errors.New("bad client certificate: " + err.Error())
   649  	}
   650  
   651  	h.certMutex.RLock()
   652  	defer h.certMutex.RUnlock()
   653  	_, err = c.Verify(opts)
   654  	if err != nil {
   655  		logger.WithError(err).Warning("failed to verify client certificate")
   656  		return errors.New("failed to verify client certificate: " + err.Error())
   657  	}
   658  	return nil
   659  }
   660  
   661  // Set up our client which we will use to call the API
   662  func getClients(ctlConfig config) (*kubernetes.Clientset, *versioned.Clientset, error) {
   663  	// Create the in-cluster config
   664  	config, err := rest.InClusterConfig()
   665  	if err != nil {
   666  		return nil, nil, errors.New("Could not create in cluster config")
   667  	}
   668  
   669  	config.QPS = float32(ctlConfig.APIServerSustainedQPS)
   670  	config.Burst = ctlConfig.APIServerBurstQPS
   671  
   672  	// Access to the Agones resources through the Agones Clientset
   673  	kubeClient, err := kubernetes.NewForConfig(config)
   674  	if err != nil {
   675  		return nil, nil, errors.New("Could not create the kubernetes api clientset")
   676  	}
   677  
   678  	// Access to the Agones resources through the Agones Clientset
   679  	agonesClient, err := versioned.NewForConfig(config)
   680  	if err != nil {
   681  		return nil, nil, errors.New("Could not create the agones api clientset")
   682  	}
   683  	return kubeClient, agonesClient, nil
   684  }
   685  
   686  func getCACertPool(path string) (*x509.CertPool, error) {
   687  	// Add all certificates under client-certs path because there could be multiple clusters
   688  	// and all client certs should be added.
   689  	caCertPool := x509.NewCertPool()
   690  	dirEntries, err := os.ReadDir(path)
   691  	if err != nil {
   692  		return nil, fmt.Errorf("error reading certs from dir %s: %s", path, err.Error())
   693  	}
   694  
   695  	for _, dirEntry := range dirEntries {
   696  		if !strings.HasSuffix(dirEntry.Name(), ".crt") && !strings.HasSuffix(dirEntry.Name(), ".pem") {
   697  			continue
   698  		}
   699  		certFile := filepath.Join(path, dirEntry.Name())
   700  		caCert, err := os.ReadFile(certFile)
   701  		if err != nil {
   702  			logger.Errorf("CA cert is not readable or missing: %s", err.Error())
   703  			continue
   704  		}
   705  		if !caCertPool.AppendCertsFromPEM(caCert) {
   706  			logger.Errorf("client cert %s cannot be installed", certFile)
   707  			continue
   708  		}
   709  		logger.Infof("client cert %s is installed", certFile)
   710  	}
   711  
   712  	return caCertPool, nil
   713  }
   714  
// serviceHandler implements the allocation service endpoints and holds the
// TLS/mTLS material used to secure them. The certificate fields are guarded by
// mutexes because they can be hot-reloaded while requests are in flight.
type serviceHandler struct {
	// allocationCallback performs the actual allocation for a
	// GameServerAllocation and returns the resulting object (or a Status).
	allocationCallback func(*allocationv1.GameServerAllocation) (k8sruntime.Object, error)

	// certMutex guards caCertPool (client CA certs used for mTLS verification).
	certMutex  sync.RWMutex
	caCertPool *x509.CertPool

	// tlsMutex guards tlsCert (the server's own TLS certificate).
	tlsMutex sync.RWMutex
	tlsCert  *tls.Certificate

	// mTLSDisabled / tlsDisabled toggle the respective security layers.
	mTLSDisabled bool
	tlsDisabled  bool

	// grpcUnallocatedStatusCode is the gRPC code returned when no game server
	// could be allocated.
	grpcUnallocatedStatusCode codes.Code

	// processorClient routes allocations through the processor when the
	// ProcessorAllocator feature is enabled.
	processorClient processor.Client
}
   731  
   732  // Allocate implements the Allocate gRPC method definition
   733  func (h *serviceHandler) Allocate(ctx context.Context, in *pb.AllocationRequest) (*pb.AllocationResponse, error) {
   734  	logger.WithField("request", in).Infof("allocation request received.")
   735  
   736  	gsa := converters.ConvertAllocationRequestToGSA(in)
   737  	gsa.ApplyDefaults()
   738  
   739  	if runtime.FeatureEnabled(runtime.FeatureProcessorAllocator) {
   740  		req := converters.ConvertGSAToAllocationRequest(gsa)
   741  
   742  		resp, err := h.processorClient.Allocate(ctx, req)
   743  		if err != nil {
   744  			logger.WithField("gsa", gsa).WithError(err).Error("allocation failed")
   745  			return nil, err
   746  		}
   747  
   748  		allocatedGsa := converters.ConvertAllocationResponseToGSA(resp, resp.Source)
   749  		response, err := converters.ConvertGSAToAllocationResponse(allocatedGsa, h.grpcUnallocatedStatusCode)
   750  		logger.WithField("response", response).WithError(err).Info("allocation response is being sent")
   751  
   752  		return response, err
   753  	}
   754  
   755  	resultObj, err := h.allocationCallback(gsa)
   756  	if err != nil {
   757  		logger.WithField("gsa", gsa).WithError(err).Error("allocation failed")
   758  		return nil, err
   759  	}
   760  
   761  	if s, ok := resultObj.(*metav1.Status); ok {
   762  		return nil, status.Errorf(codes.Code(s.Code), s.Message, resultObj)
   763  	}
   764  
   765  	allocatedGsa, ok := resultObj.(*allocationv1.GameServerAllocation)
   766  	if !ok {
   767  		logger.Errorf("internal server error - Bad GSA format %v", resultObj)
   768  		return nil, status.Errorf(codes.Internal, "internal server error- Bad GSA format %v", resultObj)
   769  	}
   770  	response, err := converters.ConvertGSAToAllocationResponse(allocatedGsa, h.grpcUnallocatedStatusCode)
   771  	logger.WithField("response", response).WithError(err).Infof("allocation response is being sent")
   772  
   773  	return response, err
   774  }
   775  
   776  // grpcCodeFromHTTPStatus converts an HTTP status code to the corresponding gRPC status code.
   777  func grpcCodeFromHTTPStatus(httpUnallocatedStatusCode int) codes.Code {
   778  	switch httpUnallocatedStatusCode {
   779  	case http.StatusOK:
   780  		return codes.OK
   781  	case 499:
   782  		return codes.Canceled
   783  	case http.StatusInternalServerError:
   784  		return codes.Internal
   785  	case http.StatusBadRequest:
   786  		return codes.InvalidArgument
   787  	case http.StatusGatewayTimeout:
   788  		return codes.DeadlineExceeded
   789  	case http.StatusNotFound:
   790  		return codes.NotFound
   791  	case http.StatusConflict:
   792  		return codes.AlreadyExists
   793  	case http.StatusForbidden:
   794  		return codes.PermissionDenied
   795  	case http.StatusUnauthorized:
   796  		return codes.Unauthenticated
   797  	case http.StatusTooManyRequests:
   798  		return codes.ResourceExhausted
   799  	case http.StatusNotImplemented:
   800  		return codes.Unimplemented
   801  	case http.StatusServiceUnavailable:
   802  		return codes.Unavailable
   803  	default:
   804  		logger.WithField("httpStatusCode", httpUnallocatedStatusCode).Warnf("received unknown http status code, defaulting to codes.ResourceExhausted / 429")
   805  		return codes.ResourceExhausted
   806  	}
   807  }