github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/devicestate/devicestatetest/devicesvc.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016-2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package devicestatetest
    21  
    22  import (
    23  	"bytes"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"sync"
    29  	"time"
    30  
    31  	. "gopkg.in/check.v1"
    32  
    33  	"github.com/snapcore/snapd/asserts"
    34  	"github.com/snapcore/snapd/httputil"
    35  )
    36  
    37  type DeviceServiceBehavior struct {
    38  	ReqID string
    39  
    40  	RequestIDURLPath     string
    41  	SerialURLPath        string
    42  	ExpectedCapabilities string
    43  
    44  	Head          func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)
    45  	PostPreflight func(c *C, bhv *DeviceServiceBehavior, w http.ResponseWriter, r *http.Request)
    46  
    47  	SignSerial func(c *C, bhv *DeviceServiceBehavior, headers map[string]interface{}, body []byte) (serial asserts.Assertion, ancillary []asserts.Assertion, err error)
    48  }
    49  
    50  // Request IDs for hard-coded behaviors.
    51  const (
    52  	ReqIDFailID501          = "REQID-FAIL-ID-501"
    53  	ReqIDBadRequest         = "REQID-BAD-REQ"
    54  	ReqIDPoll               = "REQID-POLL"
    55  	ReqIDSerialWithBadModel = "REQID-SERIAL-W-BAD-MODEL"
    56  )
    57  
    58  const (
    59  	requestIDURLPath = "/api/v1/snaps/auth/request-id"
    60  	serialURLPath    = "/api/v1/snaps/auth/devices"
    61  )
    62  
    63  func MockDeviceService(c *C, bhv *DeviceServiceBehavior) *httptest.Server {
    64  	expectedUserAgent := httputil.UserAgent()
    65  
    66  	// default URL paths
    67  	if bhv.RequestIDURLPath == "" {
    68  		bhv.RequestIDURLPath = requestIDURLPath
    69  		bhv.SerialURLPath = serialURLPath
    70  	}
    71  	// currently supported
    72  	if bhv.ExpectedCapabilities == "" {
    73  		bhv.ExpectedCapabilities = "serial-stream"
    74  	}
    75  
    76  	var mu sync.Mutex
    77  	count := 0
    78  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    79  		switch r.Method {
    80  		default:
    81  			c.Fatalf("unexpected verb %q", r.Method)
    82  		case "HEAD":
    83  			if r.URL.Path != "/" {
    84  				c.Fatalf("unexpected HEAD request %q", r.URL.String())
    85  			}
    86  			if bhv.Head != nil {
    87  				bhv.Head(c, bhv, w, r)
    88  			}
    89  			w.WriteHeader(200)
    90  			return
    91  		case "POST":
    92  			// carry on
    93  		}
    94  
    95  		if bhv.PostPreflight != nil {
    96  			bhv.PostPreflight(c, bhv, w, r)
    97  		}
    98  
    99  		switch r.URL.Path {
   100  		default:
   101  			c.Fatalf("unexpected POST request %q", r.URL.String())
   102  		case bhv.RequestIDURLPath:
   103  			if bhv.ReqID == ReqIDFailID501 {
   104  				w.WriteHeader(501)
   105  				return
   106  			}
   107  			w.WriteHeader(200)
   108  			c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent)
   109  			io.WriteString(w, fmt.Sprintf(`{"request-id": "%s"}`, bhv.ReqID))
   110  		case bhv.SerialURLPath:
   111  			c.Check(r.Header.Get("User-Agent"), Equals, expectedUserAgent)
   112  			c.Check(r.Header.Get("Snap-Device-Capabilities"), Equals, bhv.ExpectedCapabilities)
   113  
   114  			mu.Lock()
   115  			serialNum := 9999 + count
   116  			count++
   117  			mu.Unlock()
   118  
   119  			dec := asserts.NewDecoder(r.Body)
   120  
   121  			a, err := dec.Decode()
   122  			c.Assert(err, IsNil)
   123  			serialReq, ok := a.(*asserts.SerialRequest)
   124  			c.Assert(ok, Equals, true)
   125  			extra := []asserts.Assertion{}
   126  			for {
   127  				a1, err := dec.Decode()
   128  				if err == io.EOF {
   129  					break
   130  				}
   131  				c.Assert(err, IsNil)
   132  				extra = append(extra, a1)
   133  			}
   134  			err = asserts.SignatureCheck(serialReq, serialReq.DeviceKey())
   135  			c.Assert(err, IsNil)
   136  			brandID := serialReq.BrandID()
   137  			model := serialReq.Model()
   138  			reqID := serialReq.RequestID()
   139  			if reqID == ReqIDBadRequest {
   140  				w.Header().Set("Content-Type", "application/json")
   141  				w.WriteHeader(400)
   142  				w.Write([]byte(`{
   143    "error_list": [{"message": "bad serial-request"}]
   144  }`))
   145  				return
   146  			}
   147  			if reqID == ReqIDPoll && serialNum != 10002 {
   148  				w.WriteHeader(202)
   149  				return
   150  			}
   151  			serialStr := fmt.Sprintf("%d", serialNum)
   152  			if serialReq.Serial() != "" {
   153  				// use proposed serial
   154  				serialStr = serialReq.Serial()
   155  			}
   156  			if serialReq.HeaderString("original-model") != "" {
   157  				// re-registration
   158  				c.Check(extra, HasLen, 2)
   159  				_, ok := extra[0].(*asserts.Model)
   160  				c.Check(ok, Equals, true)
   161  				origSerial, ok := extra[1].(*asserts.Serial)
   162  				c.Check(ok, Equals, true)
   163  				c.Check(origSerial.DeviceKey(), DeepEquals, serialReq.DeviceKey())
   164  				// TODO: more checks once we have Original* accessors
   165  			} else {
   166  				c.Check(extra, HasLen, 0)
   167  			}
   168  			serial, ancillary, err := bhv.SignSerial(c, bhv, map[string]interface{}{
   169  				"authority-id":        "canonical",
   170  				"brand-id":            brandID,
   171  				"model":               model,
   172  				"serial":              serialStr,
   173  				"device-key":          serialReq.HeaderString("device-key"),
   174  				"device-key-sha3-384": serialReq.SignKeyID(),
   175  				"timestamp":           time.Now().Format(time.RFC3339),
   176  			}, serialReq.Body())
   177  			c.Assert(err, IsNil)
   178  			w.Header().Set("Content-Type", asserts.MediaType)
   179  			w.WriteHeader(200)
   180  			if reqID == ReqIDSerialWithBadModel {
   181  				encoded := asserts.Encode(serial)
   182  
   183  				encoded = bytes.Replace(encoded, []byte("model: pc"), []byte("model: bad-model-foo"), 1)
   184  				w.Write(encoded)
   185  				return
   186  			}
   187  			enc := asserts.NewEncoder(w)
   188  			enc.Encode(serial)
   189  			for _, a := range ancillary {
   190  				enc.Encode(a)
   191  			}
   192  		}
   193  	}))
   194  }