github.com/davidhuie/signcryption@v0.0.0-20180606214722-1b2fa07edd39/stl/relayer_test.go (about) 1 package stl 2 3 import ( 4 "bytes" 5 "io" 6 "math/rand" 7 "net" 8 "sync" 9 "testing" 10 11 "github.com/DavidHuie/signcryption" 12 "github.com/DavidHuie/signcryption/aai" 13 "github.com/pkg/errors" 14 ) 15 16 type ServerConnFetcherImpl struct { 17 topic []byte 18 cert *signcryption.Certificate 19 addr net.Addr 20 } 21 22 func (s *ServerConnFetcherImpl) GetConn(topic []byte, 23 cert *signcryption.Certificate) (net.Conn, error) { 24 if !bytes.Equal(topic, s.topic) || !cert.Equal(s.cert) { 25 return nil, errors.New("invalid server requested") 26 } 27 28 return net.Dial("tcp", s.addr.String()) 29 } 30 31 func getClientServerRelayer(t testing.TB, r io.Reader) (*Conn, *Conn, *Relayer, func()) { 32 clientCert := generateCert(t, r) 33 serverCert := generateCert(t, r) 34 relayerCert := generateCert(t, r) 35 36 verifier := &sessionVerifierImpl{ 37 clientCert: clientCert, 38 serverCert: serverCert, 39 relayerCert: relayerCert, 40 topic: []byte("t1"), 41 } 42 43 serverListener, err := net.Listen("tcp", ":") 44 if err != nil { 45 t.Fatal(err) 46 } 47 48 var serverConn Conn 49 go func() { 50 conn, err := serverListener.Accept() 51 if err != nil { 52 t.Logf("error accepting conn: %s", err) 53 } 54 55 serverConn = *NewServerConn(conn, &ServerConfig{ 56 ServerCertificate: serverCert, 57 SessionVerifier: verifier, 58 }) 59 if err := serverConn.Handshake(); err != nil { 60 t.Error(err) 61 conn.Close() 62 } 63 }() 64 65 fetcher := &ServerConnFetcherImpl{ 66 topic: []byte("t1"), 67 cert: serverCert, 68 addr: serverListener.Addr(), 69 } 70 71 relayerListener, err := net.Listen("tcp", ":") 72 if err != nil { 73 t.Fatal(err) 74 } 75 76 var relayer Relayer 77 go func() { 78 conn, err := relayerListener.Accept() 79 if err != nil { 80 t.Logf("error accepting conn: %s", err) 81 } 82 83 relayer = *NewRelayer(conn, &RelayerConfig{ 84 Verifier: verifier, 85 ConnFetcher: fetcher, 86 RelayerCert: relayerCert, 87 Signcrypter: aai.NewP256(), 88 }) 89 90 if err := relayer.Start(); err != nil { 91 t.Fatal(err) 92 } 93 }() 94 95 conn, err := net.Dial("tcp", relayerListener.Addr().String()) 96 if err != nil { 97 t.Fatal(err) 98 } 99 clientConn := NewConn(conn, &ClientConfig{ 100 Topic: []byte("t1"), 101 ClientCertificate: clientCert, 102 ServerCertificate: serverCert, 103 RelayerCeriificate: relayerCert, 104 }) 105 106 return clientConn, &serverConn, &relayer, func() { 107 relayer.Close() 108 conn.Close() 109 serverListener.Close() 110 relayerListener.Close() 111 } 112 } 113 114 func TestRelayerHandshake(t *testing.T) { 115 r := rand.New(rand.NewSource(0)) 116 clientConn, serverConn, relayer, cleanup := getClientServerRelayer(t, r) 117 defer cleanup() 118 119 if err := clientConn.Handshake(); err != nil { 120 t.Fatal(err) 121 } 122 if err := serverConn.Handshake(); err != nil { 123 t.Fatal(err) 124 } 125 126 if !bytes.Equal(relayer.sessionKey, clientConn.sessionKey) || 127 !bytes.Equal(clientConn.sessionKey, serverConn.sessionKey) { 128 t.Fatal("session keys should be equal") 129 } 130 } 131 132 func TestRelayerBidirectionalReadWrite(t *testing.T) { 133 currentRand := int64(0) 134 getRand := func() *rand.Rand { 135 currentRand++ 136 return rand.New(rand.NewSource(currentRand)) 137 } 138 139 clientConn, serverConn, _, cleanup := getClientServerRelayer(t, getRand()) 140 defer cleanup() 141 142 if err := clientConn.Handshake(); err != nil { 143 t.Fatal(err) 144 } 145 if err := serverConn.Handshake(); err != nil { 146 t.Fatal(err) 147 } 148 149 numBytes := int64(1 * 1024 * 1024) 150 151 rand1 := make([]byte, numBytes) 152 rand2 := make([]byte, numBytes) 153 154 clientReader := &bytes.Buffer{} 155 serverReader := &bytes.Buffer{} 156 157 if _, err := io.ReadFull(getRand(), rand1); err != nil { 158 t.Fatal(err) 159 } 160 if _, err := io.ReadFull(getRand(), rand2); err != nil { 161 t.Fatal(err) 162 } 163 164 wg := &sync.WaitGroup{} 165 wg.Add(4) 166 167 go func() { 168 defer wg.Done() 169 n, err := clientConn.Write(rand1) 170 if err != nil { 171 t.Fatalf("copied %d bytes, error: %s", n, err) 172 } 173 }() 174 175 go func() { 176 defer wg.Done() 177 n, err := serverConn.Write(rand2) 178 if err != nil { 179 t.Fatalf("copied %d bytes, error: %s", n, err) 180 } 181 }() 182 183 go func() { 184 defer wg.Done() 185 if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil { 186 t.Fatal(err) 187 } 188 }() 189 190 go func() { 191 defer wg.Done() 192 if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil { 193 t.Fatal(err) 194 } 195 }() 196 197 wg.Wait() 198 199 if bytes.Compare(clientReader.Bytes(), rand2) != 0 { 200 t.Fatal("client buffers not equal") 201 } 202 if bytes.Compare(serverReader.Bytes(), rand1) != 0 { 203 t.Fatal("server buffers not equal") 204 } 205 } 206 207 func BenchmarkRelayerBidirectionalReadWrite(t *testing.B) { 208 currentRand := int64(0) 209 getRand := func() *rand.Rand { 210 currentRand++ 211 return rand.New(rand.NewSource(currentRand)) 212 } 213 214 clientConn, serverConn, _, cleanup := getClientServerRelayer(t, getRand()) 215 defer cleanup() 216 217 if err := clientConn.Handshake(); err != nil { 218 t.Fatal(err) 219 } 220 if err := serverConn.Handshake(); err != nil { 221 t.Fatal(err) 222 } 223 224 numBytes := int64(1024 * 1024 * 1024) 225 226 rand1 := make([]byte, numBytes) 227 rand2 := make([]byte, numBytes) 228 229 clientReader := &bytes.Buffer{} 230 serverReader := &bytes.Buffer{} 231 232 if _, err := io.ReadFull(getRand(), rand1); err != nil { 233 t.Fatal(err) 234 } 235 if _, err := io.ReadFull(getRand(), rand2); err != nil { 236 t.Fatal(err) 237 } 238 239 wg := &sync.WaitGroup{} 240 wg.Add(4) 241 242 t.ResetTimer() 243 244 go func() { 245 defer wg.Done() 246 n, err := clientConn.Write(rand1) 247 if err != nil { 248 t.Fatalf("copied %d bytes, error: %s", n, err) 249 } 250 }() 251 252 go func() { 253 defer wg.Done() 254 n, err := serverConn.Write(rand2) 255 if err != nil { 256 t.Fatalf("copied %d bytes, error: %s", n, err) 257 } 258 }() 259 260 go func() { 261 defer wg.Done() 262 if _, err := io.CopyN(clientReader, clientConn, numBytes); err != nil { 263 t.Fatal(err) 264 } 265 }() 266 267 go func() { 268 defer wg.Done() 269 if _, err := io.CopyN(serverReader, serverConn, numBytes); err != nil { 270 t.Fatal(err) 271 } 272 }() 273 274 wg.Wait() 275 276 if bytes.Compare(clientReader.Bytes(), rand2) != 0 { 277 t.Fatal("client buffers not equal") 278 } 279 if bytes.Compare(serverReader.Bytes(), rand1) != 0 { 280 t.Fatal("server buffers not equal") 281 } 282 }