github.phpd.cn/thought-machine/please@v12.2.0+incompatible/tools/cache/server/rpc_server.go

package server

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/signal"
	"path"
	"sync"
	"syscall"

	"github.com/grpc-ecosystem/go-grpc-prometheus"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	_ "google.golang.org/grpc/encoding/gzip" // Registers the gzip compressor at init
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	pb "cache/proto/rpc_cache"
	"fs"
	"tools/cache/cluster"
)

// maxMsgSize is the maximum message size our gRPC server accepts.
// We deliberately set this to something high since we don't want to limit artifact size here.
const maxMsgSize = 200 * 1024 * 1024

// metricsOnce is used to track whether or not we've registered the server metrics.
// In normal operation this only happens once, but in tests it can happen multiple times.
var metricsOnce sync.Once

func init() {
	// When tracing is enabled, it appears to keep references to messages alive, possibly indefinitely.
	// This is very bad for us: since our messages are large, it can leak memory very quickly and
	// ultimately cause OOM errors. Disabling tracing appears to alleviate the problem.
	grpc.EnableTracing = false
}

// An RPCCacheServer implements our RPC cache, including communication in a cluster.
type RPCCacheServer struct {
	cache                                                                          *Cache
	readonlyKeys                                                                   map[string]*x509.Certificate
	writableKeys                                                                   map[string]*x509.Certificate
	cluster                                                                        *cluster.Cluster
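	// Counters for server metrics; all of them are labelled by the client's architecture.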
	retrievedCounter, storedCounter, retrievedBytes, storedBytes, retrieveFailures *prometheus.CounterVec
}

// Store implements the Store RPC to store an artifact in the cache.
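// A client would call it via the generated stub, along these lines (a sketch; conn is
// assumed to be an established *grpc.ClientConn):
//
//	resp, err := pb.NewRpcCacheClient(conn).Store(ctx, req)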
func (r *RPCCacheServer) Store(ctx context.Context, req *pb.StoreRequest) (*pb.StoreResponse, error) {
	if err := r.authenticateClient(ctx, r.writableKeys); err != nil {
		return nil, err
	}
	success := storeArtifact(r.cache, req.Os, req.Arch, req.Hash, req.Artifacts, req.Hostname, extractAddress(ctx), "")
	if success && r.cluster != nil {
		// Replicate this artifact to another node. Doesn't have to be done synchronously.
		go r.cluster.ReplicateArtifacts(req)
	}
	if success {
		r.storedCounter.WithLabelValues(req.Arch).Inc()
		total := 0
		for _, artifact := range req.Artifacts {
			total += len(artifact.Body)
		}
		r.storedBytes.WithLabelValues(req.Arch).Add(float64(total))
	}
	return &pb.StoreResponse{Success: success}, nil
}

// storeArtifact stores a series of artifacts in the cache.
// It's broken out of Store above so it can be shared with Replicate below.
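// Artifacts end up stored under <os>_<arch>/<package>/<target>/<url-safe base64 hash>/<file>,
// with the hostname, address and peer recorded alongside as metadata.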
func storeArtifact(cache *Cache, os, arch string, hash []byte, artifacts []*pb.Artifact, hostname, address, peer string) bool {
	arch = os + "_" + arch
	hashStr := base64.RawURLEncoding.EncodeToString(hash)
	for _, artifact := range artifacts {
		dir := path.Join(arch, artifact.Package, artifact.Target, hashStr)
		file := path.Join(dir, artifact.File)
		if err := cache.StoreArtifact(file, artifact.Body, artifact.Symlink); err != nil {
			return false
		}
		go cache.StoreMetadata(dir, hostname, address, peer)
	}
	return true
}

// Retrieve implements the Retrieve RPC to retrieve artifacts from the cache.
func (r *RPCCacheServer) Retrieve(ctx context.Context, req *pb.RetrieveRequest) (*pb.RetrieveResponse, error) {
	if err := r.authenticateClient(ctx, r.readonlyKeys); err != nil {
		return nil, err
	}
	response := pb.RetrieveResponse{Success: true}
	arch := req.Os + "_" + req.Arch
	hash := base64.RawURLEncoding.EncodeToString(req.Hash)
	total := 0
	for _, artifact := range req.Artifacts {
		root := path.Join(arch, artifact.Package, artifact.Target, hash)
		fileRoot := path.Join(root, artifact.File)
		arts, err := r.cache.RetrieveArtifact(fileRoot)
		if err != nil {
			log.Debug("Failed to retrieve artifact %s: %s", fileRoot, err)
			r.retrieveFailures.WithLabelValues(req.Arch).Inc()
			return &pb.RetrieveResponse{Success: false}, nil
		}
		for _, art := range arts {
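			// Return each file with its path made relative to the artifact root.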
			response.Artifacts = append(response.Artifacts, &pb.Artifact{
				Package: artifact.Package,
				Target:  artifact.Target,
				File:    art.File[len(root)+1:],
				Body:    art.Body,
				Symlink: art.Symlink,
			})
			total += len(art.Body)
		}
	}
	r.retrievedCounter.WithLabelValues(req.Arch).Inc()
	r.retrievedBytes.WithLabelValues(req.Arch).Add(float64(total))
	return &response, nil
}

// Delete implements the Delete RPC to delete an artifact from the cache.
func (r *RPCCacheServer) Delete(ctx context.Context, req *pb.DeleteRequest) (*pb.DeleteResponse, error) {
	if err := r.authenticateClient(ctx, r.writableKeys); err != nil {
		return nil, err
	}
	if req.Everything {
		return &pb.DeleteResponse{Success: r.cache.DeleteAllArtifacts() == nil}, nil
	}
	success := deleteArtifact(r.cache, req.Os, req.Arch, req.Artifacts)
	if success && r.cluster != nil {
		// Delete this artifact from other nodes. Doesn't have to be done synchronously.
		go r.cluster.DeleteArtifacts(req)
	}
	return &pb.DeleteResponse{Success: success}, nil
}

// deleteArtifact handles the actual removal of artifacts from the cache.
// It's split out from Delete to share with the replication RPCs below.
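// Note that the deleted path doesn't include the artifact hash, so this removes all cached
// versions of each target.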
func deleteArtifact(cache *Cache, os, arch string, artifacts []*pb.Artifact) bool {
	success := true
	for _, artifact := range artifacts {
		if cache.DeleteArtifact(path.Join(os+"_"+arch, artifact.Package, artifact.Target)) != nil {
			success = false
		}
	}
	return success
}

// ListNodes implements the ListNodes RPC for clustered servers.
func (r *RPCCacheServer) ListNodes(ctx context.Context, req *pb.ListRequest) (*pb.ListResponse, error) {
	if err := r.authenticateClient(ctx, r.readonlyKeys); err != nil {
		return nil, err
	}
	if r.cluster == nil {
		return &pb.ListResponse{}, nil
	}
	return &pb.ListResponse{Nodes: r.cluster.GetMembers()}, nil
}

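// authenticateClient checks the calling client's TLS certificate against the given set of
// accepted certificates. An empty set means the server is open to anyone.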
func (r *RPCCacheServer) authenticateClient(ctx context.Context, certs map[string]*x509.Certificate) error {
	if len(certs) == 0 {
		return nil // Open to anyone.
	}
	p, ok := peer.FromContext(ctx)
	if !ok {
		return status.Error(codes.Unauthenticated, "Missing client certificate")
	}
	info, ok := p.AuthInfo.(credentials.TLSInfo)
	if !ok {
		return status.Error(codes.Unauthenticated, "Could not extract auth info")
	}
	if len(info.State.PeerCertificates) == 0 {
		return status.Error(codes.Unauthenticated, "No peer certificate available")
	}
	cert := info.State.PeerCertificates[0]
	okCert := certs[string(cert.RawSubject)]
	if okCert == nil || !okCert.Equal(cert) {
		return status.Error(codes.Unauthenticated, "Invalid or unknown certificate")
	}
	return nil
}

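// extractAddress returns the address of the calling peer, or an empty string if it's unavailable.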
func extractAddress(ctx context.Context) string {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return ""
	}
	return p.Addr.String()
}

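// loadKeys walks the given file or directory and loads all the PEM-encoded certificates it
// finds, keyed by their raw subjects. Any failure is fatal.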
func loadKeys(filename string) map[string]*x509.Certificate {
	ret := map[string]*x509.Certificate{}
	if err := fs.Walk(filename, func(name string, isDir bool) error {
		if !isDir {
			data, err := ioutil.ReadFile(name)
			if err != nil {
				log.Fatalf("Failed to read cert from %s: %s", name, err)
			}
			p, _ := pem.Decode(data)
			if p == nil {
				log.Fatalf("Couldn't decode PEM data from %s", name)
			}
			cert, err := x509.ParseCertificate(p.Bytes)
			if err != nil {
				log.Fatalf("Couldn't parse certificate from %s: %s", name, err)
			}
			ret[string(cert.RawSubject)] = cert
		}
		return nil
	}); err != nil {
		log.Fatalf("%s", err)
	}
	return ret
}

// RPCServer implements the gRPC server for communication between cache nodes.
type RPCServer struct {
	cache   *Cache
	cluster *cluster.Cluster
}

// Join implements the Join RPC for a new server joining the cluster.
func (r *RPCServer) Join(ctx context.Context, req *pb.JoinRequest) (*pb.JoinResponse, error) {
	// TODO(pebers): Authentication.
	return r.cluster.AddNode(req), nil
}

// Replicate implements the Replicate RPC for replicating an artifact from another node.
func (r *RPCServer) Replicate(ctx context.Context, req *pb.ReplicateRequest) (*pb.ReplicateResponse, error) {
	// TODO(pebers): Authentication.
	if req.Delete {
		return &pb.ReplicateResponse{
			Success: deleteArtifact(r.cache, req.Os, req.Arch, req.Artifacts),
		}, nil
	}
	return &pb.ReplicateResponse{
		Success: storeArtifact(r.cache, req.Os, req.Arch, req.Hash, req.Artifacts, req.Hostname, extractAddress(ctx), req.Peer),
	}, nil
}

// BuildGrpcServer creates a new, unstarted grpc.Server and returns it.
// It also returns a net.Listener to start it on.
func BuildGrpcServer(port int, cache *Cache, cluster *cluster.Cluster, keyFile, certFile, caCertFile, readonlyKeys, writableKeys string) (*grpc.Server, net.Listener) {
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		log.Fatalf("Failed to listen on port %d: %v", port, err)
	}
	s := serverWithAuth(keyFile, certFile, caCertFile)
	r := &RPCCacheServer{
		cache:   cache,
		cluster: cluster,
		retrievedCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieved_count",
			Help: "Number of artifacts successfully retrieved",
		}, []string{"arch"}),
		storedCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "stored_count",
			Help: "Number of artifacts successfully stored",
		}, []string{"arch"}),
		retrievedBytes: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieved_bytes",
			Help: "Number of bytes successfully retrieved",
		}, []string{"arch"}),
		storedBytes: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "stored_bytes",
			Help: "Number of bytes successfully stored",
		}, []string{"arch"}),
		retrieveFailures: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieve_failures",
			Help: "Number of failed retrieval attempts",
		}, []string{"arch"}),
	}
	if writableKeys != "" {
		r.writableKeys = loadKeys(writableKeys)
	}
	if readonlyKeys != "" {
		r.readonlyKeys = loadKeys(readonlyKeys)
		if len(r.readonlyKeys) > 0 {
			// This saves duplication when checking later; writable keys are implicitly readable too.
			for k, v := range r.writableKeys {
				if _, present := r.readonlyKeys[k]; !present {
					r.readonlyKeys[k] = v
				}
			}
		}
	}
	r2 := &RPCServer{cache: cache, cluster: cluster}
	pb.RegisterRpcCacheServer(s, r)
	pb.RegisterRpcServerServer(s, r2)
	healthserver := health.NewServer()
	healthserver.SetServingStatus("plz-rpc-cache", healthpb.HealthCheckResponse_SERVING)
	healthpb.RegisterHealthServer(s, healthserver)
	metricsOnce.Do(func() {
		prometheus.MustRegister(r.retrievedCounter)
		prometheus.MustRegister(r.storedCounter)
		prometheus.MustRegister(r.retrievedBytes)
		prometheus.MustRegister(r.storedBytes)
		prometheus.MustRegister(r.retrieveFailures)
	})
	return s, lis
}

// ServeGrpcForever serves gRPC on the given server until killed.
// It's very simple and provided as a convenience so callers don't have to import grpc themselves.
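// A typical caller wires it up with BuildGrpcServer, along these lines (the port and the
// empty key / cert / CA / key-directory arguments are purely illustrative, and here mean
// no TLS and no client authentication):
//
//	srv, lis := BuildGrpcServer(7677, cache, cluster, "", "", "", "", "")
//	ServeGrpcForever(srv, lis) // Blocks until the process receives a signal.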
func ServeGrpcForever(server *grpc.Server, lis net.Listener) {
	log.Notice("Serving RPC cache on %s", lis.Addr())
	go handleSignals(server)
	server.Serve(lis)
}

// serverWithAuth builds a gRPC server, possibly with authentication if key / cert files are given.
func serverWithAuth(keyFile, certFile, caCertFile string) *grpc.Server {
	if keyFile == "" {
		return grpc.NewServer(grpc.MaxRecvMsgSize(maxMsgSize), grpc.MaxSendMsgSize(maxMsgSize)) // No auth.
	}
	log.Debug("Loading x509 key pair from key: %s cert: %s", keyFile, certFile)
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		log.Fatalf("Failed to load x509 key pair: %s", err)
	}
	config := tls.Config{
		Certificates: []tls.Certificate{cert},
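		// Request (but don't require or verify) a client certificate here; it's checked
		// against the known key sets later, in authenticateClient.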
		ClientAuth:   tls.RequestClientCert,
	}
	if caCertFile != "" {
		cert, err := ioutil.ReadFile(caCertFile)
		if err != nil {
			log.Fatalf("Failed to read CA cert file: %s", err)
		}
		config.ClientCAs = x509.NewCertPool()
		if !config.ClientCAs.AppendCertsFromPEM(cert) {
			log.Fatalf("Failed to find any PEM certificates in CA cert")
		}
	}
	return grpc.NewServer(
		grpc.Creds(credentials.NewTLS(&config)),
		grpc.MaxRecvMsgSize(maxMsgSize),
		grpc.MaxSendMsgSize(maxMsgSize),
		grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
		grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
	)
}

// handleSignals receives SIGTERM / SIGINT etc. to gracefully shut down a gRPC server.
// Repeated signals cause the server to terminate with increasing levels of urgency.
func handleSignals(s *grpc.Server) {
	c := make(chan os.Signal, 3) // Channel should be buffered a bit.
	signal.Notify(c, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
	sig := <-c
	log.Warning("Received signal %s, gracefully shutting down gRPC server", sig)
	go s.GracefulStop()
	sig = <-c
	log.Warning("Received signal %s, non-gracefully shutting down gRPC server", sig)
	go s.Stop()
	sig = <-c
	log.Fatalf("Received signal %s, terminating", sig)
}