github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/overlord/devicestate/devicestatetest/devicesvc.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016-2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package devicestatetest
    21  
    22  import (
    23  	"bytes"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"sync"
    29  	"time"
    30  
    31  	. "gopkg.in/check.v1"
    32  
    33  	"github.com/snapcore/snapd/asserts"
    34  	"github.com/snapcore/snapd/snapdenv"
    35  )
    36  
        // DeviceServiceBehavior configures the fake device service started
        // by MockDeviceService. The same value is handed back to each of
        // its hooks as bhv.
    37  type DeviceServiceBehavior struct {
        	// ReqID is the request-id returned by the request-id endpoint.
        	// Setting it to ReqIDFailID501 makes that endpoint reply 501
        	// instead.
    38  	ReqID string
    39  
        	// RequestIDURLPath and SerialURLPath are the URL paths of the
        	// request-id and serial endpoints. MockDeviceService fills in
        	// default paths when these are left empty.
    40  	RequestIDURLPath     string
    41  	SerialURLPath        string
        	// ExpectedCapabilities is the value the Snap-Device-Capabilities
        	// header of serial requests is checked against.
        	// MockDeviceService sets it to "serial-stream" when empty.
    42  	ExpectedCapabilities string
    43  
        	// Head, if set, is called for HEAD requests to "/" before the
        	// server writes its 200 status.
    44  	Head          func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)
        	// PostPreflight, if set, is called for every POST request before
        	// the request is dispatched on its URL path.
    45  	PostPreflight func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)
    46  
        	// SignSerial is called by the serial endpoint with the headers it
        	// assembled from the serial-request and with the serial-request
        	// body. The returned serial and ancillary assertions are sent
        	// back to the client; a non-nil err is reported through c and
        	// answered with a 500.
    47  	SignSerial func(c *C, bhv *DeviceServiceBehavior, headers map[string]interface{}, body []byte) (serial asserts.Assertion, ancillary []asserts.Assertion, err error)
    48  }
    49  
    50  // Request IDs for hard-coded behaviors.
        // MockDeviceService compares them against DeviceServiceBehavior.ReqID
        // (request-id endpoint) or against the request-id carried by the
        // serial-request (serial endpoint).
    51  const (
        	// ReqIDFailID501 makes the request-id endpoint reply 501.
    52  	ReqIDFailID501          = "REQID-FAIL-ID-501"
        	// ReqIDBadRequest makes the serial endpoint reply 400 with a
        	// "bad serial-request" error.
    53  	ReqIDBadRequest         = "REQID-BAD-REQ"
        	// ReqIDPoll makes the serial endpoint reply 202 unless the
        	// server-side serial counter is at 10002, i.e. on the fourth
        	// serial request received by the server.
    54  	ReqIDPoll               = "REQID-POLL"
        	// ReqIDSerialWithBadModel makes the serial endpoint reply with
        	// the serial assertion whose encoded "model: pc" has been
        	// replaced by "model: bad-model-foo".
    55  	ReqIDSerialWithBadModel = "REQID-SERIAL-W-BAD-MODEL"
    56  )
    57  
        // Default URL paths of the request-id and serial endpoints, used by
        // MockDeviceService when the behavior does not provide its own.
    58  const (
    59  	requestIDURLPath = "/api/v1/snaps/auth/request-id"
    60  	serialURLPath    = "/api/v1/snaps/auth/devices"
    61  )
    62  
    63  func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server {
    64  	expectedUserAgent := snapdenv.UserAgent()
    65  
    66  	// default URL paths
    67  	if bhv.RequestIDURLPath == "" {
    68  		bhv.RequestIDURLPath = requestIDURLPath
    69  		bhv.SerialURLPath = serialURLPath
    70  	}
    71  	// currently supported
    72  	if bhv.ExpectedCapabilities == "" {
    73  		bhv.ExpectedCapabilities = "serial-stream"
    74  	}
    75  
    76  	var mu sync.Mutex
    77  	count := 0
    78  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    79  		// check.Assert here will produce harder to understand failure
    80  		// modes
    81  
    82  		switch r.Method {
    83  		default:
    84  			c.Errorf("unexpected verb %q", r.Method)
    85  			w.WriteHeader(500)
    86  			return
    87  		case "HEAD":
    88  			if r.URL.Path != "/" {
    89  				c.Errorf("unexpected HEAD request %q", r.URL.String())
    90  				w.WriteHeader(500)
    91  				return
    92  			}
    93  			if bhv.Head != nil {
    94  				bhv.Head(c, bhv, w, r)
    95  			}
    96  			w.WriteHeader(200)
    97  			return
    98  		case "POST":
    99  			// carry on
   100  		}
   101  
   102  		if bhv.PostPreflight != nil {
   103  			bhv.PostPreflight(c, bhv, w, r)
   104  		}
   105  
   106  		switch r.URL.Path {
   107  		default:
   108  			c.Errorf("unexpected POST request %q", r.URL.String())
   109  			w.WriteHeader(500)
   110  			return
   111  		case bhv.RequestIDURLPath:
   112  			if bhv.ReqID == ReqIDFailID501 {
   113  				w.WriteHeader(501)
   114  				return
   115  			}
   116  			w.WriteHeader(200)
   117  			c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent)
   118  			io.WriteString(w, fmt.Sprintf(`{"request-id": "%s"}`, bhv.ReqID))
   119  		case bhv.SerialURLPath:
   120  			c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent)
   121  			c.Check(r.Header.Get("Snap-Device-Capabilities"), Equals, bhv.ExpectedCapabilities)
   122  
   123  			mu.Lock()
   124  			serialNum := 9999 + count
   125  			count++
   126  			mu.Unlock()
   127  
   128  			dec := asserts.NewDecoder(r.Body)
   129  
   130  			a, err := dec.Decode()
   131  			if err != nil {
   132  				w.WriteHeader(400)
   133  				return
   134  			}
   135  			serialReq, ok := a.(*asserts.SerialRequest)
   136  			if !ok {
   137  				w.WriteHeader(400)
   138  				w.Write([]byte(`{
   139    "error_list": [{"message": "expected serial-request"}]
   140  }`))
   141  				return
   142  			}
   143  			extra := []asserts.Assertion{}
   144  			for {
   145  				a1, err := dec.Decode()
   146  				if err == io.EOF {
   147  					break
   148  				}
   149  				if err != nil {
   150  					w.WriteHeader(400)
   151  					return
   152  				}
   153  				extra = append(extra, a1)
   154  			}
   155  			err = asserts.SignatureCheck(serialReq, serialReq.DeviceKey())
   156  			c.Check(err, IsNil)
   157  			if err != nil {
   158  				// also return response to client
   159  				w.WriteHeader(400)
   160  				w.Write([]byte(`{
   161    "error_list": [{"message": "invalid serial-request self-signature"}]
   162  }`))
   163  				return
   164  			}
   165  			brandID := serialReq.BrandID()
   166  			model := serialReq.Model()
   167  			reqID := serialReq.RequestID()
   168  			if reqID == ReqIDBadRequest {
   169  				w.Header().Set("Content-Type", "application/json")
   170  				w.WriteHeader(400)
   171  				w.Write([]byte(`{
   172    "error_list": [{"message": "bad serial-request"}]
   173  }`))
   174  				return
   175  			}
   176  			if reqID == ReqIDPoll && serialNum != 10002 {
   177  				w.WriteHeader(202)
   178  				return
   179  			}
   180  			serialStr := fmt.Sprintf("%d", serialNum)
   181  			if serialReq.Serial() != "" {
   182  				// use proposed serial
   183  				serialStr = serialReq.Serial()
   184  			}
   185  			if serialReq.HeaderString("original-model") != "" {
   186  				// re-registration
   187  				if len(extra) != 2 {
   188  					w.WriteHeader(400)
   189  					w.Write([]byte(`{
   190    "error_list": [{"message": "expected model and original serial"}]
   191  }`))
   192  					return
   193  				}
   194  				_, ok := extra[0].(*asserts.Model)
   195  				if !ok {
   196  					w.WriteHeader(400)
   197  					w.Write([]byte(`{
   198    "error_list": [{"message": "expected model"}]
   199  }`))
   200  					return
   201  				}
   202  				origSerial, ok := extra[1].(*asserts.Serial)
   203  				if !ok {
   204  					w.WriteHeader(400)
   205  					w.Write([]byte(`{
   206    "error_list": [{"message": "expected model"}]
   207  }`))
   208  				}
   209  				c.Check(origSerial.DeviceKey(), DeepEquals, serialReq.DeviceKey())
   210  				// TODO: more checks once we have Original* accessors
   211  			} else {
   212  
   213  				mod, ok := extra[0].(*asserts.Model)
   214  				if !ok {
   215  					w.WriteHeader(400)
   216  					w.Write([]byte(`{
   217    "error_list": [{"message": "expected model"}]
   218  }`))
   219  					return
   220  				}
   221  				c.Check(mod.BrandID(), Equals, brandID)
   222  				c.Check(mod.Model(), Equals, model)
   223  			}
   224  			serial, ancillary, err := bhv.SignSerial(c, bhv, map[string]interface{}{
   225  				"authority-id":        "canonical",
   226  				"brand-id":            brandID,
   227  				"model":               model,
   228  				"serial":              serialStr,
   229  				"device-key":          serialReq.HeaderString("device-key"),
   230  				"device-key-sha3-384": serialReq.SignKeyID(),
   231  				"timestamp":           time.Now().Format(time.RFC3339),
   232  			}, serialReq.Body())
   233  			c.Check(err, IsNil)
   234  			if err != nil {
   235  				// also return response to client
   236  				w.WriteHeader(500)
   237  				return
   238  			}
   239  			w.Header().Set("Content-Type", asserts.MediaType)
   240  			w.WriteHeader(200)
   241  			if reqID == ReqIDSerialWithBadModel {
   242  				encoded := asserts.Encode(serial)
   243  
   244  				encoded = bytes.Replace(encoded, []byte("model: pc"), []byte("model: bad-model-foo"), 1)
   245  				w.Write(encoded)
   246  				return
   247  			}
   248  			enc := asserts.NewEncoder(w)
   249  			enc.Encode(serial)
   250  			for _, a := range ancillary {
   251  				enc.Encode(a)
   252  			}
   253  		}
   254  	}))
   255  }