github.com/amnezia-vpn/amnezia-wg@v0.1.8/conn/bind_std_test.go (about)

     1  package conn
     2  
     3  import (
     4  	"encoding/binary"
     5  	"net"
     6  	"testing"
     7  
     8  	"golang.org/x/net/ipv6"
     9  )
    10  
    11  func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
    12  	bind := NewStdNetBind().(*StdNetBind)
    13  	fns, _, err := bind.Open(0)
    14  	if err != nil {
    15  		t.Fatal(err)
    16  	}
    17  	bind.Close()
    18  	bufs := make([][]byte, 1)
    19  	bufs[0] = make([]byte, 1)
    20  	sizes := make([]int, 1)
    21  	eps := make([]Endpoint, 1)
    22  	for _, fn := range fns {
    23  		// The ReceiveFuncs must not access conn-related fields on StdNetBind
    24  		// unguarded. Close() nils the conn-related fields resulting in a panic
    25  		// if they violate the mutex.
    26  		fn(bufs, sizes, eps)
    27  	}
    28  }
    29  
    30  func mockSetGSOSize(control *[]byte, gsoSize uint16) {
    31  	*control = (*control)[:cap(*control)]
    32  	binary.LittleEndian.PutUint16(*control, gsoSize)
    33  }
    34  
    35  func Test_coalesceMessages(t *testing.T) {
    36  	cases := []struct {
    37  		name     string
    38  		buffs    [][]byte
    39  		wantLens []int
    40  		wantGSO  []int
    41  	}{
    42  		{
    43  			name: "one message no coalesce",
    44  			buffs: [][]byte{
    45  				make([]byte, 1, 1),
    46  			},
    47  			wantLens: []int{1},
    48  			wantGSO:  []int{0},
    49  		},
    50  		{
    51  			name: "two messages equal len coalesce",
    52  			buffs: [][]byte{
    53  				make([]byte, 1, 2),
    54  				make([]byte, 1, 1),
    55  			},
    56  			wantLens: []int{2},
    57  			wantGSO:  []int{1},
    58  		},
    59  		{
    60  			name: "two messages unequal len coalesce",
    61  			buffs: [][]byte{
    62  				make([]byte, 2, 3),
    63  				make([]byte, 1, 1),
    64  			},
    65  			wantLens: []int{3},
    66  			wantGSO:  []int{2},
    67  		},
    68  		{
    69  			name: "three messages second unequal len coalesce",
    70  			buffs: [][]byte{
    71  				make([]byte, 2, 3),
    72  				make([]byte, 1, 1),
    73  				make([]byte, 2, 2),
    74  			},
    75  			wantLens: []int{3, 2},
    76  			wantGSO:  []int{2, 0},
    77  		},
    78  		{
    79  			name: "three messages limited cap coalesce",
    80  			buffs: [][]byte{
    81  				make([]byte, 2, 4),
    82  				make([]byte, 2, 2),
    83  				make([]byte, 2, 2),
    84  			},
    85  			wantLens: []int{4, 2},
    86  			wantGSO:  []int{2, 0},
    87  		},
    88  	}
    89  
    90  	for _, tt := range cases {
    91  		t.Run(tt.name, func(t *testing.T) {
    92  			addr := &net.UDPAddr{
    93  				IP:   net.ParseIP("127.0.0.1").To4(),
    94  				Port: 1,
    95  			}
    96  			msgs := make([]ipv6.Message, len(tt.buffs))
    97  			for i := range msgs {
    98  				msgs[i].Buffers = make([][]byte, 1)
    99  				msgs[i].OOB = make([]byte, 0, 2)
   100  			}
   101  			got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
   102  			if got != len(tt.wantLens) {
   103  				t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
   104  			}
   105  			for i := 0; i < got; i++ {
   106  				if msgs[i].Addr != addr {
   107  					t.Errorf("msgs[%d].Addr != passed addr", i)
   108  				}
   109  				gotLen := len(msgs[i].Buffers[0])
   110  				if gotLen != tt.wantLens[i] {
   111  					t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
   112  				}
   113  				gotGSO, err := mockGetGSOSize(msgs[i].OOB)
   114  				if err != nil {
   115  					t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
   116  				}
   117  				if gotGSO != tt.wantGSO[i] {
   118  					t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
   119  				}
   120  			}
   121  		})
   122  	}
   123  }
   124  
   125  func mockGetGSOSize(control []byte) (int, error) {
   126  	if len(control) < 2 {
   127  		return 0, nil
   128  	}
   129  	return int(binary.LittleEndian.Uint16(control)), nil
   130  }
   131  
   132  func Test_splitCoalescedMessages(t *testing.T) {
   133  	newMsg := func(n, gso int) ipv6.Message {
   134  		msg := ipv6.Message{
   135  			Buffers: [][]byte{make([]byte, 1<<16-1)},
   136  			N:       n,
   137  			OOB:     make([]byte, 2),
   138  		}
   139  		binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
   140  		if gso > 0 {
   141  			msg.NN = 2
   142  		}
   143  		return msg
   144  	}
   145  
   146  	cases := []struct {
   147  		name        string
   148  		msgs        []ipv6.Message
   149  		firstMsgAt  int
   150  		wantNumEval int
   151  		wantMsgLens []int
   152  		wantErr     bool
   153  	}{
   154  		{
   155  			name: "second last split last empty",
   156  			msgs: []ipv6.Message{
   157  				newMsg(0, 0),
   158  				newMsg(0, 0),
   159  				newMsg(3, 1),
   160  				newMsg(0, 0),
   161  			},
   162  			firstMsgAt:  2,
   163  			wantNumEval: 3,
   164  			wantMsgLens: []int{1, 1, 1, 0},
   165  			wantErr:     false,
   166  		},
   167  		{
   168  			name: "second last no split last empty",
   169  			msgs: []ipv6.Message{
   170  				newMsg(0, 0),
   171  				newMsg(0, 0),
   172  				newMsg(1, 0),
   173  				newMsg(0, 0),
   174  			},
   175  			firstMsgAt:  2,
   176  			wantNumEval: 1,
   177  			wantMsgLens: []int{1, 0, 0, 0},
   178  			wantErr:     false,
   179  		},
   180  		{
   181  			name: "second last no split last no split",
   182  			msgs: []ipv6.Message{
   183  				newMsg(0, 0),
   184  				newMsg(0, 0),
   185  				newMsg(1, 0),
   186  				newMsg(1, 0),
   187  			},
   188  			firstMsgAt:  2,
   189  			wantNumEval: 2,
   190  			wantMsgLens: []int{1, 1, 0, 0},
   191  			wantErr:     false,
   192  		},
   193  		{
   194  			name: "second last no split last split",
   195  			msgs: []ipv6.Message{
   196  				newMsg(0, 0),
   197  				newMsg(0, 0),
   198  				newMsg(1, 0),
   199  				newMsg(3, 1),
   200  			},
   201  			firstMsgAt:  2,
   202  			wantNumEval: 4,
   203  			wantMsgLens: []int{1, 1, 1, 1},
   204  			wantErr:     false,
   205  		},
   206  		{
   207  			name: "second last split last split",
   208  			msgs: []ipv6.Message{
   209  				newMsg(0, 0),
   210  				newMsg(0, 0),
   211  				newMsg(2, 1),
   212  				newMsg(2, 1),
   213  			},
   214  			firstMsgAt:  2,
   215  			wantNumEval: 4,
   216  			wantMsgLens: []int{1, 1, 1, 1},
   217  			wantErr:     false,
   218  		},
   219  		{
   220  			name: "second last no split last split overflow",
   221  			msgs: []ipv6.Message{
   222  				newMsg(0, 0),
   223  				newMsg(0, 0),
   224  				newMsg(1, 0),
   225  				newMsg(4, 1),
   226  			},
   227  			firstMsgAt:  2,
   228  			wantNumEval: 4,
   229  			wantMsgLens: []int{1, 1, 1, 1},
   230  			wantErr:     true,
   231  		},
   232  	}
   233  
   234  	for _, tt := range cases {
   235  		t.Run(tt.name, func(t *testing.T) {
   236  			got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
   237  			if err != nil && !tt.wantErr {
   238  				t.Fatalf("err: %v", err)
   239  			}
   240  			if got != tt.wantNumEval {
   241  				t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
   242  			}
   243  			for i, msg := range tt.msgs {
   244  				if msg.N != tt.wantMsgLens[i] {
   245  					t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
   246  				}
   247  			}
   248  		})
   249  	}
   250  }