github.com/coreos/rocket@v1.30.1-0.20200224141603-171c416fac02/rkt/image/resumablesession.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package image
    16  
    17  import (
    18  	"crypto/tls"
    19  	"fmt"
    20  	"io"
    21  	"io/ioutil"
    22  	"net/http"
    23  	"net/url"
    24  	"os"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/rkt/rkt/rkt/config"
    30  	"github.com/rkt/rkt/version"
    31  )
    32  
    33  // statusAcceptedError is an error returned when resumableSession
    34  // receives a 202 HTTP status. It is mostly used for deferring
    35  // signature downloads.
    36  type statusAcceptedError struct{}
    37  
    38  func (*statusAcceptedError) Error() string {
    39  	return "HTTP 202"
    40  }
    41  
    42  // cacheData holds caching-specific information taken from various
    43  // HTTP headers.
    44  type cacheData struct {
    45  	// whether we should reuse an image from store
    46  	UseCached bool
    47  	// image ETag, used for redownloading the obsolete images
    48  	ETag string
    49  	// MaxAge is a number of seconds telling when the downloaded
    50  	// image is obsolete
    51  	MaxAge int
    52  }
    53  
    54  // resumableSession is an implementation of the downloadSession
    55  // interface, it allows sending custom headers for authentication,
    56  // resuming interrupted downloads, handling cache data.
    57  type resumableSession struct {
    58  	// InsecureSkipTLSVerify tells whether TLS certificate
    59  	// validation should be skipped.
    60  	InsecureSkipTLSVerify bool
    61  	// Headers are HTTP headers to be added to the HTTP
    62  	// request. Used for ETAG.
    63  	Headers http.Header
    64  	// Headerers used for authentication.
    65  	Headerers map[string]config.Headerer
    66  	// File possibly holds the downloaded data - it is used for
    67  	// resuming interrupted downloads.
    68  	File *os.File
    69  	// ETagFilePath is a path to a file holding an ETag of a
    70  	// downloaded file. It is used for resuming interrupted
    71  	// downloads.
    72  	ETagFilePath string
    73  	// Label is used for printing the type of the downloaded data
    74  	// when printing a pretty progress bar.
    75  	Label string
    76  
    77  	// Cd is a cache data returned by HTTP server. It is an output
    78  	// value.
    79  	Cd *cacheData
    80  
    81  	u                  *url.URL
    82  	client             *http.Client
    83  	amountAlreadyHere  int64
    84  	byteRangeSupported bool
    85  }
    86  
    87  func (s *resumableSession) Client() (*http.Client, error) {
    88  	s.ensureClient()
    89  	return s.client, nil
    90  }
    91  
    92  func (s *resumableSession) Request(u *url.URL) (*http.Request, error) {
    93  	s.u = u
    94  	if err := s.maybeSetupDownloadResume(u); err != nil {
    95  		return nil, err
    96  	}
    97  	return s.getRequest(u), nil
    98  }
    99  
   100  func (s *resumableSession) HandleStatus(res *http.Response) (bool, error) {
   101  	switch res.StatusCode {
   102  	case http.StatusOK, http.StatusPartialContent:
   103  		fallthrough
   104  	case http.StatusNotModified:
   105  		s.Cd = &cacheData{
   106  			ETag:      res.Header.Get("ETag"),
   107  			MaxAge:    s.getMaxAge(res.Header.Get("Cache-Control")),
   108  			UseCached: res.StatusCode == http.StatusNotModified,
   109  		}
   110  		return s.Cd.UseCached, nil
   111  	case http.StatusAccepted:
   112  		// If the server returns Status Accepted (HTTP 202), we should retry
   113  		// downloading the signature later.
   114  		return false, &statusAcceptedError{}
   115  	case http.StatusRequestedRangeNotSatisfiable:
   116  		return s.handleRangeNotSatisfiable()
   117  	default:
   118  		return false, fmt.Errorf("bad HTTP status code: %d", res.StatusCode)
   119  	}
   120  }
   121  
   122  func (s *resumableSession) BodyReader(res *http.Response) (io.Reader, error) {
   123  	reader := getIoProgressReader(s.Label, res)
   124  	return reader, nil
   125  }
   126  
   127  type rangeStatus int
   128  
   129  const (
   130  	rangeSupported rangeStatus = iota
   131  	rangeInvalid
   132  	rangeUnsupported
   133  )
   134  
   135  func (s *resumableSession) maybeSetupDownloadResume(u *url.URL) error {
   136  	fi, err := s.File.Stat()
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	size := fi.Size()
   142  	if size < 1 {
   143  		return nil
   144  	}
   145  
   146  	s.ensureClient()
   147  	headReq := s.headRequest(u)
   148  	res, err := s.client.Do(headReq)
   149  	if err != nil {
   150  		return err
   151  	}
   152  	if res.StatusCode != http.StatusOK {
   153  		log.Printf("bad HTTP status code from HEAD request: %d", res.StatusCode)
   154  		if err := s.reset(); err != nil {
   155  			return err
   156  		}
   157  		return nil
   158  	}
   159  	status := s.verifyAcceptRange(res, fi.ModTime())
   160  	if status == rangeSupported {
   161  		s.byteRangeSupported = true
   162  		s.amountAlreadyHere = size
   163  	} else {
   164  		if status == rangeInvalid {
   165  			log.Printf("cannot use cached partial download, resource updated.")
   166  		} else {
   167  			log.Printf("cannot use cached partial download, range request unsupported.")
   168  		}
   169  		if err := s.reset(); err != nil {
   170  			return err
   171  		}
   172  	}
   173  	return nil
   174  }
   175  
   176  func (s *resumableSession) ensureClient() {
   177  	if s.client == nil {
   178  		s.client = s.getClient()
   179  	}
   180  }
   181  
   182  func (s *resumableSession) getRequest(u *url.URL) *http.Request {
   183  	return s.httpRequest("GET", u)
   184  }
   185  
   186  func (s *resumableSession) getMaxAge(headerValue string) int {
   187  	if headerValue == "" {
   188  		return 0
   189  	}
   190  
   191  	maxAge := 0
   192  	parts := strings.Split(headerValue, ",")
   193  
   194  	for i := 0; i < len(parts); i++ {
   195  		parts[i] = strings.TrimSpace(parts[i])
   196  		attr, val := parts[i], ""
   197  		if j := strings.Index(attr, "="); j >= 0 {
   198  			attr, val = attr[:j], attr[j+1:]
   199  		}
   200  		lowerAttr := strings.ToLower(attr)
   201  
   202  		switch lowerAttr {
   203  		case "no-store", "no-cache":
   204  			maxAge = 0
   205  			// TODO(krnowak): Just break out of the loop
   206  			// at this point.
   207  		case "max-age":
   208  			secs, err := strconv.Atoi(val)
   209  			if err != nil || secs != 0 && val[0] == '0' {
   210  				// TODO(krnowak): Set maxAge to zero.
   211  				break
   212  			}
   213  			if secs <= 0 {
   214  				maxAge = 0
   215  			} else {
   216  				maxAge = secs
   217  			}
   218  		}
   219  	}
   220  	return maxAge
   221  }
   222  
   223  func (s *resumableSession) handleRangeNotSatisfiable() (bool, error) {
   224  	if fi, err := s.File.Stat(); err != nil {
   225  		return false, err
   226  	} else if fi.Size() > 0 {
   227  		if err := s.reset(); err != nil {
   228  			return false, err
   229  		}
   230  		dl := &downloader{
   231  			Session: s,
   232  		}
   233  		if err := dl.Download(s.u, s.File); err != nil {
   234  			return false, err
   235  		}
   236  		return true, nil
   237  	}
   238  	code := http.StatusRequestedRangeNotSatisfiable
   239  	return false, fmt.Errorf("bad HTTP status code: %d", code)
   240  }
   241  
   242  func (s *resumableSession) headRequest(u *url.URL) *http.Request {
   243  	return s.httpRequest("HEAD", u)
   244  }
   245  
   246  func (s *resumableSession) verifyAcceptRange(res *http.Response, mod time.Time) rangeStatus {
   247  	acceptRanges, hasRange := res.Header["Accept-Ranges"]
   248  	if !hasRange {
   249  		return rangeUnsupported
   250  	}
   251  	if !s.modificationTimeOK(res, mod) && !s.eTagOK(res) {
   252  		return rangeInvalid
   253  	}
   254  	for _, rng := range acceptRanges {
   255  		if rng == "bytes" {
   256  			return rangeSupported
   257  		}
   258  	}
   259  	return rangeInvalid
   260  }
   261  
   262  func (s *resumableSession) reset() error {
   263  	s.amountAlreadyHere = 0
   264  	s.byteRangeSupported = false
   265  	if _, err := s.File.Seek(0, 0); err != nil {
   266  		return err
   267  	}
   268  	if err := s.File.Truncate(0); err != nil {
   269  		return err
   270  	}
   271  	if err := os.Remove(s.ETagFilePath); err != nil && !os.IsNotExist(err) {
   272  		return err
   273  	}
   274  	return nil
   275  }
   276  
   277  func (s *resumableSession) getClient() *http.Client {
   278  	transport := http.DefaultTransport
   279  	if s.InsecureSkipTLSVerify {
   280  		transport = &http.Transport{
   281  			Proxy:           http.ProxyFromEnvironment,
   282  			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
   283  		}
   284  	}
   285  
   286  	return &http.Client{
   287  		Transport: transport,
   288  		CheckRedirect: func(req *http.Request, via []*http.Request) error {
   289  			if len(via) >= 10 {
   290  				return fmt.Errorf("too many redirects")
   291  			}
   292  			stripAuth := false
   293  			// don't propagate "Authorization" if the redirect is to a
   294  			// different host
   295  			previousHost := via[len(via)-1].URL.Host
   296  			if previousHost != req.URL.Host {
   297  				stripAuth = true
   298  			}
   299  			s.setHTTPHeaders(req, stripAuth)
   300  			return nil
   301  		},
   302  	}
   303  }
   304  
   305  func (s *resumableSession) httpRequest(method string, u *url.URL) *http.Request {
   306  	req := &http.Request{
   307  		Method:     method,
   308  		URL:        u,
   309  		Proto:      "HTTP/1.1",
   310  		ProtoMajor: 1,
   311  		ProtoMinor: 1,
   312  		Header:     make(http.Header),
   313  		Host:       u.Host,
   314  	}
   315  
   316  	s.setHTTPHeaders(req, false)
   317  
   318  	// Send credentials only over secure channel
   319  	// TODO(krnowak): This could be controlled with another
   320  	// insecure flag.
   321  	if req.URL.Scheme != "https" {
   322  		return req
   323  	}
   324  
   325  	if hostOpts, ok := s.Headerers[req.URL.Host]; ok {
   326  		req = hostOpts.SignRequest(req)
   327  		if req == nil {
   328  			panic("Req is nil!")
   329  		}
   330  	}
   331  
   332  	return req
   333  }
   334  
   335  func (s *resumableSession) modificationTimeOK(res *http.Response, mod time.Time) bool {
   336  	lastModified := res.Header.Get("Last-Modified")
   337  	if lastModified != "" {
   338  		layout := "Mon, 02 Jan 2006 15:04:05 MST"
   339  		t, err := time.Parse(layout, lastModified)
   340  		if err == nil && t.Before(mod) {
   341  			return true
   342  		}
   343  	}
   344  	return false
   345  }
   346  
   347  func (s *resumableSession) eTagOK(res *http.Response) bool {
   348  	etag := res.Header.Get("ETag")
   349  	if etag != "" {
   350  		savedEtag, err := ioutil.ReadFile(s.ETagFilePath)
   351  		if err == nil && string(savedEtag) == etag {
   352  			return true
   353  		}
   354  	}
   355  	return false
   356  }
   357  
   358  func (s *resumableSession) setHTTPHeaders(req *http.Request, stripAuth bool) {
   359  	for k, v := range s.Headers {
   360  		if stripAuth && k == "Authorization" {
   361  			continue
   362  		}
   363  		for _, e := range v {
   364  			req.Header.Set(k, e)
   365  		}
   366  	}
   367  	req.Header.Add("User-Agent", fmt.Sprintf("rkt/%s", version.Version))
   368  	if s.amountAlreadyHere > 0 && s.byteRangeSupported {
   369  		req.Header.Add("Range", fmt.Sprintf("bytes=%d-", s.amountAlreadyHere))
   370  	}
   371  }