github.com/gravitational/teleport/api@v0.0.0-20240507183017-3110591cbafc/types/databaseserver_test.go (about) 1 /* 2 Copyright 2021 Gravitational, Inc. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package types 18 19 import ( 20 "fmt" 21 "testing" 22 23 "github.com/gravitational/trace" 24 "github.com/stretchr/testify/require" 25 ) 26 27 func TestDatabaseServerSorter(t *testing.T) { 28 t.Parallel() 29 30 testValsUnordered := []string{"d", "b", "a", "c"} 31 32 // DB types are hardcoded and types are determined 33 // by which spec fields are set, values don't matter. 34 // Used to randomly assign db types. 35 dbSpecs := []DatabaseSpecV3{ 36 // type redshift 37 { 38 Protocol: "_", 39 URI: "_", 40 AWS: AWS{ 41 Redshift: Redshift{ 42 ClusterID: "_", 43 }, 44 }, 45 }, 46 // type azure 47 { 48 Protocol: "_", 49 URI: "_", 50 Azure: Azure{ 51 Name: "_", 52 }, 53 }, 54 // type rds 55 { 56 Protocol: "_", 57 URI: "_", 58 AWS: AWS{ 59 Region: "_", 60 }, 61 }, 62 // type gcp 63 { 64 Protocol: "_", 65 URI: "_", 66 GCP: GCPCloudSQL{ 67 ProjectID: "_", 68 InstanceID: "_", 69 }, 70 }, 71 } 72 73 cases := []struct { 74 name string 75 wantErr bool 76 fieldName string 77 }{ 78 { 79 name: "by name", 80 fieldName: ResourceMetadataName, 81 }, 82 { 83 name: "by description", 84 fieldName: ResourceSpecDescription, 85 }, 86 { 87 name: "by type", 88 fieldName: ResourceSpecType, 89 }, 90 } 91 92 for _, c := range cases { 93 c := c 94 t.Run(fmt.Sprintf("%s desc", c.name), func(t *testing.T) { 95 sortBy := SortBy{Field: c.fieldName, IsDesc: true} 96 servers := DatabaseServers(makeServers(t, testValsUnordered, dbSpecs, c.fieldName)) 97 require.NoError(t, servers.SortByCustom(sortBy)) 98 targetVals, err := servers.GetFieldVals(c.fieldName) 99 require.NoError(t, err) 100 require.IsDecreasing(t, targetVals) 101 }) 102 103 t.Run(fmt.Sprintf("%s asc", c.name), func(t *testing.T) { 104 sortBy := SortBy{Field: c.fieldName} 105 servers := DatabaseServers(makeServers(t, testValsUnordered, dbSpecs, c.fieldName)) 106 require.NoError(t, servers.SortByCustom(sortBy)) 107 targetVals, err := servers.GetFieldVals(c.fieldName) 108 require.NoError(t, err) 109 require.IsIncreasing(t, targetVals) 110 }) 111 } 112 113 // Test error. 114 sortBy := SortBy{Field: "unsupported"} 115 servers := makeServers(t, testValsUnordered, dbSpecs, "does-not-matter") 116 require.True(t, trace.IsNotImplemented(DatabaseServers(servers).SortByCustom(sortBy))) 117 } 118 119 func makeServers(t *testing.T, testVals []string, dbSpecs []DatabaseSpecV3, testField string) []DatabaseServer { 120 t.Helper() 121 servers := make([]DatabaseServer, len(testVals)) 122 for i := 0; i < len(testVals); i++ { 123 testVal := testVals[i] 124 dbSpec := dbSpecs[i%len(dbSpecs)] 125 var err error 126 127 servers[i], err = NewDatabaseServerV3(Metadata{ 128 Name: "foo", 129 }, DatabaseServerSpecV3{ 130 HostID: "_", 131 Hostname: "_", 132 Database: &DatabaseV3{ 133 Metadata: Metadata{ 134 Name: getTestVal(testField == ResourceMetadataName, testVal), 135 Description: getTestVal(testField == ResourceSpecDescription, testVal), 136 }, 137 Spec: dbSpec, 138 }, 139 }) 140 require.NoError(t, err) 141 } 142 return servers 143 } 144 145 func TestDatabaseServersToDatabases(t *testing.T) { 146 t.Parallel() 147 148 databaseServers := []DatabaseServer{ 149 makeDatabaseServer(t, "db1", "agent1"), 150 makeDatabaseServer(t, "db1", "agent2"), 151 makeDatabaseServer(t, "db2", "agent1"), 152 makeDatabaseServer(t, "db3", "agent2"), 153 makeDatabaseServer(t, "db3", "agent3"), 154 } 155 156 wantDatabases := []Database{ 157 databaseServers[0].GetDatabase(), // db1 158 databaseServers[2].GetDatabase(), // db2 159 databaseServers[3].GetDatabase(), // db3 160 } 161 162 actualDatabases := DatabaseServers(databaseServers).ToDatabases() 163 require.Equal(t, wantDatabases, actualDatabases) 164 } 165 166 func makeDatabaseServer(t *testing.T, dbName, agentName string) DatabaseServer { 167 t.Helper() 168 169 databaseServer, err := NewDatabaseServerV3(Metadata{ 170 Name: dbName, 171 }, DatabaseServerSpecV3{ 172 HostID: agentName, 173 Hostname: agentName, 174 Database: &DatabaseV3{ 175 Metadata: Metadata{ 176 Name: dbName, 177 }, 178 Spec: DatabaseSpecV3{ 179 Protocol: "postgres", 180 URI: "localhost:5432", 181 }, 182 }, 183 }) 184 require.NoError(t, err) 185 return databaseServer 186 }