github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/types/databaseserver_test.go (about)

     1  /*
     2  Copyright 2021 Gravitational, Inc.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package types
    18  
    19  import (
    20  	"fmt"
    21  	"testing"
    22  
    23  	"github.com/gravitational/trace"
    24  	"github.com/stretchr/testify/require"
    25  )
    26  
    27  func TestDatabaseServerSorter(t *testing.T) {
    28  	t.Parallel()
    29  
    30  	testValsUnordered := []string{"d", "b", "a", "c"}
    31  
    32  	// DB types are hardcoded and types are determined
    33  	// by which spec fields are set, values don't matter.
    34  	// Used to randomly assign db types.
    35  	dbSpecs := []DatabaseSpecV3{
    36  		// type redshift
    37  		{
    38  			Protocol: "_",
    39  			URI:      "_",
    40  			AWS: AWS{
    41  				Redshift: Redshift{
    42  					ClusterID: "_",
    43  				},
    44  			},
    45  		},
    46  		// type azure
    47  		{
    48  			Protocol: "_",
    49  			URI:      "_",
    50  			Azure: Azure{
    51  				Name: "_",
    52  			},
    53  		},
    54  		// type rds
    55  		{
    56  			Protocol: "_",
    57  			URI:      "_",
    58  			AWS: AWS{
    59  				Region: "_",
    60  			},
    61  		},
    62  		// type gcp
    63  		{
    64  			Protocol: "_",
    65  			URI:      "_",
    66  			GCP: GCPCloudSQL{
    67  				ProjectID:  "_",
    68  				InstanceID: "_",
    69  			},
    70  		},
    71  	}
    72  
    73  	cases := []struct {
    74  		name      string
    75  		wantErr   bool
    76  		fieldName string
    77  	}{
    78  		{
    79  			name:      "by name",
    80  			fieldName: ResourceMetadataName,
    81  		},
    82  		{
    83  			name:      "by description",
    84  			fieldName: ResourceSpecDescription,
    85  		},
    86  		{
    87  			name:      "by type",
    88  			fieldName: ResourceSpecType,
    89  		},
    90  	}
    91  
    92  	for _, c := range cases {
    93  		c := c
    94  		t.Run(fmt.Sprintf("%s desc", c.name), func(t *testing.T) {
    95  			sortBy := SortBy{Field: c.fieldName, IsDesc: true}
    96  			servers := DatabaseServers(makeServers(t, testValsUnordered, dbSpecs, c.fieldName))
    97  			require.NoError(t, servers.SortByCustom(sortBy))
    98  			targetVals, err := servers.GetFieldVals(c.fieldName)
    99  			require.NoError(t, err)
   100  			require.IsDecreasing(t, targetVals)
   101  		})
   102  
   103  		t.Run(fmt.Sprintf("%s asc", c.name), func(t *testing.T) {
   104  			sortBy := SortBy{Field: c.fieldName}
   105  			servers := DatabaseServers(makeServers(t, testValsUnordered, dbSpecs, c.fieldName))
   106  			require.NoError(t, servers.SortByCustom(sortBy))
   107  			targetVals, err := servers.GetFieldVals(c.fieldName)
   108  			require.NoError(t, err)
   109  			require.IsIncreasing(t, targetVals)
   110  		})
   111  	}
   112  
   113  	// Test error.
   114  	sortBy := SortBy{Field: "unsupported"}
   115  	servers := makeServers(t, testValsUnordered, dbSpecs, "does-not-matter")
   116  	require.True(t, trace.IsNotImplemented(DatabaseServers(servers).SortByCustom(sortBy)))
   117  }
   118  
   119  func makeServers(t *testing.T, testVals []string, dbSpecs []DatabaseSpecV3, testField string) []DatabaseServer {
   120  	t.Helper()
   121  	servers := make([]DatabaseServer, len(testVals))
   122  	for i := 0; i < len(testVals); i++ {
   123  		testVal := testVals[i]
   124  		dbSpec := dbSpecs[i%len(dbSpecs)]
   125  		var err error
   126  
   127  		servers[i], err = NewDatabaseServerV3(Metadata{
   128  			Name: "foo",
   129  		}, DatabaseServerSpecV3{
   130  			HostID:   "_",
   131  			Hostname: "_",
   132  			Database: &DatabaseV3{
   133  				Metadata: Metadata{
   134  					Name:        getTestVal(testField == ResourceMetadataName, testVal),
   135  					Description: getTestVal(testField == ResourceSpecDescription, testVal),
   136  				},
   137  				Spec: dbSpec,
   138  			},
   139  		})
   140  		require.NoError(t, err)
   141  	}
   142  	return servers
   143  }
   144  
   145  func TestDatabaseServersToDatabases(t *testing.T) {
   146  	t.Parallel()
   147  
   148  	databaseServers := []DatabaseServer{
   149  		makeDatabaseServer(t, "db1", "agent1"),
   150  		makeDatabaseServer(t, "db1", "agent2"),
   151  		makeDatabaseServer(t, "db2", "agent1"),
   152  		makeDatabaseServer(t, "db3", "agent2"),
   153  		makeDatabaseServer(t, "db3", "agent3"),
   154  	}
   155  
   156  	wantDatabases := []Database{
   157  		databaseServers[0].GetDatabase(), // db1
   158  		databaseServers[2].GetDatabase(), // db2
   159  		databaseServers[3].GetDatabase(), // db3
   160  	}
   161  
   162  	actualDatabases := DatabaseServers(databaseServers).ToDatabases()
   163  	require.Equal(t, wantDatabases, actualDatabases)
   164  }
   165  
   166  func makeDatabaseServer(t *testing.T, dbName, agentName string) DatabaseServer {
   167  	t.Helper()
   168  
   169  	databaseServer, err := NewDatabaseServerV3(Metadata{
   170  		Name: dbName,
   171  	}, DatabaseServerSpecV3{
   172  		HostID:   agentName,
   173  		Hostname: agentName,
   174  		Database: &DatabaseV3{
   175  			Metadata: Metadata{
   176  				Name: dbName,
   177  			},
   178  			Spec: DatabaseSpecV3{
   179  				Protocol: "postgres",
   180  				URI:      "localhost:5432",
   181  			},
   182  		},
   183  	})
   184  	require.NoError(t, err)
   185  	return databaseServer
   186  }