github.com/davidhuie/signcryption@v0.0.0-20180606214722-1b2fa07edd39/stl/conn_test.go (about) 1 package stl 2 3 import ( 4 "bytes" 5 "io" 6 "math/rand" 7 "net" 8 "sync" 9 "testing" 10 ) 11 12 func getClientServer(t testing.TB, r io.Reader) (*Conn, *Conn, func()) { 13 clientCert := generateCert(t, r) 14 serverCert := generateCert(t, r) 15 relayerCert := generateCert(t, r) 16 17 verifier := &sessionVerifierImpl{ 18 clientCert: clientCert, 19 serverCert: serverCert, 20 relayerCert: relayerCert, 21 topic: []byte("t1"), 22 } 23 24 listener, err := net.Listen("tcp", ":") 25 if err != nil { 26 t.Fatal(err) 27 } 28 29 var serverConn Conn 30 31 go func() { 32 conn, err := listener.Accept() 33 if err != nil { 34 t.Logf("error accepting conn: %s", err) 35 } 36 37 serverConn = *NewServerConn(conn, &ServerConfig{ 38 ServerCertificate: serverCert, 39 SessionVerifier: verifier, 40 }) 41 if err := serverConn.Handshake(); err != nil { 42 t.Error(err) 43 conn.Close() 44 } 45 }() 46 47 conn, err := net.Dial("tcp", listener.Addr().String()) 48 if err != nil { 49 t.Fatal(err) 50 } 51 clientConn := NewConn(conn, &ClientConfig{ 52 Topic: []byte("t1"), 53 ClientCertificate: clientCert, 54 ServerCertificate: serverCert, 55 RelayerCeriificate: relayerCert, 56 }) 57 58 return clientConn, &serverConn, func() { 59 conn.Close() 60 listener.Close() 61 } 62 } 63 64 func TestConnIntegration(t *testing.T) { 65 clientConn, serverConn, cleanup := getClientServer(t, rand.New(rand.NewSource(0))) 66 defer cleanup() 67 68 if err := clientConn.Handshake(); err != nil { 69 t.Fatal(err) 70 } 71 if err := serverConn.Handshake(); err != nil { 72 t.Fatal(err) 73 } 74 75 if bytes.Compare(clientConn.sessionKey, serverConn.sessionKey) != 0 { 76 t.Fatal("session keys must match") 77 } 78 } 79 80 func TestBidirectionalReadWrite(t *testing.T) { 81 currentRand := int64(0) 82 getRand := func() *rand.Rand { 83 currentRand++ 84 return rand.New(rand.NewSource(currentRand)) 85 } 86 87 clientConn, serverConn, cleanup := getClientServer(t, getRand()) 88 defer cleanup() 89 90 if err := clientConn.Handshake(); err != nil { 91 t.Fatal(err) 92 } 93 if err := serverConn.Handshake(); err != nil { 94 t.Fatal(err) 95 } 96 97 numBytes := int64(10 * 1024 * 1024) 98 99 rand1 := make([]byte, numBytes) 100 rand2 := make([]byte, numBytes) 101 102 clientReader := &bytes.Buffer{} 103 serverReader := &bytes.Buffer{} 104 105 if _, err := io.ReadFull(getRand(), rand1); err != nil { 106 t.Fatal(err) 107 } 108 if _, err := io.ReadFull(getRand(), rand2); err != nil { 109 t.Fatal(err) 110 } 111 112 wg := &sync.WaitGroup{} 113 wg.Add(4) 114 115 go func() { 116 defer wg.Done() 117 n, err := clientConn.Write(rand1) 118 if err != nil { 119 t.Fatalf("copied %d bytes, error: %s", n, err) 120 } 121 }() 122 123 go func() { 124 defer wg.Done() 125 n, err := serverConn.Write(rand2) 126 if err != nil { 127 t.Fatalf("copied %d bytes, error: %s", n, err) 128 } 129 }() 130 131 go func() { 132 defer wg.Done() 133 if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil { 134 panic(err) 135 } 136 }() 137 138 go func() { 139 defer wg.Done() 140 if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil { 141 panic(err) 142 } 143 }() 144 145 wg.Wait() 146 147 if bytes.Compare(clientReader.Bytes(), rand2) != 0 { 148 t.Fatal("client buffers not equal") 149 } 150 if bytes.Compare(serverReader.Bytes(), rand1) != 0 { 151 t.Fatal("server buffers not equal") 152 } 153 } 154 155 func BenchmarkBidirectionalReadWrite(t *testing.B) { 156 currentRand := int64(0) 157 getRand := func() *rand.Rand { 158 currentRand++ 159 return rand.New(rand.NewSource(currentRand)) 160 } 161 162 clientConn, serverConn, cleanup := getClientServer(t, getRand()) 163 defer cleanup() 164 165 if err := clientConn.Handshake(); err != nil { 166 t.Fatal(err) 167 } 168 if err := serverConn.Handshake(); err != nil { 169 t.Fatal(err) 170 } 171 172 numBytes := int64(1 * 1024 * 1024) 173 174 rand1 := make([]byte, numBytes) 175 rand2 := make([]byte, numBytes) 176 177 clientReader := &bytes.Buffer{} 178 serverReader := &bytes.Buffer{} 179 180 if _, err := io.ReadFull(getRand(), rand1); err != nil { 181 t.Fatal(err) 182 } 183 if _, err := io.ReadFull(getRand(), rand2); err != nil { 184 t.Fatal(err) 185 } 186 187 wg := &sync.WaitGroup{} 188 wg.Add(4) 189 190 t.ResetTimer() 191 192 go func() { 193 defer wg.Done() 194 n, err := clientConn.Write(rand1) 195 if err != nil { 196 t.Fatalf("copied %d bytes, error: %s", n, err) 197 } 198 }() 199 200 go func() { 201 defer wg.Done() 202 n, err := serverConn.Write(rand2) 203 if err != nil { 204 t.Fatalf("copied %d bytes, error: %s", n, err) 205 } 206 }() 207 208 go func() { 209 defer wg.Done() 210 if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil { 211 t.Fatal(err) 212 } 213 }() 214 215 go func() { 216 defer wg.Done() 217 if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil { 218 t.Fatal(err) 219 } 220 }() 221 222 wg.Wait() 223 224 if bytes.Compare(clientReader.Bytes(), rand2) != 0 { 225 t.Fatal("client buffers not equal") 226 } 227 if bytes.Compare(serverReader.Bytes(), rand1) != 0 { 228 t.Fatal("server buffers not equal") 229 } 230 }