github.com/ethanhsieh/snapd@v0.0.0-20210615102523-3db9b8e4edc5/store/store_asserts.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2014-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Package store has support to use the Ubuntu Store for querying and downloading of snaps, and the related services.
    21  package store
    22  
    23  import (
    24  	"context"
    25  	"encoding/json"
    26  	"fmt"
    27  	"io"
    28  	"net/http"
    29  	"net/url"
    30  	"path"
    31  	"strconv"
    32  
    33  	"github.com/snapcore/snapd/asserts"
    34  	"github.com/snapcore/snapd/httputil"
    35  	"github.com/snapcore/snapd/overlord/auth"
    36  )
    37  
    38  func (s *Store) assertionsEndpointURL(p string, query url.Values) *url.URL {
    39  	defBaseURL := s.cfg.StoreBaseURL
    40  	// can be overridden separately!
    41  	if s.cfg.AssertionsBaseURL != nil {
    42  		defBaseURL = s.cfg.AssertionsBaseURL
    43  	}
    44  	return endpointURL(s.baseURL(defBaseURL), path.Join(assertionsPath, p), query)
    45  }
    46  
    47  type assertionSvcError struct {
    48  	// v1 error fields
    49  	// XXX: remove once switched to v2 API request.
    50  	Status int    `json:"status"`
    51  	Type   string `json:"type"`
    52  	Title  string `json:"title"`
    53  	Detail string `json:"detail"`
    54  
    55  	// v2 error list - the only field included in v2 error response.
    56  	// XXX: there is an overlap with searchV2Results (and partially with
    57  	// errorListEntry), we could share the definition.
    58  	ErrorList []struct {
    59  		Code    string `json:"code"`
    60  		Message string `json:"message"`
    61  	} `json:"error-list"`
    62  }
    63  
    64  func (e *assertionSvcError) isNotFound() bool {
    65  	return (len(e.ErrorList) > 0 && e.ErrorList[0].Code == "not-found" /* v2 error */) || e.Status == 404
    66  }
    67  
    68  func (e *assertionSvcError) toError() error {
    69  	// is it v2 error?
    70  	if len(e.ErrorList) > 0 {
    71  		return fmt.Errorf("assertion service error: %q", e.ErrorList[0].Message)
    72  	}
    73  	// v1 error
    74  	return fmt.Errorf("assertion service error: [%s] %q", e.Title, e.Detail)
    75  }
    76  
    77  // Assertion retrieves the assertion for the given type and primary key.
    78  func (s *Store) Assertion(assertType *asserts.AssertionType, primaryKey []string, user *auth.UserState) (asserts.Assertion, error) {
    79  	v := url.Values{}
    80  	v.Set("max-format", strconv.Itoa(assertType.MaxSupportedFormat()))
    81  	u := s.assertionsEndpointURL(path.Join(assertType.Name, path.Join(primaryKey...)), v)
    82  
    83  	var asrt asserts.Assertion
    84  
    85  	err := s.downloadAssertions(u, func(r io.Reader) error {
    86  		// decode assertion
    87  		dec := asserts.NewDecoder(r)
    88  		var e error
    89  		asrt, e = dec.Decode()
    90  		return e
    91  	}, func(svcErr *assertionSvcError) error {
    92  		// error-list indicates v2 error response.
    93  		if svcErr.isNotFound() {
    94  			// best-effort
    95  			headers, _ := asserts.HeadersFromPrimaryKey(assertType, primaryKey)
    96  			return &asserts.NotFoundError{
    97  				Type:    assertType,
    98  				Headers: headers,
    99  			}
   100  		}
   101  		// default error
   102  		return nil
   103  	}, "fetch assertion", user)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	return asrt, nil
   108  }
   109  
   110  // SeqFormingAssertion retrieves the sequence-forming assertion for the given
   111  // type (currently validation-set only). For sequence <= 0 we query for the
   112  // latest sequence, otherwise the latest revision of the given sequence is
   113  // requested.
   114  func (s *Store) SeqFormingAssertion(assertType *asserts.AssertionType, sequenceKey []string, sequence int, user *auth.UserState) (asserts.Assertion, error) {
   115  	if !assertType.SequenceForming() {
   116  		return nil, fmt.Errorf("internal error: requested non sequence-forming assertion type %q", assertType.Name)
   117  	}
   118  	v := url.Values{}
   119  	v.Set("max-format", strconv.Itoa(assertType.MaxSupportedFormat()))
   120  
   121  	hasSequenceNumber := sequence > 0
   122  	if hasSequenceNumber {
   123  		// full primary key passed, query specific sequence number.
   124  		v.Set("sequence", fmt.Sprintf("%d", sequence))
   125  	} else {
   126  		// query for the latest sequence.
   127  		v.Set("sequence", "latest")
   128  	}
   129  	u := s.assertionsEndpointURL(path.Join(assertType.Name, path.Join(sequenceKey...)), v)
   130  
   131  	var asrt asserts.Assertion
   132  
   133  	err := s.downloadAssertions(u, func(r io.Reader) error {
   134  		// decode assertion
   135  		dec := asserts.NewDecoder(r)
   136  		var e error
   137  		asrt, e = dec.Decode()
   138  		return e
   139  	}, func(svcErr *assertionSvcError) error {
   140  		// error-list indicates v2 error response.
   141  		if svcErr.isNotFound() {
   142  			// XXX: this re-implements asserts.HeadersFromPrimaryKey() but is
   143  			// more relaxed about key length, making sequence optional. Should
   144  			// we make it a helper on its own in store for the not-found-error
   145  			// handling?
   146  			if len(sequenceKey) != len(assertType.PrimaryKey)-1 {
   147  				return fmt.Errorf("sequence key has wrong length for %q assertion", assertType.Name)
   148  			}
   149  			headers := make(map[string]string)
   150  			for i, keyVal := range sequenceKey {
   151  				name := assertType.PrimaryKey[i]
   152  				if keyVal == "" {
   153  					return fmt.Errorf("sequence key %q header cannot be empty", name)
   154  				}
   155  				headers[name] = keyVal
   156  			}
   157  			if hasSequenceNumber {
   158  				headers[assertType.PrimaryKey[len(assertType.PrimaryKey)-1]] = fmt.Sprintf("%d", sequence)
   159  			}
   160  			return &asserts.NotFoundError{
   161  				Type:    assertType,
   162  				Headers: headers,
   163  			}
   164  		}
   165  		// default error
   166  		return nil
   167  	}, "fetch assertion", user)
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	return asrt, nil
   172  }
   173  
   174  func (s *Store) downloadAssertions(u *url.URL, decodeBody func(io.Reader) error, handleSvcErr func(*assertionSvcError) error, what string, user *auth.UserState) error {
   175  	reqOptions := &requestOptions{
   176  		Method: "GET",
   177  		URL:    u,
   178  		Accept: asserts.MediaType,
   179  	}
   180  
   181  	resp, err := httputil.RetryRequest(reqOptions.URL.String(), func() (*http.Response, error) {
   182  		return s.doRequest(context.TODO(), s.client, reqOptions, user)
   183  	}, func(resp *http.Response) error {
   184  		var e error
   185  		if resp.StatusCode == 200 {
   186  			e = decodeBody(resp.Body)
   187  		} else {
   188  			contentType := resp.Header.Get("Content-Type")
   189  			if contentType == jsonContentType || contentType == "application/problem+json" {
   190  				var svcErr assertionSvcError
   191  				dec := json.NewDecoder(resp.Body)
   192  				if e = dec.Decode(&svcErr); e != nil {
   193  					return fmt.Errorf("cannot decode assertion service error with HTTP status code %d: %v", resp.StatusCode, e)
   194  				}
   195  				if handleSvcErr != nil {
   196  					if e := handleSvcErr(&svcErr); e != nil {
   197  						return e
   198  					}
   199  				}
   200  				// default error handling
   201  				return svcErr.toError()
   202  			}
   203  		}
   204  		return e
   205  	}, defaultRetryStrategy)
   206  
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	if resp.StatusCode != 200 {
   212  		return respToError(resp, what)
   213  	}
   214  
   215  	return nil
   216  }
   217  
   218  // DownloadAssertions download the assertion streams at the given URLs
   219  // and adds their assertions to the given asserts.Batch.
   220  func (s *Store) DownloadAssertions(streamURLs []string, b *asserts.Batch, user *auth.UserState) error {
   221  	for _, ustr := range streamURLs {
   222  		u, err := url.Parse(ustr)
   223  		if err != nil {
   224  			return fmt.Errorf("invalid assertions stream URL: %v", err)
   225  		}
   226  
   227  		err = s.downloadAssertions(u, func(r io.Reader) error {
   228  			// decode stream
   229  			_, e := b.AddStream(r)
   230  			return e
   231  		}, nil, "download assertion stream", user)
   232  		if err != nil {
   233  			return err
   234  		}
   235  
   236  	}
   237  	return nil
   238  }