github.com/blixtra/rkt@v0.8.1-0.20160204105720-ab0d1add1a43/rkt/image/resumablesession.go (about)

     1  // Copyright 2015 The rkt Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package image
    16  
    17  import (
    18  	"crypto/tls"
    19  	"fmt"
    20  	"io"
    21  	"io/ioutil"
    22  	"net/http"
    23  	"net/url"
    24  	"os"
    25  	"strconv"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/coreos/rkt/version"
    30  )
    31  
    32  // statusAcceptedError is an error returned when resumableSession
    33  // receives a 202 HTTP status. It is mostly used for deferring
    34  // signature downloads.
    35  type statusAcceptedError struct{}
    36  
    37  func (*statusAcceptedError) Error() string {
    38  	return "HTTP 202"
    39  }
    40  
    41  // cacheData holds caching-specific informations taken from various
    42  // HTTP headers.
    43  type cacheData struct {
    44  	// whether we should reuse an image from store
    45  	UseCached bool
    46  	// image ETag, used for redownloading the obsolete images
    47  	ETag string
    48  	// MaxAge is a number of seconds telling when the downloaded
    49  	// image is obsolete
    50  	MaxAge int
    51  }
    52  
    53  // resumableSession is an implementation of the downloadSession
    54  // interface, it allows sending custom headers for authentication,
    55  // resuming interrupted downloads, handling cache data.
    56  type resumableSession struct {
    57  	// InsecureSkipTLSVerify tells whether TLS certificate
    58  	// validation should be skipped.
    59  	InsecureSkipTLSVerify bool
    60  	// Headers are HTTP headers to be added to the HTTP
    61  	// request. Used for authentication.
    62  	Headers http.Header
    63  	// File possibly holds the downloaded data - it is used for
    64  	// resuming interrupted downloads.
    65  	File *os.File
    66  	// ETagFilePath is a path to a file holding an ETag of a
    67  	// downloaded file. It is used for resuming interrupted
    68  	// downloads.
    69  	ETagFilePath string
    70  	// Label is used for printing the type of the downloaded data
    71  	// when printing a pretty progress bar.
    72  	Label string
    73  
    74  	// Cd is a cache data returned by HTTP server. It is an output
    75  	// value.
    76  	Cd *cacheData
    77  
    78  	u                  *url.URL
    79  	client             *http.Client
    80  	amountAlreadyHere  int64
    81  	byteRangeSupported bool
    82  }
    83  
    84  func (s *resumableSession) GetClient() (*http.Client, error) {
    85  	s.ensureClient()
    86  	return s.client, nil
    87  }
    88  
    89  func (s *resumableSession) GetRequest(u *url.URL) (*http.Request, error) {
    90  	s.u = u
    91  	if err := s.maybeSetupDownloadResume(u); err != nil {
    92  		return nil, err
    93  	}
    94  	return s.getRequest(u), nil
    95  }
    96  
    97  func (s *resumableSession) HandleStatus(res *http.Response) (bool, error) {
    98  	switch res.StatusCode {
    99  	case http.StatusOK, http.StatusPartialContent:
   100  		fallthrough
   101  	case http.StatusNotModified:
   102  		s.Cd = &cacheData{
   103  			ETag:      res.Header.Get("ETag"),
   104  			MaxAge:    s.getMaxAge(res.Header.Get("Cache-Control")),
   105  			UseCached: res.StatusCode == http.StatusNotModified,
   106  		}
   107  		return s.Cd.UseCached, nil
   108  	case http.StatusAccepted:
   109  		// If the server returns Status Accepted (HTTP 202), we should retry
   110  		// downloading the signature later.
   111  		return false, &statusAcceptedError{}
   112  	case http.StatusRequestedRangeNotSatisfiable:
   113  		return s.handleRangeNotSatisfiable()
   114  	default:
   115  		return false, fmt.Errorf("bad HTTP status code: %d", res.StatusCode)
   116  	}
   117  }
   118  
   119  func (s *resumableSession) GetBodyReader(res *http.Response) (io.Reader, error) {
   120  	reader := getIoProgressReader(s.Label, res)
   121  	return reader, nil
   122  }
   123  
   124  type rangeStatus int
   125  
   126  const (
   127  	rangeSupported rangeStatus = iota
   128  	rangeInvalid
   129  	rangeUnsupported
   130  )
   131  
   132  func (s *resumableSession) maybeSetupDownloadResume(u *url.URL) error {
   133  	fi, err := s.File.Stat()
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	size := fi.Size()
   139  	if size < 1 {
   140  		return nil
   141  	}
   142  
   143  	s.ensureClient()
   144  	headReq := s.headRequest(u)
   145  	res, err := s.client.Do(headReq)
   146  	if err != nil {
   147  		return err
   148  	}
   149  	if res.StatusCode != http.StatusOK {
   150  		log.Printf("bad HTTP status code from HEAD request: %d", res.StatusCode)
   151  		return nil
   152  	}
   153  	status := s.verifyAcceptRange(res, fi.ModTime())
   154  	if status == rangeSupported {
   155  		s.byteRangeSupported = true
   156  		s.amountAlreadyHere = size
   157  	} else {
   158  		if status == rangeInvalid {
   159  			log.Printf("cannot use cached partial download, resource updated.")
   160  		} else {
   161  			log.Printf("cannot use cached partial download, range request unsupported.")
   162  		}
   163  		if err := s.reset(); err != nil {
   164  			return err
   165  		}
   166  	}
   167  	return nil
   168  }
   169  
   170  func (s *resumableSession) ensureClient() {
   171  	if s.client == nil {
   172  		s.client = s.getClient()
   173  	}
   174  }
   175  
   176  func (s *resumableSession) getRequest(u *url.URL) *http.Request {
   177  	return s.httpRequest("GET", u)
   178  }
   179  
   180  func (s *resumableSession) getMaxAge(headerValue string) int {
   181  	if headerValue == "" {
   182  		return 0
   183  	}
   184  
   185  	maxAge := 0
   186  	parts := strings.Split(headerValue, ",")
   187  
   188  	for i := 0; i < len(parts); i++ {
   189  		parts[i] = strings.TrimSpace(parts[i])
   190  		attr, val := parts[i], ""
   191  		if j := strings.Index(attr, "="); j >= 0 {
   192  			attr, val = attr[:j], attr[j+1:]
   193  		}
   194  		lowerAttr := strings.ToLower(attr)
   195  
   196  		switch lowerAttr {
   197  		case "no-store", "no-cache":
   198  			maxAge = 0
   199  			// TODO(krnowak): Just break out of the loop
   200  			// at this point.
   201  		case "max-age":
   202  			secs, err := strconv.Atoi(val)
   203  			if err != nil || secs != 0 && val[0] == '0' {
   204  				// TODO(krnowak): Set maxAge to zero.
   205  				break
   206  			}
   207  			if secs <= 0 {
   208  				maxAge = 0
   209  			} else {
   210  				maxAge = secs
   211  			}
   212  		}
   213  	}
   214  	return maxAge
   215  }
   216  
   217  func (s *resumableSession) handleRangeNotSatisfiable() (bool, error) {
   218  	if fi, err := s.File.Stat(); err != nil {
   219  		return false, err
   220  	} else if fi.Size() > 0 {
   221  		if err := s.reset(); err != nil {
   222  			return false, err
   223  		}
   224  		dl := &downloader{
   225  			Session: s,
   226  		}
   227  		if err := dl.Download(s.u, s.File); err != nil {
   228  			return false, err
   229  		}
   230  		return true, nil
   231  	}
   232  	code := http.StatusRequestedRangeNotSatisfiable
   233  	return false, fmt.Errorf("bad HTTP status code: %d", code)
   234  }
   235  
   236  func (s *resumableSession) headRequest(u *url.URL) *http.Request {
   237  	return s.httpRequest("HEAD", u)
   238  }
   239  
   240  func (s *resumableSession) verifyAcceptRange(res *http.Response, mod time.Time) rangeStatus {
   241  	acceptRanges, hasRange := res.Header["Accept-Ranges"]
   242  	if !hasRange {
   243  		return rangeUnsupported
   244  	}
   245  	if !s.modificationTimeOK(res, mod) && !s.eTagOK(res) {
   246  		return rangeInvalid
   247  	}
   248  	for _, rng := range acceptRanges {
   249  		if rng == "bytes" {
   250  			return rangeSupported
   251  		}
   252  	}
   253  	return rangeInvalid
   254  }
   255  
   256  func (s *resumableSession) reset() error {
   257  	s.amountAlreadyHere = 0
   258  	s.byteRangeSupported = false
   259  	if _, err := s.File.Seek(0, 0); err != nil {
   260  		return err
   261  	}
   262  	if err := s.File.Truncate(0); err != nil {
   263  		return err
   264  	}
   265  	if err := os.Remove(s.ETagFilePath); err != nil && !os.IsNotExist(err) {
   266  		return err
   267  	}
   268  	return nil
   269  }
   270  
   271  func (s *resumableSession) getClient() *http.Client {
   272  	transport := http.DefaultTransport
   273  	if s.InsecureSkipTLSVerify {
   274  		transport = &http.Transport{
   275  			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
   276  		}
   277  	}
   278  
   279  	return &http.Client{
   280  		Transport: transport,
   281  		CheckRedirect: func(req *http.Request, via []*http.Request) error {
   282  			if len(via) >= 10 {
   283  				return fmt.Errorf("too many redirects")
   284  			}
   285  			s.setHTTPHeaders(req)
   286  			return nil
   287  		},
   288  	}
   289  }
   290  
   291  func (s *resumableSession) httpRequest(method string, u *url.URL) *http.Request {
   292  	req := &http.Request{
   293  		Method:     method,
   294  		URL:        u,
   295  		Proto:      "HTTP/1.1",
   296  		ProtoMajor: 1,
   297  		ProtoMinor: 1,
   298  		Header:     make(http.Header),
   299  		Host:       u.Host,
   300  	}
   301  
   302  	s.setHTTPHeaders(req)
   303  
   304  	return req
   305  }
   306  
   307  func (s *resumableSession) modificationTimeOK(res *http.Response, mod time.Time) bool {
   308  	lastModified := res.Header.Get("Last-Modified")
   309  	if lastModified != "" {
   310  		layout := "Mon, 02 Jan 2006 15:04:05 MST"
   311  		t, err := time.Parse(layout, lastModified)
   312  		if err == nil && t.Before(mod) {
   313  			return true
   314  		}
   315  	}
   316  	return false
   317  }
   318  
   319  func (s *resumableSession) eTagOK(res *http.Response) bool {
   320  	etag := res.Header.Get("ETag")
   321  	if etag != "" {
   322  		savedEtag, err := ioutil.ReadFile(s.ETagFilePath)
   323  		if err == nil && string(savedEtag) == etag {
   324  			return true
   325  		}
   326  	}
   327  	return false
   328  }
   329  
   330  func (s *resumableSession) setHTTPHeaders(req *http.Request) {
   331  	for k, v := range s.Headers {
   332  		for _, e := range v {
   333  			req.Header.Add(k, e)
   334  		}
   335  	}
   336  	req.Header.Add("User-Agent", fmt.Sprintf("rkt/%s", version.Version))
   337  	if s.amountAlreadyHere > 0 && s.byteRangeSupported {
   338  		req.Header.Add("Range", fmt.Sprintf("bytes=%d-", s.amountAlreadyHere))
   339  	}
   340  }