github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/usersession/client/client.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2015-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package client
    21  
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"net/http"
	"net/url"
	"path/filepath"
	"sort"
	"strconv"
	"sync"

	"github.com/snapcore/snapd/dirs"
)
    38  
    39  // dialSessionAgent connects to a user's session agent
    40  //
    41  // The host portion of the address is interpreted as the numeric user
    42  // ID of the target user.
    43  func dialSessionAgent(network, address string) (net.Conn, error) {
    44  	host, _, err := net.SplitHostPort(address)
    45  	if err != nil {
    46  		return nil, err
    47  	}
    48  	socket := filepath.Join(dirs.XdgRuntimeDirBase, host, "snapd-session-agent.socket")
    49  	return net.Dial("unix", socket)
    50  }
    51  
    52  type Client struct {
    53  	doer *http.Client
    54  }
    55  
    56  func New() *Client {
    57  	transport := &http.Transport{Dial: dialSessionAgent, DisableKeepAlives: true}
    58  	return &Client{
    59  		doer: &http.Client{Transport: transport},
    60  	}
    61  }
    62  
    63  type Error struct {
    64  	Kind    string      `json:"kind"`
    65  	Value   interface{} `json:"value"`
    66  	Message string      `json:"message"`
    67  }
    68  
    69  func (e *Error) Error() string {
    70  	return e.Message
    71  }
    72  
    73  type response struct {
    74  	// Not from JSON
    75  	uid        int
    76  	err        error
    77  	statusCode int
    78  
    79  	Result json.RawMessage `json:"result"`
    80  	Type   string          `json:"type"`
    81  }
    82  
    83  func (resp *response) checkError() {
    84  	if resp.Type != "error" {
    85  		return
    86  	}
    87  	var resultErr Error
    88  	err := json.Unmarshal(resp.Result, &resultErr)
    89  	if err != nil || resultErr.Message == "" {
    90  		resp.err = fmt.Errorf("server error: %q", http.StatusText(resp.statusCode))
    91  	} else {
    92  		resp.err = &resultErr
    93  	}
    94  }
    95  
    96  func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) {
    97  	sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket"))
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	var (
   102  		wg        sync.WaitGroup
   103  		mu        sync.Mutex
   104  		responses []*response
   105  	)
   106  	for _, socket := range sockets {
   107  		wg.Add(1)
   108  		go func(socket string) {
   109  			defer wg.Done()
   110  			uidStr := filepath.Base(filepath.Dir(socket))
   111  			uid, err := strconv.Atoi(uidStr)
   112  			if err != nil {
   113  				// Ignore directories that do not
   114  				// appear to be valid XDG runtime dirs
   115  				// (i.e. /run/user/NNNN).
   116  				return
   117  			}
   118  			response := response{uid: uid}
   119  			defer func() {
   120  				mu.Lock()
   121  				defer mu.Unlock()
   122  				responses = append(responses, &response)
   123  			}()
   124  
   125  			u := url.URL{
   126  				Scheme:   "http",
   127  				Host:     uidStr,
   128  				Path:     urlpath,
   129  				RawQuery: query.Encode(),
   130  			}
   131  			req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(body))
   132  			if err != nil {
   133  				response.err = fmt.Errorf("internal error: %v", err)
   134  				return
   135  			}
   136  			req = req.WithContext(ctx)
   137  			for key, value := range headers {
   138  				req.Header.Set(key, value)
   139  			}
   140  			httpResp, err := client.doer.Do(req)
   141  			if err != nil {
   142  				response.err = err
   143  				return
   144  			}
   145  			defer httpResp.Body.Close()
   146  			response.statusCode = httpResp.StatusCode
   147  			response.err = decodeInto(httpResp.Body, &response)
   148  			response.checkError()
   149  		}(socket)
   150  	}
   151  	wg.Wait()
   152  	return responses, nil
   153  }
   154  
   155  func decodeInto(reader io.Reader, v interface{}) error {
   156  	dec := json.NewDecoder(reader)
   157  	if err := dec.Decode(v); err != nil {
   158  		r := dec.Buffered()
   159  		buf, err1 := ioutil.ReadAll(r)
   160  		if err1 != nil {
   161  			buf = []byte(fmt.Sprintf("error reading buffered response body: %s", err1))
   162  		}
   163  		return fmt.Errorf("cannot decode %q: %s", buf, err)
   164  	}
   165  	return nil
   166  }
   167  
   168  type SessionInfo struct {
   169  	Version string `json:"version"`
   170  }
   171  
   172  func (client *Client) SessionInfo(ctx context.Context) (info map[int]SessionInfo, err error) {
   173  	responses, err := client.doMany(ctx, "GET", "/v1/session-info", nil, nil, nil)
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	info = make(map[int]SessionInfo)
   179  	for _, resp := range responses {
   180  		if resp.err != nil {
   181  			if err == nil {
   182  				err = resp.err
   183  			}
   184  			continue
   185  		}
   186  		var si SessionInfo
   187  		if decodeErr := json.Unmarshal(resp.Result, &si); decodeErr != nil {
   188  			if err == nil {
   189  				err = decodeErr
   190  			}
   191  			continue
   192  		}
   193  		info[resp.uid] = si
   194  	}
   195  	return info, err
   196  }
   197  
   198  type ServiceFailure struct {
   199  	Uid     int
   200  	Service string
   201  	Error   string
   202  }
   203  
   204  func decodeServiceErrors(uid int, errorValue map[string]interface{}, kind string) ([]ServiceFailure, error) {
   205  	if errorValue[kind] == nil {
   206  		return nil, nil
   207  	}
   208  	errors, ok := errorValue[kind].(map[string]interface{})
   209  	if !ok {
   210  		return nil, fmt.Errorf("cannot decode %s failures: expected a map, got %T", kind, errorValue[kind])
   211  	}
   212  	var failures []ServiceFailure
   213  	var err error
   214  	for service, reason := range errors {
   215  		if reasonString, ok := reason.(string); ok {
   216  			failures = append(failures, ServiceFailure{
   217  				Uid:     uid,
   218  				Service: service,
   219  				Error:   reasonString,
   220  			})
   221  		} else if err == nil {
   222  			err = fmt.Errorf("cannot decode %s failure for %q: expected string, but got %T", kind, service, reason)
   223  		}
   224  	}
   225  	return failures, err
   226  }
   227  
   228  func (client *Client) serviceControlCall(ctx context.Context, action string, services []string) (startFailures, stopFailures []ServiceFailure, err error) {
   229  	headers := map[string]string{"Content-Type": "application/json"}
   230  	reqBody, err := json.Marshal(map[string]interface{}{
   231  		"action":   action,
   232  		"services": services,
   233  	})
   234  	if err != nil {
   235  		return nil, nil, err
   236  	}
   237  	responses, err := client.doMany(ctx, "POST", "/v1/service-control", nil, headers, reqBody)
   238  	if err != nil {
   239  		return nil, nil, err
   240  	}
   241  	for _, resp := range responses {
   242  		if agentErr, ok := resp.err.(*Error); ok && agentErr.Kind == "service-control" {
   243  			if errorValue, ok := agentErr.Value.(map[string]interface{}); ok {
   244  				failures, _ := decodeServiceErrors(resp.uid, errorValue, "start-errors")
   245  				startFailures = append(startFailures, failures...)
   246  				failures, _ = decodeServiceErrors(resp.uid, errorValue, "stop-errors")
   247  				stopFailures = append(stopFailures, failures...)
   248  			}
   249  		}
   250  		if resp.err != nil && err == nil {
   251  			err = resp.err
   252  		}
   253  	}
   254  	return startFailures, stopFailures, err
   255  }
   256  
   257  func (client *Client) ServicesDaemonReload(ctx context.Context) error {
   258  	_, _, err := client.serviceControlCall(ctx, "daemon-reload", nil)
   259  	return err
   260  }
   261  
   262  func (client *Client) ServicesStart(ctx context.Context, services []string) (startFailures, stopFailures []ServiceFailure, err error) {
   263  	return client.serviceControlCall(ctx, "start", services)
   264  }
   265  
   266  func (client *Client) ServicesStop(ctx context.Context, services []string) (stopFailures []ServiceFailure, err error) {
   267  	_, stopFailures, err = client.serviceControlCall(ctx, "stop", services)
   268  	return stopFailures, err
   269  }