// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2016-2019 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package devicestatetest

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"sync"
	"time"

	. "gopkg.in/check.v1"

	"github.com/snapcore/snapd/asserts"
	"github.com/snapcore/snapd/snapdenv"
)

// DeviceServiceBehavior configures the HTTP test server started by
// MockDeviceService.
type DeviceServiceBehavior struct {
	// ReqID is the request ID returned by the request-id endpoint.
	// When it equals ReqIDFailID501 that endpoint replies with a 501
	// status instead.
	ReqID string

	// RequestIDURLPath and SerialURLPath are the POST paths served for
	// request-id and serial requests. When RequestIDURLPath is empty
	// MockDeviceService resets both of them to its built-in defaults.
	RequestIDURLPath string
	SerialURLPath    string
	// ExpectedCapabilities is the value the Snap-Device-Capabilities
	// header of serial requests is checked against; when empty
	// MockDeviceService sets it to "serial-stream".
	ExpectedCapabilities string

	// Head, if set, is called for "HEAD /" requests before the server
	// writes its 200 status.
	Head func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)
	// PostPreflight, if set, is called for every POST request before
	// the request path is looked at.
	PostPreflight func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)

	// SignSerial is called for serial requests that passed the checks
	// of the mock service. It receives the headers for the serial
	// assertion to produce and the body of the serial-request, and
	// returns the serial assertion plus any ancillary assertions to
	// send after it. A returned error is reported on c and answered
	// with a 500 status. It is called without a nil check, so it must
	// be set whenever the serial endpoint is exercised.
	SignSerial func(c *C, bhv *DeviceServiceBehavior, headers map[string]interface{}, body []byte) (serial asserts.Assertion, ancillary []asserts.Assertion, err error)
}

// Request IDs for hard-coded behaviors.
51 const ( 52 ReqIDFailID501 = "REQID-FAIL-ID-501" 53 ReqIDBadRequest = "REQID-BAD-REQ" 54 ReqIDPoll = "REQID-POLL" 55 ReqIDSerialWithBadModel = "REQID-SERIAL-W-BAD-MODEL" 56 ) 57 58 const ( 59 requestIDURLPath = "/api/v1/snaps/auth/request-id" 60 serialURLPath = "/api/v1/snaps/auth/devices" 61 ) 62 63 func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server { 64 expectedUserAgent := snapdenv.UserAgent() 65 66 // default URL paths 67 if bhv.RequestIDURLPath == "" { 68 bhv.RequestIDURLPath = requestIDURLPath 69 bhv.SerialURLPath = serialURLPath 70 } 71 // currently supported 72 if bhv.ExpectedCapabilities == "" { 73 bhv.ExpectedCapabilities = "serial-stream" 74 } 75 76 var mu sync.Mutex 77 count := 0 78 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 // check.Assert here will produce harder to understand failure 80 // modes 81 82 switch r.Method { 83 default: 84 c.Errorf("unexpected verb %q", r.Method) 85 w.WriteHeader(500) 86 return 87 case "HEAD": 88 if r.URL.Path != "/" { 89 c.Errorf("unexpected HEAD request %q", r.URL.String()) 90 w.WriteHeader(500) 91 return 92 } 93 if bhv.Head != nil { 94 bhv.Head(c, bhv, w, r) 95 } 96 w.WriteHeader(200) 97 return 98 case "POST": 99 // carry on 100 } 101 102 if bhv.PostPreflight != nil { 103 bhv.PostPreflight(c, bhv, w, r) 104 } 105 106 switch r.URL.Path { 107 default: 108 c.Errorf("unexpected POST request %q", r.URL.String()) 109 w.WriteHeader(500) 110 return 111 case bhv.RequestIDURLPath: 112 if bhv.ReqID == ReqIDFailID501 { 113 w.WriteHeader(501) 114 return 115 } 116 w.WriteHeader(200) 117 c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent) 118 io.WriteString(w, fmt.Sprintf(`{"request-id": "%s"}`, bhv.ReqID)) 119 case bhv.SerialURLPath: 120 c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent) 121 c.Check(r.Header.Get("Snap-Device-Capabilities"), Equals, bhv.ExpectedCapabilities) 122 123 mu.Lock() 124 serialNum := 9999 + count 125 count++ 126 
mu.Unlock() 127 128 dec := asserts.NewDecoder(r.Body) 129 130 a, err := dec.Decode() 131 if err != nil { 132 w.WriteHeader(400) 133 return 134 } 135 serialReq, ok := a.(*asserts.SerialRequest) 136 if !ok { 137 w.WriteHeader(400) 138 w.Write([]byte(`{ 139 "error_list": [{"message": "expected serial-request"}] 140 }`)) 141 return 142 } 143 extra := []asserts.Assertion{} 144 for { 145 a1, err := dec.Decode() 146 if err == io.EOF { 147 break 148 } 149 if err != nil { 150 w.WriteHeader(400) 151 return 152 } 153 extra = append(extra, a1) 154 } 155 err = asserts.SignatureCheck(serialReq, serialReq.DeviceKey()) 156 c.Check(err, IsNil) 157 if err != nil { 158 // also return response to client 159 w.WriteHeader(400) 160 w.Write([]byte(`{ 161 "error_list": [{"message": "invalid serial-request self-signature"}] 162 }`)) 163 return 164 } 165 brandID := serialReq.BrandID() 166 model := serialReq.Model() 167 reqID := serialReq.RequestID() 168 if reqID == ReqIDBadRequest { 169 w.Header().Set("Content-Type", "application/json") 170 w.WriteHeader(400) 171 w.Write([]byte(`{ 172 "error_list": [{"message": "bad serial-request"}] 173 }`)) 174 return 175 } 176 if reqID == ReqIDPoll && serialNum != 10002 { 177 w.WriteHeader(202) 178 return 179 } 180 serialStr := fmt.Sprintf("%d", serialNum) 181 if serialReq.Serial() != "" { 182 // use proposed serial 183 serialStr = serialReq.Serial() 184 } 185 if serialReq.HeaderString("original-model") != "" { 186 // re-registration 187 if len(extra) != 2 { 188 w.WriteHeader(400) 189 w.Write([]byte(`{ 190 "error_list": [{"message": "expected model and original serial"}] 191 }`)) 192 return 193 } 194 _, ok := extra[0].(*asserts.Model) 195 if !ok { 196 w.WriteHeader(400) 197 w.Write([]byte(`{ 198 "error_list": [{"message": "expected model"}] 199 }`)) 200 return 201 } 202 origSerial, ok := extra[1].(*asserts.Serial) 203 if !ok { 204 w.WriteHeader(400) 205 w.Write([]byte(`{ 206 "error_list": [{"message": "expected model"}] 207 }`)) 208 } 209 
c.Check(origSerial.DeviceKey(), DeepEquals, serialReq.DeviceKey()) 210 // TODO: more checks once we have Original* accessors 211 } else { 212 213 mod, ok := extra[0].(*asserts.Model) 214 if !ok { 215 w.WriteHeader(400) 216 w.Write([]byte(`{ 217 "error_list": [{"message": "expected model"}] 218 }`)) 219 return 220 } 221 c.Check(mod.BrandID(), Equals, brandID) 222 c.Check(mod.Model(), Equals, model) 223 } 224 serial, ancillary, err := bhv.SignSerial(c, bhv, map[string]interface{}{ 225 "authority-id": "canonical", 226 "brand-id": brandID, 227 "model": model, 228 "serial": serialStr, 229 "device-key": serialReq.HeaderString("device-key"), 230 "device-key-sha3-384": serialReq.SignKeyID(), 231 "timestamp": time.Now().Format(time.RFC3339), 232 }, serialReq.Body()) 233 c.Check(err, IsNil) 234 if err != nil { 235 // also return response to client 236 w.WriteHeader(500) 237 return 238 } 239 w.Header().Set("Content-Type", asserts.MediaType) 240 w.WriteHeader(200) 241 if reqID == ReqIDSerialWithBadModel { 242 encoded := asserts.Encode(serial) 243 244 encoded = bytes.Replace(encoded, []byte("model: pc"), []byte("model: bad-model-foo"), 1) 245 w.Write(encoded) 246 return 247 } 248 enc := asserts.NewEncoder(w) 249 enc.Encode(serial) 250 for _, a := range ancillary { 251 enc.Encode(a) 252 } 253 } 254 })) 255 }