github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/client/client.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015-2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package client
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"encoding/json"
    26  	"fmt"
    27  	"io"
    28  	"io/ioutil"
    29  	"net"
    30  	"net/http"
    31  	"net/url"
    32  	"os"
    33  	"path"
    34  	"time"
    35  
    36  	"github.com/snapcore/snapd/dirs"
    37  	"github.com/snapcore/snapd/jsonutil"
    38  )
    39  
    40  func unixDialer(socketPath string) func(string, string) (net.Conn, error) {
    41  	if socketPath == "" {
    42  		socketPath = dirs.SnapdSocket
    43  	}
    44  	return func(_, _ string) (net.Conn, error) {
    45  		return net.Dial("unix", socketPath)
    46  	}
    47  }
    48  
    49  type doer interface {
    50  	Do(*http.Request) (*http.Response, error)
    51  }
    52  
    53  // Config allows to customize client behavior.
    54  type Config struct {
    55  	// BaseURL contains the base URL where snappy daemon is expected to be.
    56  	// It can be empty for a default behavior of talking over a unix socket.
    57  	BaseURL string
    58  
    59  	// DisableAuth controls whether the client should send an
    60  	// Authorization header from reading the auth.json data.
    61  	DisableAuth bool
    62  
    63  	// Interactive controls whether the client runs in interactive mode.
    64  	// At present, this only affects whether interactive polkit
    65  	// authorisation is requested.
    66  	Interactive bool
    67  
    68  	// Socket is the path to the unix socket to use
    69  	Socket string
    70  
    71  	// DisableKeepAlive indicates whether the connections should not be kept
    72  	// alive for later reuse
    73  	DisableKeepAlive bool
    74  
    75  	// User-Agent to sent to the snapd daemon
    76  	UserAgent string
    77  }
    78  
    79  // A Client knows how to talk to the snappy daemon.
    80  type Client struct {
    81  	baseURL url.URL
    82  	doer    doer
    83  
    84  	disableAuth bool
    85  	interactive bool
    86  
    87  	maintenance error
    88  
    89  	warningCount     int
    90  	warningTimestamp time.Time
    91  
    92  	userAgent string
    93  }
    94  
    95  // New returns a new instance of Client
    96  func New(config *Config) *Client {
    97  	if config == nil {
    98  		config = &Config{}
    99  	}
   100  
   101  	// By default talk over an UNIX socket.
   102  	if config.BaseURL == "" {
   103  		transport := &http.Transport{Dial: unixDialer(config.Socket), DisableKeepAlives: config.DisableKeepAlive}
   104  		return &Client{
   105  			baseURL: url.URL{
   106  				Scheme: "http",
   107  				Host:   "localhost",
   108  			},
   109  			doer:        &http.Client{Transport: transport},
   110  			disableAuth: config.DisableAuth,
   111  			interactive: config.Interactive,
   112  			userAgent:   config.UserAgent,
   113  		}
   114  	}
   115  
   116  	baseURL, err := url.Parse(config.BaseURL)
   117  	if err != nil {
   118  		panic(fmt.Sprintf("cannot parse server base URL: %q (%v)", config.BaseURL, err))
   119  	}
   120  	return &Client{
   121  		baseURL:     *baseURL,
   122  		doer:        &http.Client{Transport: &http.Transport{DisableKeepAlives: config.DisableKeepAlive}},
   123  		disableAuth: config.DisableAuth,
   124  		interactive: config.Interactive,
   125  		userAgent:   config.UserAgent,
   126  	}
   127  }
   128  
   129  // Maintenance returns an error reflecting the daemon maintenance status or nil.
   130  func (client *Client) Maintenance() error {
   131  	return client.maintenance
   132  }
   133  
   134  // WarningsSummary returns the number of warnings that are ready to be shown to
   135  // the user, and the timestamp of the most recently added warning (useful for
   136  // silencing the warning alerts, and OKing the returned warnings).
   137  func (client *Client) WarningsSummary() (count int, timestamp time.Time) {
   138  	return client.warningCount, client.warningTimestamp
   139  }
   140  
   141  func (client *Client) WhoAmI() (string, error) {
   142  	user, err := readAuthData()
   143  	if os.IsNotExist(err) {
   144  		return "", nil
   145  	}
   146  	if err != nil {
   147  		return "", err
   148  	}
   149  
   150  	return user.Email, nil
   151  }
   152  
   153  func (client *Client) setAuthorization(req *http.Request) error {
   154  	user, err := readAuthData()
   155  	if os.IsNotExist(err) {
   156  		return nil
   157  	}
   158  	if err != nil {
   159  		return err
   160  	}
   161  
   162  	var buf bytes.Buffer
   163  	fmt.Fprintf(&buf, `Macaroon root="%s"`, user.Macaroon)
   164  	for _, discharge := range user.Discharges {
   165  		fmt.Fprintf(&buf, `, discharge="%s"`, discharge)
   166  	}
   167  	req.Header.Set("Authorization", buf.String())
   168  	return nil
   169  }
   170  
   171  type RequestError struct{ error }
   172  
   173  func (e RequestError) Error() string {
   174  	return fmt.Sprintf("cannot build request: %v", e.error)
   175  }
   176  
   177  type AuthorizationError struct{ error }
   178  
   179  func (e AuthorizationError) Error() string {
   180  	return fmt.Sprintf("cannot add authorization: %v", e.error)
   181  }
   182  
   183  type ConnectionError struct{ Err error }
   184  
   185  func (e ConnectionError) Error() string {
   186  	var errStr string
   187  	switch e.Err {
   188  	case context.DeadlineExceeded:
   189  		errStr = "timeout exceeded while waiting for response"
   190  	case context.Canceled:
   191  		errStr = "request canceled"
   192  	default:
   193  		errStr = e.Err.Error()
   194  	}
   195  	return fmt.Sprintf("cannot communicate with server: %s", errStr)
   196  }
   197  
   198  func (e ConnectionError) Unwrap() error {
   199  	return e.Err
   200  }
   201  
   202  // AllowInteractionHeader is the HTTP request header used to indicate
   203  // that the client is willing to allow interaction.
   204  const AllowInteractionHeader = "X-Allow-Interaction"
   205  
   206  // raw performs a request and returns the resulting http.Response and
   207  // error. You usually only need to call this directly if you expect the
   208  // response to not be JSON, otherwise you'd call Do(...) instead.
   209  func (client *Client) raw(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader) (*http.Response, error) {
   210  	// fake a url to keep http.Client happy
   211  	u := client.baseURL
   212  	u.Path = path.Join(client.baseURL.Path, urlpath)
   213  	u.RawQuery = query.Encode()
   214  	req, err := http.NewRequest(method, u.String(), body)
   215  	if err != nil {
   216  		return nil, RequestError{err}
   217  	}
   218  	if client.userAgent != "" {
   219  		req.Header.Set("User-Agent", client.userAgent)
   220  	}
   221  
   222  	for key, value := range headers {
   223  		req.Header.Set(key, value)
   224  	}
   225  
   226  	if !client.disableAuth {
   227  		// set Authorization header if there are user's credentials
   228  		err = client.setAuthorization(req)
   229  		if err != nil {
   230  			return nil, AuthorizationError{err}
   231  		}
   232  	}
   233  
   234  	if client.interactive {
   235  		req.Header.Set(AllowInteractionHeader, "true")
   236  	}
   237  
   238  	if ctx != nil {
   239  		req = req.WithContext(ctx)
   240  	}
   241  
   242  	rsp, err := client.doer.Do(req)
   243  	if err != nil {
   244  		return nil, ConnectionError{err}
   245  	}
   246  
   247  	return rsp, nil
   248  }
   249  
   250  // rawWithTimeout is like raw(), but sets a timeout based on opts for
   251  // the whole of request and response (including rsp.Body() read) round
   252  // trip. If opts is nil the default doTimeout is used.
   253  // The caller is responsible for canceling the internal context
   254  // to release the resources associated with the request by calling the
   255  // returned cancel function.
   256  func (client *Client) rawWithTimeout(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body io.Reader, opts *doOptions) (*http.Response, context.CancelFunc, error) {
   257  	opts = ensureDoOpts(opts)
   258  	if opts.Timeout <= 0 {
   259  		return nil, nil, fmt.Errorf("internal error: timeout not set in options for rawWithTimeout")
   260  	}
   261  
   262  	ctx, cancel := context.WithTimeout(ctx, opts.Timeout)
   263  	rsp, err := client.raw(ctx, method, urlpath, query, headers, body)
   264  	if err != nil && ctx.Err() != nil {
   265  		cancel()
   266  		return nil, nil, ConnectionError{ctx.Err()}
   267  	}
   268  
   269  	return rsp, cancel, err
   270  }
   271  
   272  var (
   273  	doRetry = 250 * time.Millisecond
   274  	// snapd may need to reach out to the store, where it uses a fixed 10s
   275  	// timeout for the whole of a single request to complete, requests are
   276  	// retried for up to 38s in total, make sure that the client timeout is
   277  	// not shorter than that
   278  	doTimeout = 120 * time.Second
   279  )
   280  
   281  // MockDoTimings mocks the delay used by the do retry loop and request timeout.
   282  func MockDoTimings(retry, timeout time.Duration) (restore func()) {
   283  	oldRetry := doRetry
   284  	oldTimeout := doTimeout
   285  	doRetry = retry
   286  	doTimeout = timeout
   287  	return func() {
   288  		doRetry = oldRetry
   289  		doTimeout = oldTimeout
   290  	}
   291  }
   292  
   293  type hijacked struct {
   294  	do func(*http.Request) (*http.Response, error)
   295  }
   296  
   297  func (h hijacked) Do(req *http.Request) (*http.Response, error) {
   298  	return h.do(req)
   299  }
   300  
   301  // Hijack lets the caller take over the raw http request
   302  func (client *Client) Hijack(f func(*http.Request) (*http.Response, error)) {
   303  	client.doer = hijacked{f}
   304  }
   305  
   306  type doOptions struct {
   307  	// Timeout is the overall request timeout
   308  	Timeout time.Duration
   309  	// Retry interval
   310  	Retry time.Duration
   311  }
   312  
   313  func ensureDoOpts(opts *doOptions) *doOptions {
   314  	if opts == nil {
   315  		// defaults
   316  		opts = &doOptions{
   317  			Timeout: doTimeout,
   318  			Retry:   doRetry,
   319  		}
   320  	}
   321  	return opts
   322  }
   323  
   324  // doNoTimeoutAndRetry can be passed to the do family to not have timeout
   325  // nor retries.
   326  var doNoTimeoutAndRetry = &doOptions{
   327  	Timeout: time.Duration(-1),
   328  }
   329  
   330  // do performs a request and decodes the resulting json into the given
   331  // value. It's low-level, for testing/experimenting only; you should
   332  // usually use a higher level interface that builds on this.
   333  func (client *Client) do(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}, opts *doOptions) (statusCode int, err error) {
   334  	opts = ensureDoOpts(opts)
   335  
   336  	var rsp *http.Response
   337  	var ctx context.Context = context.Background()
   338  	if opts.Timeout <= 0 {
   339  		// no timeout and retries
   340  		rsp, err = client.raw(ctx, method, path, query, headers, body)
   341  	} else {
   342  		if opts.Retry <= 0 {
   343  			return 0, fmt.Errorf("internal error: retry setting %s invalid", opts.Retry)
   344  		}
   345  		retry := time.NewTicker(opts.Retry)
   346  		defer retry.Stop()
   347  		timeout := time.NewTimer(opts.Timeout)
   348  		defer timeout.Stop()
   349  
   350  		for {
   351  			var cancel context.CancelFunc
   352  			// use the same timeout as for the whole of the retry
   353  			// loop to error out the whole do() call when a single
   354  			// request exceeds the deadline
   355  			rsp, cancel, err = client.rawWithTimeout(ctx, method, path, query, headers, body, opts)
   356  			if err == nil {
   357  				defer cancel()
   358  			}
   359  			if err == nil || method != "GET" {
   360  				break
   361  			}
   362  			select {
   363  			case <-retry.C:
   364  				continue
   365  			case <-timeout.C:
   366  			}
   367  			break
   368  		}
   369  	}
   370  	if err != nil {
   371  		return 0, err
   372  	}
   373  	defer rsp.Body.Close()
   374  
   375  	if v != nil {
   376  		if err := decodeInto(rsp.Body, v); err != nil {
   377  			return rsp.StatusCode, err
   378  		}
   379  	}
   380  
   381  	return rsp.StatusCode, nil
   382  }
   383  
   384  func decodeInto(reader io.Reader, v interface{}) error {
   385  	dec := json.NewDecoder(reader)
   386  	if err := dec.Decode(v); err != nil {
   387  		r := dec.Buffered()
   388  		buf, err1 := ioutil.ReadAll(r)
   389  		if err1 != nil {
   390  			buf = []byte(fmt.Sprintf("error reading buffered response body: %s", err1))
   391  		}
   392  		return fmt.Errorf("cannot decode %q: %s", buf, err)
   393  	}
   394  	return nil
   395  }
   396  
   397  // doSync performs a request to the given path using the specified HTTP method.
   398  // It expects a "sync" response from the API and on success decodes the JSON
   399  // response payload into the given value using the "UseNumber" json decoding
   400  // which produces json.Numbers instead of float64 types for numbers.
   401  func (client *Client) doSync(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}) (*ResultInfo, error) {
   402  	return client.doSyncWithOpts(method, path, query, headers, body, v, nil)
   403  }
   404  
   405  func (client *Client) doSyncWithOpts(method, path string, query url.Values, headers map[string]string, body io.Reader, v interface{}, opts *doOptions) (*ResultInfo, error) {
   406  	var rsp response
   407  	statusCode, err := client.do(method, path, query, headers, body, &rsp, opts)
   408  	if err != nil {
   409  		return nil, err
   410  	}
   411  	if err := rsp.err(client, statusCode); err != nil {
   412  		return nil, err
   413  	}
   414  	if rsp.Type != "sync" {
   415  		return nil, fmt.Errorf("expected sync response, got %q", rsp.Type)
   416  	}
   417  
   418  	if v != nil {
   419  		if err := jsonutil.DecodeWithNumber(bytes.NewReader(rsp.Result), v); err != nil {
   420  			return nil, fmt.Errorf("cannot unmarshal: %v", err)
   421  		}
   422  	}
   423  
   424  	client.warningCount = rsp.WarningCount
   425  	client.warningTimestamp = rsp.WarningTimestamp
   426  
   427  	return &rsp.ResultInfo, nil
   428  }
   429  
   430  func (client *Client) doAsync(method, path string, query url.Values, headers map[string]string, body io.Reader) (changeID string, err error) {
   431  	_, changeID, err = client.doAsyncFull(method, path, query, headers, body, nil)
   432  	return
   433  }
   434  
   435  func (client *Client) doAsyncFull(method, path string, query url.Values, headers map[string]string, body io.Reader, opts *doOptions) (result json.RawMessage, changeID string, err error) {
   436  	var rsp response
   437  	statusCode, err := client.do(method, path, query, headers, body, &rsp, opts)
   438  	if err != nil {
   439  		return nil, "", err
   440  	}
   441  	if err := rsp.err(client, statusCode); err != nil {
   442  		return nil, "", err
   443  	}
   444  	if rsp.Type != "async" {
   445  		return nil, "", fmt.Errorf("expected async response for %q on %q, got %q", method, path, rsp.Type)
   446  	}
   447  	if statusCode != 202 {
   448  		return nil, "", fmt.Errorf("operation not accepted")
   449  	}
   450  	if rsp.Change == "" {
   451  		return nil, "", fmt.Errorf("async response without change reference")
   452  	}
   453  
   454  	return rsp.Result, rsp.Change, nil
   455  }
   456  
   457  type ServerVersion struct {
   458  	Version     string
   459  	Series      string
   460  	OSID        string
   461  	OSVersionID string
   462  	OnClassic   bool
   463  
   464  	KernelVersion  string
   465  	Architecture   string
   466  	Virtualization string
   467  }
   468  
   469  func (client *Client) ServerVersion() (*ServerVersion, error) {
   470  	sysInfo, err := client.SysInfo()
   471  	if err != nil {
   472  		return nil, err
   473  	}
   474  
   475  	return &ServerVersion{
   476  		Version:     sysInfo.Version,
   477  		Series:      sysInfo.Series,
   478  		OSID:        sysInfo.OSRelease.ID,
   479  		OSVersionID: sysInfo.OSRelease.VersionID,
   480  		OnClassic:   sysInfo.OnClassic,
   481  
   482  		KernelVersion:  sysInfo.KernelVersion,
   483  		Architecture:   sysInfo.Architecture,
   484  		Virtualization: sysInfo.Virtualization,
   485  	}, nil
   486  }
   487  
   488  // A response produced by the REST API will usually fit in this
   489  // (exceptions are the icons/ endpoints obvs)
   490  type response struct {
   491  	Result json.RawMessage `json:"result"`
   492  	Type   string          `json:"type"`
   493  	Change string          `json:"change"`
   494  
   495  	WarningCount     int       `json:"warning-count"`
   496  	WarningTimestamp time.Time `json:"warning-timestamp"`
   497  
   498  	ResultInfo
   499  
   500  	Maintenance *Error `json:"maintenance"`
   501  }
   502  
   503  // Error is the real value of response.Result when an error occurs.
   504  type Error struct {
   505  	Kind    ErrorKind   `json:"kind"`
   506  	Value   interface{} `json:"value"`
   507  	Message string      `json:"message"`
   508  
   509  	StatusCode int
   510  }
   511  
   512  func (e *Error) Error() string {
   513  	return e.Message
   514  }
   515  
   516  // IsRetryable returns true if the given error is an error
   517  // that can be retried later.
   518  func IsRetryable(err error) bool {
   519  	switch e := err.(type) {
   520  	case *Error:
   521  		return e.Kind == ErrorKindSnapChangeConflict
   522  	}
   523  	return false
   524  }
   525  
   526  // IsTwoFactorError returns whether the given error is due to problems
   527  // in two-factor authentication.
   528  func IsTwoFactorError(err error) bool {
   529  	e, ok := err.(*Error)
   530  	if !ok || e == nil {
   531  		return false
   532  	}
   533  
   534  	return e.Kind == ErrorKindTwoFactorFailed || e.Kind == ErrorKindTwoFactorRequired
   535  }
   536  
   537  // IsInterfacesUnchangedError returns whether the given error means the requested
   538  // change to interfaces was not made, because there was nothing to do.
   539  func IsInterfacesUnchangedError(err error) bool {
   540  	e, ok := err.(*Error)
   541  	if !ok || e == nil {
   542  		return false
   543  	}
   544  	return e.Kind == ErrorKindInterfacesUnchanged
   545  }
   546  
   547  // IsAssertionNotFoundError returns whether the given error means that the
   548  // assertion wasn't found and thus the device isn't ready/seeded.
   549  func IsAssertionNotFoundError(err error) bool {
   550  	e, ok := err.(*Error)
   551  	if !ok || e == nil {
   552  		return false
   553  	}
   554  
   555  	return e.Kind == ErrorKindAssertionNotFound
   556  }
   557  
   558  // OSRelease contains information about the system extracted from /etc/os-release.
   559  type OSRelease struct {
   560  	ID        string `json:"id"`
   561  	VersionID string `json:"version-id,omitempty"`
   562  }
   563  
   564  // RefreshInfo contains information about refreshes.
   565  type RefreshInfo struct {
   566  	// Timer contains the refresh.timer setting.
   567  	Timer string `json:"timer,omitempty"`
   568  	// Schedule contains the legacy refresh.schedule setting.
   569  	Schedule string `json:"schedule,omitempty"`
   570  	Last     string `json:"last,omitempty"`
   571  	Hold     string `json:"hold,omitempty"`
   572  	Next     string `json:"next,omitempty"`
   573  }
   574  
   575  // SysInfo holds system information
   576  type SysInfo struct {
   577  	Series    string    `json:"series,omitempty"`
   578  	Version   string    `json:"version,omitempty"`
   579  	BuildID   string    `json:"build-id"`
   580  	OSRelease OSRelease `json:"os-release"`
   581  	OnClassic bool      `json:"on-classic"`
   582  	Managed   bool      `json:"managed"`
   583  
   584  	KernelVersion  string `json:"kernel-version,omitempty"`
   585  	Architecture   string `json:"architecture,omitempty"`
   586  	Virtualization string `json:"virtualization,omitempty"`
   587  
   588  	Refresh         RefreshInfo         `json:"refresh,omitempty"`
   589  	Confinement     string              `json:"confinement"`
   590  	SandboxFeatures map[string][]string `json:"sandbox-features,omitempty"`
   591  }
   592  
   593  func (rsp *response) err(cli *Client, statusCode int) error {
   594  	if cli != nil {
   595  		maintErr := rsp.Maintenance
   596  		// avoid setting to (*client.Error)(nil)
   597  		if maintErr != nil {
   598  			cli.maintenance = maintErr
   599  		} else {
   600  			cli.maintenance = nil
   601  		}
   602  	}
   603  	if rsp.Type != "error" {
   604  		return nil
   605  	}
   606  	var resultErr Error
   607  	err := json.Unmarshal(rsp.Result, &resultErr)
   608  	if err != nil || resultErr.Message == "" {
   609  		return fmt.Errorf("server error: %q", http.StatusText(statusCode))
   610  	}
   611  	resultErr.StatusCode = statusCode
   612  
   613  	return &resultErr
   614  }
   615  
   616  func parseError(r *http.Response) error {
   617  	var rsp response
   618  	if r.Header.Get("Content-Type") != "application/json" {
   619  		return fmt.Errorf("server error: %q", r.Status)
   620  	}
   621  
   622  	dec := json.NewDecoder(r.Body)
   623  	if err := dec.Decode(&rsp); err != nil {
   624  		return fmt.Errorf("cannot unmarshal error: %v", err)
   625  	}
   626  
   627  	err := rsp.err(nil, r.StatusCode)
   628  	if err == nil {
   629  		return fmt.Errorf("server error: %q", r.Status)
   630  	}
   631  	return err
   632  }
   633  
   634  // SysInfo gets system information from the REST API.
   635  func (client *Client) SysInfo() (*SysInfo, error) {
   636  	var sysInfo SysInfo
   637  
   638  	if _, err := client.doSync("GET", "/v2/system-info", nil, nil, nil, &sysInfo); err != nil {
   639  		return nil, fmt.Errorf("cannot obtain system details: %v", err)
   640  	}
   641  
   642  	return &sysInfo, nil
   643  }
   644  
   645  type debugAction struct {
   646  	Action string      `json:"action"`
   647  	Params interface{} `json:"params,omitempty"`
   648  }
   649  
   650  // Debug is only useful when writing test code, it will trigger
   651  // an internal action with the given parameters.
   652  func (client *Client) Debug(action string, params interface{}, result interface{}) error {
   653  	body, err := json.Marshal(debugAction{
   654  		Action: action,
   655  		Params: params,
   656  	})
   657  	if err != nil {
   658  		return err
   659  	}
   660  
   661  	_, err = client.doSync("POST", "/v2/debug", nil, nil, bytes.NewReader(body), result)
   662  	return err
   663  }
   664  
   665  func (client *Client) DebugGet(aspect string, result interface{}, params map[string]string) error {
   666  	urlParams := url.Values{"aspect": []string{aspect}}
   667  	for k, v := range params {
   668  		urlParams.Set(k, v)
   669  	}
   670  	_, err := client.doSync("GET", "/v2/debug", urlParams, nil, nil, &result)
   671  	return err
   672  }