github.com/vmware/govmomi@v0.37.1/vim25/soap/client.go (about) 1 /* 2 Copyright (c) 2014-2023 VMware, Inc. All Rights Reserved. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package soap 18 19 import ( 20 "bufio" 21 "bytes" 22 "context" 23 "crypto/sha1" 24 "crypto/sha256" 25 "crypto/tls" 26 "crypto/x509" 27 "encoding/json" 28 "errors" 29 "fmt" 30 "io" 31 "log" 32 "net" 33 "net/http" 34 "net/http/cookiejar" 35 "net/url" 36 "os" 37 "path" 38 "path/filepath" 39 "reflect" 40 "regexp" 41 "runtime" 42 "strings" 43 "sync" 44 45 "github.com/vmware/govmomi/internal/version" 46 "github.com/vmware/govmomi/vim25/progress" 47 "github.com/vmware/govmomi/vim25/types" 48 "github.com/vmware/govmomi/vim25/xml" 49 ) 50 51 type HasFault interface { 52 Fault() *Fault 53 } 54 55 type RoundTripper interface { 56 RoundTrip(ctx context.Context, req, res HasFault) error 57 } 58 59 const ( 60 SessionCookieName = "vmware_soap_session" 61 ) 62 63 // defaultUserAgent is the default user agent string, e.g. 64 // "govc govmomi/0.28.0 (go1.18.3;linux;amd64)" 65 var defaultUserAgent = fmt.Sprintf( 66 "%s %s/%s (%s)", 67 execName(), 68 version.ClientName, 69 version.ClientVersion, 70 strings.Join([]string{runtime.Version(), runtime.GOOS, runtime.GOARCH}, ";"), 71 ) 72 73 type Client struct { 74 http.Client 75 76 u *url.URL 77 k bool // Named after curl's -k flag 78 d *debugContainer 79 t *http.Transport 80 81 hostsMu sync.Mutex 82 hosts map[string]string 83 84 Namespace string `json:"namespace"` // Vim namespace 85 Version string `json:"version"` // Vim version 86 Types types.Func `json:"types"` 87 UserAgent string `json:"userAgent"` 88 89 cookie string 90 insecureCookies bool 91 92 useJSON bool 93 } 94 95 var schemeMatch = regexp.MustCompile(`^\w+://`) 96 97 type errInvalidCACertificate struct { 98 File string 99 } 100 101 func (e errInvalidCACertificate) Error() string { 102 return fmt.Sprintf( 103 "invalid certificate '%s', cannot be used as a trusted CA certificate", 104 e.File, 105 ) 106 } 107 108 // ParseURL is wrapper around url.Parse, where Scheme defaults to "https" and Path defaults to "/sdk" 109 func ParseURL(s string) (*url.URL, error) { 110 var err error 111 var u *url.URL 112 113 if s != "" { 114 // Default the scheme to https 115 if !schemeMatch.MatchString(s) { 116 s = "https://" + s 117 } 118 119 s := strings.TrimSuffix(s, "/") 120 u, err = url.Parse(s) 121 if err != nil { 122 return nil, err 123 } 124 125 // Default the path to /sdk 126 if u.Path == "" { 127 u.Path = "/sdk" 128 } 129 130 if u.User == nil { 131 u.User = url.UserPassword("", "") 132 } 133 } 134 135 return u, nil 136 } 137 138 func NewClient(u *url.URL, insecure bool) *Client { 139 var t *http.Transport 140 141 if d, ok := http.DefaultTransport.(*http.Transport); ok { 142 t = d.Clone() 143 } else { 144 t = new(http.Transport) 145 } 146 147 if insecure { 148 if t.TLSClientConfig == nil { 149 t.TLSClientConfig = new(tls.Config) 150 } 151 t.TLSClientConfig.InsecureSkipVerify = insecure 152 } 153 154 c := newClientWithTransport(u, insecure, t) 155 156 // Always set DialTLS and DialTLSContext, even if InsecureSkipVerify=true, 157 // because of how certificate verification has been delegated to the host's 158 // PKI framework in Go 1.18. Please see the following links for more info: 159 // 160 // * https://tip.golang.org/doc/go1.18 (search for "Certificate.Verify") 161 // * https://github.com/square/certigo/issues/264 162 t.DialTLSContext = c.dialTLSContext 163 164 return c 165 } 166 167 func newClientWithTransport(u *url.URL, insecure bool, t *http.Transport) *Client { 168 c := Client{ 169 u: u, 170 k: insecure, 171 d: newDebug(), 172 t: t, 173 174 Types: types.TypeFunc(), 175 } 176 177 c.hosts = make(map[string]string) 178 179 c.Client.Transport = c.t 180 c.Client.Jar, _ = cookiejar.New(nil) 181 182 // Remove user information from a copy of the URL 183 c.u = c.URL() 184 c.u.User = nil 185 186 if c.u.Scheme == "http" { 187 c.insecureCookies = os.Getenv("GOVMOMI_INSECURE_COOKIES") == "true" 188 } 189 190 return &c 191 } 192 193 func (c *Client) DefaultTransport() *http.Transport { 194 return c.t 195 } 196 197 // NewServiceClient creates a NewClient with the given URL.Path and namespace. 198 func (c *Client) NewServiceClient(path string, namespace string) *Client { 199 return c.newServiceClientWithTransport(path, namespace, c.t) 200 } 201 202 func (c *Client) newServiceClientWithTransport(path string, namespace string, t *http.Transport) *Client { 203 vc := c.URL() 204 u, err := url.Parse(path) 205 if err != nil { 206 log.Panicf("url.Parse(%q): %s", path, err) 207 } 208 if u.Host == "" { 209 u.Scheme = vc.Scheme 210 u.Host = vc.Host 211 } 212 213 client := newClientWithTransport(u, c.k, t) 214 client.Namespace = "urn:" + namespace 215 216 // Copy the trusted thumbprints 217 c.hostsMu.Lock() 218 for k, v := range c.hosts { 219 client.hosts[k] = v 220 } 221 c.hostsMu.Unlock() 222 223 // Copy the cookies 224 client.Client.Jar.SetCookies(u, c.Client.Jar.Cookies(u)) 225 226 // Set SOAP Header cookie 227 for _, cookie := range client.Jar.Cookies(u) { 228 if cookie.Name == SessionCookieName { 229 client.cookie = cookie.Value 230 break 231 } 232 } 233 234 // Copy any query params (e.g. GOVMOMI_TUNNEL_PROXY_PORT used in testing) 235 client.u.RawQuery = vc.RawQuery 236 237 client.UserAgent = c.UserAgent 238 239 vimTypes := c.Types 240 client.Types = func(name string) (reflect.Type, bool) { 241 kind, ok := vimTypes(name) 242 if ok { 243 return kind, ok 244 } 245 // vim25/xml typeToString() does not have an option to include namespace prefix. 246 // Workaround this by re-trying the lookup with the namespace prefix. 247 return vimTypes(namespace + ":" + name) 248 } 249 250 return client 251 } 252 253 // UseJSON changes the protocol between SOAP and JSON. Starting with vCenter 254 // 8.0.1 JSON over HTTP can be used. Note this method has no locking and clients 255 // should be careful to not interfere with concurrent use of the client 256 // instance. 257 func (c *Client) UseJSON(useJSON bool) { 258 c.useJSON = useJSON 259 } 260 261 // SetRootCAs defines the set of PEM-encoded file locations of root certificate 262 // authorities the client uses when verifying server certificates instead of the 263 // TLS defaults which uses the host's root CA set. Multiple PEM file locations 264 // can be specified using the OS-specific PathListSeparator. 265 // 266 // See: http.Client.Transport.TLSClientConfig.RootCAs and 267 // https://pkg.go.dev/os#PathListSeparator 268 func (c *Client) SetRootCAs(pemPaths string) error { 269 pool := x509.NewCertPool() 270 271 for _, name := range filepath.SplitList(pemPaths) { 272 pem, err := os.ReadFile(filepath.Clean(name)) 273 if err != nil { 274 return err 275 } 276 277 if ok := pool.AppendCertsFromPEM(pem); !ok { 278 return errInvalidCACertificate{ 279 File: name, 280 } 281 } 282 } 283 284 c.t.TLSClientConfig.RootCAs = pool 285 286 return nil 287 } 288 289 // Add default https port if missing 290 func hostAddr(addr string) string { 291 _, port := splitHostPort(addr) 292 if port == "" { 293 return addr + ":443" 294 } 295 return addr 296 } 297 298 // SetThumbprint sets the known certificate thumbprint for the given host. 299 // A custom DialTLS function is used to support thumbprint based verification. 300 // We first try tls.Dial with the default tls.Config, only falling back to thumbprint verification 301 // if it fails with an x509.UnknownAuthorityError or x509.HostnameError 302 // 303 // See: http.Client.Transport.DialTLS 304 func (c *Client) SetThumbprint(host string, thumbprint string) { 305 host = hostAddr(host) 306 307 c.hostsMu.Lock() 308 if thumbprint == "" { 309 delete(c.hosts, host) 310 } else { 311 c.hosts[host] = thumbprint 312 } 313 c.hostsMu.Unlock() 314 } 315 316 // Thumbprint returns the certificate thumbprint for the given host if known to this client. 317 func (c *Client) Thumbprint(host string) string { 318 host = hostAddr(host) 319 c.hostsMu.Lock() 320 defer c.hostsMu.Unlock() 321 return c.hosts[host] 322 } 323 324 // KnownThumbprint checks whether the provided thumbprint is known to this client. 325 func (c *Client) KnownThumbprint(tp string) bool { 326 c.hostsMu.Lock() 327 defer c.hostsMu.Unlock() 328 329 for _, v := range c.hosts { 330 if v == tp { 331 return true 332 } 333 } 334 335 return false 336 } 337 338 // LoadThumbprints from file with the give name. 339 // If name is empty or name does not exist this function will return nil. 340 func (c *Client) LoadThumbprints(file string) error { 341 if file == "" { 342 return nil 343 } 344 345 for _, name := range filepath.SplitList(file) { 346 err := c.loadThumbprints(name) 347 if err != nil { 348 return err 349 } 350 } 351 352 return nil 353 } 354 355 func (c *Client) loadThumbprints(name string) error { 356 f, err := os.Open(filepath.Clean(name)) 357 if err != nil { 358 if os.IsNotExist(err) { 359 return nil 360 } 361 return err 362 } 363 364 scanner := bufio.NewScanner(f) 365 366 for scanner.Scan() { 367 e := strings.SplitN(scanner.Text(), " ", 2) 368 if len(e) != 2 { 369 continue 370 } 371 372 c.SetThumbprint(e[0], e[1]) 373 } 374 375 _ = f.Close() 376 377 return scanner.Err() 378 } 379 380 // ThumbprintSHA1 returns the thumbprint of the given cert in the same format used by the SDK and Client.SetThumbprint. 381 // 382 // See: SSLVerifyFault.Thumbprint, SessionManagerGenericServiceTicket.Thumbprint, HostConnectSpec.SslThumbprint 383 func ThumbprintSHA1(cert *x509.Certificate) string { 384 sum := sha1.Sum(cert.Raw) 385 hex := make([]string, len(sum)) 386 for i, b := range sum { 387 hex[i] = fmt.Sprintf("%02X", b) 388 } 389 return strings.Join(hex, ":") 390 } 391 392 // ThumbprintSHA256 returns the sha256 thumbprint of the given cert. 393 func ThumbprintSHA256(cert *x509.Certificate) string { 394 sum := sha256.Sum256(cert.Raw) 395 hex := make([]string, len(sum)) 396 for i, b := range sum { 397 hex[i] = fmt.Sprintf("%02X", b) 398 } 399 return strings.Join(hex, ":") 400 } 401 402 func thumbprintMatches(thumbprint string, cert *x509.Certificate) bool { 403 return thumbprint == ThumbprintSHA256(cert) || thumbprint == ThumbprintSHA1(cert) 404 } 405 406 func (c *Client) dialTLSContext( 407 ctx context.Context, 408 network, addr string) (net.Conn, error) { 409 410 // Would be nice if there was a tls.Config.Verify func, 411 // see tls.clientHandshakeState.doFullHandshake 412 413 conn, err := tls.Dial(network, addr, c.t.TLSClientConfig) 414 415 if err == nil { 416 return conn, nil 417 } 418 419 // Allow a thumbprint verification attempt if the error indicates 420 // the failure was due to lack of trust. 421 if !IsCertificateUntrusted(err) { 422 return nil, err 423 } 424 425 thumbprint := c.Thumbprint(addr) 426 if thumbprint == "" { 427 return nil, err 428 } 429 430 config := &tls.Config{InsecureSkipVerify: true} 431 conn, err = tls.Dial(network, addr, config) 432 if err != nil { 433 return nil, err 434 } 435 436 cert := conn.ConnectionState().PeerCertificates[0] 437 if thumbprintMatches(thumbprint, cert) { 438 return conn, nil 439 } 440 441 _ = conn.Close() 442 443 return nil, fmt.Errorf("host %q thumbprint does not match %q", addr, thumbprint) 444 } 445 446 // splitHostPort is similar to net.SplitHostPort, 447 // but rather than return error if there isn't a ':port', 448 // return an empty string for the port. 449 func splitHostPort(host string) (string, string) { 450 ix := strings.LastIndex(host, ":") 451 452 if ix <= strings.LastIndex(host, "]") { 453 return host, "" 454 } 455 456 name := host[:ix] 457 port := host[ix+1:] 458 459 return name, port 460 } 461 462 const sdkTunnel = "sdkTunnel:8089" 463 464 // Certificate returns the current TLS certificate. 465 func (c *Client) Certificate() *tls.Certificate { 466 certs := c.t.TLSClientConfig.Certificates 467 if len(certs) == 0 { 468 return nil 469 } 470 return &certs[0] 471 } 472 473 // SetCertificate st a certificate for TLS use. 474 func (c *Client) SetCertificate(cert tls.Certificate) { 475 t := c.Client.Transport.(*http.Transport) 476 477 // Extension or HoK certificate 478 t.TLSClientConfig.Certificates = []tls.Certificate{cert} 479 } 480 481 // UseServiceVersion sets Client.Version to the current version of the service endpoint via /sdk/vimServiceVersions.xml 482 func (c *Client) UseServiceVersion(kind ...string) error { 483 ns := "vim" 484 if len(kind) != 0 { 485 ns = kind[0] 486 } 487 488 u := c.URL() 489 u.Path = path.Join("/sdk", ns+"ServiceVersions.xml") 490 491 res, err := c.Get(u.String()) 492 if err != nil { 493 return err 494 } 495 496 if res.StatusCode != http.StatusOK { 497 return fmt.Errorf("http.Get(%s): %s", u.Path, res.Status) 498 } 499 500 v := struct { 501 Namespace *string `xml:"namespace>name"` 502 Version *string `xml:"namespace>version"` 503 }{ 504 &c.Namespace, 505 &c.Version, 506 } 507 508 err = xml.NewDecoder(res.Body).Decode(&v) 509 _ = res.Body.Close() 510 if err != nil { 511 return fmt.Errorf("xml.Decode(%s): %s", u.Path, err) 512 } 513 514 return nil 515 } 516 517 // Tunnel returns a Client configured to proxy requests through vCenter's http port 80, 518 // to the SDK tunnel virtual host. Use of the SDK tunnel is required by LoginExtensionByCertificate() 519 // and optional for other methods. 520 func (c *Client) Tunnel() *Client { 521 tunnel := c.newServiceClientWithTransport(c.u.Path, c.Namespace, c.DefaultTransport().Clone()) 522 523 t := tunnel.Client.Transport.(*http.Transport) 524 // Proxy to vCenter host on port 80 525 host := tunnel.u.Hostname() 526 // Should be no reason to change the default port other than testing 527 key := "GOVMOMI_TUNNEL_PROXY_PORT" 528 529 port := tunnel.URL().Query().Get(key) 530 if port == "" { 531 port = os.Getenv(key) 532 } 533 534 if port != "" { 535 host += ":" + port 536 } 537 538 t.Proxy = http.ProxyURL(&url.URL{ 539 Scheme: "http", 540 Host: host, 541 }) 542 543 // Rewrite url Host to use the sdk tunnel, required for a certificate request. 544 tunnel.u.Host = sdkTunnel 545 return tunnel 546 } 547 548 // URL returns the URL to which the client is configured 549 func (c *Client) URL() *url.URL { 550 urlCopy := *c.u 551 return &urlCopy 552 } 553 554 type marshaledClient struct { 555 Cookies []*http.Cookie `json:"cookies"` 556 URL *url.URL `json:"url"` 557 Insecure bool `json:"insecure"` 558 Version string `json:"version"` 559 UseJSON bool `json:"useJSON"` 560 } 561 562 // MarshalJSON writes the Client configuration to JSON. 563 func (c *Client) MarshalJSON() ([]byte, error) { 564 m := marshaledClient{ 565 Cookies: c.Jar.Cookies(c.u), 566 URL: c.u, 567 Insecure: c.k, 568 Version: c.Version, 569 UseJSON: c.useJSON, 570 } 571 572 return json.Marshal(m) 573 } 574 575 // UnmarshalJSON rads Client configuration from JSON. 576 func (c *Client) UnmarshalJSON(b []byte) error { 577 var m marshaledClient 578 579 err := json.Unmarshal(b, &m) 580 if err != nil { 581 return err 582 } 583 584 *c = *NewClient(m.URL, m.Insecure) 585 c.Version = m.Version 586 c.Jar.SetCookies(m.URL, m.Cookies) 587 c.useJSON = m.UseJSON 588 589 return nil 590 } 591 592 func (c *Client) setInsecureCookies(res *http.Response) { 593 cookies := res.Cookies() 594 if len(cookies) != 0 { 595 for _, cookie := range cookies { 596 cookie.Secure = false 597 } 598 c.Jar.SetCookies(c.u, cookies) 599 } 600 } 601 602 // Do is equivalent to http.Client.Do and takes care of API specifics including 603 // logging, user-agent header, handling cookies, measuring responsiveness of the 604 // API 605 func (c *Client) Do(ctx context.Context, req *http.Request, f func(*http.Response) error) error { 606 if ctx == nil { 607 ctx = context.Background() 608 } 609 // Create debugging context for this round trip 610 d := c.d.newRoundTrip() 611 if d.enabled() { 612 defer d.done() 613 } 614 615 // use default 616 if c.UserAgent == "" { 617 c.UserAgent = defaultUserAgent 618 } 619 620 req.Header.Set(`User-Agent`, c.UserAgent) 621 622 ext := "" 623 if d.enabled() { 624 ext = d.debugRequest(req) 625 } 626 627 res, err := c.Client.Do(req.WithContext(ctx)) 628 if err != nil { 629 return err 630 } 631 632 if d.enabled() { 633 d.debugResponse(res, ext) 634 } 635 636 if c.insecureCookies { 637 c.setInsecureCookies(res) 638 } 639 640 defer res.Body.Close() 641 642 return f(res) 643 } 644 645 // Signer can be implemented by soap.Header.Security to sign requests. 646 // If the soap.Header.Security field is set to an implementation of Signer via WithHeader(), 647 // then Client.RoundTrip will call Sign() to marshal the SOAP request. 648 type Signer interface { 649 Sign(Envelope) ([]byte, error) 650 } 651 652 type headerContext struct{} 653 654 // WithHeader can be used to modify the outgoing request soap.Header fields. 655 func (c *Client) WithHeader(ctx context.Context, header Header) context.Context { 656 return context.WithValue(ctx, headerContext{}, header) 657 } 658 659 type statusError struct { 660 res *http.Response 661 } 662 663 // Temporary returns true for HTTP response codes that can be retried 664 // See vim25.IsTemporaryNetworkError 665 func (e *statusError) Temporary() bool { 666 switch e.res.StatusCode { 667 case http.StatusBadGateway: 668 return true 669 } 670 return false 671 } 672 673 func (e *statusError) Error() string { 674 return e.res.Status 675 } 676 677 func newStatusError(res *http.Response) error { 678 return &url.Error{ 679 Op: res.Request.Method, 680 URL: res.Request.URL.Path, 681 Err: &statusError{res}, 682 } 683 } 684 685 // RoundTrip executes an API request to VMOMI server. 686 func (c *Client) RoundTrip(ctx context.Context, reqBody, resBody HasFault) error { 687 if !c.useJSON { 688 return c.soapRoundTrip(ctx, reqBody, resBody) 689 } 690 return c.jsonRoundTrip(ctx, reqBody, resBody) 691 } 692 693 func (c *Client) soapRoundTrip(ctx context.Context, reqBody, resBody HasFault) error { 694 var err error 695 var b []byte 696 697 reqEnv := Envelope{Body: reqBody} 698 resEnv := Envelope{Body: resBody} 699 700 h, ok := ctx.Value(headerContext{}).(Header) 701 if !ok { 702 h = Header{} 703 } 704 705 // We added support for OperationID before soap.Header was exported. 706 if id, ok := ctx.Value(types.ID{}).(string); ok { 707 h.ID = id 708 } 709 710 h.Cookie = c.cookie 711 if h.Cookie != "" || h.ID != "" || h.Security != nil { 712 reqEnv.Header = &h // XML marshal header only if a field is set 713 } 714 715 if signer, ok := h.Security.(Signer); ok { 716 b, err = signer.Sign(reqEnv) 717 if err != nil { 718 return err 719 } 720 } else { 721 b, err = xml.Marshal(reqEnv) 722 if err != nil { 723 panic(err) 724 } 725 } 726 727 rawReqBody := io.MultiReader(strings.NewReader(xml.Header), bytes.NewReader(b)) 728 req, err := http.NewRequest("POST", c.u.String(), rawReqBody) 729 if err != nil { 730 panic(err) 731 } 732 733 req.Header.Set(`Content-Type`, `text/xml; charset="utf-8"`) 734 735 action := h.Action 736 if action == "" { 737 action = fmt.Sprintf("%s/%s", c.Namespace, c.Version) 738 } 739 req.Header.Set(`SOAPAction`, action) 740 741 return c.Do(ctx, req, func(res *http.Response) error { 742 switch res.StatusCode { 743 case http.StatusOK: 744 // OK 745 case http.StatusInternalServerError: 746 // Error, but typically includes a body explaining the error 747 default: 748 return newStatusError(res) 749 } 750 751 dec := xml.NewDecoder(res.Body) 752 dec.TypeFunc = c.Types 753 err = dec.Decode(&resEnv) 754 if err != nil { 755 return err 756 } 757 758 if f := resBody.Fault(); f != nil { 759 return WrapSoapFault(f) 760 } 761 762 return err 763 }) 764 } 765 766 func (c *Client) CloseIdleConnections() { 767 c.t.CloseIdleConnections() 768 } 769 770 // ParseURL wraps url.Parse to rewrite the URL.Host field 771 // In the case of VM guest uploads or NFC lease URLs, a Host 772 // field with a value of "*" is rewritten to the Client's URL.Host. 773 func (c *Client) ParseURL(urlStr string) (*url.URL, error) { 774 u, err := url.Parse(urlStr) 775 if err != nil { 776 return nil, err 777 } 778 779 host, _ := splitHostPort(u.Host) 780 if host == "*" { 781 // Also use Client's port, to support port forwarding 782 u.Host = c.URL().Host 783 } 784 785 return u, nil 786 } 787 788 type Upload struct { 789 Type string 790 Method string 791 ContentLength int64 792 Headers map[string]string 793 Ticket *http.Cookie 794 Progress progress.Sinker 795 Close bool 796 } 797 798 var DefaultUpload = Upload{ 799 Type: "application/octet-stream", 800 Method: "PUT", 801 } 802 803 // Upload PUTs the local file to the given URL 804 func (c *Client) Upload(ctx context.Context, f io.Reader, u *url.URL, param *Upload) error { 805 var err error 806 807 if param.Progress != nil { 808 pr := progress.NewReader(ctx, param.Progress, f, param.ContentLength) 809 f = pr 810 811 // Mark progress reader as done when returning from this function. 812 defer func() { 813 pr.Done(err) 814 }() 815 } 816 817 req, err := http.NewRequest(param.Method, u.String(), f) 818 if err != nil { 819 return err 820 } 821 822 req = req.WithContext(ctx) 823 req.Close = param.Close 824 req.ContentLength = param.ContentLength 825 req.Header.Set("Content-Type", param.Type) 826 827 for k, v := range param.Headers { 828 req.Header.Add(k, v) 829 } 830 831 if param.Ticket != nil { 832 req.AddCookie(param.Ticket) 833 } 834 835 res, err := c.Client.Do(req) 836 if err != nil { 837 return err 838 } 839 840 defer res.Body.Close() 841 842 switch res.StatusCode { 843 case http.StatusOK: 844 case http.StatusCreated: 845 default: 846 err = errors.New(res.Status) 847 } 848 849 return err 850 } 851 852 // UploadFile PUTs the local file to the given URL 853 func (c *Client) UploadFile(ctx context.Context, file string, u *url.URL, param *Upload) error { 854 if param == nil { 855 p := DefaultUpload // Copy since we set ContentLength 856 param = &p 857 } 858 859 s, err := os.Stat(file) 860 if err != nil { 861 return err 862 } 863 864 f, err := os.Open(filepath.Clean(file)) 865 if err != nil { 866 return err 867 } 868 defer f.Close() 869 870 param.ContentLength = s.Size() 871 872 return c.Upload(ctx, f, u, param) 873 } 874 875 type Download struct { 876 Method string 877 Headers map[string]string 878 Ticket *http.Cookie 879 Progress progress.Sinker 880 Writer io.Writer 881 Close bool 882 } 883 884 var DefaultDownload = Download{ 885 Method: "GET", 886 } 887 888 // DownloadRequest wraps http.Client.Do, returning the http.Response without checking its StatusCode 889 func (c *Client) DownloadRequest(ctx context.Context, u *url.URL, param *Download) (*http.Response, error) { 890 req, err := http.NewRequest(param.Method, u.String(), nil) 891 if err != nil { 892 return nil, err 893 } 894 895 req = req.WithContext(ctx) 896 req.Close = param.Close 897 898 for k, v := range param.Headers { 899 req.Header.Add(k, v) 900 } 901 902 if param.Ticket != nil { 903 req.AddCookie(param.Ticket) 904 } 905 906 return c.Client.Do(req) 907 } 908 909 // Download GETs the remote file from the given URL 910 func (c *Client) Download(ctx context.Context, u *url.URL, param *Download) (io.ReadCloser, int64, error) { 911 res, err := c.DownloadRequest(ctx, u, param) 912 if err != nil { 913 return nil, 0, err 914 } 915 916 switch res.StatusCode { 917 case http.StatusOK: 918 default: 919 err = fmt.Errorf("download(%s): %s", u, res.Status) 920 } 921 922 if err != nil { 923 return nil, 0, err 924 } 925 926 r := res.Body 927 928 return r, res.ContentLength, nil 929 } 930 931 func (c *Client) WriteFile(ctx context.Context, file string, src io.Reader, size int64, s progress.Sinker, w io.Writer) error { 932 var err error 933 934 r := src 935 936 fh, err := os.Create(file) 937 if err != nil { 938 return err 939 } 940 941 if s != nil { 942 pr := progress.NewReader(ctx, s, src, size) 943 r = pr 944 945 // Mark progress reader as done when returning from this function. 946 defer func() { 947 pr.Done(err) 948 }() 949 } 950 951 if w == nil { 952 w = fh 953 } else { 954 w = io.MultiWriter(w, fh) 955 } 956 957 _, err = io.Copy(w, r) 958 959 cerr := fh.Close() 960 961 if err == nil { 962 err = cerr 963 } 964 965 return err 966 } 967 968 // DownloadFile GETs the given URL to a local file 969 func (c *Client) DownloadFile(ctx context.Context, file string, u *url.URL, param *Download) error { 970 var err error 971 if param == nil { 972 param = &DefaultDownload 973 } 974 975 rc, contentLength, err := c.Download(ctx, u, param) 976 if err != nil { 977 return err 978 } 979 980 return c.WriteFile(ctx, file, rc, contentLength, param.Progress, param.Writer) 981 } 982 983 // execName gets the name of the executable for the current process 984 func execName() string { 985 name, err := os.Executable() 986 if err != nil { 987 return "N/A" 988 } 989 return strings.TrimSuffix(filepath.Base(name), ".exe") 990 }