github.com/bazelbuild/remote-apis-sdks@v0.0.0-20240425170053-8a36686a6350/go/pkg/client/client.go

// Package client contains a high-level remote execution client library.
package client

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"net/http"
	"os"
	"os/user"
	"strings"
	"sync"
	"time"

	"github.com/bazelbuild/remote-apis-sdks/go/pkg/actas"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/chunker"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/digest"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/retry"
	"github.com/bazelbuild/remote-apis-sdks/go/pkg/uploadinfo"
	"github.com/pkg/errors"
	"golang.org/x/oauth2"
	"golang.org/x/sync/semaphore"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/oauth"
	"google.golang.org/grpc/status"

	// Redundant imports are required for the google3 mirror. Aliases should not be changed.
	configpb "github.com/bazelbuild/remote-apis-sdks/go/pkg/balancer/proto"
	regrpc "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	repb "github.com/bazelbuild/remote-apis/build/bazel/remote/execution/v2"
	log "github.com/golang/glog"
	bsgrpc "google.golang.org/genproto/googleapis/bytestream"
	bspb "google.golang.org/genproto/googleapis/bytestream"
	opgrpc "google.golang.org/genproto/googleapis/longrunning"
	oppb "google.golang.org/genproto/googleapis/longrunning"
	emptypb "google.golang.org/protobuf/types/known/emptypb"
)

const (
	scopes = "https://www.googleapis.com/auth/cloud-platform"

	// HomeDirMacro is replaced by the current user's home dir in the CredFile dial parameter.
	HomeDirMacro = "${HOME}"
)

// ErrEmptySegment indicates an attempt to construct a resource name with an empty segment.
var ErrEmptySegment = errors.New("empty segment in resource name")

// AuthType indicates the type of authentication being used.
type AuthType int

const (
	// UnknownAuth refers to an unknown authentication type.
	UnknownAuth AuthType = iota

	// NoAuth refers to no authentication when connecting to the RBE service.
	NoAuth

	// ExternalTokenAuth refers to an externally provided auth token used to connect to the RBE service.
	ExternalTokenAuth

	// CredsFileAuth refers to a JSON credentials file used to connect to the RBE service.
	CredsFileAuth

	// ApplicationDefaultCredsAuth refers to Google Application Default Credentials that are
	// used to connect to the RBE service.
	ApplicationDefaultCredsAuth

	// GCECredsAuth refers to GCE machine credentials that are
	// used to connect to the RBE service.
	GCECredsAuth
)

// String returns a human-readable form of the authentication used to connect to RBE.
func (a AuthType) String() string {
	switch a {
	case NoAuth:
		return "no authentication"
	case ExternalTokenAuth:
		return "external authentication token (gcert?)"
	case CredsFileAuth:
		return "credentials file"
	case ApplicationDefaultCredsAuth:
		return "application default credentials"
	case GCECredsAuth:
		return "gce credentials"
	}
	return "unknown authentication type"
}

// InitError is used to wrap the error returned when initializing a new
// client to also indicate the type of authentication used.
type InitError struct {
	// Err refers to the underlying client initialization error.
	Err error

	// AuthUsed stores the type of authentication used to connect to RBE.
	AuthUsed AuthType
}

// Error returns a string error that includes information about the
// type of auth used to connect to RBE.
func (ce *InitError) Error() string {
	return fmt.Sprintf("%v, authentication type (identity) used=%q", ce.Err.Error(), ce.AuthUsed)
}
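
// exampleInspectInitError is an illustrative sketch and is NOT part of the original file:
// it shows how the *InitError returned by NewClient can be unwrapped to report which
// authentication type was in use when construction failed. The service address and
// instance name below are placeholders.
func exampleInspectInitError(ctx context.Context) (*Client, error) {
	c, err := NewClient(ctx, "projects/example/instances/default_instance", DialParams{
		Service:               "remote.example.com:443",
		UseApplicationDefault: true,
	})
	if err != nil {
		// NewClient returns *InitError directly, so a plain type assertion is enough.
		if initErr, ok := err.(*InitError); ok {
			log.Errorf("client init failed; auth type used: %s, underlying error: %v", initErr.AuthUsed, initErr.Err)
		}
		return nil, err
	}
	return c, nil
}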

// Client is a client to several services, including remote execution and services used in
// conjunction with remote execution. A Client must be constructed by calling Dial() or NewClient()
// rather than attempting to assemble it directly.
//
// Unless specified otherwise, and provided the fields are not modified, a Client is safe for
// concurrent use.
type Client struct {
	// InstanceName is the instance name for the targeted remote execution instance; e.g. for Google
	// RBE: "projects/<foo>/instances/default_instance".
	// It should NOT be used to construct resource names, but rather only for reusing the instance name as is.
	// Use the ResourceName method to create correctly formatted resource names.
	InstanceName string
	actionCache  regrpc.ActionCacheClient
	byteStream   bsgrpc.ByteStreamClient
	cas          regrpc.ContentAddressableStorageClient
	execution    regrpc.ExecutionClient
	operations   opgrpc.OperationsClient
	// Retrier is the Retrier that is used for RPCs made by this client.
	//
	// These fields are logically "protected" and are intended for use by extensions of Client.
	Retrier       *Retrier
	Connection    *grpc.ClientConn
	CASConnection *grpc.ClientConn // Can be different from Connection if a separate CAS endpoint is provided.
	// StartupCapabilities denotes whether to load ServerCapabilities on startup.
	StartupCapabilities StartupCapabilities
	// LegacyExecRootRelativeOutputs denotes whether outputs are relative to the exec root.
	LegacyExecRootRelativeOutputs LegacyExecRootRelativeOutputs
	// ChunkMaxSize is the maximum chunk size to use for CAS uploads/downloads.
	ChunkMaxSize ChunkMaxSize
	// CompressedBytestreamThreshold is the threshold in bytes for which blobs are read and written
	// compressed. Use 0 to compress all writes, and a negative number to disable compression for
	// all operations.
	CompressedBytestreamThreshold CompressedBytestreamThreshold
	// UploadCompressionPredicate is a function called to decide whether a blob should be compressed for upload.
	UploadCompressionPredicate UploadCompressionPredicate
	// MaxBatchDigests is the maximum number of digests to batch in upload and download operations.
	MaxBatchDigests MaxBatchDigests
	// MaxQueryBatchDigests is the maximum number of digests to batch in CAS query operations.
	MaxQueryBatchDigests MaxQueryBatchDigests
	// MaxBatchSize is the maximum size in bytes of a batch request for batch operations.
	MaxBatchSize MaxBatchSize
	// DirMode is the mode used to create directories.
	DirMode os.FileMode
	// ExecutableMode is the mode used to create executable files.
	ExecutableMode os.FileMode
	// RegularMode is the mode used to create non-executable files.
	RegularMode os.FileMode
	// UtilizeLocality specifies whether the client downloads files utilizing disk access locality.
	UtilizeLocality UtilizeLocality
	// UnifiedUploads specifies whether the client uploads files in the background.
	UnifiedUploads UnifiedUploads
	// UnifiedUploadBufferSize specifies when the unified upload daemon flushes the pending requests.
	UnifiedUploadBufferSize UnifiedUploadBufferSize
	// UnifiedUploadTickDuration specifies how often the unified upload daemon flushes the pending requests.
	UnifiedUploadTickDuration UnifiedUploadTickDuration
	// UnifiedDownloads specifies whether the client downloads files in the background.
	UnifiedDownloads UnifiedDownloads
	// UnifiedDownloadBufferSize specifies when the unified download daemon flushes the pending requests.
	UnifiedDownloadBufferSize UnifiedDownloadBufferSize
	// UnifiedDownloadTickDuration specifies how often the unified download daemon flushes the pending requests.
	UnifiedDownloadTickDuration UnifiedDownloadTickDuration
	// TreeSymlinkOpts controls how symlinks are handled when constructing a tree.
	TreeSymlinkOpts *TreeSymlinkOpts

	serverCaps          *repb.ServerCapabilities
	useBatchOps         UseBatchOps
	casConcurrency      int64
	casUploaders        *semaphore.Weighted
	casUploadRequests   chan *uploadRequest
	casUploads          map[digest.Digest]*uploadState
	casDownloaders      *semaphore.Weighted
	casDownloadRequests chan *downloadRequest
	rpcTimeouts         RPCTimeouts
	creds               credentials.PerRPCCredentials
	uploadOnce          sync.Once
	downloadOnce        sync.Once
	useBatchCompression UseBatchCompression
}

const (
	// DefaultMaxBatchSize is the maximum size of a batch to upload with BatchWriteBlobs. We set it to slightly
	// below 4 MB, because that is the default limit on gRPC message size.
	DefaultMaxBatchSize = 4*1024*1024 - 1024

	// DefaultMaxBatchDigests is a suggested approximate limit based on the current RBE implementation.
	// Above that, BatchUpdateBlobs calls start to exceed a typical one-minute timeout.
	DefaultMaxBatchDigests = 4000

	// DefaultMaxQueryBatchDigests is a suggested limit for the number of items in a batch for a missing blobs query.
	DefaultMaxQueryBatchDigests = 10_000

	// DefaultDirMode is the mode used to create directories.
	DefaultDirMode = 0777

	// DefaultExecutableMode is the mode used to create executable files.
	DefaultExecutableMode = 0777

	// DefaultRegularMode is the mode used to create non-executable files.
	DefaultRegularMode = 0644
)

// Close closes the underlying gRPC connection(s).
func (c *Client) Close() error {
	// Close the channels & stop background operations.
	UnifiedUploads(false).Apply(c)
	UnifiedDownloads(false).Apply(c)
	err := c.Connection.Close()
	if err != nil {
		return err
	}
	if c.CASConnection != c.Connection {
		return c.CASConnection.Close()
	}
	return nil
}

// Opt is an option that can be passed to Dial in order to configure the behaviour of the client.
type Opt interface {
	Apply(*Client)
}

// ChunkMaxSize is the maximum chunk size to use in ByteStream wrappers.
type ChunkMaxSize int

// Apply sets the client's maximum chunk size to s.
func (s ChunkMaxSize) Apply(c *Client) {
	c.ChunkMaxSize = s
}

// CompressedBytestreamThreshold is the threshold for compressing blobs when writing/reading.
// See the comment on the related field of the Client struct.
type CompressedBytestreamThreshold int64

// Apply sets the client's compressed bytestream threshold to s.
func (s CompressedBytestreamThreshold) Apply(c *Client) {
	c.CompressedBytestreamThreshold = s
}

// An UploadCompressionPredicate determines whether to compress a blob on upload.
// Note that the CompressedBytestreamThreshold takes priority over this (i.e.
// if the blob to be uploaded is smaller than the threshold, this will not be called).
type UploadCompressionPredicate func(*uploadinfo.Entry) bool

// Apply sets the client's compression predicate.
func (cc UploadCompressionPredicate) Apply(c *Client) {
	c.UploadCompressionPredicate = cc
}

// UtilizeLocality specifies whether the client downloads files utilizing disk access locality.
type UtilizeLocality bool

// Apply sets the client's UtilizeLocality.
func (s UtilizeLocality) Apply(c *Client) {
	c.UtilizeLocality = s
}

// UnifiedUploads specifies whether the client uploads files in the background, unifying operations between different actions.
type UnifiedUploads bool

// Apply sets the client's UnifiedUploads.
func (s UnifiedUploads) Apply(c *Client) {
	c.UnifiedUploads = s
}

// UnifiedUploadBufferSize tunes when the daemon for UnifiedUploads flushes the pending requests.
type UnifiedUploadBufferSize int

// DefaultUnifiedUploadBufferSize is the default UnifiedUploadBufferSize.
const DefaultUnifiedUploadBufferSize = 10000

// Apply sets the client's UnifiedUploadBufferSize.
func (s UnifiedUploadBufferSize) Apply(c *Client) {
	c.UnifiedUploadBufferSize = s
}

// UnifiedUploadTickDuration tunes how often the daemon for UnifiedUploads flushes the pending requests.
type UnifiedUploadTickDuration time.Duration

// DefaultUnifiedUploadTickDuration is the default UnifiedUploadTickDuration.
const DefaultUnifiedUploadTickDuration = UnifiedUploadTickDuration(50 * time.Millisecond)

// Apply sets the client's UnifiedUploadTickDuration.
func (s UnifiedUploadTickDuration) Apply(c *Client) {
	c.UnifiedUploadTickDuration = s
}

// UnifiedDownloads specifies whether the client downloads files in the background, unifying operations between different actions.
type UnifiedDownloads bool

// Apply sets the client's UnifiedDownloads.
// Note: it is unsafe to change this property when connections are ongoing.
func (s UnifiedDownloads) Apply(c *Client) {
	c.UnifiedDownloads = s
}

// UnifiedDownloadBufferSize tunes when the daemon for UnifiedDownloads flushes the pending requests.
type UnifiedDownloadBufferSize int

// DefaultUnifiedDownloadBufferSize is the default UnifiedDownloadBufferSize.
const DefaultUnifiedDownloadBufferSize = 10000

// Apply sets the client's UnifiedDownloadBufferSize.
func (s UnifiedDownloadBufferSize) Apply(c *Client) {
	c.UnifiedDownloadBufferSize = s
}

// UnifiedDownloadTickDuration tunes how often the daemon for UnifiedDownloads flushes the pending requests.
type UnifiedDownloadTickDuration time.Duration

// DefaultUnifiedDownloadTickDuration is the default UnifiedDownloadTickDuration.
const DefaultUnifiedDownloadTickDuration = UnifiedDownloadTickDuration(50 * time.Millisecond)

// Apply sets the client's UnifiedDownloadTickDuration.
func (s UnifiedDownloadTickDuration) Apply(c *Client) {
	c.UnifiedDownloadTickDuration = s
}

// Apply sets the client's TreeSymlinkOpts.
func (o *TreeSymlinkOpts) Apply(c *Client) {
	c.TreeSymlinkOpts = o
}
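
// exampleApplyOpts is an illustrative sketch and is NOT part of the original file: it shows
// how Opt values compose, with each option mutating one field on the Client via Apply. The
// concrete values below are arbitrary placeholders.
func exampleApplyOpts(c *Client) {
	opts := []Opt{
		ChunkMaxSize(1024 * 1024),              // 1 MiB ByteStream chunks.
		CompressedBytestreamThreshold(1 << 20), // compress blobs of at least 1 MiB.
		UnifiedUploads(true),                   // batch uploads across actions.
		UnifiedUploadTickDuration(100 * time.Millisecond),
	}
	for _, o := range opts {
		o.Apply(c)
	}
}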

// MaxBatchDigests is the maximum number of digests to batch in upload and download operations.
type MaxBatchDigests int

// Apply sets the client's maximum batch digests to s.
func (s MaxBatchDigests) Apply(c *Client) {
	c.MaxBatchDigests = s
}

// MaxQueryBatchDigests is the maximum number of digests to batch in query operations.
type MaxQueryBatchDigests int

// Apply sets the client's maximum query batch digests to s.
func (s MaxQueryBatchDigests) Apply(c *Client) {
	c.MaxQueryBatchDigests = s
}

// MaxBatchSize is the maximum size in bytes of a batch request for batch operations.
type MaxBatchSize int64

// Apply sets the client's maximum batch size to s.
func (s MaxBatchSize) Apply(c *Client) {
	c.MaxBatchSize = s
}

// DirMode is the mode used to create directories.
type DirMode os.FileMode

// Apply sets the client's DirMode to m.
func (m DirMode) Apply(c *Client) {
	c.DirMode = os.FileMode(m)
}

// ExecutableMode is the mode used to create executable files.
type ExecutableMode os.FileMode

// Apply sets the client's ExecutableMode to m.
func (m ExecutableMode) Apply(c *Client) {
	c.ExecutableMode = os.FileMode(m)
}

// RegularMode is the mode used to create non-executable files.
type RegularMode os.FileMode

// Apply sets the client's RegularMode to m.
func (m RegularMode) Apply(c *Client) {
	c.RegularMode = os.FileMode(m)
}

// UseBatchOps can be set to true to use batch CAS operations when uploading multiple blobs, or
// false to always use individual ByteStream requests.
type UseBatchOps bool

// Apply sets the UseBatchOps flag on a client.
func (u UseBatchOps) Apply(c *Client) {
	c.useBatchOps = u
}

// UseBatchCompression is currently set to true when the server has the
// SupportedBatchUpdateCompressors capability and supports ZSTD compression.
type UseBatchCompression bool

// Apply sets the useBatchCompression flag on a client.
func (u UseBatchCompression) Apply(c *Client) {
	c.useBatchCompression = u
}

// CASConcurrency is the number of simultaneous requests that will be issued for CAS upload and
// download operations.
type CASConcurrency int

// DefaultCASConcurrency is the default maximum number of concurrent upload and download operations.
const DefaultCASConcurrency = 500

// DefaultMaxConcurrentRequests specifies the default maximum number of concurrent requests on a single connection
// that the gRPC balancer can perform.
const DefaultMaxConcurrentRequests = 25

// DefaultMaxConcurrentStreams specifies the default threshold value at which the gRPC balancer should create
// new sub-connections.
const DefaultMaxConcurrentStreams = 25

// Apply sets the CASConcurrency flag on a client.
func (cy CASConcurrency) Apply(c *Client) {
	c.casConcurrency = int64(cy)
	c.casUploaders = semaphore.NewWeighted(c.casConcurrency)
	c.casDownloaders = semaphore.NewWeighted(c.casConcurrency)
}

// StartupCapabilities controls whether the client should attempt to fetch the remote
// server capabilities on NewClient. If set to true, some configuration such as MaxBatchSize
// is set according to the remote server capabilities instead of using the provided values.
type StartupCapabilities bool

// Apply sets the StartupCapabilities flag on a client.
func (s StartupCapabilities) Apply(c *Client) {
	c.StartupCapabilities = s
}

// LegacyExecRootRelativeOutputs controls whether the client uses the legacy behavior of
// treating output paths as relative to the exec root instead of the working directory.
type LegacyExecRootRelativeOutputs bool

// Apply sets the LegacyExecRootRelativeOutputs flag on a client.
func (l LegacyExecRootRelativeOutputs) Apply(c *Client) {
	c.LegacyExecRootRelativeOutputs = l
}

// PerRPCCreds sets per-call options that will be set on all RPCs to the underlying connection.
type PerRPCCreds struct {
	Creds credentials.PerRPCCredentials
}

// Apply saves the per-RPC creds in the Client.
func (p *PerRPCCreds) Apply(c *Client) {
	c.creds = p.Creds
}

func getImpersonatedRPCCreds(ctx context.Context, actAs string, cred credentials.PerRPCCredentials) credentials.PerRPCCredentials {
	// Wrap in a ReuseTokenSource to cache valid tokens in memory (i.e., non-nil, with a non-expired
	// access token).
	ts := oauth2.ReuseTokenSource(
		nil, actas.NewTokenSource(ctx, cred, http.DefaultClient, actAs, []string{scopes}))
	return oauth.TokenSource{
		TokenSource: ts,
	}
}

func getRPCCreds(ctx context.Context, credFile string, useApplicationDefault bool, useComputeEngine bool) (credentials.PerRPCCredentials, AuthType, error) {
	if useApplicationDefault {
		c, err := oauth.NewApplicationDefault(ctx, scopes)
		return c, ApplicationDefaultCredsAuth, err
	}
	if useComputeEngine {
		return oauth.NewComputeEngine(), GCECredsAuth, nil
	}
	rpcCreds, err := oauth.NewServiceAccountFromFile(credFile, scopes)
	if err != nil {
		return nil, CredsFileAuth, fmt.Errorf("couldn't create RPC creds from %s: %v", credFile, err)
	}
	return rpcCreds, CredsFileAuth, nil
}

// DialParams contains all the parameters that Dial needs.
type DialParams struct {
	// Service contains the address of the remote execution service.
	Service string

	// CASService contains the address of the CAS service, if it is separate from
	// the remote execution service.
	CASService string

	// UseApplicationDefault indicates that the default credentials should be used.
	UseApplicationDefault bool

	// UseComputeEngine indicates that the default Compute Engine credentials should be used.
	UseComputeEngine bool

	// UseExternalAuthToken indicates whether an externally specified auth token should be used.
	// If set to true, ExternalPerRPCCreds should also be non-nil.
	UseExternalAuthToken bool

	// ExternalPerRPCCreds refers to the per-RPC credentials that should be used for each RPC.
	ExternalPerRPCCreds *PerRPCCreds

	// CredFile is the JSON file that contains the credentials for RPCs.
	CredFile string

	// ActAsAccount is the service account to act as when making RPC calls.
	ActAsAccount string

	// NoSecurity is true if there is no security: no credentials are configured
	// (NoAuth is implied) and grpc.WithInsecure() is passed in. Should only be
	// used in test code.
	NoSecurity bool

	// NoAuth is true if TLS is enabled (NoSecurity is false) but the client does
	// not need to authenticate with the server.
	NoAuth bool

	// TransportCredsOnly is true if it's the caller's responsibility to set per-RPC credentials
	// on individual calls. This overrides ActAsAccount, UseApplicationDefault, and UseComputeEngine.
	// This is not the same as NoSecurity, as transport credentials will still be set.
	TransportCredsOnly bool

	// TLSCACertFile is the PEM file that contains TLS root certificates.
	TLSCACertFile string

	// TLSServerName overrides the server name sent in TLS, if set to a non-empty string.
	TLSServerName string

	// DialOpts defines the set of gRPC DialOptions to apply, in addition to any used internally.
	DialOpts []grpc.DialOption

	// MaxConcurrentRequests specifies the maximum number of concurrent RPCs on a single connection.
	MaxConcurrentRequests uint32

	// MaxConcurrentStreams specifies the maximum number of concurrent stream RPCs on a single connection.
	MaxConcurrentStreams uint32

	// TLSClientAuthCert specifies the public key in PEM format for using mTLS auth to connect to the RBE service.
	//
	// If this is specified, TLSClientAuthKey must also be specified.
	TLSClientAuthCert string

	// TLSClientAuthKey specifies the private key for using mTLS auth to connect to the RBE service.
	//
	// If this is specified, TLSClientAuthCert must also be specified.
	TLSClientAuthKey string
}

func createGRPCInterceptor(p DialParams) *balancer.GCPInterceptor {
	apiConfig := &configpb.ApiConfig{
		ChannelPool: &configpb.ChannelPoolConfig{
			MaxSize:                          p.MaxConcurrentRequests,
			MaxConcurrentStreamsLowWatermark: p.MaxConcurrentStreams,
		},
		Method: []*configpb.MethodConfig{
			{
				Name: []string{".*"},
				Affinity: &configpb.AffinityConfig{
					Command:     configpb.AffinityConfig_BIND,
					AffinityKey: "bind-affinity",
				},
			},
		},
	}
	return balancer.NewGCPInterceptor(apiConfig)
}

func createTLSConfig(params DialParams) (*tls.Config, error) {
	var certPool *x509.CertPool
	if params.TLSCACertFile != "" {
		certPool = x509.NewCertPool()
		ca, err := os.ReadFile(params.TLSCACertFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read %s: %w", params.TLSCACertFile, err)
		}
		if ok := certPool.AppendCertsFromPEM(ca); !ok {
			return nil, fmt.Errorf("failed to load TLS CA certificates from %s", params.TLSCACertFile)
		}
	}

	var mTLSCredentials []tls.Certificate
	if params.TLSClientAuthCert != "" || params.TLSClientAuthKey != "" {
		if params.TLSClientAuthCert == "" || params.TLSClientAuthKey == "" {
			return nil, fmt.Errorf("TLSClientAuthCert and TLSClientAuthKey must both be empty or both be set, got TLSClientAuthCert='%v' and TLSClientAuthKey='%v'", params.TLSClientAuthCert, params.TLSClientAuthKey)
		}

		cert, err := tls.LoadX509KeyPair(params.TLSClientAuthCert, params.TLSClientAuthKey)
		if err != nil {
			return nil, fmt.Errorf("failed to read mTLS cert pair ('%v', '%v'): %v", params.TLSClientAuthCert, params.TLSClientAuthKey, err)
		}
		mTLSCredentials = append(mTLSCredentials, cert)
	}

	c := &tls.Config{
		ServerName:   params.TLSServerName,
		RootCAs:      certPool,
		Certificates: mTLSCredentials,
	}
	return c, nil
}
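
// exampleMTLSDialParams is an illustrative sketch and is NOT part of the original file: it
// shows a DialParams value for a deployment behind a private CA that also requires client
// certificates (mTLS). All paths and the endpoint are placeholders.
func exampleMTLSDialParams() DialParams {
	return DialParams{
		Service:           "remote-execution.internal.example.com:8980",
		TLSCACertFile:     "/etc/rbe/ca.pem",
		TLSClientAuthCert: "/etc/rbe/client-cert.pem",
		TLSClientAuthKey:  "/etc/rbe/client-key.pem",
		NoAuth:            true, // transport security only; no per-RPC credentials.
	}
}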

// Dial dials a given endpoint and returns the gRPC connection that is established.
func Dial(ctx context.Context, endpoint string, params DialParams) (*grpc.ClientConn, AuthType, error) {
	var authUsed AuthType

	var opts []grpc.DialOption
	opts = append(opts, params.DialOpts...)

	if params.MaxConcurrentRequests == 0 {
		params.MaxConcurrentRequests = DefaultMaxConcurrentRequests
	}
	if params.MaxConcurrentStreams == 0 {
		params.MaxConcurrentStreams = DefaultMaxConcurrentStreams
	}
	if params.NoSecurity {
		authUsed = NoAuth
		opts = append(opts, grpc.WithInsecure())
	} else if params.NoAuth {
		authUsed = NoAuth
		// Set the ServerName and RootCAs fields, if needed.
		tlsConfig, err := createTLSConfig(params)
		if err != nil {
			return nil, authUsed, fmt.Errorf("could not create TLS config: %v", err)
		}
		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	} else if params.UseExternalAuthToken {
		authUsed = ExternalTokenAuth
		if params.ExternalPerRPCCreds == nil {
			return nil, authUsed, fmt.Errorf("ExternalPerRPCCreds unspecified when using external auth token mechanism")
		}
		opts = append(opts, grpc.WithPerRPCCredentials(params.ExternalPerRPCCreds.Creds))
		// Set the ServerName and RootCAs fields, if needed.
		tlsConfig, err := createTLSConfig(params)
		if err != nil {
			return nil, authUsed, fmt.Errorf("could not create TLS config: %v", err)
		}
		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	} else {
		credFile := params.CredFile
		if strings.Contains(credFile, HomeDirMacro) {
			authUsed = CredsFileAuth
			usr, err := user.Current()
			if err != nil {
				return nil, authUsed, fmt.Errorf("could not fetch home directory because of error determining current user: %v", err)
			}
			credFile = strings.Replace(credFile, HomeDirMacro, usr.HomeDir, -1 /* no limit */)
		}

		if !params.TransportCredsOnly {
			var (
				rpcCreds credentials.PerRPCCredentials
				err      error
			)
			rpcCreds, authUsed, err = getRPCCreds(ctx, credFile, params.UseApplicationDefault, params.UseComputeEngine)
			if err != nil {
				return nil, authUsed, fmt.Errorf("couldn't create RPC creds for %s: %v", scopes, err)
			}

			if params.ActAsAccount != "" {
				rpcCreds = getImpersonatedRPCCreds(ctx, params.ActAsAccount, rpcCreds)
			}

			opts = append(opts, grpc.WithPerRPCCredentials(rpcCreds))
		}
		tlsConfig, err := createTLSConfig(params)
		if err != nil {
			return nil, authUsed, fmt.Errorf("could not create TLS config: %v", err)
		}
		opts = append(opts, grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)))
	}
	grpcInt := createGRPCInterceptor(params)
	opts = append(opts, grpc.WithDisableServiceConfig())
	opts = append(opts, grpc.WithDefaultServiceConfig(fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, balancer.Name)))
	opts = append(opts, grpc.WithUnaryInterceptor(grpcInt.GCPUnaryClientInterceptor))
	opts = append(opts, grpc.WithStreamInterceptor(grpcInt.GCPStreamClientInterceptor))

	conn, err := grpc.Dial(endpoint, opts...)
	if err != nil {
		return nil, authUsed, fmt.Errorf("couldn't dial gRPC %q: %v", endpoint, err)
	}
	return conn, authUsed, nil
}
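
// exampleDialSeparateCAS is an illustrative sketch and is NOT part of the original file:
// it dials the execution and CAS endpoints separately and assembles a client from the two
// connections via NewClientFromConnection. Endpoints and the instance name are placeholders;
// the same effect can usually be achieved by setting DialParams.CASService and calling NewClient.
func exampleDialSeparateCAS(ctx context.Context, params DialParams) (*Client, error) {
	conn, _, err := Dial(ctx, "exec.example.com:443", params)
	if err != nil {
		return nil, err
	}
	casConn, _, err := Dial(ctx, "cas.example.com:443", params)
	if err != nil {
		return nil, err
	}
	return NewClientFromConnection(ctx, "projects/example/instances/default_instance", conn, casConn)
}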

// DialRaw dials a remote execution service and returns the gRPC connection that is established.
// TODO(olaola): remove this overload when all clients use Dial.
func DialRaw(ctx context.Context, params DialParams) (*grpc.ClientConn, AuthType, error) {
	if params.Service == "" {
		return nil, UnknownAuth, fmt.Errorf("service needs to be specified")
	}
	log.Infof("Connecting to remote execution service %s", params.Service)
	return Dial(ctx, params.Service, params)
}

// NewClient connects to a remote execution service and returns a client suitable for higher-level
// functionality.
func NewClient(ctx context.Context, instanceName string, params DialParams, opts ...Opt) (*Client, error) {
	if instanceName == "" {
		log.Warning("Instance name was not specified.")
	}
	if params.Service == "" {
		return nil, &InitError{Err: fmt.Errorf("service needs to be specified")}
	}
	log.Infof("Connecting to remote execution instance %s", instanceName)
	log.Infof("Connecting to remote execution service %s", params.Service)
	conn, authUsed, err := Dial(ctx, params.Service, params)
	casConn := conn
	if params.CASService != "" && params.CASService != params.Service {
		log.Infof("Connecting to CAS service %s", params.CASService)
		casConn, authUsed, err = Dial(ctx, params.CASService, params)
	}
	if err != nil {
		return nil, &InitError{Err: statusWrap(err), AuthUsed: authUsed}
	}
	client, err := NewClientFromConnection(ctx, instanceName, conn, casConn, opts...)
	if err != nil {
		return nil, &InitError{Err: err, AuthUsed: authUsed}
	}
	return client, nil
}
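
// exampleNewClientWithOpts is an illustrative sketch and is NOT part of the original file:
// it combines DialParams with a few Opt values at construction time. The endpoint, instance
// name, and credentials file path are placeholders.
func exampleNewClientWithOpts(ctx context.Context) (*Client, error) {
	params := DialParams{
		Service:  "remotebuildexecution.googleapis.com:443",
		CredFile: HomeDirMacro + "/.config/rbe/credentials.json", // expanded by Dial.
	}
	return NewClient(ctx, "projects/example/instances/default_instance", params,
		CASConcurrency(200),  // cap concurrent CAS uploads/downloads.
		UseBatchOps(true),    // prefer batch CAS RPCs over per-blob ByteStream calls.
		RetryTransient(),     // *Retrier implements Opt, so the default policy can be passed directly.
	)
}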

// NewClientFromConnection creates a client from gRPC connections to a remote execution service and a CAS service.
func NewClientFromConnection(ctx context.Context, instanceName string, conn, casConn *grpc.ClientConn, opts ...Opt) (*Client, error) {
	if conn == nil {
		return nil, fmt.Errorf("connection to remote execution service may not be nil")
	}
	if casConn == nil {
		return nil, fmt.Errorf("connection to CAS service may not be nil")
	}
	client := &Client{
		InstanceName:                  instanceName,
		actionCache:                   regrpc.NewActionCacheClient(casConn),
		byteStream:                    bsgrpc.NewByteStreamClient(casConn),
		cas:                           regrpc.NewContentAddressableStorageClient(casConn),
		execution:                     regrpc.NewExecutionClient(conn),
		operations:                    opgrpc.NewOperationsClient(conn),
		rpcTimeouts:                   DefaultRPCTimeouts,
		Connection:                    conn,
		CASConnection:                 casConn,
		CompressedBytestreamThreshold: DefaultCompressedBytestreamThreshold,
		ChunkMaxSize:                  chunker.DefaultChunkSize,
		MaxBatchDigests:               DefaultMaxBatchDigests,
		MaxQueryBatchDigests:          DefaultMaxQueryBatchDigests,
		MaxBatchSize:                  DefaultMaxBatchSize,
		DirMode:                       DefaultDirMode,
		ExecutableMode:                DefaultExecutableMode,
		RegularMode:                   DefaultRegularMode,
		useBatchOps:                   true,
		StartupCapabilities:           true,
		LegacyExecRootRelativeOutputs: false,
		casConcurrency:                DefaultCASConcurrency,
		casUploaders:                  semaphore.NewWeighted(DefaultCASConcurrency),
		casDownloaders:                semaphore.NewWeighted(DefaultCASConcurrency),
		casUploads:                    make(map[digest.Digest]*uploadState),
		UnifiedUploadTickDuration:     DefaultUnifiedUploadTickDuration,
		UnifiedUploadBufferSize:       DefaultUnifiedUploadBufferSize,
		UnifiedDownloadTickDuration:   DefaultUnifiedDownloadTickDuration,
		UnifiedDownloadBufferSize:     DefaultUnifiedDownloadBufferSize,
		Retrier:                       RetryTransient(),
	}
	for _, o := range opts {
		o.Apply(client)
	}
	if client.StartupCapabilities {
		if err := client.CheckCapabilities(ctx); err != nil {
			return nil, statusWrap(err)
		}
	}
	if client.casConcurrency < 1 {
		return nil, fmt.Errorf("CASConcurrency should be at least 1")
	}
	client.RunBackgroundTasks(ctx)
	return client, nil
}

// RunBackgroundTasks starts background goroutines for the client.
func (c *Client) RunBackgroundTasks(ctx context.Context) {
	if c.UnifiedUploads {
		c.uploadOnce.Do(func() {
			c.casUploadRequests = make(chan *uploadRequest, c.UnifiedUploadBufferSize)
			go c.uploadProcessor(ctx)
		})
	}
	if c.UnifiedDownloads {
		c.downloadOnce.Do(func() {
			c.casDownloadRequests = make(chan *downloadRequest, c.UnifiedDownloadBufferSize)
			go c.downloadProcessor(ctx)
		})
	}
}

// RPCTimeouts is an Opt that sets the per-RPC deadline.
// The keys are RPC names. The "default" key, if present, is the default
// timeout. 0 values are valid and indicate no timeout.
type RPCTimeouts map[string]time.Duration

// Apply sets the client's per-RPC timeouts to d, replacing any previously configured values.
func (d RPCTimeouts) Apply(c *Client) {
	c.rpcTimeouts = map[string]time.Duration(d)
}

// DefaultRPCTimeouts contains the default timeouts of various RPC calls to RBE.
var DefaultRPCTimeouts = map[string]time.Duration{
	"default":          20 * time.Second,
	"GetCapabilities":  5 * time.Second,
	"BatchUpdateBlobs": time.Minute,
	"BatchReadBlobs":   time.Minute,
	"GetTree":          time.Minute,
	// Note: due to an implementation detail, WaitExecution will use the same
	// per-RPC timeout as Execute. It is extremely ill-advised to set the Execute
	// timeout above 0; most users should use the Action Timeout instead.
	"Execute":       0,
	"WaitExecution": 0,
}

// ResourceName constructs a correctly formatted resource name as defined in the spec.
// No keyword validation is performed since the semantics of the path are defined by the server.
// See: https://github.com/bazelbuild/remote-apis/blob/cb8058798964f0adf6dbab2f4c2176ae2d653447/build/bazel/remote/execution/v2/remote_execution.proto#L223
func (c *Client) ResourceName(segments ...string) (string, error) {
	segs := make([]string, 0, len(segments)+1)
	if c.InstanceName != "" {
		segs = append(segs, c.InstanceName)
	}
	for _, s := range segments {
		if s == "" {
			return "", ErrEmptySegment
		}
		segs = append(segs, s)
	}
	return strings.Join(segs, "/"), nil
}

// RPCOpts returns the default RPC options that should be used for calls made with this client.
//
// This method is logically "protected" and is intended for use by extensions of Client.
func (c *Client) RPCOpts() []grpc.CallOption {
	// Set a high limit on receiving large messages from the server.
	opts := []grpc.CallOption{grpc.MaxCallRecvMsgSize(100 * 1024 * 1024)}
	if c.creds == nil {
		return opts
	}
	return append(opts, grpc.PerRPCCredentials(c.creds))
}
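
// exampleBlobReadResourceName is an illustrative sketch and is NOT part of the original
// file: it builds a ByteStream read resource name for a blob, assuming the
// "{instance}/blobs/{hash}/{size}" layout from the Remote Execution API spec and the
// Hash/Size fields of the digest package's Digest type.
func exampleBlobReadResourceName(c *Client, d digest.Digest) (string, error) {
	return c.ResourceName("blobs", d.Hash, fmt.Sprintf("%d", d.Size))
}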

// CallWithTimeout executes the given function f with a context that times out after an RPC timeout.
//
// This method is logically "protected" and is intended for use by extensions of Client.
func (c *Client) CallWithTimeout(ctx context.Context, rpcName string, f func(ctx context.Context) error) error {
	timeout, ok := c.rpcTimeouts[rpcName]
	if !ok {
		if timeout, ok = c.rpcTimeouts["default"]; !ok {
			timeout = 0
		}
	}
	if timeout == 0 {
		return f(ctx)
	}
	childCtx, cancel := context.WithTimeout(ctx, timeout)
	defer cancel()
	e := f(childCtx)
	if childCtx.Err() != nil {
		return childCtx.Err()
	}
	return e
}

// Retrier is the retry policy applied to all client requests.
type Retrier struct {
	Backoff     retry.BackoffPolicy
	ShouldRetry retry.ShouldRetry
}

// Apply sets the client's retrier function to r.
func (r *Retrier) Apply(c *Client) {
	c.Retrier = r
}

// Do executes f() with retries.
// It can be called with a nil receiver; in that case no retries are done (just a passthrough call
// to f()).
func (r *Retrier) Do(ctx context.Context, f func() error) error {
	if r == nil {
		return f()
	}
	return retry.WithPolicy(ctx, r.ShouldRetry, r.Backoff, f)
}

// RetryTransient is a default retry policy for transient status codes.
func RetryTransient() *Retrier {
	return &Retrier{
		Backoff:     retry.ExponentialBackoff(225*time.Millisecond, 2*time.Second, retry.Attempts(6)),
		ShouldRetry: retry.TransientOnly,
	}
}
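
// exampleTimeoutAndRetryOpts is an illustrative sketch and is NOT part of the original file:
// it shows how the default per-RPC deadlines and retry policy can be overridden at
// construction time via Opt values. The concrete durations and attempt counts are arbitrary.
func exampleTimeoutAndRetryOpts() []Opt {
	return []Opt{
		RPCTimeouts(map[string]time.Duration{
			"default":          30 * time.Second,
			"BatchUpdateBlobs": 2 * time.Minute,
			"Execute":          0, // no per-RPC deadline; rely on the action timeout.
		}),
		&Retrier{
			Backoff:     retry.ExponentialBackoff(250*time.Millisecond, 5*time.Second, retry.Attempts(8)),
			ShouldRetry: retry.TransientOnly,
		},
	}
}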

// GetActionResult wraps the underlying call with specific client options.
func (c *Client) GetActionResult(ctx context.Context, req *repb.GetActionResultRequest) (res *repb.ActionResult, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "GetActionResult", func(ctx context.Context) (e error) {
			res, e = c.actionCache.GetActionResult(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// UpdateActionResult wraps the underlying call with specific client options.
func (c *Client) UpdateActionResult(ctx context.Context, req *repb.UpdateActionResultRequest) (res *repb.ActionResult, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "UpdateActionResult", func(ctx context.Context) (e error) {
			res, e = c.actionCache.UpdateActionResult(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// Read wraps the underlying call with specific client options.
// The wrapper is here for completeness to provide access to the low-level
// RPCs. Prefer using higher-level functions such as ReadBlob(ToFile) instead,
// as they include retry and timeout handling.
func (c *Client) Read(ctx context.Context, req *bspb.ReadRequest) (res bsgrpc.ByteStream_ReadClient, err error) {
	return c.byteStream.Read(ctx, req, c.RPCOpts()...)
}

// Write wraps the underlying call with specific client options.
// The wrapper is here for completeness to provide access to the low-level
// RPCs. Prefer using higher-level functions such as WriteBlob(s) instead,
// as they include retry and timeout handling.
func (c *Client) Write(ctx context.Context) (res bsgrpc.ByteStream_WriteClient, err error) {
	return c.byteStream.Write(ctx, c.RPCOpts()...)
}

// QueryWriteStatus wraps the underlying call with specific client options.
func (c *Client) QueryWriteStatus(ctx context.Context, req *bspb.QueryWriteStatusRequest) (res *bspb.QueryWriteStatusResponse, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "QueryWriteStatus", func(ctx context.Context) (e error) {
			res, e = c.byteStream.QueryWriteStatus(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// FindMissingBlobs wraps the underlying call with specific client options.
func (c *Client) FindMissingBlobs(ctx context.Context, req *repb.FindMissingBlobsRequest) (res *repb.FindMissingBlobsResponse, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "FindMissingBlobs", func(ctx context.Context) (e error) {
			res, e = c.cas.FindMissingBlobs(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// BatchUpdateBlobs wraps the underlying call with specific client options.
// NOTE that its retry logic ignores the per-blob errors embedded in the response; you probably want
// to use BatchWriteBlobs() instead.
func (c *Client) BatchUpdateBlobs(ctx context.Context, req *repb.BatchUpdateBlobsRequest) (res *repb.BatchUpdateBlobsResponse, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "BatchUpdateBlobs", func(ctx context.Context) (e error) {
			res, e = c.cas.BatchUpdateBlobs(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// BatchReadBlobs wraps the underlying call with specific client options.
// NOTE that its retry logic ignores the per-blob errors embedded in the response.
// It is recommended to use BatchDownloadBlobs instead.
func (c *Client) BatchReadBlobs(ctx context.Context, req *repb.BatchReadBlobsRequest) (res *repb.BatchReadBlobsResponse, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "BatchReadBlobs", func(ctx context.Context) (e error) {
			res, e = c.cas.BatchReadBlobs(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// GetTree wraps the underlying call with specific client options.
// The wrapper is here for completeness to provide access to the low-level
// RPCs. Prefer using the higher-level GetDirectoryTree instead,
// as it includes retry and timeout handling.
func (c *Client) GetTree(ctx context.Context, req *repb.GetTreeRequest) (res regrpc.ContentAddressableStorage_GetTreeClient, err error) {
	return c.cas.GetTree(ctx, req, c.RPCOpts()...)
}

// Execute wraps the underlying call with specific client options.
// The wrapper is here for completeness to provide access to the low-level
// RPCs. Prefer using the higher-level ExecuteAndWait instead,
// as it includes retry and timeout handling.
func (c *Client) Execute(ctx context.Context, req *repb.ExecuteRequest) (res regrpc.Execution_ExecuteClient, err error) {
	return c.execution.Execute(ctx, req, c.RPCOpts()...)
}

// WaitExecution wraps the underlying call with specific client options.
// The wrapper is here for completeness to provide access to the low-level
// RPCs. Prefer using the higher-level ExecuteAndWait instead,
// as it includes retry and timeout handling.
func (c *Client) WaitExecution(ctx context.Context, req *repb.WaitExecutionRequest) (res regrpc.Execution_ExecuteClient, err error) {
	return c.execution.WaitExecution(ctx, req, c.RPCOpts()...)
}

// GetBackendCapabilities returns the capabilities for a specific server connection
// (either the main connection or the CAS connection).
func (c *Client) GetBackendCapabilities(ctx context.Context, conn *grpc.ClientConn, req *repb.GetCapabilitiesRequest) (res *repb.ServerCapabilities, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "GetCapabilities", func(ctx context.Context) (e error) {
			res, e = regrpc.NewCapabilitiesClient(conn).GetCapabilities(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, err
	}
	return res, nil
}

// GetOperation wraps the underlying call with specific client options.
func (c *Client) GetOperation(ctx context.Context, req *oppb.GetOperationRequest) (res *oppb.Operation, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "GetOperation", func(ctx context.Context) (e error) {
			res, e = c.operations.GetOperation(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// ListOperations wraps the underlying call with specific client options.
func (c *Client) ListOperations(ctx context.Context, req *oppb.ListOperationsRequest) (res *oppb.ListOperationsResponse, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "ListOperations", func(ctx context.Context) (e error) {
			res, e = c.operations.ListOperations(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// CancelOperation wraps the underlying call with specific client options.
func (c *Client) CancelOperation(ctx context.Context, req *oppb.CancelOperationRequest) (res *emptypb.Empty, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "CancelOperation", func(ctx context.Context) (e error) {
			res, e = c.operations.CancelOperation(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// DeleteOperation wraps the underlying call with specific client options.
func (c *Client) DeleteOperation(ctx context.Context, req *oppb.DeleteOperationRequest) (res *emptypb.Empty, err error) {
	opts := c.RPCOpts()
	err = c.Retrier.Do(ctx, func() (e error) {
		return c.CallWithTimeout(ctx, "DeleteOperation", func(ctx context.Context) (e error) {
			res, e = c.operations.DeleteOperation(ctx, req, opts...)
			return e
		})
	})
	if err != nil {
		return nil, statusWrap(err)
	}
	return res, nil
}

// gRPC errors are incompatible with simple wraps. See
// https://github.com/grpc/grpc-go/issues/3115
func statusWrap(err error) error {
	if st, ok := status.FromError(err); ok {
		return status.Errorf(st.Code(), errors.WithStack(err).Error())
	}
	return errors.WithStack(err)
}