sigs.k8s.io/external-dns@v0.14.1/provider/coredns/coredns_test.go (about) 1 /* 2 Copyright 2017 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package coredns 18 19 import ( 20 "context" 21 "strings" 22 "testing" 23 24 "sigs.k8s.io/external-dns/endpoint" 25 "sigs.k8s.io/external-dns/plan" 26 27 "github.com/stretchr/testify/require" 28 ) 29 30 const defaultCoreDNSPrefix = "/skydns/" 31 32 type fakeETCDClient struct { 33 services map[string]Service 34 } 35 36 func (c fakeETCDClient) GetServices(prefix string) ([]*Service, error) { 37 var result []*Service 38 for key, value := range c.services { 39 if strings.HasPrefix(key, prefix) { 40 valueCopy := value 41 valueCopy.Key = key 42 result = append(result, &valueCopy) 43 } 44 } 45 return result, nil 46 } 47 48 func (c fakeETCDClient) SaveService(service *Service) error { 49 c.services[service.Key] = *service 50 return nil 51 } 52 53 func (c fakeETCDClient) DeleteService(key string) error { 54 delete(c.services, key) 55 return nil 56 } 57 58 func TestAServiceTranslation(t *testing.T) { 59 expectedTarget := "1.2.3.4" 60 expectedDNSName := "example.com" 61 expectedRecordType := endpoint.RecordTypeA 62 63 client := fakeETCDClient{ 64 map[string]Service{ 65 "/skydns/com/example": {Host: expectedTarget}, 66 }, 67 } 68 provider := coreDNSProvider{ 69 client: client, 70 coreDNSPrefix: defaultCoreDNSPrefix, 71 } 72 endpoints, err := provider.Records(context.Background()) 73 require.NoError(t, err) 74 if len(endpoints) != 1 { 75 t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) 76 } 77 if endpoints[0].DNSName != expectedDNSName { 78 t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName) 79 } 80 if endpoints[0].Targets[0] != expectedTarget { 81 t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget) 82 } 83 if endpoints[0].RecordType != expectedRecordType { 84 t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType) 85 } 86 } 87 88 func TestCNAMEServiceTranslation(t *testing.T) { 89 expectedTarget := "example.net" 90 expectedDNSName := "example.com" 91 expectedRecordType := endpoint.RecordTypeCNAME 92 93 client := fakeETCDClient{ 94 map[string]Service{ 95 "/skydns/com/example": {Host: expectedTarget}, 96 }, 97 } 98 provider := coreDNSProvider{ 99 client: client, 100 coreDNSPrefix: defaultCoreDNSPrefix, 101 } 102 endpoints, err := provider.Records(context.Background()) 103 require.NoError(t, err) 104 if len(endpoints) != 1 { 105 t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) 106 } 107 if endpoints[0].DNSName != expectedDNSName { 108 t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName) 109 } 110 if endpoints[0].Targets[0] != expectedTarget { 111 t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget) 112 } 113 if endpoints[0].RecordType != expectedRecordType { 114 t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType) 115 } 116 } 117 118 func TestTXTServiceTranslation(t *testing.T) { 119 expectedTarget := "string" 120 expectedDNSName := "example.com" 121 expectedRecordType := endpoint.RecordTypeTXT 122 123 client := fakeETCDClient{ 124 map[string]Service{ 125 "/skydns/com/example": {Text: expectedTarget}, 126 }, 127 } 128 provider := coreDNSProvider{ 129 client: client, 130 coreDNSPrefix: defaultCoreDNSPrefix, 131 } 132 endpoints, err := provider.Records(context.Background()) 133 require.NoError(t, err) 134 if len(endpoints) != 1 { 135 t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) 136 } 137 if endpoints[0].DNSName != expectedDNSName { 138 t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName) 139 } 140 if endpoints[0].Targets[0] != expectedTarget { 141 t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget) 142 } 143 if endpoints[0].RecordType != expectedRecordType { 144 t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType) 145 } 146 } 147 148 func TestAWithTXTServiceTranslation(t *testing.T) { 149 expectedTargets := map[string]string{ 150 endpoint.RecordTypeA: "1.2.3.4", 151 endpoint.RecordTypeTXT: "string", 152 } 153 expectedDNSName := "example.com" 154 155 client := fakeETCDClient{ 156 map[string]Service{ 157 "/skydns/com/example": {Host: "1.2.3.4", Text: "string"}, 158 }, 159 } 160 provider := coreDNSProvider{ 161 client: client, 162 coreDNSPrefix: defaultCoreDNSPrefix, 163 } 164 endpoints, err := provider.Records(context.Background()) 165 require.NoError(t, err) 166 if len(endpoints) != len(expectedTargets) { 167 t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) 168 } 169 170 for _, ep := range endpoints { 171 expectedTarget := expectedTargets[ep.RecordType] 172 if expectedTarget == "" { 173 t.Errorf("got unexpected DNS record type: %s", ep.RecordType) 174 continue 175 } 176 delete(expectedTargets, ep.RecordType) 177 178 if ep.DNSName != expectedDNSName { 179 t.Errorf("got unexpected DNS name: %s != %s", ep.DNSName, expectedDNSName) 180 } 181 182 if ep.Targets[0] != expectedTarget { 183 t.Errorf("got unexpected DNS target: %s != %s", ep.Targets[0], expectedTarget) 184 } 185 } 186 } 187 188 func TestCNAMEWithTXTServiceTranslation(t *testing.T) { 189 expectedTargets := map[string]string{ 190 endpoint.RecordTypeCNAME: "example.net", 191 endpoint.RecordTypeTXT: "string", 192 } 193 expectedDNSName := "example.com" 194 195 client := fakeETCDClient{ 196 map[string]Service{ 197 "/skydns/com/example": {Host: "example.net", Text: "string"}, 198 }, 199 } 200 provider := coreDNSProvider{ 201 client: client, 202 coreDNSPrefix: defaultCoreDNSPrefix, 203 } 204 endpoints, err := provider.Records(context.Background()) 205 require.NoError(t, err) 206 if len(endpoints) != len(expectedTargets) { 207 t.Fatalf("got unexpected number of endpoints: %d", len(endpoints)) 208 } 209 210 for _, ep := range endpoints { 211 expectedTarget := expectedTargets[ep.RecordType] 212 if expectedTarget == "" { 213 t.Errorf("got unexpected DNS record type: %s", ep.RecordType) 214 continue 215 } 216 delete(expectedTargets, ep.RecordType) 217 218 if ep.DNSName != expectedDNSName { 219 t.Errorf("got unexpected DNS name: %s != %s", ep.DNSName, expectedDNSName) 220 } 221 222 if ep.Targets[0] != expectedTarget { 223 t.Errorf("got unexpected DNS target: %s != %s", ep.Targets[0], expectedTarget) 224 } 225 } 226 } 227 228 func TestCoreDNSApplyChanges(t *testing.T) { 229 client := fakeETCDClient{ 230 map[string]Service{}, 231 } 232 coredns := coreDNSProvider{ 233 client: client, 234 coreDNSPrefix: defaultCoreDNSPrefix, 235 } 236 237 changes1 := &plan.Changes{ 238 Create: []*endpoint.Endpoint{ 239 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "5.5.5.5"), 240 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeTXT, "string1"), 241 endpoint.NewEndpoint("domain2.local", endpoint.RecordTypeCNAME, "site.local"), 242 }, 243 } 244 err := coredns.ApplyChanges(context.Background(), changes1) 245 require.NoError(t, err) 246 247 expectedServices1 := map[string][]*Service{ 248 "/skydns/local/domain1": {{Host: "5.5.5.5", Text: "string1"}}, 249 "/skydns/local/domain2": {{Host: "site.local"}}, 250 } 251 validateServices(client.services, expectedServices1, t, 1) 252 253 changes2 := &plan.Changes{ 254 Create: []*endpoint.Endpoint{ 255 endpoint.NewEndpoint("domain3.local", endpoint.RecordTypeA, "7.7.7.7"), 256 }, 257 UpdateNew: []*endpoint.Endpoint{ 258 endpoint.NewEndpoint("domain1.local", "A", "6.6.6.6"), 259 }, 260 } 261 records, _ := coredns.Records(context.Background()) 262 for _, ep := range records { 263 if ep.DNSName == "domain1.local" { 264 changes2.UpdateOld = append(changes2.UpdateOld, ep) 265 } 266 } 267 err = applyServiceChanges(coredns, changes2) 268 require.NoError(t, err) 269 270 expectedServices2 := map[string][]*Service{ 271 "/skydns/local/domain1": {{Host: "6.6.6.6", Text: "string1"}}, 272 "/skydns/local/domain2": {{Host: "site.local"}}, 273 "/skydns/local/domain3": {{Host: "7.7.7.7"}}, 274 } 275 validateServices(client.services, expectedServices2, t, 2) 276 277 changes3 := &plan.Changes{ 278 Delete: []*endpoint.Endpoint{ 279 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "6.6.6.6"), 280 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeTXT, "string"), 281 endpoint.NewEndpoint("domain3.local", endpoint.RecordTypeA, "7.7.7.7"), 282 }, 283 } 284 285 err = applyServiceChanges(coredns, changes3) 286 require.NoError(t, err) 287 288 expectedServices3 := map[string][]*Service{ 289 "/skydns/local/domain2": {{Host: "site.local"}}, 290 } 291 validateServices(client.services, expectedServices3, t, 3) 292 293 // Test for multiple A records for the same FQDN 294 changes4 := &plan.Changes{ 295 Create: []*endpoint.Endpoint{ 296 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "5.5.5.5"), 297 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "6.6.6.6"), 298 endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "7.7.7.7"), 299 }, 300 } 301 err = coredns.ApplyChanges(context.Background(), changes4) 302 require.NoError(t, err) 303 304 expectedServices4 := map[string][]*Service{ 305 "/skydns/local/domain2": {{Host: "site.local"}}, 306 "/skydns/local/domain1": {{Host: "5.5.5.5"}, {Host: "6.6.6.6"}, {Host: "7.7.7.7"}}, 307 } 308 validateServices(client.services, expectedServices4, t, 4) 309 } 310 311 func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) error { 312 ctx := context.Background() 313 records, _ := provider.Records(ctx) 314 for _, col := range [][]*endpoint.Endpoint{changes.Create, changes.UpdateNew, changes.Delete} { 315 for _, record := range col { 316 for _, existingRecord := range records { 317 if existingRecord.DNSName == record.DNSName && existingRecord.RecordType == record.RecordType { 318 mergeLabels(record, existingRecord.Labels) 319 } 320 } 321 } 322 } 323 return provider.ApplyChanges(ctx, changes) 324 } 325 326 func validateServices(services map[string]Service, expectedServices map[string][]*Service, t *testing.T, step int) { 327 t.Helper() 328 for key, value := range services { 329 keyParts := strings.Split(key, "/") 330 expectedKey := strings.Join(keyParts[:len(keyParts)-value.TargetStrip], "/") 331 expectedServiceEntries := expectedServices[expectedKey] 332 if expectedServiceEntries == nil { 333 t.Errorf("unexpected service %s", key) 334 continue 335 } 336 found := false 337 for i, expectedServiceEntry := range expectedServiceEntries { 338 if value.Host == expectedServiceEntry.Host && value.Text == expectedServiceEntry.Text { 339 expectedServiceEntries = append(expectedServiceEntries[:i], expectedServiceEntries[i+1:]...) 340 found = true 341 break 342 } 343 } 344 if !found { 345 t.Errorf("unexpected service %s: %s on step %d", key, value.Host, step) 346 } 347 if len(expectedServiceEntries) == 0 { 348 delete(expectedServices, expectedKey) 349 } else { 350 expectedServices[expectedKey] = expectedServiceEntries 351 } 352 } 353 if len(expectedServices) != 0 { 354 t.Errorf("unmatched expected services: %+v on step %d", expectedServices, step) 355 } 356 } 357 358 // mergeLabels adds keys to labels if not defined for the endpoint 359 func mergeLabels(e *endpoint.Endpoint, labels map[string]string) { 360 for k, v := range labels { 361 if e.Labels[k] == "" { 362 e.Labels[k] = v 363 } 364 } 365 }