github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/parameters/parameters_test.go (about)

     1  /*
     2   * Copyright (c) 2018, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package parameters
    21  
    22  import (
    23  	"encoding/json"
    24  	"net/http"
    25  	"reflect"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    30  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
    31  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms"
    32  )
    33  
    34  func TestGetDefaultParameters(t *testing.T) {
    35  
    36  	p, err := NewParameters(nil)
    37  	if err != nil {
    38  		t.Fatalf("NewParameters failed: %s", err)
    39  	}
    40  
    41  	for name, defaults := range defaultParameters {
    42  		switch v := defaults.value.(type) {
    43  		case string:
    44  			g := p.Get().String(name)
    45  			if v != g {
    46  				t.Fatalf("String returned %+v expected %+v", g, v)
    47  			}
    48  		case []string:
    49  			g := p.Get().Strings(name)
    50  			if !reflect.DeepEqual(v, g) {
    51  				t.Fatalf("Strings returned %+v expected %+v", g, v)
    52  			}
    53  		case int:
    54  			g := p.Get().Int(name)
    55  			if v != g {
    56  				t.Fatalf("Int returned %+v expected %+v", g, v)
    57  			}
    58  		case float64:
    59  			g := p.Get().Float(name)
    60  			if v != g {
    61  				t.Fatalf("Float returned %+v expected %+v", g, v)
    62  			}
    63  		case bool:
    64  			g := p.Get().Bool(name)
    65  			if v != g {
    66  				t.Fatalf("Bool returned %+v expected %+v", g, v)
    67  			}
    68  		case time.Duration:
    69  			g := p.Get().Duration(name)
    70  			if v != g {
    71  				t.Fatalf("Duration returned %+v expected %+v", g, v)
    72  			}
    73  		case protocol.TunnelProtocols:
    74  			g := p.Get().TunnelProtocols(name)
    75  			if !reflect.DeepEqual(v, g) {
    76  				t.Fatalf("TunnelProtocols returned %+v expected %+v", g, v)
    77  			}
    78  		case protocol.TLSProfiles:
    79  			g := p.Get().TLSProfiles(name)
    80  			if !reflect.DeepEqual(v, g) {
    81  				t.Fatalf("TLSProfiles returned %+v expected %+v", g, v)
    82  			}
    83  		case protocol.LabeledTLSProfiles:
    84  			for label, profiles := range v {
    85  				g := p.Get().LabeledTLSProfiles(name, label)
    86  				if !reflect.DeepEqual(profiles, g) {
    87  					t.Fatalf("LabeledTLSProfiles returned %+v expected %+v", g, profiles)
    88  				}
    89  			}
    90  		case protocol.QUICVersions:
    91  			g := p.Get().QUICVersions(name)
    92  			if !reflect.DeepEqual(v, g) {
    93  				t.Fatalf("QUICVersions returned %+v expected %+v", g, v)
    94  			}
    95  		case protocol.LabeledQUICVersions:
    96  			for label, versions := range v {
    97  				g := p.Get().LabeledTLSProfiles(name, label)
    98  				if !reflect.DeepEqual(versions, g) {
    99  					t.Fatalf("LabeledQUICVersions returned %+v expected %+v", g, versions)
   100  				}
   101  			}
   102  		case TransferURLs:
   103  			g := p.Get().TransferURLs(name)
   104  			if !reflect.DeepEqual(v, g) {
   105  				t.Fatalf("TransferURLs returned %+v expected %+v", g, v)
   106  			}
   107  		case common.RateLimits:
   108  			g := p.Get().RateLimits(name)
   109  			if !reflect.DeepEqual(v, g) {
   110  				t.Fatalf("RateLimits returned %+v expected %+v", g, v)
   111  			}
   112  		case http.Header:
   113  			g := p.Get().HTTPHeaders(name)
   114  			if !reflect.DeepEqual(v, g) {
   115  				t.Fatalf("HTTPHeaders returned %+v expected %+v", g, v)
   116  			}
   117  		case protocol.CustomTLSProfiles:
   118  			g := p.Get().CustomTLSProfileNames()
   119  			names := make([]string, len(v))
   120  			for i, profile := range v {
   121  				names[i] = profile.Name
   122  			}
   123  			if !reflect.DeepEqual(names, g) {
   124  				t.Fatalf("CustomTLSProfileNames returned %+v expected %+v", g, names)
   125  			}
   126  		case KeyValues:
   127  			g := p.Get().KeyValues(name)
   128  			if !reflect.DeepEqual(v, g) {
   129  				t.Fatalf("KeyValues returned %+v expected %+v", g, v)
   130  			}
   131  		case *BPFProgramSpec:
   132  			ok, name, rawInstructions := p.Get().BPFProgram(name)
   133  			if v != nil || ok || name != "" || rawInstructions != nil {
   134  				t.Fatalf(
   135  					"BPFProgramSpec returned %+v %+v %+v expected %+v",
   136  					ok, name, rawInstructions, v)
   137  			}
   138  		case PacketManipulationSpecs:
   139  			g := p.Get().PacketManipulationSpecs(name)
   140  			if !reflect.DeepEqual(v, g) {
   141  				t.Fatalf("PacketManipulationSpecs returned %+v expected %+v", g, v)
   142  			}
   143  		case ProtocolPacketManipulations:
   144  			g := p.Get().ProtocolPacketManipulations(name)
   145  			if !reflect.DeepEqual(v, g) {
   146  				t.Fatalf("ProtocolPacketManipulations returned %+v expected %+v", g, v)
   147  			}
   148  		case RegexStrings:
   149  			g := p.Get().RegexStrings(name)
   150  			if !reflect.DeepEqual(v, g) {
   151  				t.Fatalf("RegexStrings returned %+v expected %+v", g, v)
   152  			}
   153  		case FrontingSpecs:
   154  			g := p.Get().FrontingSpecs(name)
   155  			if !reflect.DeepEqual(v, g) {
   156  				t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v)
   157  			}
   158  		case TunnelProtocolPortLists:
   159  			g := p.Get().TunnelProtocolPortLists(name)
   160  			if !reflect.DeepEqual(v, g) {
   161  				t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v)
   162  			}
   163  		case LabeledCIDRs:
   164  			for label, CIDRs := range v {
   165  				g := p.Get().LabeledCIDRs(name, label)
   166  				if !reflect.DeepEqual(CIDRs, g) {
   167  					t.Fatalf("LabeledCIDRs returned %+v expected %+v", g, CIDRs)
   168  				}
   169  			}
   170  		case transforms.Specs:
   171  			g := p.Get().ProtocolTransformSpecs(name)
   172  			if !reflect.DeepEqual(v, g) {
   173  				t.Fatalf("ProtocolTransformSpecs returned %+v expected %+v", g, v)
   174  			}
   175  		case transforms.ScopedSpecNames:
   176  			g := p.Get().ProtocolTransformScopedSpecNames(name)
   177  			if !reflect.DeepEqual(v, g) {
   178  				t.Fatalf("ProtocolTransformScopedSpecNames returned %+v expected %+v", g, v)
   179  			}
   180  		default:
   181  			t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value)
   182  		}
   183  	}
   184  }
   185  
   186  func TestGetValueLogger(t *testing.T) {
   187  
   188  	loggerCalled := false
   189  
   190  	p, err := NewParameters(
   191  		func(error) {
   192  			loggerCalled = true
   193  		})
   194  	if err != nil {
   195  		t.Fatalf("NewParameters failed: %s", err)
   196  	}
   197  
   198  	p.Get().Int("unknown-parameter-name")
   199  
   200  	if !loggerCalled {
   201  		t.Fatalf("logged not called")
   202  	}
   203  }
   204  
   205  func TestOverrides(t *testing.T) {
   206  
   207  	tag := "tag"
   208  	applyParameters := make(map[string]interface{})
   209  
   210  	// Below minimum, should not apply
   211  	defaultConnectionWorkerPoolSize := defaultParameters[ConnectionWorkerPoolSize].value.(int)
   212  	minimumConnectionWorkerPoolSize := defaultParameters[ConnectionWorkerPoolSize].minimum.(int)
   213  	newConnectionWorkerPoolSize := minimumConnectionWorkerPoolSize - 1
   214  	applyParameters[ConnectionWorkerPoolSize] = newConnectionWorkerPoolSize
   215  
   216  	// Above minimum, should apply
   217  	defaultInitialLimitTunnelProtocolsCandidateCount := defaultParameters[InitialLimitTunnelProtocolsCandidateCount].value.(int)
   218  	minimumInitialLimitTunnelProtocolsCandidateCount := defaultParameters[InitialLimitTunnelProtocolsCandidateCount].minimum.(int)
   219  	newInitialLimitTunnelProtocolsCandidateCount := minimumInitialLimitTunnelProtocolsCandidateCount + 1
   220  	applyParameters[InitialLimitTunnelProtocolsCandidateCount] = newInitialLimitTunnelProtocolsCandidateCount
   221  
   222  	p, err := NewParameters(nil)
   223  	if err != nil {
   224  		t.Fatalf("NewParameters failed: %s", err)
   225  	}
   226  
   227  	// No skip on error; should fail and not apply any changes
   228  
   229  	_, err = p.Set(tag, false, applyParameters)
   230  	if err == nil {
   231  		t.Fatalf("Set succeeded unexpectedly")
   232  	}
   233  
   234  	if p.Get().Tag() != "" {
   235  		t.Fatalf("GetTag returned unexpected value")
   236  	}
   237  
   238  	v := p.Get().Int(ConnectionWorkerPoolSize)
   239  	if v != defaultConnectionWorkerPoolSize {
   240  		t.Fatalf("GetInt returned unexpected ConnectionWorkerPoolSize: %d", v)
   241  	}
   242  
   243  	v = p.Get().Int(InitialLimitTunnelProtocolsCandidateCount)
   244  	if v != defaultInitialLimitTunnelProtocolsCandidateCount {
   245  		t.Fatalf("GetInt returned unexpected InitialLimitTunnelProtocolsCandidateCount: %d", v)
   246  	}
   247  
   248  	// Skip on error; should skip ConnectionWorkerPoolSize and apply InitialLimitTunnelProtocolsCandidateCount
   249  
   250  	counts, err := p.Set(tag, true, applyParameters)
   251  	if err != nil {
   252  		t.Fatalf("Set failed: %s", err)
   253  	}
   254  
   255  	if counts[0] != 1 {
   256  		t.Fatalf("Apply returned unexpected count: %d", counts[0])
   257  	}
   258  
   259  	v = p.Get().Int(ConnectionWorkerPoolSize)
   260  	if v != defaultConnectionWorkerPoolSize {
   261  		t.Fatalf("GetInt returned unexpected ConnectionWorkerPoolSize: %d", v)
   262  	}
   263  
   264  	v = p.Get().Int(InitialLimitTunnelProtocolsCandidateCount)
   265  	if v != newInitialLimitTunnelProtocolsCandidateCount {
   266  		t.Fatalf("GetInt returned unexpected InitialLimitTunnelProtocolsCandidateCount: %d", v)
   267  	}
   268  }
   269  
   270  func TestNetworkLatencyMultiplier(t *testing.T) {
   271  	p, err := NewParameters(nil)
   272  	if err != nil {
   273  		t.Fatalf("NewParameters failed: %s", err)
   274  	}
   275  
   276  	timeout1 := p.Get().Duration(TunnelConnectTimeout)
   277  
   278  	applyParameters := map[string]interface{}{"NetworkLatencyMultiplier": 2.0}
   279  
   280  	_, err = p.Set("", false, applyParameters)
   281  	if err != nil {
   282  		t.Fatalf("Set failed: %s", err)
   283  	}
   284  
   285  	timeout2 := p.Get().Duration(TunnelConnectTimeout)
   286  
   287  	if 2*timeout1 != timeout2 {
   288  		t.Fatalf("Unexpected timeouts: 2 * %s != %s", timeout1, timeout2)
   289  	}
   290  }
   291  
   292  func TestCustomNetworkLatencyMultiplier(t *testing.T) {
   293  	p, err := NewParameters(nil)
   294  	if err != nil {
   295  		t.Fatalf("NewParameters failed: %s", err)
   296  	}
   297  
   298  	timeout1 := p.Get().Duration(TunnelConnectTimeout)
   299  
   300  	applyParameters := map[string]interface{}{"NetworkLatencyMultiplier": 2.0}
   301  
   302  	_, err = p.Set("", false, applyParameters)
   303  	if err != nil {
   304  		t.Fatalf("Set failed: %s", err)
   305  	}
   306  
   307  	timeout2 := p.GetCustom(4.0).Duration(TunnelConnectTimeout)
   308  
   309  	if 4*timeout1 != timeout2 {
   310  		t.Fatalf("Unexpected timeouts: 4 * %s != %s", timeout1, timeout2)
   311  	}
   312  }
   313  
   314  func TestLimitTunnelProtocolProbability(t *testing.T) {
   315  	p, err := NewParameters(nil)
   316  	if err != nil {
   317  		t.Fatalf("NewParameters failed: %s", err)
   318  	}
   319  
   320  	// Default probability should be 1.0 and always return tunnelProtocols
   321  
   322  	tunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"}
   323  
   324  	applyParameters := map[string]interface{}{
   325  		"LimitTunnelProtocols": tunnelProtocols,
   326  	}
   327  
   328  	_, err = p.Set("", false, applyParameters)
   329  	if err != nil {
   330  		t.Fatalf("Set failed: %s", err)
   331  	}
   332  
   333  	for i := 0; i < 1000; i++ {
   334  		l := p.Get().TunnelProtocols(LimitTunnelProtocols)
   335  		if !reflect.DeepEqual(l, tunnelProtocols) {
   336  			t.Fatalf("unexpected %+v != %+v", l, tunnelProtocols)
   337  		}
   338  	}
   339  
   340  	// With probability set to 0.5, should return tunnelProtocols ~50%
   341  
   342  	defaultLimitTunnelProtocols := protocol.TunnelProtocols{}
   343  
   344  	applyParameters = map[string]interface{}{
   345  		"LimitTunnelProtocolsProbability": 0.5,
   346  		"LimitTunnelProtocols":            tunnelProtocols,
   347  	}
   348  
   349  	_, err = p.Set("", false, applyParameters)
   350  	if err != nil {
   351  		t.Fatalf("Set failed: %s", err)
   352  	}
   353  
   354  	matchCount := 0
   355  
   356  	for i := 0; i < 1000; i++ {
   357  		l := p.Get().TunnelProtocols(LimitTunnelProtocols)
   358  		if reflect.DeepEqual(l, tunnelProtocols) {
   359  			matchCount += 1
   360  		} else if !reflect.DeepEqual(l, defaultLimitTunnelProtocols) {
   361  			t.Fatalf("unexpected %+v != %+v", l, defaultLimitTunnelProtocols)
   362  		}
   363  	}
   364  
   365  	if matchCount < 250 || matchCount > 750 {
   366  		t.Fatalf("Unexpected probability result: %d", matchCount)
   367  	}
   368  }
   369  
   370  func TestLabeledLists(t *testing.T) {
   371  	p, err := NewParameters(nil)
   372  	if err != nil {
   373  		t.Fatalf("NewParameters failed: %s", err)
   374  	}
   375  
   376  	tlsProfiles := make(protocol.TLSProfiles, 0)
   377  	for i, tlsProfile := range protocol.SupportedTLSProfiles {
   378  		if i%2 == 0 {
   379  			tlsProfiles = append(tlsProfiles, tlsProfile)
   380  		}
   381  	}
   382  
   383  	quicVersions := make(protocol.QUICVersions, 0)
   384  	for i, quicVersion := range protocol.SupportedQUICVersions {
   385  		if i%2 == 0 {
   386  			quicVersions = append(quicVersions, quicVersion)
   387  		}
   388  	}
   389  
   390  	applyParameters := map[string]interface{}{
   391  		"DisableFrontingProviderTLSProfiles":  protocol.LabeledTLSProfiles{"validLabel": tlsProfiles},
   392  		"DisableFrontingProviderQUICVersions": protocol.LabeledQUICVersions{"validLabel": quicVersions},
   393  	}
   394  
   395  	_, err = p.Set("", false, applyParameters)
   396  	if err != nil {
   397  		t.Fatalf("Set failed: %s", err)
   398  	}
   399  
   400  	disableTLSProfiles := p.Get().LabeledTLSProfiles(DisableFrontingProviderTLSProfiles, "validLabel")
   401  	if !reflect.DeepEqual(disableTLSProfiles, tlsProfiles) {
   402  		t.Fatalf("LabeledTLSProfiles returned %+v expected %+v", disableTLSProfiles, tlsProfiles)
   403  	}
   404  
   405  	disableTLSProfiles = p.Get().LabeledTLSProfiles(DisableFrontingProviderTLSProfiles, "invalidLabel")
   406  	if disableTLSProfiles != nil {
   407  		t.Fatalf("LabeledTLSProfiles returned unexpected non-empty list %+v", disableTLSProfiles)
   408  	}
   409  
   410  	disableQUICVersions := p.Get().LabeledQUICVersions(DisableFrontingProviderQUICVersions, "validLabel")
   411  	if !reflect.DeepEqual(disableQUICVersions, quicVersions) {
   412  		t.Fatalf("LabeledQUICVersions returned %+v expected %+v", disableQUICVersions, quicVersions)
   413  	}
   414  
   415  	disableQUICVersions = p.Get().LabeledQUICVersions(DisableFrontingProviderQUICVersions, "invalidLabel")
   416  	if disableQUICVersions != nil {
   417  		t.Fatalf("LabeledQUICVersions returned unexpected non-empty list %+v", disableQUICVersions)
   418  	}
   419  }
   420  
   421  func TestCustomTLSProfiles(t *testing.T) {
   422  	p, err := NewParameters(nil)
   423  	if err != nil {
   424  		t.Fatalf("NewParameters failed: %s", err)
   425  	}
   426  
   427  	customTLSProfiles := protocol.CustomTLSProfiles{
   428  		&protocol.CustomTLSProfile{Name: "Profile1", UTLSSpec: &protocol.UTLSSpec{}},
   429  		&protocol.CustomTLSProfile{Name: "Profile2", UTLSSpec: &protocol.UTLSSpec{}},
   430  	}
   431  
   432  	applyParameters := map[string]interface{}{
   433  		"CustomTLSProfiles": customTLSProfiles}
   434  
   435  	_, err = p.Set("", false, applyParameters)
   436  	if err != nil {
   437  		t.Fatalf("Set failed: %s", err)
   438  	}
   439  
   440  	names := p.Get().CustomTLSProfileNames()
   441  
   442  	if len(names) != 2 || names[0] != "Profile1" || names[1] != "Profile2" {
   443  		t.Fatalf("Unexpected CustomTLSProfileNames: %+v", names)
   444  	}
   445  
   446  	profile := p.Get().CustomTLSProfile("Profile1")
   447  	if profile == nil || profile.Name != "Profile1" {
   448  		t.Fatalf("Unexpected profile")
   449  	}
   450  
   451  	profile = p.Get().CustomTLSProfile("Profile2")
   452  	if profile == nil || profile.Name != "Profile2" {
   453  		t.Fatalf("Unexpected profile")
   454  	}
   455  
   456  	profile = p.Get().CustomTLSProfile("Profile3")
   457  	if profile != nil {
   458  		t.Fatalf("Unexpected profile")
   459  	}
   460  }
   461  
   462  func TestApplicationParameters(t *testing.T) {
   463  
   464  	parametersJSON := []byte(`
   465      {
   466         "ApplicationParameters" : {
   467           "AppFlag1" : true,
   468           "AppConfig1" : {"Option1" : "A", "Option2" : "B"},
   469           "AppSwitches1" : [1, 2, 3, 4]
   470         }
   471      }
   472      `)
   473  
   474  	validators := map[string]func(v interface{}) bool{
   475  		"AppFlag1": func(v interface{}) bool { return reflect.DeepEqual(v, true) },
   476  		"AppConfig1": func(v interface{}) bool {
   477  			return reflect.DeepEqual(v, map[string]interface{}{"Option1": "A", "Option2": "B"})
   478  		},
   479  		"AppSwitches1": func(v interface{}) bool {
   480  			return reflect.DeepEqual(v, []interface{}{float64(1), float64(2), float64(3), float64(4)})
   481  		},
   482  	}
   483  
   484  	var applyParameters map[string]interface{}
   485  	err := json.Unmarshal(parametersJSON, &applyParameters)
   486  	if err != nil {
   487  		t.Fatalf("Unmarshal failed: %s", err)
   488  	}
   489  
   490  	p, err := NewParameters(nil)
   491  	if err != nil {
   492  		t.Fatalf("NewParameters failed: %s", err)
   493  	}
   494  
   495  	_, err = p.Set("", false, applyParameters)
   496  	if err != nil {
   497  		t.Fatalf("Set failed: %s", err)
   498  	}
   499  
   500  	keyValues := p.Get().KeyValues(ApplicationParameters)
   501  
   502  	if len(keyValues) != len(validators) {
   503  		t.Fatalf("Unexpected key value count")
   504  	}
   505  
   506  	for key, value := range keyValues {
   507  
   508  		validator, ok := validators[key]
   509  		if !ok {
   510  			t.Fatalf("Unexpected key: %s", key)
   511  		}
   512  
   513  		var unmarshaledValue interface{}
   514  		err := json.Unmarshal(value, &unmarshaledValue)
   515  		if err != nil {
   516  			t.Fatalf("Unmarshal failed: %s", err)
   517  		}
   518  
   519  		if !validator(unmarshaledValue) {
   520  			t.Fatalf("Invalid value: %s, %T: %+v",
   521  				key, unmarshaledValue, unmarshaledValue)
   522  		}
   523  	}
   524  }