github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/utils/route_test.go (about)

     1  // Copyright 2023 Gravitational, Inc
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package utils
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  
    21  	"github.com/google/uuid"
    22  	"github.com/stretchr/testify/require"
    23  )
    24  
// NOTE: many details of this type's behavior are tested in lib/proxy as part
// of the main router test coverage.
    27  
    28  // TestSSHRouteMatcherHostnameMatching verifies the expected behavior of the custom ssh
    29  // hostname matching logic.
    30  func TestSSHRouteMatcherHostnameMatching(t *testing.T) {
    31  	tts := []struct {
    32  		desc        string
    33  		principal   string
    34  		target      string
    35  		insensitive bool
    36  		match       bool
    37  	}{
    38  		{
    39  			desc:        "upper-eq",
    40  			principal:   "Foo",
    41  			target:      "Foo",
    42  			insensitive: true,
    43  			match:       true,
    44  		},
    45  		{
    46  			desc:        "lower-eq",
    47  			principal:   "foo",
    48  			target:      "foo",
    49  			insensitive: true,
    50  			match:       true,
    51  		},
    52  		{
    53  			desc:        "lower-target-match",
    54  			principal:   "Foo",
    55  			target:      "foo",
    56  			insensitive: true,
    57  			match:       true,
    58  		},
    59  		{
    60  			desc:        "upper-target-mismatch",
    61  			principal:   "foo",
    62  			target:      "Foo",
    63  			insensitive: true,
    64  			match:       false,
    65  		},
    66  		{
    67  			desc:        "upper-mismatch",
    68  			principal:   "Foo",
    69  			target:      "fOO",
    70  			insensitive: true,
    71  			match:       false,
    72  		},
    73  		{
    74  			desc:        "non-ascii-match",
    75  			principal:   "🌲",
    76  			target:      "🌲",
    77  			insensitive: true,
    78  			match:       true,
    79  		},
    80  		{
    81  			desc:        "non-ascii-mismatch",
    82  			principal:   "🌲",
    83  			target:      "🔥",
    84  			insensitive: true,
    85  			match:       false,
    86  		},
    87  		{
    88  			desc:        "sensitive-match",
    89  			principal:   "Foo",
    90  			target:      "Foo",
    91  			insensitive: false,
    92  			match:       true,
    93  		},
    94  		{
    95  			desc:        "sensitive-mismatch",
    96  			principal:   "Foo",
    97  			target:      "foo",
    98  			insensitive: false,
    99  			match:       false,
   100  		},
   101  	}
   102  
   103  	for _, tt := range tts {
   104  		matcher := NewSSHRouteMatcher(tt.target, "", tt.insensitive)
   105  		require.Equal(t, tt.match, matcher.routeToHostname(tt.principal), "desc=%q", tt.desc)
   106  	}
   107  }
   108  
   109  type mockRouteableServer struct {
   110  	name       string
   111  	hostname   string
   112  	addr       string
   113  	useTunnel  bool
   114  	publicAddr []string
   115  }
   116  
   117  func (m mockRouteableServer) GetName() string {
   118  	return m.name
   119  }
   120  
   121  func (m mockRouteableServer) GetHostname() string {
   122  	return m.hostname
   123  }
   124  
   125  func (m mockRouteableServer) GetAddr() string {
   126  	return m.addr
   127  }
   128  
   129  func (m mockRouteableServer) GetUseTunnel() bool {
   130  	return m.useTunnel
   131  }
   132  
   133  func (m mockRouteableServer) GetPublicAddrs() []string {
   134  	return m.publicAddr
   135  }
   136  
   137  func TestRouteToServer(t *testing.T) {
   138  	t.Parallel()
   139  	testUUID := uuid.NewString()
   140  
   141  	matchAddrServer := mockRouteableServer{
   142  		name:       "test",
   143  		addr:       "example.com:1111",
   144  		publicAddr: []string{"node:1234", "public.example.com:1111"},
   145  	}
   146  
   147  	tests := []struct {
   148  		name    string
   149  		matcher SSHRouteMatcher
   150  		server  RouteableServer
   151  		assert  require.BoolAssertionFunc
   152  	}{
   153  		{
   154  			name:    "no match",
   155  			matcher: NewSSHRouteMatcher(testUUID, "", true),
   156  			server: mockRouteableServer{
   157  				name:       "test",
   158  				addr:       "localhost",
   159  				hostname:   "example.com",
   160  				publicAddr: []string{"example.com"},
   161  			},
   162  			assert: require.False,
   163  		},
   164  		{
   165  			name:    "match by server name",
   166  			matcher: NewSSHRouteMatcher(testUUID, "", true),
   167  			server: mockRouteableServer{
   168  				name:       testUUID,
   169  				addr:       "localhost",
   170  				hostname:   "example.com",
   171  				publicAddr: []string{"example.com"},
   172  			},
   173  			assert: require.True,
   174  		},
   175  		{
   176  			name:    "match by hostname over tunnel",
   177  			matcher: NewSSHRouteMatcher("example.com", "", true),
   178  			server: mockRouteableServer{
   179  				name:       testUUID,
   180  				addr:       "addr.example.com",
   181  				hostname:   "example.com",
   182  				publicAddr: []string{"public.example.com"},
   183  				useTunnel:  true,
   184  			},
   185  			assert: require.True,
   186  		},
   187  		{
   188  			name:    "mismatch hostname over tunnel",
   189  			matcher: NewSSHRouteMatcher("example.com", "", true),
   190  			server: mockRouteableServer{
   191  				name:       testUUID,
   192  				addr:       "example.com",
   193  				hostname:   "fake.example.com",
   194  				publicAddr: []string{"example.com"},
   195  				useTunnel:  true,
   196  			},
   197  			assert: require.False,
   198  		},
   199  		{
   200  			name:    "match addr",
   201  			matcher: NewSSHRouteMatcher("example.com", "1111", true),
   202  			server:  matchAddrServer,
   203  			assert:  require.True,
   204  		},
   205  		{
   206  			name:    "match addr with empty port",
   207  			matcher: NewSSHRouteMatcher("example.com", "", true),
   208  			server:  matchAddrServer,
   209  			assert:  require.True,
   210  		},
   211  		{
   212  			name:    "mismatch addr with wrong port",
   213  			matcher: NewSSHRouteMatcher("example.com", "2222", true),
   214  			server:  matchAddrServer,
   215  			assert:  require.False,
   216  		},
   217  		{
   218  			name:    "match first public addr",
   219  			matcher: NewSSHRouteMatcher("node", "1234", true),
   220  			server:  matchAddrServer,
   221  			assert:  require.True,
   222  		},
   223  		{
   224  			name:    "match second public addr",
   225  			matcher: NewSSHRouteMatcher("public.example.com", "1111", true),
   226  			server:  matchAddrServer,
   227  			assert:  require.True,
   228  		},
   229  		{
   230  			name:    "match public addr with empty port",
   231  			matcher: NewSSHRouteMatcher("public.example.com", "", true),
   232  			server:  matchAddrServer,
   233  			assert:  require.True,
   234  		},
   235  		{
   236  			name:    "mismatch public addr with wrong port",
   237  			matcher: NewSSHRouteMatcher("public.example.com", "2222", true),
   238  			server:  matchAddrServer,
   239  			assert:  require.False,
   240  		},
   241  	}
   242  	for _, tc := range tests {
   243  		t.Run(tc.name, func(t *testing.T) {
   244  			tc.assert(t, tc.matcher.RouteToServer(tc.server))
   245  		})
   246  	}
   247  }
   248  
   249  type mockHostResolver struct {
   250  	ips []string
   251  }
   252  
   253  func (r mockHostResolver) LookupHost(ctx context.Context, host string) (addrs []string, err error) {
   254  	return r.ips, nil
   255  }
   256  
   257  // TestSSHRouteMatcherScoring verifies the expected scoring behavior of SSHRouteMatcher.
   258  func TestSSHRouteMatcherScoring(t *testing.T) {
   259  	t.Parallel()
   260  
   261  	// set up matcher with mock resolver in order to control ips
   262  	matcher, err := NewSSHRouteMatcherFromConfig(SSHRouteMatcherConfig{
   263  		Host: "foo.example.com",
   264  		Resolver: mockHostResolver{
   265  			ips: []string{
   266  				"1.2.3.4",
   267  				"4.5.6.7",
   268  			},
   269  		},
   270  	})
   271  	require.NoError(t, err)
   272  
   273  	tts := []struct {
   274  		desc     string
   275  		hostname string
   276  		addrs    []string
   277  		score    int
   278  	}{
   279  		{
   280  			desc:     "multi factor match",
   281  			hostname: "foo.example.com",
   282  			addrs: []string{
   283  				"1.2.3.4:0",
   284  			},
   285  			score: directMatch,
   286  		},
   287  		{
   288  			desc:     "ip match only",
   289  			hostname: "bar.example.com",
   290  			addrs: []string{
   291  				"1.2.3.4:0",
   292  			},
   293  			score: indirectMatch,
   294  		},
   295  		{
   296  			desc:     "hostname match only",
   297  			hostname: "foo.example.com",
   298  			addrs: []string{
   299  				"7.7.7.7:0",
   300  			},
   301  			score: directMatch,
   302  		},
   303  		{
   304  			desc:     "not match",
   305  			hostname: "bar.example.com",
   306  			addrs: []string{
   307  				"0.0.0.0:0",
   308  				"1.1.1.1:0",
   309  			},
   310  			score: notMatch,
   311  		},
   312  	}
   313  
   314  	for _, tt := range tts {
   315  		t.Run(tt.desc, func(t *testing.T) {
   316  			score := matcher.RouteToServerScore(mockRouteableServer{
   317  				name:       uuid.NewString(),
   318  				hostname:   tt.hostname,
   319  				publicAddr: tt.addrs,
   320  			})
   321  
   322  			require.Equal(t, tt.score, score)
   323  		})
   324  	}
   325  }