github.com/ethanhsieh/snapd@v0.0.0-20210615102523-3db9b8e4edc5/usersession/client/client.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package client
    21  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"sort"
	"strconv"
	"sync"
	"time"

	"github.com/snapcore/snapd/dirs"
)
    39  
    40  // dialSessionAgent connects to a user's session agent
    41  //
    42  // The host portion of the address is interpreted as the numeric user
    43  // ID of the target user.
    44  func dialSessionAgent(network, address string) (net.Conn, error) {
    45  	host, _, err := net.SplitHostPort(address)
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	socket := filepath.Join(dirs.XdgRuntimeDirBase, host, "snapd-session-agent.socket")
    50  	return net.Dial("unix", socket)
    51  }
    52  
    53  type Client struct {
    54  	doer *http.Client
    55  }
    56  
    57  func New() *Client {
    58  	transport := &http.Transport{Dial: dialSessionAgent, DisableKeepAlives: true}
    59  	return &Client{
    60  		doer: &http.Client{Transport: transport},
    61  	}
    62  }
    63  
    64  type Error struct {
    65  	Kind    string      `json:"kind"`
    66  	Value   interface{} `json:"value"`
    67  	Message string      `json:"message"`
    68  }
    69  
    70  func (e *Error) Error() string {
    71  	return e.Message
    72  }
    73  
    74  type response struct {
    75  	// Not from JSON
    76  	uid        int
    77  	err        error
    78  	statusCode int
    79  
    80  	Result json.RawMessage `json:"result"`
    81  	Type   string          `json:"type"`
    82  }
    83  
    84  func (resp *response) checkError() {
    85  	if resp.Type != "error" {
    86  		return
    87  	}
    88  	var resultErr Error
    89  	err := json.Unmarshal(resp.Result, &resultErr)
    90  	if err != nil || resultErr.Message == "" {
    91  		resp.err = fmt.Errorf("server error: %q", http.StatusText(resp.statusCode))
    92  	} else {
    93  		resp.err = &resultErr
    94  	}
    95  }
    96  
    97  func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) {
    98  	sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket"))
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  	var (
   103  		wg        sync.WaitGroup
   104  		mu        sync.Mutex
   105  		responses []*response
   106  	)
   107  	for _, socket := range sockets {
   108  		wg.Add(1)
   109  		go func(socket string) {
   110  			defer wg.Done()
   111  			uidStr := filepath.Base(filepath.Dir(socket))
   112  			uid, err := strconv.Atoi(uidStr)
   113  			if err != nil {
   114  				// Ignore directories that do not
   115  				// appear to be valid XDG runtime dirs
   116  				// (i.e. /run/user/NNNN).
   117  				return
   118  			}
   119  			response := response{uid: uid}
   120  			defer func() {
   121  				mu.Lock()
   122  				defer mu.Unlock()
   123  				responses = append(responses, &response)
   124  			}()
   125  
   126  			u := url.URL{
   127  				Scheme:   "http",
   128  				Host:     uidStr,
   129  				Path:     urlpath,
   130  				RawQuery: query.Encode(),
   131  			}
   132  			req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(body))
   133  			if err != nil {
   134  				response.err = fmt.Errorf("internal error: %v", err)
   135  				return
   136  			}
   137  			req = req.WithContext(ctx)
   138  			for key, value := range headers {
   139  				req.Header.Set(key, value)
   140  			}
   141  			httpResp, err := client.doer.Do(req)
   142  			if err != nil {
   143  				response.err = err
   144  				return
   145  			}
   146  			defer httpResp.Body.Close()
   147  			response.statusCode = httpResp.StatusCode
   148  			response.err = decodeInto(httpResp.Body, &response)
   149  			response.checkError()
   150  		}(socket)
   151  	}
   152  	wg.Wait()
   153  	return responses, nil
   154  }
   155  
   156  func decodeInto(reader io.Reader, v interface{}) error {
   157  	dec := json.NewDecoder(reader)
   158  	if err := dec.Decode(v); err != nil {
   159  		r := dec.Buffered()
   160  		buf, err1 := ioutil.ReadAll(r)
   161  		if err1 != nil {
   162  			buf = []byte(fmt.Sprintf("error reading buffered response body: %s", err1))
   163  		}
   164  		return fmt.Errorf("cannot decode %q: %s", buf, err)
   165  	}
   166  	return nil
   167  }
   168  
   169  type SessionInfo struct {
   170  	Version string `json:"version"`
   171  }
   172  
   173  func (client *Client) SessionInfo(ctx context.Context) (info map[int]SessionInfo, err error) {
   174  	responses, err := client.doMany(ctx, "GET", "/v1/session-info", nil, nil, nil)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	info = make(map[int]SessionInfo)
   180  	for _, resp := range responses {
   181  		if resp.err != nil {
   182  			if err == nil {
   183  				err = resp.err
   184  			}
   185  			continue
   186  		}
   187  		var si SessionInfo
   188  		if decodeErr := json.Unmarshal(resp.Result, &si); decodeErr != nil {
   189  			if err == nil {
   190  				err = decodeErr
   191  			}
   192  			continue
   193  		}
   194  		info[resp.uid] = si
   195  	}
   196  	return info, err
   197  }
   198  
   199  type ServiceFailure struct {
   200  	Uid     int
   201  	Service string
   202  	Error   string
   203  }
   204  
   205  func decodeServiceErrors(uid int, errorValue map[string]interface{}, kind string) ([]ServiceFailure, error) {
   206  	if errorValue[kind] == nil {
   207  		return nil, nil
   208  	}
   209  	errors, ok := errorValue[kind].(map[string]interface{})
   210  	if !ok {
   211  		return nil, fmt.Errorf("cannot decode %s failures: expected a map, got %T", kind, errorValue[kind])
   212  	}
   213  	var failures []ServiceFailure
   214  	var err error
   215  	for service, reason := range errors {
   216  		if reasonString, ok := reason.(string); ok {
   217  			failures = append(failures, ServiceFailure{
   218  				Uid:     uid,
   219  				Service: service,
   220  				Error:   reasonString,
   221  			})
   222  		} else if err == nil {
   223  			err = fmt.Errorf("cannot decode %s failure for %q: expected string, but got %T", kind, service, reason)
   224  		}
   225  	}
   226  	return failures, err
   227  }
   228  
   229  func (client *Client) serviceControlCall(ctx context.Context, action string, services []string) (startFailures, stopFailures []ServiceFailure, err error) {
   230  	headers := map[string]string{"Content-Type": "application/json"}
   231  	reqBody, err := json.Marshal(map[string]interface{}{
   232  		"action":   action,
   233  		"services": services,
   234  	})
   235  	if err != nil {
   236  		return nil, nil, err
   237  	}
   238  	responses, err := client.doMany(ctx, "POST", "/v1/service-control", nil, headers, reqBody)
   239  	if err != nil {
   240  		return nil, nil, err
   241  	}
   242  	for _, resp := range responses {
   243  		if agentErr, ok := resp.err.(*Error); ok && agentErr.Kind == "service-control" {
   244  			if errorValue, ok := agentErr.Value.(map[string]interface{}); ok {
   245  				failures, _ := decodeServiceErrors(resp.uid, errorValue, "start-errors")
   246  				startFailures = append(startFailures, failures...)
   247  				failures, _ = decodeServiceErrors(resp.uid, errorValue, "stop-errors")
   248  				stopFailures = append(stopFailures, failures...)
   249  			}
   250  		}
   251  		if resp.err != nil && err == nil {
   252  			err = resp.err
   253  		}
   254  	}
   255  	return startFailures, stopFailures, err
   256  }
   257  
   258  func (client *Client) ServicesDaemonReload(ctx context.Context) error {
   259  	_, _, err := client.serviceControlCall(ctx, "daemon-reload", nil)
   260  	return err
   261  }
   262  
   263  func (client *Client) ServicesStart(ctx context.Context, services []string) (startFailures, stopFailures []ServiceFailure, err error) {
   264  	return client.serviceControlCall(ctx, "start", services)
   265  }
   266  
   267  func (client *Client) ServicesStop(ctx context.Context, services []string) (stopFailures []ServiceFailure, err error) {
   268  	_, stopFailures, err = client.serviceControlCall(ctx, "stop", services)
   269  	return stopFailures, err
   270  }
   271  
   272  // PendingSnapRefreshInfo holds information about pending snap refresh provided to userd.
   273  type PendingSnapRefreshInfo struct {
   274  	InstanceName        string        `json:"instance-name"`
   275  	TimeRemaining       time.Duration `json:"time-remaining,omitempty"`
   276  	BusyAppName         string        `json:"busy-app-name,omitempty"`
   277  	BusyAppDesktopEntry string        `json:"busy-app-desktop-entry,omitempty"`
   278  }
   279  
   280  // PendingRefreshNotification broadcasts information about a refresh.
   281  func (client *Client) PendingRefreshNotification(ctx context.Context, refreshInfo *PendingSnapRefreshInfo) error {
   282  	headers := map[string]string{"Content-Type": "application/json"}
   283  	reqBody, err := json.Marshal(refreshInfo)
   284  	if err != nil {
   285  		return err
   286  	}
   287  	_, err = client.doMany(ctx, "POST", "/v1/notifications/pending-refresh", nil, headers, reqBody)
   288  	return err
   289  }