github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/downloader/download.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package downloader
     5  
     6  import (
     7  	"io"
     8  	"io/ioutil"
     9  	"net/url"
    10  	"os"
    11  
    12  	"github.com/juju/errors"
    13  	"github.com/juju/utils"
    14  )
    15  
    16  // Request holds a single download request.
    17  type Request struct {
    18  	// URL is the location from which the file will be downloaded.
    19  	URL *url.URL
    20  
    21  	// TargetDir is the directory into which the file will be downloaded.
    22  	// It defaults to os.TempDir().
    23  	TargetDir string
    24  
    25  	// Verify is used to ensure that the download result is correct. If
    26  	// the download is invalid then the func must return errors.NotValid.
    27  	// If no func is provided then no verification happens.
    28  	Verify func(*os.File) error
    29  
    30  	// Abort is a channel that will cancel the download when it is closed.
    31  	Abort <-chan struct{}
    32  }
    33  
// Status represents the status of a completed download.
type Status struct {
	// Filename is the name of the file which holds the downloaded
	// data on success. It is empty whenever Err is non-nil.
	Filename string

	// Err describes any error encountered while downloading, including
	// an abort or a failed verification. It is nil on success.
	Err error
}
    43  
    44  // StartDownload starts a new download as specified by `req` using
    45  // `openBlob` to actually pull the remote data.
    46  func StartDownload(req Request, openBlob func(*url.URL) (io.ReadCloser, error)) *Download {
    47  	if openBlob == nil {
    48  		openBlob = NewHTTPBlobOpener(utils.NoVerifySSLHostnames)
    49  	}
    50  	dl := &Download{
    51  		done:     make(chan Status, 1),
    52  		openBlob: openBlob,
    53  	}
    54  	go dl.run(req)
    55  	return dl
    56  }
    57  
    58  // Download can download a file from the network.
    59  type Download struct {
    60  	done     chan Status
    61  	openBlob func(*url.URL) (io.ReadCloser, error)
    62  }
    63  
    64  // Done returns a channel that receives a status when the download has
    65  // completed or is aborted. Exactly one Status value will be sent for
    66  // each download once it finishes (successfully or otherwise) or is
    67  // aborted.
    68  //
    69  // It is the receiver's responsibility to handle and remove the
    70  // downloaded file.
    71  func (dl *Download) Done() <-chan Status {
    72  	return dl.done
    73  }
    74  
    75  // Wait blocks until the download finishes (successfully or
    76  // otherwise), or the download is aborted. There will only be a
    77  // filename if err is nil.
    78  func (dl *Download) Wait() (string, error) {
    79  	// No select required here because each download will always
    80  	// return a value once it completes. Downloads can be aborted via
    81  	// the Abort channel provided a creation time.
    82  	status := <-dl.Done()
    83  	return status.Filename, errors.Trace(status.Err)
    84  }
    85  
    86  func (dl *Download) run(req Request) {
    87  	// TODO(dimitern) 2013-10-03 bug #1234715
    88  	// Add a testing HTTPS storage to verify the
    89  	// disableSSLHostnameVerification behavior here.
    90  	filename, err := dl.download(req)
    91  	if err != nil {
    92  		err = errors.Trace(err)
    93  	} else {
    94  		logger.Infof("download complete (%q)", req.URL)
    95  		err = verifyDownload(filename, req)
    96  		if err != nil {
    97  			os.Remove(filename)
    98  			filename = ""
    99  		}
   100  	}
   101  
   102  	// No select needed here because the channel has a size of 1 and
   103  	// will only be written to once.
   104  	dl.done <- Status{
   105  		Filename: filename,
   106  		Err:      err,
   107  	}
   108  }
   109  
   110  func (dl *Download) download(req Request) (filename string, err error) {
   111  	logger.Infof("downloading from %s", req.URL)
   112  
   113  	dir := req.TargetDir
   114  	if dir == "" {
   115  		dir = os.TempDir()
   116  	}
   117  	tempFile, err := ioutil.TempFile(dir, "inprogress-")
   118  	if err != nil {
   119  		return "", errors.Trace(err)
   120  	}
   121  	defer func() {
   122  		tempFile.Close()
   123  		if err != nil {
   124  			os.Remove(tempFile.Name())
   125  		}
   126  	}()
   127  
   128  	blobReader, err := dl.openBlob(req.URL)
   129  	if err != nil {
   130  		return "", errors.Trace(err)
   131  	}
   132  	defer blobReader.Close()
   133  
   134  	reader := &abortableReader{blobReader, req.Abort}
   135  	_, err = io.Copy(tempFile, reader)
   136  	if err != nil {
   137  		return "", errors.Trace(err)
   138  	}
   139  
   140  	return tempFile.Name(), nil
   141  }
   142  
   143  // abortableReader wraps a Reader, returning an error from Read calls
   144  // if the abort channel provided is closed.
   145  type abortableReader struct {
   146  	r     io.Reader
   147  	abort <-chan struct{}
   148  }
   149  
   150  // Read implements io.Reader.
   151  func (ar *abortableReader) Read(p []byte) (int, error) {
   152  	select {
   153  	case <-ar.abort:
   154  		return 0, errors.New("download aborted")
   155  	default:
   156  	}
   157  	return ar.r.Read(p)
   158  }
   159  
   160  func verifyDownload(filename string, req Request) error {
   161  	if req.Verify == nil {
   162  		return nil
   163  	}
   164  
   165  	file, err := os.Open(filename)
   166  	if err != nil {
   167  		return errors.Annotate(err, "opening for verify")
   168  	}
   169  	defer file.Close()
   170  
   171  	if err := req.Verify(file); err != nil {
   172  		return errors.Trace(err)
   173  	}
   174  	logger.Infof("download verified (%q)", req.URL)
   175  	return nil
   176  }