github.com/gopacket/gopacket@v1.1.0/tcpassembly/tcpreader/reader_test.go (about)

     1  // Copyright 2012 Google, Inc. All rights reserved.
     2  //
     3  // Use of this source code is governed by a BSD-style license
     4  // that can be found in the LICENSE file in the root of the source
     5  // tree.
     6  
     7  package tcpreader
     8  
     9  import (
    10  	"bytes"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"testing"
    15  
    16  	"github.com/gopacket/gopacket"
    17  	"github.com/gopacket/gopacket/layers"
    18  	"github.com/gopacket/gopacket/tcpassembly"
    19  )
    20  
    21  var netFlow gopacket.Flow
    22  
    23  func init() {
    24  	netFlow, _ = gopacket.FlowFromEndpoints(
    25  		layers.NewIPEndpoint(net.IP{1, 2, 3, 4}),
    26  		layers.NewIPEndpoint(net.IP{5, 6, 7, 8}))
    27  }
    28  
    29  type readReturn struct {
    30  	data []byte
    31  	err  error
    32  }
    33  type readSequence struct {
    34  	in   []layers.TCP
    35  	want []readReturn
    36  }
    37  type testReaderFactory struct {
    38  	lossErrors bool
    39  	readSize   int
    40  	ReaderStream
    41  	output chan []byte
    42  }
    43  
    44  func (t *testReaderFactory) New(a, b gopacket.Flow) tcpassembly.Stream {
    45  	return &t.ReaderStream
    46  }
    47  
    48  func testReadSequence(t *testing.T, lossErrors bool, readSize int, seq readSequence) {
    49  	f := &testReaderFactory{ReaderStream: NewReaderStream()}
    50  	f.ReaderStream.LossErrors = lossErrors
    51  	p := tcpassembly.NewStreamPool(f)
    52  	a := tcpassembly.NewAssembler(p)
    53  	buf := make([]byte, readSize)
    54  	go func() {
    55  		for i, test := range seq.in {
    56  			fmt.Println("Assembling", i)
    57  			a.Assemble(netFlow, &test)
    58  			fmt.Println("Assembly done")
    59  		}
    60  	}()
    61  	for i, test := range seq.want {
    62  		fmt.Println("Waiting for read", i)
    63  		n, err := f.Read(buf[:])
    64  		fmt.Println("Got read")
    65  		if n != len(test.data) {
    66  			t.Errorf("test %d want %d bytes, got %d bytes", i, len(test.data), n)
    67  		} else if err != test.err {
    68  			t.Errorf("test %d want err %v, got err %v", i, test.err, err)
    69  		} else if !bytes.Equal(buf[:n], test.data) {
    70  			t.Errorf("test %d\nwant: %v\n got: %v\n", i, test.data, buf[:n])
    71  		}
    72  	}
    73  	fmt.Println("All done reads")
    74  }
    75  
    76  func TestRead(t *testing.T) {
    77  	testReadSequence(t, false, 10, readSequence{
    78  		in: []layers.TCP{
    79  			{
    80  				SYN:       true,
    81  				SrcPort:   1,
    82  				DstPort:   2,
    83  				Seq:       1000,
    84  				BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
    85  			},
    86  			{
    87  				FIN:     true,
    88  				SrcPort: 1,
    89  				DstPort: 2,
    90  				Seq:     1004,
    91  			},
    92  		},
    93  		want: []readReturn{
    94  			{data: []byte{1, 2, 3}},
    95  			{err: io.EOF},
    96  		},
    97  	})
    98  }
    99  
   100  func TestReadSmallChunks(t *testing.T) {
   101  	testReadSequence(t, false, 2, readSequence{
   102  		in: []layers.TCP{
   103  			{
   104  				SYN:       true,
   105  				SrcPort:   1,
   106  				DstPort:   2,
   107  				Seq:       1000,
   108  				BaseLayer: layers.BaseLayer{Payload: []byte{1, 2, 3}},
   109  			},
   110  			{
   111  				FIN:     true,
   112  				SrcPort: 1,
   113  				DstPort: 2,
   114  				Seq:     1004,
   115  			},
   116  		},
   117  		want: []readReturn{
   118  			{data: []byte{1, 2}},
   119  			{data: []byte{3}},
   120  			{err: io.EOF},
   121  		},
   122  	})
   123  }
   124  
   125  func ExampleDiscardBytesToEOF() {
   126  	b := bytes.NewBuffer([]byte{1, 2, 3, 4, 5})
   127  	fmt.Println(DiscardBytesToEOF(b))
   128  	// Output:
   129  	// 5
   130  }