github.phpd.cn/thought-machine/please@v12.2.0+incompatible/tools/cache/server/rpc_server.go

package server

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"encoding/pem"
	"fmt"
	"io/ioutil"
	"net"
	"os"
	"os/signal"
	"path"
	"sync"
	"syscall"

	"github.com/grpc-ecosystem/go-grpc-prometheus"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/net/context"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	_ "google.golang.org/grpc/encoding/gzip" // Registers the gzip compressor at init
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/peer"
	"google.golang.org/grpc/status"

	pb "cache/proto/rpc_cache"
	"fs"
	"tools/cache/cluster"
)

// maxMsgSize is the maximum message size our gRPC server accepts.
// We deliberately set this to something high since we don't want to limit artifact size here.
const maxMsgSize = 200 * 1024 * 1024

// metricsOnce is used to track whether or not we've registered the server metrics.
// In normal operation this only happens once, but in tests it can happen multiple times.
var metricsOnce sync.Once

func init() {
	// When tracing is enabled, it appears to keep references to messages alive, possibly indefinitely (?).
	// This is very bad for us since our messages are large; it can result in leaking memory very quickly
	// and ultimately OOM errors. Disabling tracing appears to alleviate the problem.
	grpc.EnableTracing = false
}

// An RPCCacheServer implements our RPC cache, including communication in a cluster.
type RPCCacheServer struct {
	cache        *Cache
	readonlyKeys map[string]*x509.Certificate
	writableKeys map[string]*x509.Certificate
	cluster      *cluster.Cluster
	retrievedCounter, storedCounter, retrievedBytes, storedBytes, retrieveFailures *prometheus.CounterVec
}

// Store implements the Store RPC to store an artifact in the cache.
func (r *RPCCacheServer) Store(ctx context.Context, req *pb.StoreRequest) (*pb.StoreResponse, error) {
	if err := r.authenticateClient(ctx, r.writableKeys); err != nil {
		return nil, err
	}
	success := storeArtifact(r.cache, req.Os, req.Arch, req.Hash, req.Artifacts, req.Hostname, extractAddress(ctx), "")
	if success && r.cluster != nil {
		// Replicate this artifact to another node. Doesn't have to be done synchronously.
		go r.cluster.ReplicateArtifacts(req)
	}
	if success {
		r.storedCounter.WithLabelValues(req.Arch).Inc()
		total := 0
		for _, artifact := range req.Artifacts {
			total += len(artifact.Body)
		}
		r.storedBytes.WithLabelValues(req.Arch).Add(float64(total))
	}
	return &pb.StoreResponse{Success: success}, nil
}

// storeArtifact stores a series of artifacts in the cache.
// Broken out of the above to share with Replicate below.
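// As a descriptive note inferred from the code below: artifacts end up under paths of the form
// <os>_<arch>/<package>/<target>/<base64 hash>/<file>, and metadata about who stored them is
// written alongside asynchronously.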
func storeArtifact(cache *Cache, os, arch string, hash []byte, artifacts []*pb.Artifact, hostname, address, peer string) bool {
	arch = os + "_" + arch
	hashStr := base64.RawURLEncoding.EncodeToString(hash)
	for _, artifact := range artifacts {
		dir := path.Join(arch, artifact.Package, artifact.Target, hashStr)
		file := path.Join(dir, artifact.File)
		if err := cache.StoreArtifact(file, artifact.Body, artifact.Symlink); err != nil {
			return false
		}
		go cache.StoreMetadata(dir, hostname, address, peer)
	}
	return true
}

// Retrieve implements the Retrieve RPC to retrieve artifacts from the cache.
func (r *RPCCacheServer) Retrieve(ctx context.Context, req *pb.RetrieveRequest) (*pb.RetrieveResponse, error) {
	if err := r.authenticateClient(ctx, r.readonlyKeys); err != nil {
		return nil, err
	}
	response := pb.RetrieveResponse{Success: true}
	arch := req.Os + "_" + req.Arch
	hash := base64.RawURLEncoding.EncodeToString(req.Hash)
	total := 0
	for _, artifact := range req.Artifacts {
		root := path.Join(arch, artifact.Package, artifact.Target, hash)
		fileRoot := path.Join(root, artifact.File)
		arts, err := r.cache.RetrieveArtifact(fileRoot)
		if err != nil {
			log.Debug("Failed to retrieve artifact %s: %s", fileRoot, err)
			r.retrieveFailures.WithLabelValues(req.Arch).Inc()
			return &pb.RetrieveResponse{Success: false}, nil
		}
		for _, art := range arts {
			response.Artifacts = append(response.Artifacts, &pb.Artifact{
				Package: artifact.Package,
				Target:  artifact.Target,
				File:    art.File[len(root)+1:],
				Body:    art.Body,
				Symlink: art.Symlink,
			})
			total += len(art.Body)
		}
	}
	r.retrievedCounter.WithLabelValues(req.Arch).Inc()
	r.retrievedBytes.WithLabelValues(req.Arch).Add(float64(total))
	return &response, nil
}

// Delete implements the Delete RPC to delete an artifact from the cache.
func (r *RPCCacheServer) Delete(ctx context.Context, req *pb.DeleteRequest) (*pb.DeleteResponse, error) {
	if err := r.authenticateClient(ctx, r.writableKeys); err != nil {
		return nil, err
	}
	if req.Everything {
		return &pb.DeleteResponse{Success: r.cache.DeleteAllArtifacts() == nil}, nil
	}
	success := deleteArtifact(r.cache, req.Os, req.Arch, req.Artifacts)
	if success && r.cluster != nil {
		// Delete this artifact from other nodes. Doesn't have to be done synchronously.
		go r.cluster.DeleteArtifacts(req)
	}
	return &pb.DeleteResponse{Success: success}, nil
}

// deleteArtifact handles the actual removal of artifacts from the cache.
// It's split out from Delete to share with the replication RPCs below.
func deleteArtifact(cache *Cache, os, arch string, artifacts []*pb.Artifact) bool {
	success := true
	for _, artifact := range artifacts {
		if cache.DeleteArtifact(path.Join(os+"_"+arch, artifact.Package, artifact.Target)) != nil {
			success = false
		}
	}
	return success
}

// ListNodes implements the RPC for clustered servers.
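// When clustering isn't configured it simply returns an empty response.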
func (r *RPCCacheServer) ListNodes(ctx context.Context, req *pb.ListRequest) (*pb.ListResponse, error) {
	if err := r.authenticateClient(ctx, r.readonlyKeys); err != nil {
		return nil, err
	}
	if r.cluster == nil {
		return &pb.ListResponse{}, nil
	}
	return &pb.ListResponse{Nodes: r.cluster.GetMembers()}, nil
}

func (r *RPCCacheServer) authenticateClient(ctx context.Context, certs map[string]*x509.Certificate) error {
	if len(certs) == 0 {
		return nil // Open to anyone.
	}
	p, ok := peer.FromContext(ctx)
	if !ok {
		return status.Error(codes.Unauthenticated, "Missing client certificate")
	}
	info, ok := p.AuthInfo.(credentials.TLSInfo)
	if !ok {
		return status.Error(codes.Unauthenticated, "Could not extract auth info")
	}
	if len(info.State.PeerCertificates) == 0 {
		return status.Error(codes.Unauthenticated, "No peer certificate available")
	}
	cert := info.State.PeerCertificates[0]
	okCert := certs[string(cert.RawSubject)]
	if okCert == nil || !okCert.Equal(cert) {
		return status.Error(codes.Unauthenticated, "Invalid or unknown certificate")
	}
	return nil
}

func extractAddress(ctx context.Context) string {
	p, ok := peer.FromContext(ctx)
	if !ok {
		return ""
	}
	return p.Addr.String()
}

func loadKeys(filename string) map[string]*x509.Certificate {
	ret := map[string]*x509.Certificate{}
	if err := fs.Walk(filename, func(name string, isDir bool) error {
		if !isDir {
			data, err := ioutil.ReadFile(name)
			if err != nil {
				log.Fatalf("Failed to read cert from %s: %s", name, err)
			}
			p, _ := pem.Decode(data)
			if p == nil {
				log.Fatalf("Couldn't decode PEM data from %s", name)
			}
			cert, err := x509.ParseCertificate(p.Bytes)
			if err != nil {
				log.Fatalf("Couldn't parse certificate from %s: %s", name, err)
			}
			ret[string(cert.RawSubject)] = cert
		}
		return nil
	}); err != nil {
		log.Fatalf("%s", err)
	}
	return ret
}

// RPCServer implements the gRPC server for communication between cache nodes.
type RPCServer struct {
	cache   *Cache
	cluster *cluster.Cluster
}

// Join implements the Join RPC for a new server joining the cluster.
func (r *RPCServer) Join(ctx context.Context, req *pb.JoinRequest) (*pb.JoinResponse, error) {
	// TODO(pebers): Authentication.
	return r.cluster.AddNode(req), nil
}

// Replicate implements the Replicate RPC for replicating an artifact from another node.
func (r *RPCServer) Replicate(ctx context.Context, req *pb.ReplicateRequest) (*pb.ReplicateResponse, error) {
	// TODO(pebers): Authentication.
	if req.Delete {
		return &pb.ReplicateResponse{
			Success: deleteArtifact(r.cache, req.Os, req.Arch, req.Artifacts),
		}, nil
	}
	return &pb.ReplicateResponse{
		Success: storeArtifact(r.cache, req.Os, req.Arch, req.Hash, req.Artifacts, req.Hostname, extractAddress(ctx), req.Peer),
	}, nil
}

// BuildGrpcServer creates a new, unstarted grpc.Server and returns it.
// It also returns a net.Listener to start it on.
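// keyFile / certFile / caCertFile enable TLS on the server when non-empty (see serverWithAuth below);
// readonlyKeys and writableKeys are locations of client certificates that gate read and write access
// respectively, with writable keys implicitly granting read access as well.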
func BuildGrpcServer(port int, cache *Cache, cluster *cluster.Cluster, keyFile, certFile, caCertFile, readonlyKeys, writableKeys string) (*grpc.Server, net.Listener) {
	lis, err := net.Listen("tcp", fmt.Sprintf(":%d", port))
	if err != nil {
		log.Fatalf("Failed to listen on port %d: %v", port, err)
	}
	s := serverWithAuth(keyFile, certFile, caCertFile)
	r := &RPCCacheServer{
		cache:   cache,
		cluster: cluster,
		retrievedCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieved_count",
			Help: "Number of artifacts successfully retrieved",
		}, []string{"arch"}),
		storedCounter: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "stored_count",
			Help: "Number of artifacts successfully stored",
		}, []string{"arch"}),
		retrievedBytes: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieved_bytes",
			Help: "Number of bytes successfully retrieved",
		}, []string{"arch"}),
		storedBytes: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "stored_bytes",
			Help: "Number of bytes successfully stored",
		}, []string{"arch"}),
		retrieveFailures: prometheus.NewCounterVec(prometheus.CounterOpts{
			Name: "retrieve_failures",
			Help: "Number of failed retrieval attempts",
		}, []string{"arch"}),
	}
	if writableKeys != "" {
		r.writableKeys = loadKeys(writableKeys)
	}
	if readonlyKeys != "" {
		r.readonlyKeys = loadKeys(readonlyKeys)
		if len(r.readonlyKeys) > 0 {
			// This saves duplication when checking later; writable keys are implicitly readable too.
			for k, v := range r.writableKeys {
				if _, present := r.readonlyKeys[k]; !present {
					r.readonlyKeys[k] = v
				}
			}
		}
	}
	r2 := &RPCServer{cache: cache, cluster: cluster}
	pb.RegisterRpcCacheServer(s, r)
	pb.RegisterRpcServerServer(s, r2)
	healthserver := health.NewServer()
	healthserver.SetServingStatus("plz-rpc-cache", healthpb.HealthCheckResponse_SERVING)
	healthpb.RegisterHealthServer(s, healthserver)
	metricsOnce.Do(func() {
		prometheus.MustRegister(r.retrievedCounter)
		prometheus.MustRegister(r.storedCounter)
		prometheus.MustRegister(r.retrievedBytes)
		prometheus.MustRegister(r.storedBytes)
		prometheus.MustRegister(r.retrieveFailures)
	})
	return s, lis
}

// ServeGrpcForever serves gRPC using the given server until killed.
// It's very simple and provided as a convenience so callers don't have to import grpc themselves.
func ServeGrpcForever(server *grpc.Server, lis net.Listener) {
	log.Notice("Serving RPC cache on %s", lis.Addr())
	go handleSignals(server)
	server.Serve(lis)
}

// serverWithAuth builds a gRPC server, possibly with authentication if key / cert files are given.
func serverWithAuth(keyFile, certFile, caCertFile string) *grpc.Server {
	if keyFile == "" {
		return grpc.NewServer(grpc.MaxRecvMsgSize(maxMsgSize), grpc.MaxSendMsgSize(maxMsgSize)) // No auth.
	}
	log.Debug("Loading x509 key pair from key: %s cert: %s", keyFile, certFile)
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		log.Fatalf("Failed to load x509 key pair: %s", err)
	}
	config := tls.Config{
		Certificates: []tls.Certificate{cert},
		ClientAuth:   tls.RequestClientCert,
	}
	if caCertFile != "" {
		cert, err := ioutil.ReadFile(caCertFile)
		if err != nil {
			log.Fatalf("Failed to read CA cert file: %s", err)
		}
		config.ClientCAs = x509.NewCertPool()
		if !config.ClientCAs.AppendCertsFromPEM(cert) {
			log.Fatalf("Failed to find any PEM certificates in CA cert")
		}
	}
	return grpc.NewServer(
		grpc.Creds(credentials.NewTLS(&config)),
		grpc.MaxRecvMsgSize(maxMsgSize),
		grpc.MaxSendMsgSize(maxMsgSize),
		grpc.UnaryInterceptor(grpc_prometheus.UnaryServerInterceptor),
		grpc.StreamInterceptor(grpc_prometheus.StreamServerInterceptor),
	)
}

// handleSignals receives SIGTERM / SIGINT etc. to gracefully shut down a gRPC server.
// Repeated signals cause the server to terminate at increasing levels of urgency.
func handleSignals(s *grpc.Server) {
	c := make(chan os.Signal, 3) // Channel should be buffered a bit
	signal.Notify(c, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
	sig := <-c
	log.Warning("Received signal %s, gracefully shutting down gRPC server", sig)
	go s.GracefulStop()
	sig = <-c
	log.Warning("Received signal %s, non-gracefully shutting down gRPC server", sig)
	go s.Stop()
	sig = <-c
	log.Fatalf("Received signal %s, terminating", sig)
}
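
// The function below is an illustrative sketch rather than part of the upstream file: it shows
// roughly how a caller might wire BuildGrpcServer and ServeGrpcForever together for a standalone,
// unauthenticated cache. The port number is an arbitrary placeholder and the *Cache is assumed to
// have been constructed elsewhere.
func exampleServe(cache *Cache) {
	// Empty key / cert / CA paths mean plain TCP with no TLS; empty key locations mean no client auth.
	s, lis := BuildGrpcServer(7677, cache, nil, "", "", "", "", "")
	ServeGrpcForever(s, lis) // Blocks until the process is signalled.
}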