github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/transport/noise_test.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package transport
    33  
import (
	"bytes"
	"encoding/binary"
	"net"
	"sync/atomic"
	"testing"
	"time"

	"github.com/neilotoole/slogt"
	"github.com/noisysockets/noisysockets/internal/conn"
	"github.com/noisysockets/noisysockets/types"
	"github.com/stretchr/testify/require"
)
    46  
    47  func TestCurveWrappers(t *testing.T) {
    48  	sk1, err := types.NewPrivateKey()
    49  	assertNil(t, err)
    50  
    51  	sk2, err := types.NewPrivateKey()
    52  	assertNil(t, err)
    53  
    54  	pk1 := sk1.Public()
    55  	pk2 := sk2.Public()
    56  
    57  	ss1, err1 := sharedSecret(sk1, pk2)
    58  	ss2, err2 := sharedSecret(sk2, pk1)
    59  
    60  	if ss1 != ss2 || err1 != nil || err2 != nil {
    61  		t.Fatal("Failed to compute shared secet")
    62  	}
    63  }
    64  
    65  func randTransport(t *testing.T) *Transport {
    66  	sk, err := types.NewPrivateKey()
    67  	if err != nil {
    68  		t.Fatal(err)
    69  	}
    70  	logger := slogt.New(t)
    71  	transport := NewTransport(logger, &discardingSink{}, conn.NewStdNetBind())
    72  	transport.SetPrivateKey(sk)
    73  	return transport
    74  }
    75  
    76  func assertNil(t *testing.T, err error) {
    77  	if err != nil {
    78  		t.Fatal(err)
    79  	}
    80  }
    81  
    82  func assertEqual(t *testing.T, a, b []byte) {
    83  	if !bytes.Equal(a, b) {
    84  		t.Fatal(a, "!=", b)
    85  	}
    86  }
    87  
    88  func TestNoiseHandshake(t *testing.T) {
    89  	trans1 := randTransport(t)
    90  	trans2 := randTransport(t)
    91  
    92  	t.Cleanup(func() {
    93  		require.NoError(t, trans1.Close())
    94  		require.NoError(t, trans2.Close())
    95  
    96  		// Time for the workers to finish.
    97  		time.Sleep(100 * time.Millisecond)
    98  	})
    99  
   100  	peer1, err := trans2.NewPeer(trans1.staticIdentity.privateKey.Public())
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	peer2, err := trans1.NewPeer(trans2.staticIdentity.privateKey.Public())
   105  	if err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	peer1.Start()
   109  	peer2.Start()
   110  
   111  	assertEqual(
   112  		t,
   113  		peer1.handshake.precomputedStaticStatic[:],
   114  		peer2.handshake.precomputedStaticStatic[:],
   115  	)
   116  
   117  	/* simulate handshake */
   118  
   119  	// initiation message
   120  
   121  	t.Log("exchange initiation message")
   122  
   123  	msg1, err := trans1.CreateMessageInitiation(peer2)
   124  	assertNil(t, err)
   125  
   126  	packet := make([]byte, 0, 256)
   127  	writer := bytes.NewBuffer(packet)
   128  	err = binary.Write(writer, binary.LittleEndian, msg1)
   129  	assertNil(t, err)
   130  	peer := trans2.ConsumeMessageInitiation(msg1)
   131  	if peer == nil {
   132  		t.Fatal("handshake failed at initiation message")
   133  	}
   134  
   135  	assertEqual(
   136  		t,
   137  		peer1.handshake.chainKey[:],
   138  		peer2.handshake.chainKey[:],
   139  	)
   140  
   141  	assertEqual(
   142  		t,
   143  		peer1.handshake.hash[:],
   144  		peer2.handshake.hash[:],
   145  	)
   146  
   147  	// response message
   148  
   149  	t.Log("exchange response message")
   150  
   151  	msg2, err := trans2.CreateMessageResponse(peer1)
   152  	assertNil(t, err)
   153  
   154  	peer = trans1.ConsumeMessageResponse(msg2)
   155  	if peer == nil {
   156  		t.Fatal("handshake failed at response message")
   157  	}
   158  
   159  	assertEqual(
   160  		t,
   161  		peer1.handshake.chainKey[:],
   162  		peer2.handshake.chainKey[:],
   163  	)
   164  
   165  	assertEqual(
   166  		t,
   167  		peer1.handshake.hash[:],
   168  		peer2.handshake.hash[:],
   169  	)
   170  
   171  	// key pairs
   172  
   173  	t.Log("deriving keys")
   174  
   175  	err = peer1.BeginSymmetricSession()
   176  	if err != nil {
   177  		t.Fatal("failed to derive keypair for peer 1", err)
   178  	}
   179  
   180  	err = peer2.BeginSymmetricSession()
   181  	if err != nil {
   182  		t.Fatal("failed to derive keypair for peer 2", err)
   183  	}
   184  
   185  	key1 := peer1.keypairs.next.Load()
   186  	key2 := peer2.keypairs.current
   187  
   188  	// encrypting / decryption test
   189  
   190  	t.Log("test key pairs")
   191  
   192  	func() {
   193  		testMsg := []byte("test message 1")
   194  		var err error
   195  		var out []byte
   196  		var nonce [12]byte
   197  		out = key1.send.Seal(out, nonce[:], testMsg, nil)
   198  		out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
   199  		assertNil(t, err)
   200  		assertEqual(t, out, testMsg)
   201  	}()
   202  
   203  	func() {
   204  		testMsg := []byte("test message 2")
   205  		var err error
   206  		var out []byte
   207  		var nonce [12]byte
   208  		out = key2.send.Seal(out, nonce[:], testMsg, nil)
   209  		out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
   210  		assertNil(t, err)
   211  		assertEqual(t, out, testMsg)
   212  	}()
   213  }
   214  
   215  type discardingSink struct {
   216  	closed bool
   217  }
   218  
   219  func (ss *discardingSink) Close() error {
   220  	ss.closed = true
   221  	return nil
   222  }
   223  
   224  func (ss *discardingSink) Read(bufs [][]byte, sizes []int, destinations []types.NoisePublicKey, offset int) (int, error) {
   225  	if ss.closed {
   226  		return 0, net.ErrClosed
   227  	}
   228  
   229  	time.Sleep(10 * time.Millisecond)
   230  
   231  	return 0, nil
   232  }
   233  
   234  func (discardingSink) Write(bufs [][]byte, sources []types.NoisePublicKey, offset int) (int, error) {
   235  	return 0, nil
   236  }
   237  
   238  func (discardingSink) MTU() int {
   239  	return 1420
   240  }
   241  
   242  func (discardingSink) BatchSize() int {
   243  	return 1
   244  }