github.com/hashicorp/vault/sdk@v0.11.0/helper/ldaputil/client_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package ldaputil
     5  
     6  import (
     7  	"testing"
     8  
     9  	"github.com/hashicorp/go-hclog"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  // TestDialLDAP duplicates a potential panic that was
    15  // present in the previous version of TestDialLDAP,
    16  // then confirms its fix by passing.
    17  func TestDialLDAP(t *testing.T) {
    18  	ldapClient := Client{
    19  		Logger: hclog.NewNullLogger(),
    20  		LDAP:   NewLDAP(),
    21  	}
    22  
    23  	ce := &ConfigEntry{
    24  		Url:            "ldap://localhost:384654786",
    25  		RequestTimeout: 3,
    26  	}
    27  	if _, err := ldapClient.DialLDAP(ce); err == nil {
    28  		t.Fatal("expected error")
    29  	}
    30  }
    31  
    32  func TestLDAPEscape(t *testing.T) {
    33  	testcases := map[string]string{
    34  		"#test":                      "\\#test",
    35  		"test,hello":                 "test\\,hello",
    36  		"test,hel+lo":                "test\\,hel\\+lo",
    37  		"test\\hello":                "test\\\\hello",
    38  		"  test  ":                   "\\  test \\ ",
    39  		"":                           "",
    40  		`\`:                          `\\`,
    41  		"trailing\000":               `trailing\00`,
    42  		"mid\000dle":                 `mid\00dle`,
    43  		"\000":                       `\00`,
    44  		"multiple\000\000":           `multiple\00\00`,
    45  		"backlash-before-null\\\000": `backlash-before-null\\\00`,
    46  		"trailing\\":                 `trailing\\`,
    47  		"double-escaping\\>":         `double-escaping\\\>`,
    48  	}
    49  
    50  	for test, answer := range testcases {
    51  		res := EscapeLDAPValue(test)
    52  		if res != answer {
    53  			t.Errorf("Failed to escape %s: %s != %s\n", test, res, answer)
    54  		}
    55  	}
    56  }
    57  
    58  func TestGetTLSConfigs(t *testing.T) {
    59  	config := testConfig(t)
    60  	if err := config.Validate(); err != nil {
    61  		t.Fatal(err)
    62  	}
    63  	tlsConfig, err := getTLSConfig(config, "138.91.247.105")
    64  	if err != nil {
    65  		t.Fatal(err)
    66  	}
    67  	if tlsConfig == nil {
    68  		t.Fatal("expected 1 TLS config because there's 1 url")
    69  	}
    70  	if tlsConfig.InsecureSkipVerify {
    71  		t.Fatal("InsecureSkipVerify should be false because we should default to the most secure connection")
    72  	}
    73  	if tlsConfig.ServerName != "138.91.247.105" {
    74  		t.Fatalf("expected ServerName of \"138.91.247.105\" but received %q", tlsConfig.ServerName)
    75  	}
    76  	expected := uint16(771)
    77  	if tlsConfig.MinVersion != expected || tlsConfig.MaxVersion != expected {
    78  		t.Fatal("expected TLS min and max version of 771 which corresponds with TLS 1.2 since TLS 1.1 and 1.0 have known vulnerabilities")
    79  	}
    80  }
    81  
    82  func TestSIDBytesToString(t *testing.T) {
    83  	testcases := map[string][]byte{
    84  		"S-1-5-21-2127521184-1604012920-1887927527-72713": {0x01, 0x05, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x15, 0x00, 0x00, 0x00, 0xA0, 0x65, 0xCF, 0x7E, 0x78, 0x4B, 0x9B, 0x5F, 0xE7, 0x7C, 0x87, 0x70, 0x09, 0x1C, 0x01, 0x00},
    85  		"S-1-1-0": {0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00},
    86  		"S-1-5":   {0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05},
    87  	}
    88  
    89  	for answer, test := range testcases {
    90  		res, err := sidBytesToString(test)
    91  		if err != nil {
    92  			t.Errorf("Failed to conver %#v: %s", test, err)
    93  		} else if answer != res {
    94  			t.Errorf("Failed to convert %#v: %s != %s", test, res, answer)
    95  		}
    96  	}
    97  }
    98  
    99  func TestClient_renderUserSearchFilter(t *testing.T) {
   100  	t.Parallel()
   101  	tests := []struct {
   102  		name        string
   103  		conf        *ConfigEntry
   104  		username    string
   105  		want        string
   106  		errContains string
   107  	}{
   108  		{
   109  			name:     "valid-default",
   110  			username: "alice",
   111  			conf: &ConfigEntry{
   112  				UserAttr: "cn",
   113  			},
   114  			want: "(cn=alice)",
   115  		},
   116  		{
   117  			name:     "escaped-malicious-filter",
   118  			username: "foo@example.com)((((((((((((((((((((((((((((((((((((((userPrincipalName=foo",
   119  			conf: &ConfigEntry{
   120  				UPNDomain:  "example.com",
   121  				UserFilter: "(&({{.UserAttr}}={{.Username}})({{.UserAttr}}=admin@example.com))",
   122  			},
   123  			want: "(&(userPrincipalName=foo@example.com\\29\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28\\28userPrincipalName=foo@example.com)(userPrincipalName=admin@example.com))",
   124  		},
   125  		{
   126  			name:     "bad-filter-unclosed-action",
   127  			username: "alice",
   128  			conf: &ConfigEntry{
   129  				UserFilter: "hello{{range",
   130  			},
   131  			errContains: "search failed due to template compilation error",
   132  		},
   133  	}
   134  	for _, tc := range tests {
   135  		t.Run(tc.name, func(t *testing.T) {
   136  			c := Client{
   137  				Logger: hclog.NewNullLogger(),
   138  				LDAP:   NewLDAP(),
   139  			}
   140  
   141  			f, err := c.RenderUserSearchFilter(tc.conf, tc.username)
   142  			if tc.errContains != "" {
   143  				require.Error(t, err)
   144  				assert.ErrorContains(t, err, tc.errContains)
   145  				return
   146  			}
   147  			require.NoError(t, err)
   148  			assert.NotEmpty(t, f)
   149  			assert.Equal(t, tc.want, f)
   150  		})
   151  	}
   152  }