sigs.k8s.io/external-dns@v0.14.1/provider/coredns/coredns_test.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package coredns
    18  
    19  import (
    20  	"context"
    21  	"strings"
    22  	"testing"
    23  
    24  	"sigs.k8s.io/external-dns/endpoint"
    25  	"sigs.k8s.io/external-dns/plan"
    26  
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  const defaultCoreDNSPrefix = "/skydns/"
    31  
// fakeETCDClient is an in-memory stand-in for the etcd client that
// coreDNSProvider talks to in these tests. Because the services map is a
// reference type, every copy of a fakeETCDClient value sees and mutates the
// same underlying store.
type fakeETCDClient struct {
	// services maps a full etcd key (e.g. "/skydns/com/example") to the
	// service entry stored under it.
	services map[string]Service
}
    35  
    36  func (c fakeETCDClient) GetServices(prefix string) ([]*Service, error) {
    37  	var result []*Service
    38  	for key, value := range c.services {
    39  		if strings.HasPrefix(key, prefix) {
    40  			valueCopy := value
    41  			valueCopy.Key = key
    42  			result = append(result, &valueCopy)
    43  		}
    44  	}
    45  	return result, nil
    46  }
    47  
    48  func (c fakeETCDClient) SaveService(service *Service) error {
    49  	c.services[service.Key] = *service
    50  	return nil
    51  }
    52  
    53  func (c fakeETCDClient) DeleteService(key string) error {
    54  	delete(c.services, key)
    55  	return nil
    56  }
    57  
    58  func TestAServiceTranslation(t *testing.T) {
    59  	expectedTarget := "1.2.3.4"
    60  	expectedDNSName := "example.com"
    61  	expectedRecordType := endpoint.RecordTypeA
    62  
    63  	client := fakeETCDClient{
    64  		map[string]Service{
    65  			"/skydns/com/example": {Host: expectedTarget},
    66  		},
    67  	}
    68  	provider := coreDNSProvider{
    69  		client:        client,
    70  		coreDNSPrefix: defaultCoreDNSPrefix,
    71  	}
    72  	endpoints, err := provider.Records(context.Background())
    73  	require.NoError(t, err)
    74  	if len(endpoints) != 1 {
    75  		t.Fatalf("got unexpected number of endpoints: %d", len(endpoints))
    76  	}
    77  	if endpoints[0].DNSName != expectedDNSName {
    78  		t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName)
    79  	}
    80  	if endpoints[0].Targets[0] != expectedTarget {
    81  		t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget)
    82  	}
    83  	if endpoints[0].RecordType != expectedRecordType {
    84  		t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType)
    85  	}
    86  }
    87  
    88  func TestCNAMEServiceTranslation(t *testing.T) {
    89  	expectedTarget := "example.net"
    90  	expectedDNSName := "example.com"
    91  	expectedRecordType := endpoint.RecordTypeCNAME
    92  
    93  	client := fakeETCDClient{
    94  		map[string]Service{
    95  			"/skydns/com/example": {Host: expectedTarget},
    96  		},
    97  	}
    98  	provider := coreDNSProvider{
    99  		client:        client,
   100  		coreDNSPrefix: defaultCoreDNSPrefix,
   101  	}
   102  	endpoints, err := provider.Records(context.Background())
   103  	require.NoError(t, err)
   104  	if len(endpoints) != 1 {
   105  		t.Fatalf("got unexpected number of endpoints: %d", len(endpoints))
   106  	}
   107  	if endpoints[0].DNSName != expectedDNSName {
   108  		t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName)
   109  	}
   110  	if endpoints[0].Targets[0] != expectedTarget {
   111  		t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget)
   112  	}
   113  	if endpoints[0].RecordType != expectedRecordType {
   114  		t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType)
   115  	}
   116  }
   117  
   118  func TestTXTServiceTranslation(t *testing.T) {
   119  	expectedTarget := "string"
   120  	expectedDNSName := "example.com"
   121  	expectedRecordType := endpoint.RecordTypeTXT
   122  
   123  	client := fakeETCDClient{
   124  		map[string]Service{
   125  			"/skydns/com/example": {Text: expectedTarget},
   126  		},
   127  	}
   128  	provider := coreDNSProvider{
   129  		client:        client,
   130  		coreDNSPrefix: defaultCoreDNSPrefix,
   131  	}
   132  	endpoints, err := provider.Records(context.Background())
   133  	require.NoError(t, err)
   134  	if len(endpoints) != 1 {
   135  		t.Fatalf("got unexpected number of endpoints: %d", len(endpoints))
   136  	}
   137  	if endpoints[0].DNSName != expectedDNSName {
   138  		t.Errorf("got unexpected DNS name: %s != %s", endpoints[0].DNSName, expectedDNSName)
   139  	}
   140  	if endpoints[0].Targets[0] != expectedTarget {
   141  		t.Errorf("got unexpected DNS target: %s != %s", endpoints[0].Targets[0], expectedTarget)
   142  	}
   143  	if endpoints[0].RecordType != expectedRecordType {
   144  		t.Errorf("got unexpected DNS record type: %s != %s", endpoints[0].RecordType, expectedRecordType)
   145  	}
   146  }
   147  
   148  func TestAWithTXTServiceTranslation(t *testing.T) {
   149  	expectedTargets := map[string]string{
   150  		endpoint.RecordTypeA:   "1.2.3.4",
   151  		endpoint.RecordTypeTXT: "string",
   152  	}
   153  	expectedDNSName := "example.com"
   154  
   155  	client := fakeETCDClient{
   156  		map[string]Service{
   157  			"/skydns/com/example": {Host: "1.2.3.4", Text: "string"},
   158  		},
   159  	}
   160  	provider := coreDNSProvider{
   161  		client:        client,
   162  		coreDNSPrefix: defaultCoreDNSPrefix,
   163  	}
   164  	endpoints, err := provider.Records(context.Background())
   165  	require.NoError(t, err)
   166  	if len(endpoints) != len(expectedTargets) {
   167  		t.Fatalf("got unexpected number of endpoints: %d", len(endpoints))
   168  	}
   169  
   170  	for _, ep := range endpoints {
   171  		expectedTarget := expectedTargets[ep.RecordType]
   172  		if expectedTarget == "" {
   173  			t.Errorf("got unexpected DNS record type: %s", ep.RecordType)
   174  			continue
   175  		}
   176  		delete(expectedTargets, ep.RecordType)
   177  
   178  		if ep.DNSName != expectedDNSName {
   179  			t.Errorf("got unexpected DNS name: %s != %s", ep.DNSName, expectedDNSName)
   180  		}
   181  
   182  		if ep.Targets[0] != expectedTarget {
   183  			t.Errorf("got unexpected DNS target: %s != %s", ep.Targets[0], expectedTarget)
   184  		}
   185  	}
   186  }
   187  
   188  func TestCNAMEWithTXTServiceTranslation(t *testing.T) {
   189  	expectedTargets := map[string]string{
   190  		endpoint.RecordTypeCNAME: "example.net",
   191  		endpoint.RecordTypeTXT:   "string",
   192  	}
   193  	expectedDNSName := "example.com"
   194  
   195  	client := fakeETCDClient{
   196  		map[string]Service{
   197  			"/skydns/com/example": {Host: "example.net", Text: "string"},
   198  		},
   199  	}
   200  	provider := coreDNSProvider{
   201  		client:        client,
   202  		coreDNSPrefix: defaultCoreDNSPrefix,
   203  	}
   204  	endpoints, err := provider.Records(context.Background())
   205  	require.NoError(t, err)
   206  	if len(endpoints) != len(expectedTargets) {
   207  		t.Fatalf("got unexpected number of endpoints: %d", len(endpoints))
   208  	}
   209  
   210  	for _, ep := range endpoints {
   211  		expectedTarget := expectedTargets[ep.RecordType]
   212  		if expectedTarget == "" {
   213  			t.Errorf("got unexpected DNS record type: %s", ep.RecordType)
   214  			continue
   215  		}
   216  		delete(expectedTargets, ep.RecordType)
   217  
   218  		if ep.DNSName != expectedDNSName {
   219  			t.Errorf("got unexpected DNS name: %s != %s", ep.DNSName, expectedDNSName)
   220  		}
   221  
   222  		if ep.Targets[0] != expectedTarget {
   223  			t.Errorf("got unexpected DNS target: %s != %s", ep.Targets[0], expectedTarget)
   224  		}
   225  	}
   226  }
   227  
   228  func TestCoreDNSApplyChanges(t *testing.T) {
   229  	client := fakeETCDClient{
   230  		map[string]Service{},
   231  	}
   232  	coredns := coreDNSProvider{
   233  		client:        client,
   234  		coreDNSPrefix: defaultCoreDNSPrefix,
   235  	}
   236  
   237  	changes1 := &plan.Changes{
   238  		Create: []*endpoint.Endpoint{
   239  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "5.5.5.5"),
   240  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeTXT, "string1"),
   241  			endpoint.NewEndpoint("domain2.local", endpoint.RecordTypeCNAME, "site.local"),
   242  		},
   243  	}
   244  	err := coredns.ApplyChanges(context.Background(), changes1)
   245  	require.NoError(t, err)
   246  
   247  	expectedServices1 := map[string][]*Service{
   248  		"/skydns/local/domain1": {{Host: "5.5.5.5", Text: "string1"}},
   249  		"/skydns/local/domain2": {{Host: "site.local"}},
   250  	}
   251  	validateServices(client.services, expectedServices1, t, 1)
   252  
   253  	changes2 := &plan.Changes{
   254  		Create: []*endpoint.Endpoint{
   255  			endpoint.NewEndpoint("domain3.local", endpoint.RecordTypeA, "7.7.7.7"),
   256  		},
   257  		UpdateNew: []*endpoint.Endpoint{
   258  			endpoint.NewEndpoint("domain1.local", "A", "6.6.6.6"),
   259  		},
   260  	}
   261  	records, _ := coredns.Records(context.Background())
   262  	for _, ep := range records {
   263  		if ep.DNSName == "domain1.local" {
   264  			changes2.UpdateOld = append(changes2.UpdateOld, ep)
   265  		}
   266  	}
   267  	err = applyServiceChanges(coredns, changes2)
   268  	require.NoError(t, err)
   269  
   270  	expectedServices2 := map[string][]*Service{
   271  		"/skydns/local/domain1": {{Host: "6.6.6.6", Text: "string1"}},
   272  		"/skydns/local/domain2": {{Host: "site.local"}},
   273  		"/skydns/local/domain3": {{Host: "7.7.7.7"}},
   274  	}
   275  	validateServices(client.services, expectedServices2, t, 2)
   276  
   277  	changes3 := &plan.Changes{
   278  		Delete: []*endpoint.Endpoint{
   279  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "6.6.6.6"),
   280  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeTXT, "string"),
   281  			endpoint.NewEndpoint("domain3.local", endpoint.RecordTypeA, "7.7.7.7"),
   282  		},
   283  	}
   284  
   285  	err = applyServiceChanges(coredns, changes3)
   286  	require.NoError(t, err)
   287  
   288  	expectedServices3 := map[string][]*Service{
   289  		"/skydns/local/domain2": {{Host: "site.local"}},
   290  	}
   291  	validateServices(client.services, expectedServices3, t, 3)
   292  
   293  	// Test for multiple A records for the same FQDN
   294  	changes4 := &plan.Changes{
   295  		Create: []*endpoint.Endpoint{
   296  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "5.5.5.5"),
   297  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "6.6.6.6"),
   298  			endpoint.NewEndpoint("domain1.local", endpoint.RecordTypeA, "7.7.7.7"),
   299  		},
   300  	}
   301  	err = coredns.ApplyChanges(context.Background(), changes4)
   302  	require.NoError(t, err)
   303  
   304  	expectedServices4 := map[string][]*Service{
   305  		"/skydns/local/domain2": {{Host: "site.local"}},
   306  		"/skydns/local/domain1": {{Host: "5.5.5.5"}, {Host: "6.6.6.6"}, {Host: "7.7.7.7"}},
   307  	}
   308  	validateServices(client.services, expectedServices4, t, 4)
   309  }
   310  
   311  func applyServiceChanges(provider coreDNSProvider, changes *plan.Changes) error {
   312  	ctx := context.Background()
   313  	records, _ := provider.Records(ctx)
   314  	for _, col := range [][]*endpoint.Endpoint{changes.Create, changes.UpdateNew, changes.Delete} {
   315  		for _, record := range col {
   316  			for _, existingRecord := range records {
   317  				if existingRecord.DNSName == record.DNSName && existingRecord.RecordType == record.RecordType {
   318  					mergeLabels(record, existingRecord.Labels)
   319  				}
   320  			}
   321  		}
   322  	}
   323  	return provider.ApplyChanges(ctx, changes)
   324  }
   325  
   326  func validateServices(services map[string]Service, expectedServices map[string][]*Service, t *testing.T, step int) {
   327  	t.Helper()
   328  	for key, value := range services {
   329  		keyParts := strings.Split(key, "/")
   330  		expectedKey := strings.Join(keyParts[:len(keyParts)-value.TargetStrip], "/")
   331  		expectedServiceEntries := expectedServices[expectedKey]
   332  		if expectedServiceEntries == nil {
   333  			t.Errorf("unexpected service %s", key)
   334  			continue
   335  		}
   336  		found := false
   337  		for i, expectedServiceEntry := range expectedServiceEntries {
   338  			if value.Host == expectedServiceEntry.Host && value.Text == expectedServiceEntry.Text {
   339  				expectedServiceEntries = append(expectedServiceEntries[:i], expectedServiceEntries[i+1:]...)
   340  				found = true
   341  				break
   342  			}
   343  		}
   344  		if !found {
   345  			t.Errorf("unexpected service %s: %s on step %d", key, value.Host, step)
   346  		}
   347  		if len(expectedServiceEntries) == 0 {
   348  			delete(expectedServices, expectedKey)
   349  		} else {
   350  			expectedServices[expectedKey] = expectedServiceEntries
   351  		}
   352  	}
   353  	if len(expectedServices) != 0 {
   354  		t.Errorf("unmatched expected services: %+v on step %d", expectedServices, step)
   355  	}
   356  }
   357  
   358  // mergeLabels adds keys to labels if not defined for the endpoint
   359  func mergeLabels(e *endpoint.Endpoint, labels map[string]string) {
   360  	for k, v := range labels {
   361  		if e.Labels[k] == "" {
   362  			e.Labels[k] = v
   363  		}
   364  	}
   365  }