github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/container/kvm/sync.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package kvm
     5  
     6  import (
     7  	"crypto/sha256"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"os"
    12  	"path"
    13  	"path/filepath"
    14  	"time"
    15  
    16  	"github.com/dustin/go-humanize"
    17  	"github.com/juju/clock"
    18  	"github.com/juju/errors"
    19  
    20  	"github.com/juju/juju/core/paths"
    21  	"github.com/juju/juju/environs/imagedownloads"
    22  	"github.com/juju/juju/environs/imagemetadata"
    23  	"github.com/juju/juju/environs/simplestreams"
    24  )
    25  
    26  // DiskImageType is the file type we want to fetch and use for kvm instances.
    27  const DiskImageType = "disk1.img"
    28  
    29  // Oner gets the one matching item from simplestreams.
    30  type Oner interface {
    31  	One() (*imagedownloads.Metadata, error)
    32  }
    33  
    34  // syncParams conveys the information necessary for calling imagedownloads.One.
    35  type syncParams struct {
    36  	fetcher                      imagemetadata.SimplestreamsFetcher
    37  	arch, version, stream, fType string
    38  	srcFunc                      func() simplestreams.DataSource
    39  }
    40  
    41  // One implements Oner.
    42  func (p syncParams) One() (*imagedownloads.Metadata, error) {
    43  	if err := p.exists(); err != nil {
    44  		return nil, errors.Trace(err)
    45  	}
    46  	return imagedownloads.One(p.fetcher, p.arch, p.version, p.stream, p.fType, p.srcFunc)
    47  }
    48  
    49  func (p syncParams) exists() error {
    50  	fName := backingFileName(p.version, p.arch)
    51  	baseDir := paths.DataDir(paths.CurrentOS())
    52  	imagePath := filepath.Join(baseDir, kvm, guestDir, fName)
    53  
    54  	if _, err := os.Stat(imagePath); err == nil {
    55  		return errors.AlreadyExistsf("%q %q image for exists at %q", p.version, p.arch, imagePath)
    56  	}
    57  	return nil
    58  }
    59  
    60  // Validate that our types fulfill their implementations.
    61  var _ Oner = (*syncParams)(nil)
    62  var _ Fetcher = (*fetcher)(nil)
    63  
    64  // Fetcher is an interface to permit faking input in tests. The default
    65  // implementation is updater, defined in this file.
    66  type Fetcher interface {
    67  	Fetch() error
    68  	Close()
    69  }
    70  
    71  type fetcher struct {
    72  	metadata         *imagedownloads.Metadata
    73  	req              *http.Request
    74  	client           *http.Client
    75  	image            *Image
    76  	imageDownloadURL string
    77  }
    78  
    79  // Fetch implements Fetcher. It fetches the image file from simplestreams and
    80  // delegates writing it out and creating the qcow3 backing file to Image.write.
    81  func (f *fetcher) Fetch() error {
    82  	resp, err := f.client.Do(f.req)
    83  	if err != nil {
    84  		return errors.Trace(err)
    85  	}
    86  
    87  	defer func() {
    88  		err = resp.Body.Close()
    89  		if err != nil {
    90  			logger.Debugf("failed defer %q", errors.Trace(err))
    91  		}
    92  	}()
    93  
    94  	if resp.StatusCode != 200 {
    95  		f.image.cleanup()
    96  		return errors.NotFoundf(
    97  			"got %d fetching image %q", resp.StatusCode, path.Base(
    98  				f.req.URL.String()))
    99  	}
   100  	err = f.image.write(resp.Body, f.metadata, f.imageDownloadURL)
   101  	if err != nil {
   102  		return errors.Trace(err)
   103  	}
   104  	return nil
   105  }
   106  
   107  // Close calls images cleanup method for deferred closing of the image tmpFile.
   108  func (f *fetcher) Close() {
   109  	f.image.cleanup()
   110  }
   111  
   112  type ProgressCallback func(message string)
   113  
   114  // Sync updates the local cached images by reading the simplestreams data and
   115  // caching if an image matching the constraints doesn't exist. It retrieves
   116  // metadata information from Oner and updates local cache via Fetcher.
   117  // A ProgressCallback can optionally be passed which will get update messages
   118  // as data is copied.
   119  func Sync(o Oner, f Fetcher, imageDownloadURL string, progress ProgressCallback) error {
   120  	md, err := o.One()
   121  	if err != nil {
   122  		if errors.IsAlreadyExists(err) {
   123  			// We've already got a backing file for this series/architecture.
   124  			return nil
   125  		}
   126  		return errors.Trace(err)
   127  	}
   128  	if f == nil {
   129  		f, err = newDefaultFetcher(md, imageDownloadURL, paths.DataDir, progress)
   130  		if err != nil {
   131  			return errors.Trace(err)
   132  		}
   133  		defer f.Close()
   134  	}
   135  	err = f.Fetch()
   136  	if err != nil {
   137  		return errors.Trace(err)
   138  	}
   139  	return nil
   140  }
   141  
   142  // Image represents a server image.
   143  type Image struct {
   144  	FilePath string
   145  	progress ProgressCallback
   146  	tmpFile  *os.File
   147  	runCmd   runFunc
   148  }
   149  
   150  type progressWriter struct {
   151  	callback    ProgressCallback
   152  	url         string
   153  	total       uint64
   154  	maxBytes    uint64
   155  	startTime   *time.Time
   156  	lastPercent int
   157  	clock       clock.Clock
   158  }
   159  
   160  var _ io.Writer = (*progressWriter)(nil)
   161  
   162  func (p *progressWriter) Write(content []byte) (n int, err error) {
   163  	if p.clock == nil {
   164  		p.clock = clock.WallClock
   165  	}
   166  	p.total += uint64(len(content))
   167  	if p.startTime == nil {
   168  		now := p.clock.Now()
   169  		p.startTime = &now
   170  		return len(content), nil
   171  	}
   172  	if p.callback != nil {
   173  		elapsed := p.clock.Now().Sub(*p.startTime)
   174  		// Avoid measurements that aren't interesting
   175  		if elapsed > time.Millisecond {
   176  			percent := (float64(p.total) * 100.0) / float64(p.maxBytes)
   177  			intPercent := int(percent + 0.5)
   178  			if p.lastPercent != intPercent {
   179  				bps := uint64((float64(p.total) / elapsed.Seconds()) + 0.5)
   180  				p.callback(fmt.Sprintf("copying %s %d%% (%s/s)", p.url, intPercent, humanize.Bytes(bps)))
   181  				p.lastPercent = intPercent
   182  			}
   183  		}
   184  	}
   185  	return len(content), nil
   186  }
   187  
   188  // write saves the stream to disk and updates the metadata file.
   189  func (i *Image) write(r io.Reader, md *imagedownloads.Metadata, imageDownloadURL string) error {
   190  	tmpPath := i.tmpFile.Name()
   191  	defer func() {
   192  		err := i.tmpFile.Close()
   193  		if err != nil {
   194  			logger.Errorf("failed to close %q %s", tmpPath, err)
   195  		}
   196  		err = os.Remove(tmpPath)
   197  		if err != nil {
   198  			logger.Errorf("failed to remove %q after use %s", tmpPath, err)
   199  		}
   200  
   201  	}()
   202  
   203  	hash := sha256.New()
   204  	var writer io.Writer
   205  	if i.progress == nil {
   206  		writer = io.MultiWriter(i.tmpFile, hash)
   207  	} else {
   208  		dlURL, _ := md.DownloadURL(imageDownloadURL)
   209  		progWriter := &progressWriter{
   210  			url:      dlURL.String(),
   211  			callback: i.progress,
   212  			maxBytes: uint64(md.Size),
   213  			total:    0,
   214  		}
   215  		writer = io.MultiWriter(i.tmpFile, hash, progWriter)
   216  	}
   217  	_, err := io.Copy(writer, r)
   218  	if err != nil {
   219  		i.cleanup()
   220  		return errors.Trace(err)
   221  	}
   222  
   223  	result := fmt.Sprintf("%x", hash.Sum(nil))
   224  	if result != md.SHA256 {
   225  		i.cleanup()
   226  		return errors.Errorf(
   227  			"hash sum mismatch for %s: %s != %s", i.tmpFile.Name(), result, md.SHA256)
   228  	}
   229  
   230  	// TODO(jam): 2017-03-19 If this is slow, maybe we want to add a progress step for it, rather than only
   231  	// indicating download progress.
   232  	output, err := i.runCmd(
   233  		"", "qemu-img", "convert", "-f", "qcow2", tmpPath, i.FilePath)
   234  	logger.Debugf("qemu-image convert output: %s", output)
   235  	if err != nil {
   236  		i.cleanupAll()
   237  		return errors.Trace(err)
   238  	}
   239  	return nil
   240  }
   241  
   242  // cleanup attempts to close and remove the tempfile download image. It can be
   243  // called if things don't work out. E.g. sha256 mismatch, incorrect size...
   244  func (i *Image) cleanup() {
   245  	if err := i.tmpFile.Close(); err != nil {
   246  		logger.Debugf("%s", err.Error())
   247  	}
   248  
   249  	if err := os.Remove(i.tmpFile.Name()); err != nil {
   250  		logger.Debugf("got %q removing %q", err.Error(), i.tmpFile.Name())
   251  	}
   252  }
   253  
   254  // cleanupAll cleans up the possible backing file as well.
   255  func (i *Image) cleanupAll() {
   256  	i.cleanup()
   257  	err := os.Remove(i.FilePath)
   258  	if err != nil {
   259  		logger.Debugf("got %q removing %q", err.Error(), i.FilePath)
   260  	}
   261  }
   262  
   263  func newDefaultFetcher(md *imagedownloads.Metadata, imageDownloadURL string, pathfinder pathfinderFunc, callback ProgressCallback) (*fetcher, error) {
   264  	i, err := newImage(md, imageDownloadURL, pathfinder, callback)
   265  	if err != nil {
   266  		return nil, errors.Trace(err)
   267  	}
   268  	dlURL, err := md.DownloadURL(imageDownloadURL)
   269  	if err != nil {
   270  		return nil, errors.Trace(err)
   271  	}
   272  	req, err := http.NewRequest("GET", dlURL.String(), nil)
   273  	if err != nil {
   274  		return nil, errors.Trace(err)
   275  	}
   276  	client := &http.Client{}
   277  	return &fetcher{metadata: md, image: i, client: client, req: req, imageDownloadURL: imageDownloadURL}, nil
   278  }
   279  
   280  func newImage(md *imagedownloads.Metadata, imageDownloadURL string, pathfinder pathfinderFunc, callback ProgressCallback) (*Image, error) {
   281  	// Setup names and paths.
   282  	dlURL, err := md.DownloadURL(imageDownloadURL)
   283  	if err != nil {
   284  		return nil, errors.Trace(err)
   285  	}
   286  	baseDir := pathfinder(paths.CurrentOS())
   287  
   288  	// Closing this is deferred in Image.write.
   289  	fh, err := os.CreateTemp("", fmt.Sprintf("juju-kvm-%s-", path.Base(dlURL.String())))
   290  	if err != nil {
   291  		return nil, errors.Trace(err)
   292  	}
   293  
   294  	return &Image{
   295  		FilePath: filepath.Join(
   296  			baseDir, kvm, guestDir, backingFileName(md.Version, md.Arch)),
   297  		tmpFile:  fh,
   298  		runCmd:   run,
   299  		progress: callback,
   300  	}, nil
   301  }
   302  
   303  func backingFileName(version, arch string) string {
   304  	return fmt.Sprintf("%s-%s-backing-file.qcow", version, arch)
   305  }