github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/usersession/client/client.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2015-2020 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package client 21 22 import ( 23 "bytes" 24 "context" 25 "encoding/json" 26 "fmt" 27 "io" 28 "io/ioutil" 29 "net" 30 "net/http" 31 "net/url" 32 "path/filepath" 33 "strconv" 34 "sync" 35 36 "github.com/snapcore/snapd/dirs" 37 ) 38 39 // dialSessionAgent connects to a user's session agent 40 // 41 // The host portion of the address is interpreted as the numeric user 42 // ID of the target user. 43 func dialSessionAgent(network, address string) (net.Conn, error) { 44 host, _, err := net.SplitHostPort(address) 45 if err != nil { 46 return nil, err 47 } 48 socket := filepath.Join(dirs.XdgRuntimeDirBase, host, "snapd-session-agent.socket") 49 return net.Dial("unix", socket) 50 } 51 52 type Client struct { 53 doer *http.Client 54 } 55 56 func New() *Client { 57 transport := &http.Transport{Dial: dialSessionAgent, DisableKeepAlives: true} 58 return &Client{ 59 doer: &http.Client{Transport: transport}, 60 } 61 } 62 63 type Error struct { 64 Kind string `json:"kind"` 65 Value interface{} `json:"value"` 66 Message string `json:"message"` 67 } 68 69 func (e *Error) Error() string { 70 return e.Message 71 } 72 73 type response struct { 74 // Not from JSON 75 uid int 76 err error 77 statusCode int 78 79 Result json.RawMessage `json:"result"` 80 Type string `json:"type"` 81 } 82 83 func (resp *response) checkError() { 84 if resp.Type != "error" { 85 return 86 } 87 var resultErr Error 88 err := json.Unmarshal(resp.Result, &resultErr) 89 if err != nil || resultErr.Message == "" { 90 resp.err = fmt.Errorf("server error: %q", http.StatusText(resp.statusCode)) 91 } else { 92 resp.err = &resultErr 93 } 94 } 95 96 func (client *Client) doMany(ctx context.Context, method, urlpath string, query url.Values, headers map[string]string, body []byte) ([]*response, error) { 97 sockets, err := filepath.Glob(filepath.Join(dirs.XdgRuntimeDirGlob, "snapd-session-agent.socket")) 98 if err != nil { 99 return nil, err 100 } 101 var ( 102 wg sync.WaitGroup 103 mu sync.Mutex 104 responses []*response 105 ) 106 for _, socket := range sockets { 107 wg.Add(1) 108 go func(socket string) { 109 defer wg.Done() 110 uidStr := filepath.Base(filepath.Dir(socket)) 111 uid, err := strconv.Atoi(uidStr) 112 if err != nil { 113 // Ignore directories that do not 114 // appear to be valid XDG runtime dirs 115 // (i.e. /run/user/NNNN). 116 return 117 } 118 response := response{uid: uid} 119 defer func() { 120 mu.Lock() 121 defer mu.Unlock() 122 responses = append(responses, &response) 123 }() 124 125 u := url.URL{ 126 Scheme: "http", 127 Host: uidStr, 128 Path: urlpath, 129 RawQuery: query.Encode(), 130 } 131 req, err := http.NewRequest(method, u.String(), bytes.NewBuffer(body)) 132 if err != nil { 133 response.err = fmt.Errorf("internal error: %v", err) 134 return 135 } 136 req = req.WithContext(ctx) 137 for key, value := range headers { 138 req.Header.Set(key, value) 139 } 140 httpResp, err := client.doer.Do(req) 141 if err != nil { 142 response.err = err 143 return 144 } 145 defer httpResp.Body.Close() 146 response.statusCode = httpResp.StatusCode 147 response.err = decodeInto(httpResp.Body, &response) 148 response.checkError() 149 }(socket) 150 } 151 wg.Wait() 152 return responses, nil 153 } 154 155 func decodeInto(reader io.Reader, v interface{}) error { 156 dec := json.NewDecoder(reader) 157 if err := dec.Decode(v); err != nil { 158 r := dec.Buffered() 159 buf, err1 := ioutil.ReadAll(r) 160 if err1 != nil { 161 buf = []byte(fmt.Sprintf("error reading buffered response body: %s", err1)) 162 } 163 return fmt.Errorf("cannot decode %q: %s", buf, err) 164 } 165 return nil 166 } 167 168 type SessionInfo struct { 169 Version string `json:"version"` 170 } 171 172 func (client *Client) SessionInfo(ctx context.Context) (info map[int]SessionInfo, err error) { 173 responses, err := client.doMany(ctx, "GET", "/v1/session-info", nil, nil, nil) 174 if err != nil { 175 return nil, err 176 } 177 178 info = make(map[int]SessionInfo) 179 for _, resp := range responses { 180 if resp.err != nil { 181 if err == nil { 182 err = resp.err 183 } 184 continue 185 } 186 var si SessionInfo 187 if decodeErr := json.Unmarshal(resp.Result, &si); decodeErr != nil { 188 if err == nil { 189 err = decodeErr 190 } 191 continue 192 } 193 info[resp.uid] = si 194 } 195 return info, err 196 } 197 198 type ServiceFailure struct { 199 Uid int 200 Service string 201 Error string 202 } 203 204 func decodeServiceErrors(uid int, errorValue map[string]interface{}, kind string) ([]ServiceFailure, error) { 205 if errorValue[kind] == nil { 206 return nil, nil 207 } 208 errors, ok := errorValue[kind].(map[string]interface{}) 209 if !ok { 210 return nil, fmt.Errorf("cannot decode %s failures: expected a map, got %T", kind, errorValue[kind]) 211 } 212 var failures []ServiceFailure 213 var err error 214 for service, reason := range errors { 215 if reasonString, ok := reason.(string); ok { 216 failures = append(failures, ServiceFailure{ 217 Uid: uid, 218 Service: service, 219 Error: reasonString, 220 }) 221 } else if err == nil { 222 err = fmt.Errorf("cannot decode %s failure for %q: expected string, but got %T", kind, service, reason) 223 } 224 } 225 return failures, err 226 } 227 228 func (client *Client) serviceControlCall(ctx context.Context, action string, services []string) (startFailures, stopFailures []ServiceFailure, err error) { 229 headers := map[string]string{"Content-Type": "application/json"} 230 reqBody, err := json.Marshal(map[string]interface{}{ 231 "action": action, 232 "services": services, 233 }) 234 if err != nil { 235 return nil, nil, err 236 } 237 responses, err := client.doMany(ctx, "POST", "/v1/service-control", nil, headers, reqBody) 238 if err != nil { 239 return nil, nil, err 240 } 241 for _, resp := range responses { 242 if agentErr, ok := resp.err.(*Error); ok && agentErr.Kind == "service-control" { 243 if errorValue, ok := agentErr.Value.(map[string]interface{}); ok { 244 failures, _ := decodeServiceErrors(resp.uid, errorValue, "start-errors") 245 startFailures = append(startFailures, failures...) 246 failures, _ = decodeServiceErrors(resp.uid, errorValue, "stop-errors") 247 stopFailures = append(stopFailures, failures...) 248 } 249 } 250 if resp.err != nil && err == nil { 251 err = resp.err 252 } 253 } 254 return startFailures, stopFailures, err 255 } 256 257 func (client *Client) ServicesDaemonReload(ctx context.Context) error { 258 _, _, err := client.serviceControlCall(ctx, "daemon-reload", nil) 259 return err 260 } 261 262 func (client *Client) ServicesStart(ctx context.Context, services []string) (startFailures, stopFailures []ServiceFailure, err error) { 263 return client.serviceControlCall(ctx, "start", services) 264 } 265 266 func (client *Client) ServicesStop(ctx context.Context, services []string) (stopFailures []ServiceFailure, err error) { 267 _, stopFailures, err = client.serviceControlCall(ctx, "stop", services) 268 return stopFailures, err 269 }