github.com/mysteriumnetwork/node@v0.0.0-20240516044423-365054f76801/router/router_test.go (about)

     1  /*
     2   * Copyright (C) 2021 The "MysteriumNetwork/node" Authors.
     3   *
     4   * This program is free software: you can redistribute it and/or modify
     5   * it under the terms of the GNU General Public License as published by
     6   * the Free Software Foundation, either version 3 of the License, or
     7   * (at your option) any later version.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   */
    17  
    18  package router
    19  
    20  import (
    21  	"fmt"
    22  	"net"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  )
    29  
    30  func Test_router_ExcludeIP(t *testing.T) {
    31  	tests := []struct {
    32  		name            string
    33  		ips             []net.IP
    34  		expectedRecords int
    35  		wantErr         bool
    36  	}{
    37  		{
    38  			name:            "Adding multiple unique records",
    39  			ips:             []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("2.3.4.5"), net.ParseIP("3.4.5.6")},
    40  			expectedRecords: 3,
    41  			wantErr:         false,
    42  		},
    43  		{
    44  			name:            "Adding duplicated rules saves only once",
    45  			ips:             []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.4")},
    46  			expectedRecords: 1,
    47  			wantErr:         false,
    48  		},
    49  		{
    50  			name:            "No panic on empty IP, just expected error",
    51  			ips:             []net.IP{nil},
    52  			expectedRecords: 1,
    53  			wantErr:         true,
    54  		},
    55  	}
    56  
    57  	for _, tt := range tests {
    58  		t.Run(tt.name, func(t *testing.T) {
    59  			table := &mockRoutingTable{gw: net.ParseIP("1.1.1.1")}
    60  			r := &manager{
    61  				stop:         make(chan struct{}),
    62  				routingTable: table,
    63  			}
    64  
    65  			for _, ip := range tt.ips {
    66  				if err := r.ExcludeIP(ip); (err != nil) != tt.wantErr {
    67  					t.Errorf("Error = %v, wantErr %v", err, tt.wantErr)
    68  				}
    69  			}
    70  
    71  			assert.Len(t, table.rules, tt.expectedRecords, "Expected number of table rules does not match")
    72  		})
    73  	}
    74  }
    75  
    76  func Test_router_Clean(t *testing.T) {
    77  	tests := []struct {
    78  		name            string
    79  		ips             []net.IP
    80  		expectedRecords int
    81  		wantErr         bool
    82  	}{
    83  		{
    84  			name:            "Clean multiple unique records",
    85  			ips:             []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("2.3.4.5"), net.ParseIP("3.4.5.6")},
    86  			expectedRecords: 0,
    87  			wantErr:         false,
    88  		},
    89  		{
    90  			name:            "Clean duplicated rules only once",
    91  			ips:             []net.IP{net.ParseIP("1.2.3.4"), net.ParseIP("1.2.3.4")},
    92  			expectedRecords: 0,
    93  			wantErr:         false,
    94  		},
    95  		{
    96  			name:            "No panic on empty table, just expected error",
    97  			ips:             []net.IP{},
    98  			expectedRecords: 0,
    99  			wantErr:         true,
   100  		},
   101  	}
   102  
   103  	for _, tt := range tests {
   104  		t.Run(tt.name, func(t *testing.T) {
   105  			table := &mockRoutingTable{gw: net.ParseIP("1.1.1.1")}
   106  			r := &manager{
   107  				stop:         make(chan struct{}),
   108  				routingTable: table,
   109  			}
   110  
   111  			for _, ip := range tt.ips {
   112  				if err := r.ExcludeIP(ip); (err != nil) != tt.wantErr {
   113  					t.Errorf("Error = %v, wantErr %v", err, tt.wantErr)
   114  				}
   115  			}
   116  
   117  			err := r.Clean()
   118  
   119  			assert.NoError(t, err)
   120  
   121  			assert.Len(t, table.rules, tt.expectedRecords, "Expected number of table rules does not match")
   122  		})
   123  	}
   124  }
   125  
   126  func Test_router_ReplaceGW(t *testing.T) {
   127  	table := &mockRoutingTable{gw: net.ParseIP("1.1.1.1")}
   128  	r := &manager{
   129  		stop:            make(chan struct{}),
   130  		gwCheckInterval: 100 * time.Millisecond,
   131  		routingTable:    table,
   132  	}
   133  
   134  	r.ExcludeIP(net.ParseIP("2.2.2.2"))
   135  	r.ExcludeIP(net.ParseIP("3.3.3.3"))
   136  
   137  	assert.Contains(t, table.rules, "2.2.2.2:1.1.1.1")
   138  	assert.Contains(t, table.rules, "3.3.3.3:1.1.1.1")
   139  	assert.NotContains(t, table.rules, "2.2.2.2:4.4.4.4")
   140  	assert.NotContains(t, table.rules, "3.3.3.3:4.4.4.4")
   141  	assert.Len(t, table.rules, 2)
   142  
   143  	table.setGW(net.ParseIP("4.4.4.4"))
   144  
   145  	assert.Eventually(t, func() bool {
   146  		table.mu.Lock()
   147  		defer table.mu.Unlock()
   148  
   149  		_, ok1 := table.rules["2.2.2.2:4.4.4.4"]
   150  		_, ok2 := table.rules["3.3.3.3:4.4.4.4"]
   151  		_, not1 := table.rules["2.2.2.2:1.1.1.1"]
   152  		_, not2 := table.rules["3.3.3.3:1.1.1.1"]
   153  
   154  		return ok1 && ok2 && !not1 && !not2
   155  	}, time.Second, 10*time.Millisecond)
   156  
   157  	assert.Len(t, table.rules, 2)
   158  }
   159  
   160  type mockRoutingTable struct {
   161  	rules map[string]int
   162  	gw    net.IP
   163  
   164  	mu sync.Mutex
   165  }
   166  
   167  func (t *mockRoutingTable) ExcludeRule(ip, gw net.IP) error {
   168  	t.mu.Lock()
   169  	defer t.mu.Unlock()
   170  
   171  	if t.rules == nil {
   172  		t.rules = make(map[string]int)
   173  	}
   174  
   175  	t.rules[fmt.Sprintf("%s:%s", ip, gw)]++
   176  
   177  	if ip.Equal(nil) {
   178  		return fmt.Errorf("expected error")
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func (t *mockRoutingTable) DeleteRule(ip, gw net.IP) error {
   185  	t.mu.Lock()
   186  	defer t.mu.Unlock()
   187  
   188  	if t.rules == nil {
   189  		t.rules = make(map[string]int)
   190  	}
   191  
   192  	t.rules[fmt.Sprintf("%s:%s", ip, gw)]--
   193  
   194  	if t.rules[fmt.Sprintf("%s:%s", ip, gw)] == 0 {
   195  		delete(t.rules, fmt.Sprintf("%s:%s", ip, gw))
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  func (t *mockRoutingTable) DiscoverGateway() (net.IP, error) {
   202  	t.mu.Lock()
   203  	defer t.mu.Unlock()
   204  
   205  	return t.gw, nil
   206  }
   207  
   208  func (t *mockRoutingTable) setGW(gw net.IP) {
   209  	t.mu.Lock()
   210  	defer t.mu.Unlock()
   211  
   212  	t.gw = gw
   213  }