github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/parameters/parameters_test.go (about) 1 /* 2 * Copyright (c) 2018, Psiphon Inc. 3 * All rights reserved. 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License as published by 7 * the Free Software Foundation, either version 3 of the License, or 8 * (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package parameters 21 22 import ( 23 "encoding/json" 24 "net/http" 25 "reflect" 26 "testing" 27 "time" 28 29 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" 30 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol" 31 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/transforms" 32 ) 33 34 func TestGetDefaultParameters(t *testing.T) { 35 36 p, err := NewParameters(nil) 37 if err != nil { 38 t.Fatalf("NewParameters failed: %s", err) 39 } 40 41 for name, defaults := range defaultParameters { 42 switch v := defaults.value.(type) { 43 case string: 44 g := p.Get().String(name) 45 if v != g { 46 t.Fatalf("String returned %+v expected %+v", g, v) 47 } 48 case []string: 49 g := p.Get().Strings(name) 50 if !reflect.DeepEqual(v, g) { 51 t.Fatalf("Strings returned %+v expected %+v", g, v) 52 } 53 case int: 54 g := p.Get().Int(name) 55 if v != g { 56 t.Fatalf("Int returned %+v expected %+v", g, v) 57 } 58 case float64: 59 g := p.Get().Float(name) 60 if v != g { 61 t.Fatalf("Float returned %+v expected %+v", g, v) 62 } 63 case bool: 64 g := p.Get().Bool(name) 65 if v != g { 66 t.Fatalf("Bool returned %+v expected %+v", g, v) 67 } 68 case time.Duration: 69 g := p.Get().Duration(name) 70 if v != g { 71 t.Fatalf("Duration returned %+v expected %+v", g, v) 72 } 73 case protocol.TunnelProtocols: 74 g := p.Get().TunnelProtocols(name) 75 if !reflect.DeepEqual(v, g) { 76 t.Fatalf("TunnelProtocols returned %+v expected %+v", g, v) 77 } 78 case protocol.TLSProfiles: 79 g := p.Get().TLSProfiles(name) 80 if !reflect.DeepEqual(v, g) { 81 t.Fatalf("TLSProfiles returned %+v expected %+v", g, v) 82 } 83 case protocol.LabeledTLSProfiles: 84 for label, profiles := range v { 85 g := p.Get().LabeledTLSProfiles(name, label) 86 if !reflect.DeepEqual(profiles, g) { 87 t.Fatalf("LabeledTLSProfiles returned %+v expected %+v", g, profiles) 88 } 89 } 90 case protocol.QUICVersions: 91 g := p.Get().QUICVersions(name) 92 if !reflect.DeepEqual(v, g) { 93 t.Fatalf("QUICVersions returned %+v expected %+v", g, v) 94 } 95 case protocol.LabeledQUICVersions: 96 for label, versions := range v { 97 g := p.Get().LabeledTLSProfiles(name, label) 98 if !reflect.DeepEqual(versions, g) { 99 t.Fatalf("LabeledQUICVersions returned %+v expected %+v", g, versions) 100 } 101 } 102 case TransferURLs: 103 g := p.Get().TransferURLs(name) 104 if !reflect.DeepEqual(v, g) { 105 t.Fatalf("TransferURLs returned %+v expected %+v", g, v) 106 } 107 case common.RateLimits: 108 g := p.Get().RateLimits(name) 109 if !reflect.DeepEqual(v, g) { 110 t.Fatalf("RateLimits returned %+v expected %+v", g, v) 111 } 112 case http.Header: 113 g := p.Get().HTTPHeaders(name) 114 if !reflect.DeepEqual(v, g) { 115 t.Fatalf("HTTPHeaders returned %+v expected %+v", g, v) 116 } 117 case protocol.CustomTLSProfiles: 118 g := p.Get().CustomTLSProfileNames() 119 names := make([]string, len(v)) 120 for i, profile := range v { 121 names[i] = profile.Name 122 } 123 if !reflect.DeepEqual(names, g) { 124 t.Fatalf("CustomTLSProfileNames returned %+v expected %+v", g, names) 125 } 126 case KeyValues: 127 g := p.Get().KeyValues(name) 128 if !reflect.DeepEqual(v, g) { 129 t.Fatalf("KeyValues returned %+v expected %+v", g, v) 130 } 131 case *BPFProgramSpec: 132 ok, name, rawInstructions := p.Get().BPFProgram(name) 133 if v != nil || ok || name != "" || rawInstructions != nil { 134 t.Fatalf( 135 "BPFProgramSpec returned %+v %+v %+v expected %+v", 136 ok, name, rawInstructions, v) 137 } 138 case PacketManipulationSpecs: 139 g := p.Get().PacketManipulationSpecs(name) 140 if !reflect.DeepEqual(v, g) { 141 t.Fatalf("PacketManipulationSpecs returned %+v expected %+v", g, v) 142 } 143 case ProtocolPacketManipulations: 144 g := p.Get().ProtocolPacketManipulations(name) 145 if !reflect.DeepEqual(v, g) { 146 t.Fatalf("ProtocolPacketManipulations returned %+v expected %+v", g, v) 147 } 148 case RegexStrings: 149 g := p.Get().RegexStrings(name) 150 if !reflect.DeepEqual(v, g) { 151 t.Fatalf("RegexStrings returned %+v expected %+v", g, v) 152 } 153 case FrontingSpecs: 154 g := p.Get().FrontingSpecs(name) 155 if !reflect.DeepEqual(v, g) { 156 t.Fatalf("FrontingSpecs returned %+v expected %+v", g, v) 157 } 158 case TunnelProtocolPortLists: 159 g := p.Get().TunnelProtocolPortLists(name) 160 if !reflect.DeepEqual(v, g) { 161 t.Fatalf("TunnelProtocolPortLists returned %+v expected %+v", g, v) 162 } 163 case LabeledCIDRs: 164 for label, CIDRs := range v { 165 g := p.Get().LabeledCIDRs(name, label) 166 if !reflect.DeepEqual(CIDRs, g) { 167 t.Fatalf("LabeledCIDRs returned %+v expected %+v", g, CIDRs) 168 } 169 } 170 case transforms.Specs: 171 g := p.Get().ProtocolTransformSpecs(name) 172 if !reflect.DeepEqual(v, g) { 173 t.Fatalf("ProtocolTransformSpecs returned %+v expected %+v", g, v) 174 } 175 case transforms.ScopedSpecNames: 176 g := p.Get().ProtocolTransformScopedSpecNames(name) 177 if !reflect.DeepEqual(v, g) { 178 t.Fatalf("ProtocolTransformScopedSpecNames returned %+v expected %+v", g, v) 179 } 180 default: 181 t.Fatalf("Unhandled default type: %s (%T)", name, defaults.value) 182 } 183 } 184 } 185 186 func TestGetValueLogger(t *testing.T) { 187 188 loggerCalled := false 189 190 p, err := NewParameters( 191 func(error) { 192 loggerCalled = true 193 }) 194 if err != nil { 195 t.Fatalf("NewParameters failed: %s", err) 196 } 197 198 p.Get().Int("unknown-parameter-name") 199 200 if !loggerCalled { 201 t.Fatalf("logged not called") 202 } 203 } 204 205 func TestOverrides(t *testing.T) { 206 207 tag := "tag" 208 applyParameters := make(map[string]interface{}) 209 210 // Below minimum, should not apply 211 defaultConnectionWorkerPoolSize := defaultParameters[ConnectionWorkerPoolSize].value.(int) 212 minimumConnectionWorkerPoolSize := defaultParameters[ConnectionWorkerPoolSize].minimum.(int) 213 newConnectionWorkerPoolSize := minimumConnectionWorkerPoolSize - 1 214 applyParameters[ConnectionWorkerPoolSize] = newConnectionWorkerPoolSize 215 216 // Above minimum, should apply 217 defaultInitialLimitTunnelProtocolsCandidateCount := defaultParameters[InitialLimitTunnelProtocolsCandidateCount].value.(int) 218 minimumInitialLimitTunnelProtocolsCandidateCount := defaultParameters[InitialLimitTunnelProtocolsCandidateCount].minimum.(int) 219 newInitialLimitTunnelProtocolsCandidateCount := minimumInitialLimitTunnelProtocolsCandidateCount + 1 220 applyParameters[InitialLimitTunnelProtocolsCandidateCount] = newInitialLimitTunnelProtocolsCandidateCount 221 222 p, err := NewParameters(nil) 223 if err != nil { 224 t.Fatalf("NewParameters failed: %s", err) 225 } 226 227 // No skip on error; should fail and not apply any changes 228 229 _, err = p.Set(tag, false, applyParameters) 230 if err == nil { 231 t.Fatalf("Set succeeded unexpectedly") 232 } 233 234 if p.Get().Tag() != "" { 235 t.Fatalf("GetTag returned unexpected value") 236 } 237 238 v := p.Get().Int(ConnectionWorkerPoolSize) 239 if v != defaultConnectionWorkerPoolSize { 240 t.Fatalf("GetInt returned unexpected ConnectionWorkerPoolSize: %d", v) 241 } 242 243 v = p.Get().Int(InitialLimitTunnelProtocolsCandidateCount) 244 if v != defaultInitialLimitTunnelProtocolsCandidateCount { 245 t.Fatalf("GetInt returned unexpected InitialLimitTunnelProtocolsCandidateCount: %d", v) 246 } 247 248 // Skip on error; should skip ConnectionWorkerPoolSize and apply InitialLimitTunnelProtocolsCandidateCount 249 250 counts, err := p.Set(tag, true, applyParameters) 251 if err != nil { 252 t.Fatalf("Set failed: %s", err) 253 } 254 255 if counts[0] != 1 { 256 t.Fatalf("Apply returned unexpected count: %d", counts[0]) 257 } 258 259 v = p.Get().Int(ConnectionWorkerPoolSize) 260 if v != defaultConnectionWorkerPoolSize { 261 t.Fatalf("GetInt returned unexpected ConnectionWorkerPoolSize: %d", v) 262 } 263 264 v = p.Get().Int(InitialLimitTunnelProtocolsCandidateCount) 265 if v != newInitialLimitTunnelProtocolsCandidateCount { 266 t.Fatalf("GetInt returned unexpected InitialLimitTunnelProtocolsCandidateCount: %d", v) 267 } 268 } 269 270 func TestNetworkLatencyMultiplier(t *testing.T) { 271 p, err := NewParameters(nil) 272 if err != nil { 273 t.Fatalf("NewParameters failed: %s", err) 274 } 275 276 timeout1 := p.Get().Duration(TunnelConnectTimeout) 277 278 applyParameters := map[string]interface{}{"NetworkLatencyMultiplier": 2.0} 279 280 _, err = p.Set("", false, applyParameters) 281 if err != nil { 282 t.Fatalf("Set failed: %s", err) 283 } 284 285 timeout2 := p.Get().Duration(TunnelConnectTimeout) 286 287 if 2*timeout1 != timeout2 { 288 t.Fatalf("Unexpected timeouts: 2 * %s != %s", timeout1, timeout2) 289 } 290 } 291 292 func TestCustomNetworkLatencyMultiplier(t *testing.T) { 293 p, err := NewParameters(nil) 294 if err != nil { 295 t.Fatalf("NewParameters failed: %s", err) 296 } 297 298 timeout1 := p.Get().Duration(TunnelConnectTimeout) 299 300 applyParameters := map[string]interface{}{"NetworkLatencyMultiplier": 2.0} 301 302 _, err = p.Set("", false, applyParameters) 303 if err != nil { 304 t.Fatalf("Set failed: %s", err) 305 } 306 307 timeout2 := p.GetCustom(4.0).Duration(TunnelConnectTimeout) 308 309 if 4*timeout1 != timeout2 { 310 t.Fatalf("Unexpected timeouts: 4 * %s != %s", timeout1, timeout2) 311 } 312 } 313 314 func TestLimitTunnelProtocolProbability(t *testing.T) { 315 p, err := NewParameters(nil) 316 if err != nil { 317 t.Fatalf("NewParameters failed: %s", err) 318 } 319 320 // Default probability should be 1.0 and always return tunnelProtocols 321 322 tunnelProtocols := protocol.TunnelProtocols{"OSSH", "SSH"} 323 324 applyParameters := map[string]interface{}{ 325 "LimitTunnelProtocols": tunnelProtocols, 326 } 327 328 _, err = p.Set("", false, applyParameters) 329 if err != nil { 330 t.Fatalf("Set failed: %s", err) 331 } 332 333 for i := 0; i < 1000; i++ { 334 l := p.Get().TunnelProtocols(LimitTunnelProtocols) 335 if !reflect.DeepEqual(l, tunnelProtocols) { 336 t.Fatalf("unexpected %+v != %+v", l, tunnelProtocols) 337 } 338 } 339 340 // With probability set to 0.5, should return tunnelProtocols ~50% 341 342 defaultLimitTunnelProtocols := protocol.TunnelProtocols{} 343 344 applyParameters = map[string]interface{}{ 345 "LimitTunnelProtocolsProbability": 0.5, 346 "LimitTunnelProtocols": tunnelProtocols, 347 } 348 349 _, err = p.Set("", false, applyParameters) 350 if err != nil { 351 t.Fatalf("Set failed: %s", err) 352 } 353 354 matchCount := 0 355 356 for i := 0; i < 1000; i++ { 357 l := p.Get().TunnelProtocols(LimitTunnelProtocols) 358 if reflect.DeepEqual(l, tunnelProtocols) { 359 matchCount += 1 360 } else if !reflect.DeepEqual(l, defaultLimitTunnelProtocols) { 361 t.Fatalf("unexpected %+v != %+v", l, defaultLimitTunnelProtocols) 362 } 363 } 364 365 if matchCount < 250 || matchCount > 750 { 366 t.Fatalf("Unexpected probability result: %d", matchCount) 367 } 368 } 369 370 func TestLabeledLists(t *testing.T) { 371 p, err := NewParameters(nil) 372 if err != nil { 373 t.Fatalf("NewParameters failed: %s", err) 374 } 375 376 tlsProfiles := make(protocol.TLSProfiles, 0) 377 for i, tlsProfile := range protocol.SupportedTLSProfiles { 378 if i%2 == 0 { 379 tlsProfiles = append(tlsProfiles, tlsProfile) 380 } 381 } 382 383 quicVersions := make(protocol.QUICVersions, 0) 384 for i, quicVersion := range protocol.SupportedQUICVersions { 385 if i%2 == 0 { 386 quicVersions = append(quicVersions, quicVersion) 387 } 388 } 389 390 applyParameters := map[string]interface{}{ 391 "DisableFrontingProviderTLSProfiles": protocol.LabeledTLSProfiles{"validLabel": tlsProfiles}, 392 "DisableFrontingProviderQUICVersions": protocol.LabeledQUICVersions{"validLabel": quicVersions}, 393 } 394 395 _, err = p.Set("", false, applyParameters) 396 if err != nil { 397 t.Fatalf("Set failed: %s", err) 398 } 399 400 disableTLSProfiles := p.Get().LabeledTLSProfiles(DisableFrontingProviderTLSProfiles, "validLabel") 401 if !reflect.DeepEqual(disableTLSProfiles, tlsProfiles) { 402 t.Fatalf("LabeledTLSProfiles returned %+v expected %+v", disableTLSProfiles, tlsProfiles) 403 } 404 405 disableTLSProfiles = p.Get().LabeledTLSProfiles(DisableFrontingProviderTLSProfiles, "invalidLabel") 406 if disableTLSProfiles != nil { 407 t.Fatalf("LabeledTLSProfiles returned unexpected non-empty list %+v", disableTLSProfiles) 408 } 409 410 disableQUICVersions := p.Get().LabeledQUICVersions(DisableFrontingProviderQUICVersions, "validLabel") 411 if !reflect.DeepEqual(disableQUICVersions, quicVersions) { 412 t.Fatalf("LabeledQUICVersions returned %+v expected %+v", disableQUICVersions, quicVersions) 413 } 414 415 disableQUICVersions = p.Get().LabeledQUICVersions(DisableFrontingProviderQUICVersions, "invalidLabel") 416 if disableQUICVersions != nil { 417 t.Fatalf("LabeledQUICVersions returned unexpected non-empty list %+v", disableQUICVersions) 418 } 419 } 420 421 func TestCustomTLSProfiles(t *testing.T) { 422 p, err := NewParameters(nil) 423 if err != nil { 424 t.Fatalf("NewParameters failed: %s", err) 425 } 426 427 customTLSProfiles := protocol.CustomTLSProfiles{ 428 &protocol.CustomTLSProfile{Name: "Profile1", UTLSSpec: &protocol.UTLSSpec{}}, 429 &protocol.CustomTLSProfile{Name: "Profile2", UTLSSpec: &protocol.UTLSSpec{}}, 430 } 431 432 applyParameters := map[string]interface{}{ 433 "CustomTLSProfiles": customTLSProfiles} 434 435 _, err = p.Set("", false, applyParameters) 436 if err != nil { 437 t.Fatalf("Set failed: %s", err) 438 } 439 440 names := p.Get().CustomTLSProfileNames() 441 442 if len(names) != 2 || names[0] != "Profile1" || names[1] != "Profile2" { 443 t.Fatalf("Unexpected CustomTLSProfileNames: %+v", names) 444 } 445 446 profile := p.Get().CustomTLSProfile("Profile1") 447 if profile == nil || profile.Name != "Profile1" { 448 t.Fatalf("Unexpected profile") 449 } 450 451 profile = p.Get().CustomTLSProfile("Profile2") 452 if profile == nil || profile.Name != "Profile2" { 453 t.Fatalf("Unexpected profile") 454 } 455 456 profile = p.Get().CustomTLSProfile("Profile3") 457 if profile != nil { 458 t.Fatalf("Unexpected profile") 459 } 460 } 461 462 func TestApplicationParameters(t *testing.T) { 463 464 parametersJSON := []byte(` 465 { 466 "ApplicationParameters" : { 467 "AppFlag1" : true, 468 "AppConfig1" : {"Option1" : "A", "Option2" : "B"}, 469 "AppSwitches1" : [1, 2, 3, 4] 470 } 471 } 472 `) 473 474 validators := map[string]func(v interface{}) bool{ 475 "AppFlag1": func(v interface{}) bool { return reflect.DeepEqual(v, true) }, 476 "AppConfig1": func(v interface{}) bool { 477 return reflect.DeepEqual(v, map[string]interface{}{"Option1": "A", "Option2": "B"}) 478 }, 479 "AppSwitches1": func(v interface{}) bool { 480 return reflect.DeepEqual(v, []interface{}{float64(1), float64(2), float64(3), float64(4)}) 481 }, 482 } 483 484 var applyParameters map[string]interface{} 485 err := json.Unmarshal(parametersJSON, &applyParameters) 486 if err != nil { 487 t.Fatalf("Unmarshal failed: %s", err) 488 } 489 490 p, err := NewParameters(nil) 491 if err != nil { 492 t.Fatalf("NewParameters failed: %s", err) 493 } 494 495 _, err = p.Set("", false, applyParameters) 496 if err != nil { 497 t.Fatalf("Set failed: %s", err) 498 } 499 500 keyValues := p.Get().KeyValues(ApplicationParameters) 501 502 if len(keyValues) != len(validators) { 503 t.Fatalf("Unexpected key value count") 504 } 505 506 for key, value := range keyValues { 507 508 validator, ok := validators[key] 509 if !ok { 510 t.Fatalf("Unexpected key: %s", key) 511 } 512 513 var unmarshaledValue interface{} 514 err := json.Unmarshal(value, &unmarshaledValue) 515 if err != nil { 516 t.Fatalf("Unmarshal failed: %s", err) 517 } 518 519 if !validator(unmarshaledValue) { 520 t.Fatalf("Invalid value: %s, %T: %+v", 521 key, unmarshaledValue, unmarshaledValue) 522 } 523 } 524 }