github.com/gopacket/gopacket@v1.1.0/tcpassembly/tcpreader/reader_test.go (about) 1 // Copyright 2012 Google, Inc. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style license 4 // that can be found in the LICENSE file in the root of the source 5 // tree. 6 7 package tcpreader 8 9 import ( 10 "bytes" 11 "fmt" 12 "io" 13 "net" 14 "testing" 15 16 "github.com/gopacket/gopacket" 17 "github.com/gopacket/gopacket/layers" 18 "github.com/gopacket/gopacket/tcpassembly" 19 ) 20 21 var netFlow gopacket.Flow 22 23 func init() { 24 netFlow, _ = gopacket.FlowFromEndpoints( 25 layers.NewIPEndpoint(net.IP{1, 2, 3, 4}), 26 layers.NewIPEndpoint(net.IP{5, 6, 7, 8})) 27 } 28 29 type readReturn struct { 30 data []byte 31 err error 32 } 33 type readSequence struct { 34 in []layers.TCP 35 want []readReturn 36 } 37 type testReaderFactory struct { 38 lossErrors bool 39 readSize int 40 ReaderStream 41 output chan []byte 42 } 43 44 func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream { 45 return &t.ReaderStream 46 } 47 48 func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) { 49 f := &testReaderFactory{ReaderStream: NewReaderStream()} 50 f.ReaderStream.LossErrors = lossErrors 51 p := tcpassembly.NewStreamPool(f) 52 a := tcpassembly.NewAssembler(p) 53 buf := make([]byte, readSize) 54 go func() { 55 for i, test := range seq.in { 56 fmt.Println("Assembling", i) 57 a.Assemble(netFlow, &test) 58 fmt.Println("Assembly done") 59 } 60 }() 61 for i, test := range seq.want { 62 fmt.Println("Waiting for read", i) 63 n, err := f.Read(buf[:]) 64 fmt.Println("Got read") 65 if n != len(test.data) { 66 t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n) 67 } else if err != test.err { 68 t.Errorf("test %d want err %v, got err %v", i, test.err, err) 69 } else if !bytes.Equal(buf[:n], test.data) { 70 t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n]) 71 } 72 } 73 fmt.Println("All done reads") 74 } 75 76 func TestRead(t *testing.T) { 77 testReadSequence(t, false, 10, readSequence{ 78 in: []layers.TCP{ 79 { 80 SYN: true, 81 SrcPort: 1, 82 DstPort: 2, 83 Seq: 1000, 84 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, 85 }, 86 { 87 FIN: true, 88 SrcPort: 1, 89 DstPort: 2, 90 Seq: 1004, 91 }, 92 }, 93 want: []readReturn{ 94 {data: []byte{1, 2, 3}}, 95 {err: io.EOF}, 96 }, 97 }) 98 } 99 100 func TestReadSmallChunks(t *testing.T) { 101 testReadSequence(t, false, 2, readSequence{ 102 in: []layers.TCP{ 103 { 104 SYN: true, 105 SrcPort: 1, 106 DstPort: 2, 107 Seq: 1000, 108 BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}}, 109 }, 110 { 111 FIN: true, 112 SrcPort: 1, 113 DstPort: 2, 114 Seq: 1004, 115 }, 116 }, 117 want: []readReturn{ 118 {data: []byte{1, 2}}, 119 {data: []byte{3}}, 120 {err: io.EOF}, 121 }, 122 }) 123 } 124 125 func ExampleDiscardBytesToEOF() { 126 b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5}) 127 fmt.Println(DiscardBytesToEOF(b)) 128 // Output: 129 // 5 130 }