github.com/davidhuie/signcryption@v0.0.0-20180606214722-1b2fa07edd39/stl/relayer_test.go (about)

     1  package stl
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"math/rand"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  
    11  	"github.com/DavidHuie/signcryption"
    12  	"github.com/DavidHuie/signcryption/aai"
    13  	"github.com/pkg/errors"
    14  )
    15  
    16  type ServerConnFetcherImpl struct {
    17  	topic []byte
    18  	cert  *signcryption.Certificate
    19  	addr  net.Addr
    20  }
    21  
    22  func (s *ServerConnFetcherImpl) GetConn(topic []byte,
    23  	cert *signcryption.Certificate) (net.Conn, error) {
    24  	if !bytes.Equal(topic, s.topic) || !cert.Equal(s.cert) {
    25  		return nil, errors.New("invalid server requested")
    26  	}
    27  
    28  	return net.Dial("tcp", s.addr.String())
    29  }
    30  
    31  func getClientServerRelayer(t testing.TB, r io.Reader) (*Conn, *Conn, *Relayer, func()) {
    32  	clientCert := generateCert(t, r)
    33  	serverCert := generateCert(t, r)
    34  	relayerCert := generateCert(t, r)
    35  
    36  	verifier := &sessionVerifierImpl{
    37  		clientCert:  clientCert,
    38  		serverCert:  serverCert,
    39  		relayerCert: relayerCert,
    40  		topic:       []byte("t1"),
    41  	}
    42  
    43  	serverListener, err := net.Listen("tcp", ":")
    44  	if err != nil {
    45  		t.Fatal(err)
    46  	}
    47  
    48  	var serverConn Conn
    49  	go func() {
    50  		conn, err := serverListener.Accept()
    51  		if err != nil {
    52  			t.Logf("error accepting conn: %s", err)
    53  		}
    54  
    55  		serverConn = *NewServerConn(conn, &ServerConfig{
    56  			ServerCertificate: serverCert,
    57  			SessionVerifier:   verifier,
    58  		})
    59  		if err := serverConn.Handshake(); err != nil {
    60  			t.Error(err)
    61  			conn.Close()
    62  		}
    63  	}()
    64  
    65  	fetcher := &ServerConnFetcherImpl{
    66  		topic: []byte("t1"),
    67  		cert:  serverCert,
    68  		addr:  serverListener.Addr(),
    69  	}
    70  
    71  	relayerListener, err := net.Listen("tcp", ":")
    72  	if err != nil {
    73  		t.Fatal(err)
    74  	}
    75  
    76  	var relayer Relayer
    77  	go func() {
    78  		conn, err := relayerListener.Accept()
    79  		if err != nil {
    80  			t.Logf("error accepting conn: %s", err)
    81  		}
    82  
    83  		relayer = *NewRelayer(conn, &RelayerConfig{
    84  			Verifier:    verifier,
    85  			ConnFetcher: fetcher,
    86  			RelayerCert: relayerCert,
    87  			Signcrypter: aai.NewP256(),
    88  		})
    89  
    90  		if err := relayer.Start(); err != nil {
    91  			t.Fatal(err)
    92  		}
    93  	}()
    94  
    95  	conn, err := net.Dial("tcp", relayerListener.Addr().String())
    96  	if err != nil {
    97  		t.Fatal(err)
    98  	}
    99  	clientConn := NewConn(conn, &ClientConfig{
   100  		Topic:              []byte("t1"),
   101  		ClientCertificate:  clientCert,
   102  		ServerCertificate:  serverCert,
   103  		RelayerCeriificate: relayerCert,
   104  	})
   105  
   106  	return clientConn, &serverConn, &relayer, func() {
   107  		relayer.Close()
   108  		conn.Close()
   109  		serverListener.Close()
   110  		relayerListener.Close()
   111  	}
   112  }
   113  
   114  func TestRelayerHandshake(t *testing.T) {
   115  	r := rand.New(rand.NewSource(0))
   116  	clientConn, serverConn, relayer, cleanup := getClientServerRelayer(t, r)
   117  	defer cleanup()
   118  
   119  	if err := clientConn.Handshake(); err != nil {
   120  		t.Fatal(err)
   121  	}
   122  	if err := serverConn.Handshake(); err != nil {
   123  		t.Fatal(err)
   124  	}
   125  
   126  	if !bytes.Equal(relayer.sessionKey, clientConn.sessionKey) ||
   127  		!bytes.Equal(clientConn.sessionKey, serverConn.sessionKey) {
   128  		t.Fatal("session keys should be equal")
   129  	}
   130  }
   131  
   132  func TestRelayerBidirectionalReadWrite(t *testing.T) {
   133  	currentRand := int64(0)
   134  	getRand := func() *rand.Rand {
   135  		currentRand++
   136  		return rand.New(rand.NewSource(currentRand))
   137  	}
   138  
   139  	clientConn, serverConn, _, cleanup := getClientServerRelayer(t, getRand())
   140  	defer cleanup()
   141  
   142  	if err := clientConn.Handshake(); err != nil {
   143  		t.Fatal(err)
   144  	}
   145  	if err := serverConn.Handshake(); err != nil {
   146  		t.Fatal(err)
   147  	}
   148  
   149  	numBytes := int64(1 * 1024 * 1024)
   150  
   151  	rand1 := make([]byte, numBytes)
   152  	rand2 := make([]byte, numBytes)
   153  
   154  	clientReader := &bytes.Buffer{}
   155  	serverReader := &bytes.Buffer{}
   156  
   157  	if _, err := io.ReadFull(getRand(), rand1); err != nil {
   158  		t.Fatal(err)
   159  	}
   160  	if _, err := io.ReadFull(getRand(), rand2); err != nil {
   161  		t.Fatal(err)
   162  	}
   163  
   164  	wg := &sync.WaitGroup{}
   165  	wg.Add(4)
   166  
   167  	go func() {
   168  		defer wg.Done()
   169  		n, err := clientConn.Write(rand1)
   170  		if err != nil {
   171  			t.Fatalf("copied %d bytes, error: %s", n, err)
   172  		}
   173  	}()
   174  
   175  	go func() {
   176  		defer wg.Done()
   177  		n, err := serverConn.Write(rand2)
   178  		if err != nil {
   179  			t.Fatalf("copied %d bytes, error: %s", n, err)
   180  		}
   181  	}()
   182  
   183  	go func() {
   184  		defer wg.Done()
   185  		if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil {
   186  			t.Fatal(err)
   187  		}
   188  	}()
   189  
   190  	go func() {
   191  		defer wg.Done()
   192  		if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil {
   193  			t.Fatal(err)
   194  		}
   195  	}()
   196  
   197  	wg.Wait()
   198  
   199  	if bytes.Compare(clientReader.Bytes(), rand2) != 0 {
   200  		t.Fatal("client buffers not equal")
   201  	}
   202  	if bytes.Compare(serverReader.Bytes(), rand1) != 0 {
   203  		t.Fatal("server buffers not equal")
   204  	}
   205  }
   206  
   207  func BenchmarkRelayerBidirectionalReadWrite(t *testing.B) {
   208  	currentRand := int64(0)
   209  	getRand := func() *rand.Rand {
   210  		currentRand++
   211  		return rand.New(rand.NewSource(currentRand))
   212  	}
   213  
   214  	clientConn, serverConn, _, cleanup := getClientServerRelayer(t, getRand())
   215  	defer cleanup()
   216  
   217  	if err := clientConn.Handshake(); err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	if err := serverConn.Handshake(); err != nil {
   221  		t.Fatal(err)
   222  	}
   223  
   224  	numBytes := int64(1024 * 1024 * 1024)
   225  
   226  	rand1 := make([]byte, numBytes)
   227  	rand2 := make([]byte, numBytes)
   228  
   229  	clientReader := &bytes.Buffer{}
   230  	serverReader := &bytes.Buffer{}
   231  
   232  	if _, err := io.ReadFull(getRand(), rand1); err != nil {
   233  		t.Fatal(err)
   234  	}
   235  	if _, err := io.ReadFull(getRand(), rand2); err != nil {
   236  		t.Fatal(err)
   237  	}
   238  
   239  	wg := &sync.WaitGroup{}
   240  	wg.Add(4)
   241  
   242  	t.ResetTimer()
   243  
   244  	go func() {
   245  		defer wg.Done()
   246  		n, err := clientConn.Write(rand1)
   247  		if err != nil {
   248  			t.Fatalf("copied %d bytes, error: %s", n, err)
   249  		}
   250  	}()
   251  
   252  	go func() {
   253  		defer wg.Done()
   254  		n, err := serverConn.Write(rand2)
   255  		if err != nil {
   256  			t.Fatalf("copied %d bytes, error: %s", n, err)
   257  		}
   258  	}()
   259  
   260  	go func() {
   261  		defer wg.Done()
   262  		if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil {
   263  			t.Fatal(err)
   264  		}
   265  	}()
   266  
   267  	go func() {
   268  		defer wg.Done()
   269  		if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil {
   270  			t.Fatal(err)
   271  		}
   272  	}()
   273  
   274  	wg.Wait()
   275  
   276  	if bytes.Compare(clientReader.Bytes(), rand2) != 0 {
   277  		t.Fatal("client buffers not equal")
   278  	}
   279  	if bytes.Compare(serverReader.Bytes(), rand1) != 0 {
   280  		t.Fatal("server buffers not equal")
   281  	}
   282  }