github.com/davidhuie/signcryption@v0.0.0-20180606214722-1b2fa07edd39/stl/conn_test.go (about)

     1  package stl
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"math/rand"
     7  	"net"
     8  	"sync"
     9  	"testing"
    10  )
    11  
    12  func getClientServer(t testing.TB, r io.Reader) (*Conn, *Conn, func()) {
    13  	clientCert := generateCert(t, r)
    14  	serverCert := generateCert(t, r)
    15  	relayerCert := generateCert(t, r)
    16  
    17  	verifier := &sessionVerifierImpl{
    18  		clientCert:  clientCert,
    19  		serverCert:  serverCert,
    20  		relayerCert: relayerCert,
    21  		topic:       []byte("t1"),
    22  	}
    23  
    24  	listener, err := net.Listen("tcp", ":")
    25  	if err != nil {
    26  		t.Fatal(err)
    27  	}
    28  
    29  	var serverConn Conn
    30  
    31  	go func() {
    32  		conn, err := listener.Accept()
    33  		if err != nil {
    34  			t.Logf("error accepting conn: %s", err)
    35  		}
    36  
    37  		serverConn = *NewServerConn(conn, &ServerConfig{
    38  			ServerCertificate: serverCert,
    39  			SessionVerifier:   verifier,
    40  		})
    41  		if err := serverConn.Handshake(); err != nil {
    42  			t.Error(err)
    43  			conn.Close()
    44  		}
    45  	}()
    46  
    47  	conn, err := net.Dial("tcp", listener.Addr().String())
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	clientConn := NewConn(conn, &ClientConfig{
    52  		Topic:              []byte("t1"),
    53  		ClientCertificate:  clientCert,
    54  		ServerCertificate:  serverCert,
    55  		RelayerCeriificate: relayerCert,
    56  	})
    57  
    58  	return clientConn, &serverConn, func() {
    59  		conn.Close()
    60  		listener.Close()
    61  	}
    62  }
    63  
    64  func TestConnIntegration(t *testing.T) {
    65  	clientConn, serverConn, cleanup := getClientServer(t, rand.New(rand.NewSource(0)))
    66  	defer cleanup()
    67  
    68  	if err := clientConn.Handshake(); err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	if err := serverConn.Handshake(); err != nil {
    72  		t.Fatal(err)
    73  	}
    74  
    75  	if bytes.Compare(clientConn.sessionKey, serverConn.sessionKey) != 0 {
    76  		t.Fatal("session keys must match")
    77  	}
    78  }
    79  
    80  func TestBidirectionalReadWrite(t *testing.T) {
    81  	currentRand := int64(0)
    82  	getRand := func() *rand.Rand {
    83  		currentRand++
    84  		return rand.New(rand.NewSource(currentRand))
    85  	}
    86  
    87  	clientConn, serverConn, cleanup := getClientServer(t, getRand())
    88  	defer cleanup()
    89  
    90  	if err := clientConn.Handshake(); err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	if err := serverConn.Handshake(); err != nil {
    94  		t.Fatal(err)
    95  	}
    96  
    97  	numBytes := int64(10 * 1024 * 1024)
    98  
    99  	rand1 := make([]byte, numBytes)
   100  	rand2 := make([]byte, numBytes)
   101  
   102  	clientReader := &bytes.Buffer{}
   103  	serverReader := &bytes.Buffer{}
   104  
   105  	if _, err := io.ReadFull(getRand(), rand1); err != nil {
   106  		t.Fatal(err)
   107  	}
   108  	if _, err := io.ReadFull(getRand(), rand2); err != nil {
   109  		t.Fatal(err)
   110  	}
   111  
   112  	wg := &sync.WaitGroup{}
   113  	wg.Add(4)
   114  
   115  	go func() {
   116  		defer wg.Done()
   117  		n, err := clientConn.Write(rand1)
   118  		if err != nil {
   119  			t.Fatalf("copied %d bytes, error: %s", n, err)
   120  		}
   121  	}()
   122  
   123  	go func() {
   124  		defer wg.Done()
   125  		n, err := serverConn.Write(rand2)
   126  		if err != nil {
   127  			t.Fatalf("copied %d bytes, error: %s", n, err)
   128  		}
   129  	}()
   130  
   131  	go func() {
   132  		defer wg.Done()
   133  		if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil {
   134  			panic(err)
   135  		}
   136  	}()
   137  
   138  	go func() {
   139  		defer wg.Done()
   140  		if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil {
   141  			panic(err)
   142  		}
   143  	}()
   144  
   145  	wg.Wait()
   146  
   147  	if bytes.Compare(clientReader.Bytes(), rand2) != 0 {
   148  		t.Fatal("client buffers not equal")
   149  	}
   150  	if bytes.Compare(serverReader.Bytes(), rand1) != 0 {
   151  		t.Fatal("server buffers not equal")
   152  	}
   153  }
   154  
   155  func BenchmarkBidirectionalReadWrite(t *testing.B) {
   156  	currentRand := int64(0)
   157  	getRand := func() *rand.Rand {
   158  		currentRand++
   159  		return rand.New(rand.NewSource(currentRand))
   160  	}
   161  
   162  	clientConn, serverConn, cleanup := getClientServer(t, getRand())
   163  	defer cleanup()
   164  
   165  	if err := clientConn.Handshake(); err != nil {
   166  		t.Fatal(err)
   167  	}
   168  	if err := serverConn.Handshake(); err != nil {
   169  		t.Fatal(err)
   170  	}
   171  
   172  	numBytes := int64(1 * 1024 * 1024)
   173  
   174  	rand1 := make([]byte, numBytes)
   175  	rand2 := make([]byte, numBytes)
   176  
   177  	clientReader := &bytes.Buffer{}
   178  	serverReader := &bytes.Buffer{}
   179  
   180  	if _, err := io.ReadFull(getRand(), rand1); err != nil {
   181  		t.Fatal(err)
   182  	}
   183  	if _, err := io.ReadFull(getRand(), rand2); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  
   187  	wg := &sync.WaitGroup{}
   188  	wg.Add(4)
   189  
   190  	t.ResetTimer()
   191  
   192  	go func() {
   193  		defer wg.Done()
   194  		n, err := clientConn.Write(rand1)
   195  		if err != nil {
   196  			t.Fatalf("copied %d bytes, error: %s", n, err)
   197  		}
   198  	}()
   199  
   200  	go func() {
   201  		defer wg.Done()
   202  		n, err := serverConn.Write(rand2)
   203  		if err != nil {
   204  			t.Fatalf("copied %d bytes, error: %s", n, err)
   205  		}
   206  	}()
   207  
   208  	go func() {
   209  		defer wg.Done()
   210  		if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil {
   211  			t.Fatal(err)
   212  		}
   213  	}()
   214  
   215  	go func() {
   216  		defer wg.Done()
   217  		if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil {
   218  			t.Fatal(err)
   219  		}
   220  	}()
   221  
   222  	wg.Wait()
   223  
   224  	if bytes.Compare(clientReader.Bytes(), rand2) != 0 {
   225  		t.Fatal("client buffers not equal")
   226  	}
   227  	if bytes.Compare(serverReader.Bytes(), rand1) != 0 {
   228  		t.Fatal("server buffers not equal")
   229  	}
   230  }