github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/devicestate/devicestatetest/devicesvc.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2016-2019 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package devicestatetest 21 22 import ( 23 "bytes" 24 "fmt" 25 "io" 26 "net/http" 27 "net/http/httptest" 28 "sync" 29 "time" 30 31 . "gopkg.in/check.v1" 32 33 "github.com/snapcore/snapd/asserts" 34 "github.com/snapcore/snapd/httputil" 35 ) 36 37 type DeviceServiceBehavior struct { 38 ReqID string 39 40 RequestIDURLPath string 41 SerialURLPath string 42 ExpectedCapabilities string 43 44 Head func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request) 45 PostPreflight func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request) 46 47 SignSerial func(c *C, bhv *DeviceServiceBehavior, headers map[string]interface{}, body []byte) (serial asserts.Assertion, ancillary []asserts.Assertion, err error) 48 } 49 50 // Request IDs for hard-coded behaviors. 51 const ( 52 ReqIDFailID501 = "REQID-FAIL-ID-501" 53 ReqIDBadRequest = "REQID-BAD-REQ" 54 ReqIDPoll = "REQID-POLL" 55 ReqIDSerialWithBadModel = "REQID-SERIAL-W-BAD-MODEL" 56 ) 57 58 const ( 59 requestIDURLPath = "/api/v1/snaps/auth/request-id" 60 serialURLPath = "/api/v1/snaps/auth/devices" 61 ) 62 63 func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server { 64 expectedUserAgent := httputil.UserAgent() 65 66 // default URL paths 67 if bhv.RequestIDURLPath == "" { 68 bhv.RequestIDURLPath = requestIDURLPath 69 bhv.SerialURLPath = serialURLPath 70 } 71 // currently supported 72 if bhv.ExpectedCapabilities == "" { 73 bhv.ExpectedCapabilities = "serial-stream" 74 } 75 76 var mu sync.Mutex 77 count := 0 78 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 79 switch r.Method { 80 default: 81 c.Fatalf("unexpected verb %q", r.Method) 82 case "HEAD": 83 if r.URL.Path != "/" { 84 c.Fatalf("unexpected HEAD request %q", r.URL.String()) 85 } 86 if bhv.Head != nil { 87 bhv.Head(c, bhv, w, r) 88 } 89 w.WriteHeader(200) 90 return 91 case "POST": 92 // carry on 93 } 94 95 if bhv.PostPreflight != nil { 96 bhv.PostPreflight(c, bhv, w, r) 97 } 98 99 switch r.URL.Path { 100 default: 101 c.Fatalf("unexpected POST request %q", r.URL.String()) 102 case bhv.RequestIDURLPath: 103 if bhv.ReqID == ReqIDFailID501 { 104 w.WriteHeader(501) 105 return 106 } 107 w.WriteHeader(200) 108 c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent) 109 io.WriteString(w, fmt.Sprintf(`{"request-id": "%s"}`, bhv.ReqID)) 110 case bhv.SerialURLPath: 111 c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent) 112 c.Check(r.Header.Get("Snap-Device-Capabilities"), Equals, bhv.ExpectedCapabilities) 113 114 mu.Lock() 115 serialNum := 9999 + count 116 count++ 117 mu.Unlock() 118 119 dec := asserts.NewDecoder(r.Body) 120 121 a, err := dec.Decode() 122 c.Assert(err, IsNil) 123 serialReq, ok := a.(*asserts.SerialRequest) 124 c.Assert(ok, Equals, true) 125 extra := []asserts.Assertion{} 126 for { 127 a1, err := dec.Decode() 128 if err == io.EOF { 129 break 130 } 131 c.Assert(err, IsNil) 132 extra = append(extra, a1) 133 } 134 err = asserts.SignatureCheck(serialReq, serialReq.DeviceKey()) 135 c.Assert(err, IsNil) 136 brandID := serialReq.BrandID() 137 model := serialReq.Model() 138 reqID := serialReq.RequestID() 139 if reqID == ReqIDBadRequest { 140 w.Header().Set("Content-Type", "application/json") 141 w.WriteHeader(400) 142 w.Write([]byte(`{ 143 "error_list": [{"message": "bad serial-request"}] 144 }`)) 145 return 146 } 147 if reqID == ReqIDPoll && serialNum != 10002 { 148 w.WriteHeader(202) 149 return 150 } 151 serialStr := fmt.Sprintf("%d", serialNum) 152 if serialReq.Serial() != "" { 153 // use proposed serial 154 serialStr = serialReq.Serial() 155 } 156 if serialReq.HeaderString("original-model") != "" { 157 // re-registration 158 c.Check(extra, HasLen, 2) 159 _, ok := extra[0].(*asserts.Model) 160 c.Check(ok, Equals, true) 161 origSerial, ok := extra[1].(*asserts.Serial) 162 c.Check(ok, Equals, true) 163 c.Check(origSerial.DeviceKey(), DeepEquals, serialReq.DeviceKey()) 164 // TODO: more checks once we have Original* accessors 165 } else { 166 c.Check(extra, HasLen, 0) 167 } 168 serial, ancillary, err := bhv.SignSerial(c, bhv, map[string]interface{}{ 169 "authority-id": "canonical", 170 "brand-id": brandID, 171 "model": model, 172 "serial": serialStr, 173 "device-key": serialReq.HeaderString("device-key"), 174 "device-key-sha3-384": serialReq.SignKeyID(), 175 "timestamp": time.Now().Format(time.RFC3339), 176 }, serialReq.Body()) 177 c.Assert(err, IsNil) 178 w.Header().Set("Content-Type", asserts.MediaType) 179 w.WriteHeader(200) 180 if reqID == ReqIDSerialWithBadModel { 181 encoded := asserts.Encode(serial) 182 183 encoded = bytes.Replace(encoded, []byte("model: pc"), []byte("model: bad-model-foo"), 1) 184 w.Write(encoded) 185 return 186 } 187 enc := asserts.NewEncoder(w) 188 enc.Encode(serial) 189 for _, a := range ancillary { 190 enc.Encode(a) 191 } 192 } 193 })) 194 }