github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/trafficRules_test.go (about)

     1  /*
     2   * Copyright (c) 2022, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package server
    21  
    22  import (
    23  	"encoding/json"
    24  	"io/ioutil"
    25  	"os"
    26  	"reflect"
    27  	"testing"
    28  
    29  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    30  )
    31  
    32  func TestTrafficRulesFilters(t *testing.T) {
    33  
    34  	trafficRulesJSON := `
    35  	{
    36        "DefaultRules" :  {
    37          "RateLimits" : {
    38            "WriteUnthrottledBytes": 1,
    39            "WriteBytesPerSecond": 2,
    40            "ReadUnthrottledBytes": 3,
    41            "ReadBytesPerSecond": 4,
    42            "UnthrottleFirstTunnelOnly": true
    43          },
    44          "AllowTCPPorts" : [5],
    45          "AllowUDPPorts" : [6]
    46        },
    47    
    48        "FilteredRules" : [
    49    
    50          {
    51            "Filter" : {
    52              "Regions" : ["R2"],
    53              "HandshakeParameters" : {
    54                  "client_version" : ["1"]
    55              }
    56            },
    57            "Rules" : {
    58              "RateLimits" : {
    59                "WriteBytesPerSecond": 7,
    60                "ReadBytesPerSecond": 8
    61              },
    62              "AllowTCPPorts" : [5,9],
    63              "AllowUDPPorts" : [6,10]
    64            }
    65          },
    66  
    67          {
    68            "Filter" : {
    69              "TunnelProtocols" : ["P2"],
    70              "Regions" : ["R3", "R4"],
    71              "HandshakeParameters" : {
    72                  "client_version" : ["1", "2"]
    73              }
    74            },
    75            "ExceptFilter" : {
    76              "ISPs" : ["I2", "I3"],
    77              "HandshakeParameters" : {
    78                  "client_version" : ["1"]
    79              }
    80            },
    81            "Rules" : {
    82              "RateLimits" : {
    83                "WriteBytesPerSecond": 11,
    84                "ReadBytesPerSecond": 12
    85              },
    86              "AllowTCPPorts" : [5,13],
    87              "AllowUDPPorts" : [6,14]
    88            }
    89          },
    90  
    91          {
    92            "Filter" : {
    93              "Regions" : ["R3", "R4"],
    94              "HandshakeParameters" : {
    95                  "client_version" : ["1", "2"]
    96              }
    97            },
    98            "ExceptFilter" : {
    99              "ISPs" : ["I2", "I3"],
   100              "HandshakeParameters" : {
   101                  "client_version" : ["1"]
   102              }
   103            },
   104            "Rules" : {
   105              "RateLimits" : {
   106                "WriteBytesPerSecond": 15,
   107                "ReadBytesPerSecond": 16
   108              },
   109              "AllowTCPPorts" : [5,17],
   110              "AllowUDPPorts" : [6,18]
   111            }
   112          }
   113        ]
   114      }
   115  	`
   116  
   117  	file, err := ioutil.TempFile("", "trafficRules.config")
   118  	if err != nil {
   119  		t.Fatalf("TempFile create failed: %s", err)
   120  	}
   121  	_, err = file.Write([]byte(trafficRulesJSON))
   122  	if err != nil {
   123  		t.Fatalf("TempFile write failed: %s", err)
   124  	}
   125  	file.Close()
   126  	configFileName := file.Name()
   127  	defer os.Remove(configFileName)
   128  
   129  	trafficRules, err := NewTrafficRulesSet(configFileName)
   130  	if err != nil {
   131  		t.Fatalf("NewTrafficRulesSet failed: %s", err)
   132  	}
   133  
   134  	err = trafficRules.Validate()
   135  	if err != nil {
   136  		t.Fatalf("TrafficRulesSet.Validate failed: %s", err)
   137  	}
   138  
   139  	makePortList := func(portsJSON string) common.PortList {
   140  		var p common.PortList
   141  		_ = json.Unmarshal([]byte(portsJSON), &p)
   142  		return p
   143  	}
   144  
   145  	testCases := []struct {
   146  		description                   string
   147  		isFirstTunnelInSession        bool
   148  		tunnelProtocol                string
   149  		geoIPData                     GeoIPData
   150  		state                         handshakeState
   151  		expectedWriteUnthrottledBytes int64
   152  		expectedWriteBytesPerSecond   int64
   153  		expectedReadUnthrottledBytes  int64
   154  		expectedReadBytesPerSecond    int64
   155  		expectedAllowTCPPorts         common.PortList
   156  		expectedAllowUDPPorts         common.PortList
   157  	}{
   158  		{
   159  			"get defaults",
   160  			true,
   161  			"P1",
   162  			GeoIPData{Country: "R1", ISP: "I1"},
   163  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   164  			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
   165  		},
   166  
   167  		{
   168  			"get defaults for not first tunnel in session",
   169  			false,
   170  			"P1",
   171  			GeoIPData{Country: "R1", ISP: "I1"},
   172  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   173  			0, 2, 0, 4, makePortList("[5]"), makePortList("[6]"),
   174  		},
   175  
   176  		{
   177  			"get first filtered rule",
   178  			true,
   179  			"P1",
   180  			GeoIPData{Country: "R2", ISP: "I1"},
   181  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   182  			1, 7, 3, 8, makePortList("[5,9]"), makePortList("[6,10]"),
   183  		},
   184  
   185  		{
   186  			"don't get first filtered rule with incomplete match",
   187  			true,
   188  			"P1",
   189  			GeoIPData{Country: "R2", ISP: "I1"},
   190  			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
   191  			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
   192  		},
   193  
   194  		{
   195  			"get second filtered rule",
   196  			true,
   197  			"P2",
   198  			GeoIPData{Country: "R3", ISP: "I1"},
   199  			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
   200  			1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
   201  		},
   202  
   203  		{
   204  			"get second filtered rule with incomplete exception",
   205  			true,
   206  			"P2",
   207  			GeoIPData{Country: "R3", ISP: "I2"},
   208  			handshakeState{apiParams: map[string]interface{}{"client_version": "2"}, completed: true},
   209  			1, 11, 3, 12, makePortList("[5,13]"), makePortList("[6,14]"),
   210  		},
   211  
   212  		{
   213  			"don't get second filtered rule due to exception",
   214  			true,
   215  			"P2",
   216  			GeoIPData{Country: "R3", ISP: "I2"},
   217  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   218  			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
   219  		},
   220  
   221  		{
   222  			"get third filtered rule",
   223  			true,
   224  			"P1",
   225  			GeoIPData{Country: "R3", ISP: "I1"},
   226  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   227  			1, 15, 3, 16, makePortList("[5,17]"), makePortList("[6,18]"),
   228  		},
   229  
   230  		{
   231  			"don't get third filtered rule due to exception",
   232  			true,
   233  			"P1",
   234  			GeoIPData{Country: "R3", ISP: "I2"},
   235  			handshakeState{apiParams: map[string]interface{}{"client_version": "1"}, completed: true},
   236  			1, 2, 3, 4, makePortList("[5]"), makePortList("[6]"),
   237  		},
   238  	}
   239  	for _, testCase := range testCases {
   240  		t.Run(testCase.description, func(t *testing.T) {
   241  
   242  			rules := trafficRules.GetTrafficRules(
   243  				testCase.isFirstTunnelInSession,
   244  				testCase.tunnelProtocol,
   245  				testCase.geoIPData,
   246  				testCase.state)
   247  
   248  			if *rules.RateLimits.WriteUnthrottledBytes != testCase.expectedWriteUnthrottledBytes {
   249  				t.Errorf("unexpected rules.RateLimits.WriteUnthrottledBytes: %v != %v",
   250  					*rules.RateLimits.WriteUnthrottledBytes, testCase.expectedWriteUnthrottledBytes)
   251  			}
   252  			if *rules.RateLimits.WriteBytesPerSecond != testCase.expectedWriteBytesPerSecond {
   253  				t.Errorf("unexpected rules.RateLimits.WriteBytesPerSecond: %v != %v",
   254  					*rules.RateLimits.WriteBytesPerSecond, testCase.expectedWriteBytesPerSecond)
   255  			}
   256  			if *rules.RateLimits.ReadUnthrottledBytes != testCase.expectedReadUnthrottledBytes {
   257  				t.Errorf("unexpected rules.RateLimits.ReadUnthrottledBytes: %v != %v",
   258  					*rules.RateLimits.ReadUnthrottledBytes, testCase.expectedReadUnthrottledBytes)
   259  			}
   260  			if *rules.RateLimits.ReadBytesPerSecond != testCase.expectedReadBytesPerSecond {
   261  				t.Errorf("unexpected rules.RateLimits.ReadBytesPerSecond: %v != %v",
   262  					*rules.RateLimits.ReadBytesPerSecond, testCase.expectedReadBytesPerSecond)
   263  			}
   264  			if !reflect.DeepEqual(*rules.AllowTCPPorts, testCase.expectedAllowTCPPorts) {
   265  				t.Errorf("unexpected rules.RateLimits.AllowTCPPorts: %v != %v",
   266  					*rules.AllowTCPPorts, testCase.expectedAllowTCPPorts)
   267  			}
   268  			if !reflect.DeepEqual(*rules.AllowUDPPorts, testCase.expectedAllowUDPPorts) {
   269  				t.Errorf("unexpected rules.RateLimits.AllowUDPPorts: %v != %v",
   270  					*rules.AllowUDPPorts, testCase.expectedAllowUDPPorts)
   271  			}
   272  		})
   273  	}
   274  }