github.com/anacrolix/torrent@v1.61.0/mse/mse_test.go (about) 1 package mse 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "crypto/rc4" 8 "io" 9 "net" 10 "sync" 11 "testing" 12 13 _ "github.com/anacrolix/envpprof" 14 "github.com/stretchr/testify/assert" 15 "github.com/stretchr/testify/require" 16 ) 17 18 func sliceIter(skeys [][]byte) SecretKeyIter { 19 return func(callback func([]byte) bool) { 20 for _, sk := range skeys { 21 if !callback(sk) { 22 break 23 } 24 } 25 } 26 } 27 28 func TestReadUntil(t *testing.T) { 29 test := func(data, until string, leftover int, expectedErr error) { 30 r := bytes.NewReader([]byte(data)) 31 err := readUntil(r, []byte(until)) 32 if err != expectedErr { 33 t.Fatal(err) 34 } 35 if r.Len() != leftover { 36 t.Fatal(r.Len()) 37 } 38 } 39 test("feakjfeafeafegbaabc00", "abc", 2, nil) 40 test("feakjfeafeafegbaadc00", "abc", 0, io.EOF) 41 } 42 43 func TestSuffixMatchLen(t *testing.T) { 44 test := func(a, b string, expected int) { 45 actual := suffixMatchLen([]byte(a), []byte(b)) 46 if actual != expected { 47 t.Fatalf("expected %d, got %d for %q and %q", expected, actual, a, b) 48 } 49 } 50 test("hello", "world", 0) 51 test("hello", "lo", 2) 52 test("hello", "llo", 3) 53 test("hello", "hell", 0) 54 test("hello", "helloooo!", 5) 55 test("hello", "lol!", 2) 56 test("hello", "mondo", 0) 57 test("mongo", "webscale", 0) 58 test("sup", "person", 1) 59 } 60 61 func handshakeTest(t testing.TB, ia []byte, aData, bData string, cryptoProvides CryptoMethod, cryptoSelect CryptoSelector) { 62 a, b := net.Pipe() 63 wg := sync.WaitGroup{} 64 wg.Add(2) 65 go func() { 66 defer wg.Done() 67 a, cm, err := InitiateHandshake(a, []byte("yep"), ia, cryptoProvides) 68 require.NoError(t, err) 69 assert.Equal(t, cryptoSelect(cryptoProvides), cm) 70 go a.Write([]byte(aData)) 71 72 var msg [20]byte 73 n, _ := a.Read(msg[:]) 74 if n != len(bData) { 75 t.FailNow() 76 } 77 // t.Log(string(msg[:n])) 78 }() 79 go func() { 80 defer wg.Done() 81 res := ReceiveHandshakeEx( 82 context.Background(), 83 b, 84 sliceIter([][]byte{[]byte("nope"), []byte("yep"), []byte("maybe")}), 85 cryptoSelect, 86 ) 87 require.NoError(t, res.error) 88 assert.EqualValues(t, "yep", res.SecretKey) 89 b := res.ReadWriter 90 assert.Equal(t, cryptoSelect(cryptoProvides), res.CryptoMethod) 91 go b.Write([]byte(bData)) 92 // Need to be exact here, as there are several reads, and net.Pipe is most synchronous. 93 msg := make([]byte, len(ia)+len(aData)) 94 n, _ := io.ReadFull(b, msg) 95 if n != len(msg) { 96 t.FailNow() 97 } 98 // t.Log(string(msg[:n])) 99 }() 100 wg.Wait() 101 a.Close() 102 b.Close() 103 } 104 105 func allHandshakeTests(t testing.TB, provides CryptoMethod, selector CryptoSelector) { 106 handshakeTest(t, []byte("jump the gun, "), "hello world", "yo dawg", provides, selector) 107 handshakeTest(t, nil, "hello world", "yo dawg", provides, selector) 108 handshakeTest(t, []byte{}, "hello world", "yo dawg", provides, selector) 109 } 110 111 func TestHandshakeDefault(t *testing.T) { 112 allHandshakeTests(t, AllSupportedCrypto, DefaultCryptoSelector) 113 t.Logf("crypto provides encountered: %s", cryptoProvidesCount) 114 } 115 116 func TestHandshakeSelectPlaintext(t *testing.T) { 117 allHandshakeTests(t, AllSupportedCrypto, func(CryptoMethod) CryptoMethod { return CryptoMethodPlaintext }) 118 } 119 120 func BenchmarkHandshakeDefault(b *testing.B) { 121 for i := 0; i < b.N; i += 1 { 122 allHandshakeTests(b, AllSupportedCrypto, DefaultCryptoSelector) 123 } 124 } 125 126 type trackReader struct { 127 r io.Reader 128 n int64 129 } 130 131 func (tr *trackReader) Read(b []byte) (n int, err error) { 132 n, err = tr.r.Read(b) 133 tr.n += int64(n) 134 return 135 } 136 137 func TestReceiveRandomData(t *testing.T) { 138 tr := trackReader{rand.Reader, 0} 139 _, _, err := ReceiveHandshake(context.Background(), readWriter{&tr, io.Discard}, nil, DefaultCryptoSelector) 140 // No skey matches 141 require.Error(t, err) 142 // Establishing S, and then reading the maximum padding for giving up on 143 // synchronizing. 144 require.EqualValues(t, 96+532, tr.n) 145 } 146 147 func fillRand(t testing.TB, bs ...[]byte) { 148 for _, b := range bs { 149 _, err := rand.Read(b) 150 require.NoError(t, err) 151 } 152 } 153 154 func readAndWrite(rw io.ReadWriter, r, w []byte) error { 155 var wg sync.WaitGroup 156 wg.Add(1) 157 var wErr error 158 go func() { 159 defer wg.Done() 160 _, wErr = rw.Write(w) 161 }() 162 _, err := io.ReadFull(rw, r) 163 if err != nil { 164 return err 165 } 166 wg.Wait() 167 return wErr 168 } 169 170 func benchmarkStream(t *testing.B, crypto CryptoMethod) { 171 ia := make([]byte, 0x1000) 172 a := make([]byte, 1<<20) 173 b := make([]byte, 1<<20) 174 fillRand(t, ia, a, b) 175 t.StopTimer() 176 t.SetBytes(int64(len(ia) + len(a) + len(b))) 177 t.ResetTimer() 178 for i := 0; i < t.N; i += 1 { 179 ac, bc := net.Pipe() 180 ar := make([]byte, len(b)) 181 br := make([]byte, len(ia)+len(a)) 182 t.StartTimer() 183 var wg sync.WaitGroup 184 wg.Add(1) 185 go func() { 186 defer ac.Close() 187 defer wg.Done() 188 rw, _, err := InitiateHandshake(ac, []byte("cats"), ia, crypto) 189 require.NoError(t, err) 190 require.NoError(t, readAndWrite(rw, ar, a)) 191 }() 192 func() { 193 defer bc.Close() 194 rw, _, err := ReceiveHandshake( 195 context.Background(), 196 bc, 197 sliceIter([][]byte{[]byte("cats")}), 198 func(CryptoMethod) CryptoMethod { return crypto }, 199 ) 200 require.NoError(t, err) 201 require.NoError(t, readAndWrite(rw, br, b)) 202 }() 203 wg.Wait() 204 t.StopTimer() 205 if !bytes.Equal(ar, b) { 206 t.Fatalf("A read the wrong bytes") 207 } 208 if !bytes.Equal(br[:len(ia)], ia) { 209 t.Fatalf("B read the wrong IA") 210 } 211 if !bytes.Equal(br[len(ia):], a) { 212 t.Fatalf("B read the wrong A") 213 } 214 // require.Equal(t, b, ar) 215 // require.Equal(t, ia, br[:len(ia)]) 216 // require.Equal(t, a, br[len(ia):]) 217 } 218 } 219 220 func BenchmarkStreamRC4(t *testing.B) { 221 benchmarkStream(t, CryptoMethodRC4) 222 } 223 224 func BenchmarkStreamPlaintext(t *testing.B) { 225 benchmarkStream(t, CryptoMethodPlaintext) 226 } 227 228 func BenchmarkPipeRC4(t *testing.B) { 229 key := make([]byte, 20) 230 n, _ := rand.Read(key) 231 require.Equal(t, len(key), n) 232 var buf bytes.Buffer 233 c, err := rc4.NewCipher(key) 234 require.NoError(t, err) 235 r := cipherReader{ 236 c: c, 237 r: &buf, 238 } 239 c, err = rc4.NewCipher(key) 240 require.NoError(t, err) 241 w := cipherWriter{ 242 c: c, 243 w: &buf, 244 } 245 a := make([]byte, 0x1000) 246 n, _ = io.ReadFull(rand.Reader, a) 247 require.Equal(t, len(a), n) 248 b := make([]byte, len(a)) 249 t.SetBytes(int64(len(a))) 250 t.ResetTimer() 251 for i := 0; i < t.N; i += 1 { 252 n, _ = w.Write(a) 253 if n != len(a) { 254 t.FailNow() 255 } 256 n, _ = r.Read(b) 257 if n != len(b) { 258 t.FailNow() 259 } 260 if !bytes.Equal(a, b) { 261 t.FailNow() 262 } 263 } 264 } 265 266 func BenchmarkSkeysReceive(b *testing.B) { 267 var skeys [][]byte 268 for i := 0; i < 100000; i += 1 { 269 skeys = append(skeys, make([]byte, 20)) 270 } 271 fillRand(b, skeys...) 272 initSkey := skeys[len(skeys)/2] 273 // c := qt.New(b) 274 b.ReportAllocs() 275 b.ResetTimer() 276 for i := 0; i < b.N; i += 1 { 277 initiator, receiver := net.Pipe() 278 go func() { 279 _, _, err := InitiateHandshake(initiator, initSkey, nil, AllSupportedCrypto) 280 if err != nil { 281 panic(err) 282 } 283 }() 284 res := ReceiveHandshakeEx(context.Background(), receiver, sliceIter(skeys), DefaultCryptoSelector) 285 if res.error != nil { 286 panic(res.error) 287 } 288 } 289 }