github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/conn/bind_std_test.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   *
     9   * Portions of this file are based on code originally from wireguard-go,
    10   *
    11   * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
    12   *
    13   * Permission is hereby granted, free of charge, to any person obtaining a copy of
    14   * this software and associated documentation files (the "Software"), to deal in
    15   * the Software without restriction, including without limitation the rights to
    16   * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
    17   * of the Software, and to permit persons to whom the Software is furnished to do
    18   * so, subject to the following conditions:
    19   *
    20   * The above copyright notice and this permission notice shall be included in all
    21   * copies or substantial portions of the Software.
    22   *
    23   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    24   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    25   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    26   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    27   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    28   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
    29   * SOFTWARE.
    30   */
    31  
    32  package conn
    33  
    34  import (
    35  	"encoding/binary"
    36  	"errors"
    37  	"net"
    38  	"testing"
    39  
    40  	"golang.org/x/net/ipv6"
    41  )
    42  
    43  func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
    44  	bind := NewStdNetBind().(*StdNetBind)
    45  	fns, _, err := bind.Open(0)
    46  	if err != nil {
    47  		t.Fatal(err)
    48  	}
    49  	bind.Close()
    50  	bufs := make([][]byte, 1)
    51  	bufs[0] = make([]byte, 1)
    52  	sizes := make([]int, 1)
    53  	eps := make([]Endpoint, 1)
    54  	for _, fn := range fns {
    55  		// The ReceiveFuncs must not access conn-related fields on StdNetBind
    56  		// unguarded. Close() nils the conn-related fields resulting in a panic
    57  		// if they violate the mutex.
    58  		if _, err := fn(bufs, sizes, eps); err != nil && !errors.Is(err, net.ErrClosed) {
    59  			t.Fatal(err)
    60  		}
    61  	}
    62  }
    63  
    64  func mockSetGSOSize(control *[]byte, gsoSize uint16) {
    65  	*control = (*control)[:cap(*control)]
    66  	binary.LittleEndian.PutUint16(*control, gsoSize)
    67  }
    68  
    69  func Test_coalesceMessages(t *testing.T) {
    70  	cases := []struct {
    71  		name     string
    72  		buffs    [][]byte
    73  		wantLens []int
    74  		wantGSO  []int
    75  	}{
    76  		{
    77  			name: "one message no coalesce",
    78  			buffs: [][]byte{
    79  				make([]byte, 1),
    80  			},
    81  			wantLens: []int{1},
    82  			wantGSO:  []int{0},
    83  		},
    84  		{
    85  			name: "two messages equal len coalesce",
    86  			buffs: [][]byte{
    87  				make([]byte, 1, 2),
    88  				make([]byte, 1),
    89  			},
    90  			wantLens: []int{2},
    91  			wantGSO:  []int{1},
    92  		},
    93  		{
    94  			name: "two messages unequal len coalesce",
    95  			buffs: [][]byte{
    96  				make([]byte, 2, 3),
    97  				make([]byte, 1),
    98  			},
    99  			wantLens: []int{3},
   100  			wantGSO:  []int{2},
   101  		},
   102  		{
   103  			name: "three messages second unequal len coalesce",
   104  			buffs: [][]byte{
   105  				make([]byte, 2, 3),
   106  				make([]byte, 1),
   107  				make([]byte, 2),
   108  			},
   109  			wantLens: []int{3, 2},
   110  			wantGSO:  []int{2, 0},
   111  		},
   112  		{
   113  			name: "three messages limited cap coalesce",
   114  			buffs: [][]byte{
   115  				make([]byte, 2, 4),
   116  				make([]byte, 2),
   117  				make([]byte, 2),
   118  			},
   119  			wantLens: []int{4, 2},
   120  			wantGSO:  []int{2, 0},
   121  		},
   122  	}
   123  
   124  	for _, tt := range cases {
   125  		t.Run(tt.name, func(t *testing.T) {
   126  			addr := &net.UDPAddr{
   127  				IP:   net.ParseIP("127.0.0.1").To4(),
   128  				Port: 1,
   129  			}
   130  			msgs := make([]ipv6.Message, len(tt.buffs))
   131  			for i := range msgs {
   132  				msgs[i].Buffers = make([][]byte, 1)
   133  				msgs[i].OOB = make([]byte, 0, 2)
   134  			}
   135  			got := coalesceMessages(addr, &StdNetEndpoint{AddrPort: addr.AddrPort()}, tt.buffs, msgs, mockSetGSOSize)
   136  			if got != len(tt.wantLens) {
   137  				t.Fatalf("got len %d want: %d", got, len(tt.wantLens))
   138  			}
   139  			for i := 0; i < got; i++ {
   140  				if msgs[i].Addr != addr {
   141  					t.Errorf("msgs[%d].Addr != passed addr", i)
   142  				}
   143  				gotLen := len(msgs[i].Buffers[0])
   144  				if gotLen != tt.wantLens[i] {
   145  					t.Errorf("len(msgs[%d].Buffers[0]) %d != %d", i, gotLen, tt.wantLens[i])
   146  				}
   147  				gotGSO, err := mockGetGSOSize(msgs[i].OOB)
   148  				if err != nil {
   149  					t.Fatalf("msgs[%d] getGSOSize err: %v", i, err)
   150  				}
   151  				if gotGSO != tt.wantGSO[i] {
   152  					t.Errorf("msgs[%d] gsoSize %d != %d", i, gotGSO, tt.wantGSO[i])
   153  				}
   154  			}
   155  		})
   156  	}
   157  }
   158  
   159  func mockGetGSOSize(control []byte) (int, error) {
   160  	if len(control) < 2 {
   161  		return 0, nil
   162  	}
   163  	return int(binary.LittleEndian.Uint16(control)), nil
   164  }
   165  
   166  func Test_splitCoalescedMessages(t *testing.T) {
   167  	newMsg := func(n, gso int) ipv6.Message {
   168  		msg := ipv6.Message{
   169  			Buffers: [][]byte{make([]byte, 1<<16-1)},
   170  			N:       n,
   171  			OOB:     make([]byte, 2),
   172  		}
   173  		binary.LittleEndian.PutUint16(msg.OOB, uint16(gso))
   174  		if gso > 0 {
   175  			msg.NN = 2
   176  		}
   177  		return msg
   178  	}
   179  
   180  	cases := []struct {
   181  		name        string
   182  		msgs        []ipv6.Message
   183  		firstMsgAt  int
   184  		wantNumEval int
   185  		wantMsgLens []int
   186  		wantErr     bool
   187  	}{
   188  		{
   189  			name: "second last split last empty",
   190  			msgs: []ipv6.Message{
   191  				newMsg(0, 0),
   192  				newMsg(0, 0),
   193  				newMsg(3, 1),
   194  				newMsg(0, 0),
   195  			},
   196  			firstMsgAt:  2,
   197  			wantNumEval: 3,
   198  			wantMsgLens: []int{1, 1, 1, 0},
   199  			wantErr:     false,
   200  		},
   201  		{
   202  			name: "second last no split last empty",
   203  			msgs: []ipv6.Message{
   204  				newMsg(0, 0),
   205  				newMsg(0, 0),
   206  				newMsg(1, 0),
   207  				newMsg(0, 0),
   208  			},
   209  			firstMsgAt:  2,
   210  			wantNumEval: 1,
   211  			wantMsgLens: []int{1, 0, 0, 0},
   212  			wantErr:     false,
   213  		},
   214  		{
   215  			name: "second last no split last no split",
   216  			msgs: []ipv6.Message{
   217  				newMsg(0, 0),
   218  				newMsg(0, 0),
   219  				newMsg(1, 0),
   220  				newMsg(1, 0),
   221  			},
   222  			firstMsgAt:  2,
   223  			wantNumEval: 2,
   224  			wantMsgLens: []int{1, 1, 0, 0},
   225  			wantErr:     false,
   226  		},
   227  		{
   228  			name: "second last no split last split",
   229  			msgs: []ipv6.Message{
   230  				newMsg(0, 0),
   231  				newMsg(0, 0),
   232  				newMsg(1, 0),
   233  				newMsg(3, 1),
   234  			},
   235  			firstMsgAt:  2,
   236  			wantNumEval: 4,
   237  			wantMsgLens: []int{1, 1, 1, 1},
   238  			wantErr:     false,
   239  		},
   240  		{
   241  			name: "second last split last split",
   242  			msgs: []ipv6.Message{
   243  				newMsg(0, 0),
   244  				newMsg(0, 0),
   245  				newMsg(2, 1),
   246  				newMsg(2, 1),
   247  			},
   248  			firstMsgAt:  2,
   249  			wantNumEval: 4,
   250  			wantMsgLens: []int{1, 1, 1, 1},
   251  			wantErr:     false,
   252  		},
   253  		{
   254  			name: "second last no split last split overflow",
   255  			msgs: []ipv6.Message{
   256  				newMsg(0, 0),
   257  				newMsg(0, 0),
   258  				newMsg(1, 0),
   259  				newMsg(4, 1),
   260  			},
   261  			firstMsgAt:  2,
   262  			wantNumEval: 4,
   263  			wantMsgLens: []int{1, 1, 1, 1},
   264  			wantErr:     true,
   265  		},
   266  	}
   267  
   268  	for _, tt := range cases {
   269  		t.Run(tt.name, func(t *testing.T) {
   270  			got, err := splitCoalescedMessages(tt.msgs, 2, mockGetGSOSize)
   271  			if err != nil && !tt.wantErr {
   272  				t.Fatalf("err: %v", err)
   273  			}
   274  			if got != tt.wantNumEval {
   275  				t.Fatalf("got to eval: %d want: %d", got, tt.wantNumEval)
   276  			}
   277  			for i, msg := range tt.msgs {
   278  				if msg.N != tt.wantMsgLens[i] {
   279  					t.Fatalf("msg[%d].N: %d want: %d", i, msg.N, tt.wantMsgLens[i])
   280  				}
   281  			}
   282  		})
   283  	}
   284  }