github.com/mirantis/virtlet@v1.5.2-0.20191204181327-1659b8a48e9b/pkg/image/download.go (about)

     1  /*
     2  Copyright 2017-2018 Mirantis
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package image
    18  
    19  import (
    20  	"context"
    21  	"crypto"
    22  	"crypto/tls"
    23  	"crypto/x509"
    24  	"fmt"
    25  	"io"
    26  	"net"
    27  	"net/http"
    28  	"net/url"
    29  	"os"
    30  	"strings"
    31  	"time"
    32  
    33  	"github.com/golang/glog"
    34  )
    35  
const (
	// copyBufferSize is the size in bytes (1 MiB) of the buffer handed to
	// io.CopyBuffer when streaming a response body to the destination writer.
	copyBufferSize = 1024 * 1024
)
    39  
// Endpoint contains all the endpoint parameters needed to download a file.
type Endpoint struct {
	// URL is the image URL. If it contains no "://" scheme separator,
	// the downloader's configured default protocol is prepended.
	URL string

	// MaxRedirects is the maximum number of redirects that the downloader
	// is allowed to follow. A negative value (e.g. -1) installs no redirect
	// check at all, which leaves the stdlib default in effect (fails on
	// request #10).
	MaxRedirects int

	// TLS is the TLS config. When nil, no custom TLS configuration is set
	// on the transport.
	TLS *TLSConfig

	// Timeout specifies a time limit for the http(s) download request; it
	// is used as the http.Client timeout. <= 0 is no timeout (default).
	Timeout time.Duration

	// Proxy is the URL of the proxy server to use. When empty (default),
	// the proxy is chosen by http.ProxyFromEnvironment.
	Proxy string

	// ProfileName is the transport profile name for this endpoint.
	// Provided for logging/debugging.
	ProfileName string
}
    61  
// TLSConfig has the TLS transport parameters.
type TLSConfig struct {
	// Certificates to use (both CA and for client authentication).
	// CA certificates are added to the root pool; non-CA certificates
	// that come with a private key are offered as client certificates.
	Certificates []TLSCertificate

	// ServerName is needed when connecting to a domain other than the one
	// the certificate was issued for. It is copied to tls.Config.ServerName.
	ServerName string

	// Insecure skips certificate verification
	// (it is copied to tls.Config.InsecureSkipVerify).
	Insecure bool
}
    73  
// TLSCertificate is an x509 certificate with an optional private key.
type TLSCertificate struct {
	// Certificate is the x509 certificate.
	Certificate *x509.Certificate

	// PrivateKey is the private key needed for certificate-based client
	// authentication. It may be nil, e.g. for CA certificates.
	PrivateKey crypto.PrivateKey
}
    82  
// Downloader is an interface for downloading files from the web.
type Downloader interface {
	// DownloadFile downloads the file described by endpoint and writes
	// its contents to w. The download is bound to ctx.
	DownloadFile(ctx context.Context, endpoint Endpoint, w io.Writer) error
}
    88  
// defaultDownloader is the Downloader implementation returned by
// NewDownloader.
type defaultDownloader struct {
	// protocol is the URL scheme prepended to endpoint URLs that do not
	// contain "://".
	protocol string
}
    92  
    93  // NewDownloader returns the default downloader for 'protocol'.
    94  // The default downloader downloads a file via an URL constructed as
    95  // 'protocol://location' and saves it in temporary file in default
    96  // system directory for temporary files
    97  func NewDownloader(protocol string) Downloader {
    98  	return &defaultDownloader{protocol}
    99  }
   100  
   101  func buildTLSConfig(config *TLSConfig, profileName string) (*tls.Config, error) {
   102  	var certificates []tls.Certificate
   103  	roots, err := x509.SystemCertPool()
   104  	if err != nil {
   105  		roots = x509.NewCertPool()
   106  	}
   107  	for _, cert := range config.Certificates {
   108  		if cert.Certificate.IsCA {
   109  			roots.AddCert(cert.Certificate)
   110  		} else if cert.PrivateKey != nil {
   111  			certificates = append(certificates, tls.Certificate{
   112  				Certificate: [][]byte{cert.Certificate.Raw},
   113  				PrivateKey:  cert.PrivateKey,
   114  			})
   115  		} else {
   116  			glog.V(3).Infof("Skipping certificate %q because it is neither CA not has a private key", cert.Certificate.SerialNumber.String())
   117  		}
   118  	}
   119  
   120  	return &tls.Config{
   121  		Certificates:       certificates,
   122  		RootCAs:            roots,
   123  		InsecureSkipVerify: config.Insecure,
   124  		ServerName:         config.ServerName,
   125  	}, nil
   126  }
   127  
   128  func createTransport(endpoint Endpoint) (*http.Transport, error) {
   129  	var tlsConfig *tls.Config
   130  	var err error
   131  	if endpoint.TLS != nil {
   132  		tlsConfig, err = buildTLSConfig(endpoint.TLS, endpoint.ProfileName)
   133  		if err != nil {
   134  			return nil, err
   135  		}
   136  	}
   137  
   138  	proxyFunc := http.ProxyFromEnvironment
   139  	if endpoint.Proxy != "" {
   140  		proxyFunc = func(*http.Request) (*url.URL, error) {
   141  			return url.Parse(endpoint.Proxy)
   142  		}
   143  	}
   144  
   145  	return &http.Transport{
   146  		Proxy: proxyFunc,
   147  		DialContext: (&net.Dialer{
   148  			Timeout:   30 * time.Second,
   149  			KeepAlive: 30 * time.Second,
   150  			DualStack: true,
   151  		}).DialContext,
   152  		MaxIdleConns:          100,
   153  		IdleConnTimeout:       90 * time.Second,
   154  		TLSHandshakeTimeout:   10 * time.Second,
   155  		ExpectContinueTimeout: 1 * time.Second,
   156  		TLSClientConfig:       tlsConfig,
   157  	}, nil
   158  }
   159  
   160  func createHTTPClient(endpoint Endpoint) (*http.Client, error) {
   161  	transport, err := createTransport(endpoint)
   162  	if err != nil {
   163  		return nil, err
   164  	}
   165  
   166  	var checkRedirects func(req *http.Request, via []*http.Request) error
   167  	if endpoint.MaxRedirects >= 0 {
   168  		checkRedirects = func(req *http.Request, via []*http.Request) error {
   169  			if len(via) > endpoint.MaxRedirects {
   170  				return fmt.Errorf("stopped after %d redirects", endpoint.MaxRedirects)
   171  			}
   172  			return nil
   173  		}
   174  	}
   175  
   176  	return &http.Client{
   177  		Transport:     transport,
   178  		Timeout:       endpoint.Timeout,
   179  		CheckRedirect: checkRedirects,
   180  	}, nil
   181  }
   182  
   183  func (d *defaultDownloader) DownloadFile(ctx context.Context, endpoint Endpoint, w io.Writer) error {
   184  	url := endpoint.URL
   185  	if !strings.Contains(url, "://") {
   186  		url = fmt.Sprintf("%s://%s", d.protocol, url)
   187  	}
   188  
   189  	client, err := createHTTPClient(endpoint)
   190  	if err != nil {
   191  		return err
   192  	}
   193  
   194  	glog.V(2).Infof("Start downloading %s", url)
   195  
   196  	req, err := http.NewRequest("GET", url, nil)
   197  	if err != nil {
   198  		return err
   199  	}
   200  	req = req.WithContext(ctx)
   201  	resp, err := client.Do(req)
   202  	if err != nil {
   203  		return err
   204  	}
   205  	defer resp.Body.Close()
   206  
   207  	if resp.StatusCode != http.StatusOK {
   208  		return fmt.Errorf("bad http status %q", resp.Status)
   209  	}
   210  
   211  	if _, err = io.CopyBuffer(w, resp.Body, make([]byte, copyBufferSize)); err != nil {
   212  		return err
   213  	}
   214  
   215  	if f, ok := w.(*os.File); ok {
   216  		glog.V(2).Infof("Data from url %s saved as %q", url, f.Name())
   217  	}
   218  	return nil
   219  }
   220  
   221  // Note that the tests for defaultDownloader are in 'imagetranslation' package (FIXME)