code.pfad.fr/gohmekit@v0.2.1/pairing/pair_setup_verify_test.go (about)

     1  package pairing_test
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ed25519"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"testing"
    10  	"time"
    11  
    12  	"code.pfad.fr/gohmekit/pairing"
    13  	"code.pfad.fr/gohmekit/storage"
    14  	"code.pfad.fr/gohmekit/tlv8"
    15  	"github.com/brutella/hc/db"
    16  	"github.com/brutella/hc/hap"
    17  	"github.com/brutella/hc/hap/pair"
    18  	"github.com/brutella/hc/util"
    19  	"gotest.tools/v3/assert"
    20  )
    21  
    22  // Table 5-3.
    23  const (
    24  	kTLVMethod_PairSetup     = 0
    25  	kTLVMethod_AddPairing    = 3
    26  	kTLVMethod_RemovePairing = 4
    27  	kTLVMethod_ListPairings  = 5
    28  )
    29  
    30  // Table 5-5.
    31  const (
    32  	kTLVError_Unavailable = 0x06
    33  )
    34  
    35  func tlv8Marshal(t *testing.T, v any) io.Reader {
    36  	m, err := tlv8.Marshal(v)
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  	return bytes.NewReader(m)
    41  }
    42  
    43  func TestPairSetupVerifyIntegration(t *testing.T) {
    44  	type request struct {
    45  		State      byte   `tlv8:"kTLVType_State"`
    46  		Method     byte   `tlv8:"kTLVType_Method"`
    47  		Identifier []byte `tlv8:"kTLVType_Identifier"`
    48  		PublicKey  []byte `tlv8:"kTLVType_PublicKey"`
    49  	}
    50  	type response struct {
    51  		State byte `tlv8:"kTLVType_State"`
    52  		Error byte `tlv8:"kTLVType_Error"`
    53  	}
    54  
    55  	deviceDatabase := storage.NewMemDatabase()
    56  	_, privateKey, err := ed25519.GenerateKey(nil)
    57  	assert.NilError(t, err)
    58  	device, err := pairing.NewDeviceWithPin(pairing.NewRandomPairingID(), "001-02-003", privateKey)
    59  	assert.NilError(t, err)
    60  
    61  	s := pairing.NewServer(&http.Server{
    62  		ReadTimeout: 5 * time.Second,
    63  	}, device, deviceDatabase)
    64  
    65  	addr := startServer(t, s)
    66  
    67  	clientStorage, err := util.NewFileStorage(t.TempDir())
    68  	assert.NilError(t, err)
    69  	clientDatabase := db.NewDatabaseWithStorage(clientStorage)
    70  
    71  	client, err := hap.NewDevice("HomeKit Client", clientDatabase)
    72  	assert.NilError(t, err)
    73  
    74  	httpClient, encryptClientConn := newHTTPClient()
    75  
    76  	unauthorizedPairings, err := httpClient.Post(addr+"pairings", pairing.ContentType, nil)
    77  	assert.NilError(t, err)
    78  	assert.Equal(t, 470, unauthorizedPairings.StatusCode)
    79  	_, err = io.ReadAll(unauthorizedPairings.Body)
    80  	assert.NilError(t, err)
    81  	assert.NilError(t, unauthorizedPairings.Body.Close())
    82  
    83  	//////////////////////////
    84  	// setup of the pairing //
    85  	//////////////////////////
    86  	setupController := pair.NewSetupClientController("001-02-003", client, clientDatabase)
    87  
    88  	// 5.6.1) C -> S
    89  	srpStartResponse, err := httpClient.Post(addr+"pair-setup", pairing.ContentType, setupController.InitialPairingRequest())
    90  	assert.NilError(t, err)
    91  	assert.Equal(t, 200, srpStartResponse.StatusCode)
    92  
    93  	// 5.6.2) S -> C
    94  	srpVerifyRequest, err := pair.HandleReaderForHandler(srpStartResponse.Body, setupController)
    95  	assert.NilError(t, err)
    96  	assert.NilError(t, srpStartResponse.Body.Close())
    97  
    98  	// 5.6.3) C -> S
    99  	srpVerifyResponse, err := httpClient.Post(addr+"pair-setup", pairing.ContentType, srpVerifyRequest)
   100  	assert.NilError(t, err)
   101  	assert.Equal(t, 200, srpVerifyResponse.StatusCode)
   102  
   103  	// 5.6.4) S -> C
   104  	exchangeRequest, err := pair.HandleReaderForHandler(srpVerifyResponse.Body, setupController)
   105  	assert.NilError(t, err)
   106  	assert.NilError(t, srpVerifyResponse.Body.Close())
   107  
   108  	// 5.6.3) C -> S
   109  	exchangeResponse, err := httpClient.Post(addr+"pair-setup", pairing.ContentType, exchangeRequest)
   110  	assert.NilError(t, err)
   111  	assert.Equal(t, 200, exchangeResponse.StatusCode)
   112  
   113  	setupCompleted, err := pair.HandleReaderForHandler(exchangeResponse.Body, setupController)
   114  	assert.NilError(t, err)
   115  	assert.NilError(t, exchangeResponse.Body.Close())
   116  	assert.Equal(t, setupCompleted, nil)
   117  
   118  	////////////////////////
   119  	// verify the pairing //
   120  	////////////////////////
   121  	verify, err := pairing.NewVerifyClientController(
   122  		HCDevice{Device: client},
   123  		HCDatabase{Database: clientDatabase},
   124  	)
   125  	assert.NilError(t, err)
   126  
   127  	// 5.7.1) C -> S
   128  	verifyStartResponse, err := httpClient.Post(addr+"pair-verify", pairing.ContentType, bytes.NewReader(verify.StartRequest()))
   129  	assert.NilError(t, err)
   130  	assert.Equal(t, 200, verifyStartResponse.StatusCode)
   131  
   132  	// 2) S -> C
   133  	verifyFinishRequest, sharedSecret, err := verify.FinishRequest(verifyStartResponse.Body)
   134  	assert.NilError(t, err)
   135  
   136  	// 3) C -> S
   137  	verifyFinishResponse, err := httpClient.Post(addr+"pair-verify", pairing.ContentType, bytes.NewReader(verifyFinishRequest))
   138  	assert.NilError(t, err)
   139  	assert.Equal(t, 200, verifyFinishResponse.StatusCode)
   140  
   141  	// // 4) S -> C
   142  	err = verify.FinishResponse(verifyFinishResponse.Body)
   143  	assert.NilError(t, err)
   144  
   145  	// encrypt the connection
   146  	var sharedSecretArray [32]byte
   147  	copy(sharedSecretArray[:], sharedSecret)
   148  	err = encryptClientConn(sharedSecretArray)
   149  	assert.NilError(t, err)
   150  
   151  	t.Log("list the pairing")
   152  	////////////////////////
   153  	// list the pairings  //
   154  	////////////////////////
   155  	pairingList, err := httpClient.Post(addr+"pairings", pairing.ContentType, tlv8Marshal(t, request{
   156  		State:  1,
   157  		Method: kTLVMethod_ListPairings,
   158  	}))
   159  	assert.NilError(t, err)
   160  	assert.Equal(t, 200, pairingList.StatusCode)
   161  
   162  	type controller struct {
   163  		Identifier  []byte   `tlv8:"kTLVType_Identifier"`
   164  		PublicKey   []byte   `tlv8:"kTLVType_PublicKey"`
   165  		Permissions byte     `tlv8:"kTLVType_Permissions"`
   166  		Separator   struct{} `tlv8:"kTLVType_Separator"`
   167  	}
   168  	type listResponse struct {
   169  		State       byte         `tlv8:"kTLVType_State"`
   170  		Controllers []controller `tlv8:""`
   171  	}
   172  	var listed listResponse
   173  	err = tlv8.NewDecoder(pairingList.Body).Decode(&listed)
   174  	assert.NilError(t, err)
   175  	assert.Equal(t, 1, len(listed.Controllers))
   176  	assert.DeepEqual(t, []byte("HomeKit Client"), listed.Controllers[0].Identifier)
   177  
   178  	////////////////////////
   179  	// add pairing  //
   180  	////////////////////////
   181  	newPublic, _, err := ed25519.GenerateKey(nil)
   182  	assert.NilError(t, err)
   183  	pairingAdd, err := httpClient.Post(addr+"pairings", pairing.ContentType, tlv8Marshal(t, request{
   184  		State:      1,
   185  		Method:     kTLVMethod_AddPairing,
   186  		Identifier: []byte("new controller"),
   187  		PublicKey:  newPublic,
   188  	}))
   189  	assert.NilError(t, err)
   190  	assert.Equal(t, 200, pairingList.StatusCode)
   191  
   192  	var added response
   193  	err = tlv8.NewDecoder(pairingAdd.Body).Decode(&added)
   194  	assert.NilError(t, err)
   195  	assert.Equal(t, byte(0), added.Error)
   196  
   197  	////////////////////////
   198  	// list the pairings  //
   199  	////////////////////////
   200  	pairingList, err = httpClient.Post(addr+"pairings", pairing.ContentType, tlv8Marshal(t, request{
   201  		State:  1,
   202  		Method: kTLVMethod_ListPairings,
   203  	}))
   204  	assert.NilError(t, err)
   205  	assert.Equal(t, 200, pairingList.StatusCode)
   206  
   207  	listed = listResponse{}
   208  	err = tlv8.NewDecoder(pairingList.Body).Decode(&listed)
   209  	assert.NilError(t, err)
   210  	assert.Equal(t, 2, len(listed.Controllers))
   211  
   212  	////////////////////////
   213  	// remove pairing  //
   214  	////////////////////////
   215  	pairingRemove, err := httpClient.Post(addr+"pairings", pairing.ContentType, tlv8Marshal(t, request{
   216  		State:      1,
   217  		Method:     kTLVMethod_RemovePairing,
   218  		Identifier: []byte("new controller"),
   219  	}))
   220  	assert.NilError(t, err)
   221  	assert.Equal(t, 200, pairingRemove.StatusCode)
   222  
   223  	var removed response
   224  	err = tlv8.NewDecoder(pairingRemove.Body).Decode(&removed)
   225  	assert.NilError(t, err)
   226  	assert.Equal(t, byte(0), removed.Error)
   227  
   228  	////////////////////////
   229  	// list the pairings  //
   230  	////////////////////////
   231  	pairingList, err = httpClient.Post(addr+"pairings", pairing.ContentType, tlv8Marshal(t, request{
   232  		State:  1,
   233  		Method: kTLVMethod_ListPairings,
   234  	}))
   235  	assert.NilError(t, err)
   236  	assert.Equal(t, 200, pairingList.StatusCode)
   237  
   238  	listed = listResponse{}
   239  	err = tlv8.NewDecoder(pairingList.Body).Decode(&listed)
   240  	assert.NilError(t, err)
   241  	assert.Equal(t, 1, len(listed.Controllers))
   242  	assert.DeepEqual(t, []byte("HomeKit Client"), listed.Controllers[0].Identifier)
   243  }
   244  
   245  func newHTTPClient() (*http.Client, func(sharedKey [32]byte) error) {
   246  	dial, encrypt := pairing.NewEncryptableDialer((&net.Dialer{
   247  		Timeout:   5 * time.Second,
   248  		KeepAlive: 5 * time.Second,
   249  	}).DialContext)
   250  	httpClient := http.Client{
   251  		Transport: &http.Transport{
   252  			Proxy:                 http.ProxyFromEnvironment,
   253  			DialContext:           dial,
   254  			ForceAttemptHTTP2:     false,
   255  			MaxIdleConns:          1,
   256  			IdleConnTimeout:       5 * time.Second,
   257  			TLSHandshakeTimeout:   5 * time.Second,
   258  			ExpectContinueTimeout: 5 * time.Second,
   259  		},
   260  	}
   261  	return &httpClient, encrypt
   262  }
   263  
   264  func ExampleNewEncryptableDialer() {
   265  	dial, encrypt := pairing.NewEncryptableDialer((&net.Dialer{
   266  		Timeout:   5 * time.Second,
   267  		KeepAlive: 5 * time.Second,
   268  	}).DialContext)
   269  	httpClient := http.Client{
   270  		Transport: &http.Transport{
   271  			Proxy:                 http.ProxyFromEnvironment,
   272  			DialContext:           dial,
   273  			ForceAttemptHTTP2:     false,
   274  			MaxIdleConns:          1,
   275  			IdleConnTimeout:       5 * time.Second,
   276  			TLSHandshakeTimeout:   5 * time.Second,
   277  			ExpectContinueTimeout: 5 * time.Second,
   278  		},
   279  	}
   280  	// do whatever you need with the httpClient
   281  	// call encrypt(sharedKey) to encrypt further communications.
   282  	_ = encrypt
   283  	_ = httpClient
   284  }
   285  
   286  func TestPairSetupAlreadyPaired(t *testing.T) {
   287  	deviceDatabase := storage.NewMemDatabase()
   288  	_, privateKey, err := ed25519.GenerateKey(nil)
   289  	assert.NilError(t, err)
   290  	device, err := pairing.NewDeviceWithPin(pairing.NewRandomPairingID(), "001-02-003", privateKey)
   291  	assert.NilError(t, err)
   292  
   293  	s := pairing.NewServer(&http.Server{
   294  		ReadTimeout: 5 * time.Second,
   295  	}, device, deviceDatabase)
   296  
   297  	addr := startServer(t, s)
   298  
   299  	err = deviceDatabase.AddLongTermPublicKey(pairing.Controller{PairingID: []byte("HomeKit client"), LongTermPublicKey: []byte{0x01}})
   300  	assert.NilError(t, err)
   301  
   302  	httpClient, _ := newHTTPClient()
   303  
   304  	srpStartResponse, err := httpClient.Post(addr+"pair-setup", pairing.ContentType, tlv8Marshal(t, struct {
   305  		State  byte `tlv8:"kTLVType_State"`
   306  		Method byte `tlv8:"kTLVType_Method"`
   307  	}{
   308  		State:  1,
   309  		Method: kTLVMethod_PairSetup,
   310  	}))
   311  	assert.NilError(t, err)
   312  	assert.Equal(t, 200, srpStartResponse.StatusCode)
   313  
   314  	type response struct {
   315  		State byte `tlv8:"kTLVType_State"`
   316  		Error byte `tlv8:"kTLVType_Error"`
   317  	}
   318  	var resp response
   319  	err = tlv8.NewDecoder(srpStartResponse.Body).Decode(&resp)
   320  	assert.NilError(t, err)
   321  	assert.Equal(t, resp.Error, byte(kTLVError_Unavailable)) // kTLVError_Unavailable
   322  }