sigs.k8s.io/external-dns@v0.14.1/provider/azure/azure.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  //nolint:staticcheck // Required due to the current dependency on a deprecated version of azure-sdk-for-go
    18  package azure
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"strings"
    24  
    25  	log "github.com/sirupsen/logrus"
    26  
    27  	azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
    28  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    29  	dns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/dns/armdns"
    30  
    31  	"sigs.k8s.io/external-dns/endpoint"
    32  	"sigs.k8s.io/external-dns/plan"
    33  	"sigs.k8s.io/external-dns/provider"
    34  )
    35  
    36  const (
    37  	azureRecordTTL = 300
    38  )
    39  
    40  // ZonesClient is an interface of dns.ZoneClient that can be stubbed for testing.
    41  type ZonesClient interface {
    42  	NewListByResourceGroupPager(resourceGroupName string, options *dns.ZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[dns.ZonesClientListByResourceGroupResponse]
    43  }
    44  
    45  // RecordSetsClient is an interface of dns.RecordSetsClient that can be stubbed for testing.
    46  type RecordSetsClient interface {
    47  	NewListAllByDNSZonePager(resourceGroupName string, zoneName string, options *dns.RecordSetsClientListAllByDNSZoneOptions) *azcoreruntime.Pager[dns.RecordSetsClientListAllByDNSZoneResponse]
    48  	Delete(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, options *dns.RecordSetsClientDeleteOptions) (dns.RecordSetsClientDeleteResponse, error)
    49  	CreateOrUpdate(ctx context.Context, resourceGroupName string, zoneName string, relativeRecordSetName string, recordType dns.RecordType, parameters dns.RecordSet, options *dns.RecordSetsClientCreateOrUpdateOptions) (dns.RecordSetsClientCreateOrUpdateResponse, error)
    50  }
    51  
    52  // AzureProvider implements the DNS provider for Microsoft's Azure cloud platform.
    53  type AzureProvider struct {
    54  	provider.BaseProvider
    55  	domainFilter                 endpoint.DomainFilter
    56  	zoneNameFilter               endpoint.DomainFilter
    57  	zoneIDFilter                 provider.ZoneIDFilter
    58  	dryRun                       bool
    59  	resourceGroup                string
    60  	userAssignedIdentityClientID string
    61  	zonesClient                  ZonesClient
    62  	recordSetsClient             RecordSetsClient
    63  }
    64  
    65  // NewAzureProvider creates a new Azure provider.
    66  //
    67  // Returns the provider or an error if a provider could not be created.
    68  func NewAzureProvider(configFile string, domainFilter endpoint.DomainFilter, zoneNameFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzureProvider, error) {
    69  	cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID)
    70  	if err != nil {
    71  		return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
    72  	}
    73  	cred, clientOpts, err := getCredentials(*cfg)
    74  	if err != nil {
    75  		return nil, fmt.Errorf("failed to get credentials: %w", err)
    76  	}
    77  
    78  	zonesClient, err := dns.NewZonesClient(cfg.SubscriptionID, cred, clientOpts)
    79  	if err != nil {
    80  		return nil, err
    81  	}
    82  	recordSetsClient, err := dns.NewRecordSetsClient(cfg.SubscriptionID, cred, clientOpts)
    83  	if err != nil {
    84  		return nil, err
    85  	}
    86  	return &AzureProvider{
    87  		domainFilter:                 domainFilter,
    88  		zoneNameFilter:               zoneNameFilter,
    89  		zoneIDFilter:                 zoneIDFilter,
    90  		dryRun:                       dryRun,
    91  		resourceGroup:                cfg.ResourceGroup,
    92  		userAssignedIdentityClientID: cfg.UserAssignedIdentityID,
    93  		zonesClient:                  zonesClient,
    94  		recordSetsClient:             recordSetsClient,
    95  	}, nil
    96  }
    97  
    98  // Records gets the current records.
    99  //
   100  // Returns the current records or an error if the operation failed.
   101  func (p *AzureProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
   102  	zones, err := p.zones(ctx)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  
   107  	for _, zone := range zones {
   108  		pager := p.recordSetsClient.NewListAllByDNSZonePager(p.resourceGroup, *zone.Name, &dns.RecordSetsClientListAllByDNSZoneOptions{Top: nil})
   109  		for pager.More() {
   110  			nextResult, err := pager.NextPage(ctx)
   111  			if err != nil {
   112  				return nil, provider.NewSoftError(fmt.Errorf("failed to fetch dns records: %w", err))
   113  			}
   114  			for _, recordSet := range nextResult.Value {
   115  				if recordSet.Name == nil || recordSet.Type == nil {
   116  					log.Error("Skipping invalid record set with nil name or type.")
   117  					continue
   118  				}
   119  				recordType := strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/dnszones/")
   120  				if !p.SupportedRecordType(recordType) {
   121  					continue
   122  				}
   123  				name := formatAzureDNSName(*recordSet.Name, *zone.Name)
   124  				if len(p.zoneNameFilter.Filters) > 0 && !p.domainFilter.Match(name) {
   125  					log.Debugf("Skipping return of record %s because it was filtered out by the specified --domain-filter", name)
   126  					continue
   127  				}
   128  				targets := extractAzureTargets(recordSet)
   129  				if len(targets) == 0 {
   130  					log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType)
   131  					continue
   132  				}
   133  				var ttl endpoint.TTL
   134  				if recordSet.Properties.TTL != nil {
   135  					ttl = endpoint.TTL(*recordSet.Properties.TTL)
   136  				}
   137  				ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...)
   138  				log.Debugf(
   139  					"Found %s record for '%s' with target '%s'.",
   140  					ep.RecordType,
   141  					ep.DNSName,
   142  					ep.Targets,
   143  				)
   144  				endpoints = append(endpoints, ep)
   145  			}
   146  		}
   147  	}
   148  	return endpoints, nil
   149  }
   150  
   151  // ApplyChanges applies the given changes.
   152  //
   153  // Returns nil if the operation was successful or an error if the operation failed.
   154  func (p *AzureProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
   155  	zones, err := p.zones(ctx)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	deleted, updated := p.mapChanges(zones, changes)
   161  	p.deleteRecords(ctx, deleted)
   162  	p.updateRecords(ctx, updated)
   163  	return nil
   164  }
   165  
   166  func (p *AzureProvider) zones(ctx context.Context) ([]dns.Zone, error) {
   167  	log.Debugf("Retrieving Azure DNS zones for resource group: %s.", p.resourceGroup)
   168  	var zones []dns.Zone
   169  	pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, &dns.ZonesClientListByResourceGroupOptions{Top: nil})
   170  	for pager.More() {
   171  		nextResult, err := pager.NextPage(ctx)
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		for _, zone := range nextResult.Value {
   176  			if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) {
   177  				zones = append(zones, *zone)
   178  			} else if zone.Name != nil && len(p.zoneNameFilter.Filters) > 0 && p.zoneNameFilter.Match(*zone.Name) {
   179  				// Handle zoneNameFilter
   180  				zones = append(zones, *zone)
   181  			}
   182  		}
   183  	}
   184  	log.Debugf("Found %d Azure DNS zone(s).", len(zones))
   185  	return zones, nil
   186  }
   187  
   188  func (p *AzureProvider) SupportedRecordType(recordType string) bool {
   189  	switch recordType {
   190  	case "MX":
   191  		return true
   192  	default:
   193  		return provider.SupportedRecordType(recordType)
   194  	}
   195  }
   196  
   197  type azureChangeMap map[string][]*endpoint.Endpoint
   198  
   199  func (p *AzureProvider) mapChanges(zones []dns.Zone, changes *plan.Changes) (azureChangeMap, azureChangeMap) {
   200  	ignored := map[string]bool{}
   201  	deleted := azureChangeMap{}
   202  	updated := azureChangeMap{}
   203  	zoneNameIDMapper := provider.ZoneIDName{}
   204  	for _, z := range zones {
   205  		if z.Name != nil {
   206  			zoneNameIDMapper.Add(*z.Name, *z.Name)
   207  		}
   208  	}
   209  	mapChange := func(changeMap azureChangeMap, change *endpoint.Endpoint) {
   210  		zone, _ := zoneNameIDMapper.FindZone(change.DNSName)
   211  		if zone == "" {
   212  			if _, ok := ignored[change.DNSName]; !ok {
   213  				ignored[change.DNSName] = true
   214  				log.Infof("Ignoring changes to '%s' because a suitable Azure DNS zone was not found.", change.DNSName)
   215  			}
   216  			return
   217  		}
   218  		// Ensure the record type is suitable
   219  		changeMap[zone] = append(changeMap[zone], change)
   220  	}
   221  
   222  	for _, change := range changes.Delete {
   223  		mapChange(deleted, change)
   224  	}
   225  
   226  	for _, change := range changes.Create {
   227  		mapChange(updated, change)
   228  	}
   229  
   230  	for _, change := range changes.UpdateNew {
   231  		mapChange(updated, change)
   232  	}
   233  	return deleted, updated
   234  }
   235  
   236  func (p *AzureProvider) deleteRecords(ctx context.Context, deleted azureChangeMap) {
   237  	// Delete records first
   238  	for zone, endpoints := range deleted {
   239  		for _, ep := range endpoints {
   240  			name := p.recordSetNameForZone(zone, ep)
   241  			if !p.domainFilter.Match(ep.DNSName) {
   242  				log.Debugf("Skipping deletion of record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
   243  				continue
   244  			}
   245  			if p.dryRun {
   246  				log.Infof("Would delete %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone)
   247  			} else {
   248  				log.Infof("Deleting %s record named '%s' for Azure DNS zone '%s'.", ep.RecordType, name, zone)
   249  				if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, name, dns.RecordType(ep.RecordType), nil); err != nil {
   250  					log.Errorf(
   251  						"Failed to delete %s record named '%s' for Azure DNS zone '%s': %v",
   252  						ep.RecordType,
   253  						name,
   254  						zone,
   255  						err,
   256  					)
   257  				}
   258  			}
   259  		}
   260  	}
   261  }
   262  
   263  func (p *AzureProvider) updateRecords(ctx context.Context, updated azureChangeMap) {
   264  	for zone, endpoints := range updated {
   265  		for _, ep := range endpoints {
   266  			name := p.recordSetNameForZone(zone, ep)
   267  			if !p.domainFilter.Match(ep.DNSName) {
   268  				log.Debugf("Skipping update of record %s because it was filtered out by the specified --domain-filter", ep.DNSName)
   269  				continue
   270  			}
   271  			if p.dryRun {
   272  				log.Infof(
   273  					"Would update %s record named '%s' to '%s' for Azure DNS zone '%s'.",
   274  					ep.RecordType,
   275  					name,
   276  					ep.Targets,
   277  					zone,
   278  				)
   279  				continue
   280  			}
   281  
   282  			log.Infof(
   283  				"Updating %s record named '%s' to '%s' for Azure DNS zone '%s'.",
   284  				ep.RecordType,
   285  				name,
   286  				ep.Targets,
   287  				zone,
   288  			)
   289  
   290  			recordSet, err := p.newRecordSet(ep)
   291  			if err == nil {
   292  				_, err = p.recordSetsClient.CreateOrUpdate(
   293  					ctx,
   294  					p.resourceGroup,
   295  					zone,
   296  					name,
   297  					dns.RecordType(ep.RecordType),
   298  					recordSet,
   299  					nil,
   300  				)
   301  			}
   302  			if err != nil {
   303  				log.Errorf(
   304  					"Failed to update %s record named '%s' to '%s' for DNS zone '%s': %v",
   305  					ep.RecordType,
   306  					name,
   307  					ep.Targets,
   308  					zone,
   309  					err,
   310  				)
   311  			}
   312  		}
   313  	}
   314  }
   315  
   316  func (p *AzureProvider) recordSetNameForZone(zone string, endpoint *endpoint.Endpoint) string {
   317  	// Remove the zone from the record set
   318  	name := endpoint.DNSName
   319  	name = name[:len(name)-len(zone)]
   320  	name = strings.TrimSuffix(name, ".")
   321  
   322  	// For root, use @
   323  	if name == "" {
   324  		return "@"
   325  	}
   326  	return name
   327  }
   328  
   329  func (p *AzureProvider) newRecordSet(endpoint *endpoint.Endpoint) (dns.RecordSet, error) {
   330  	var ttl int64 = azureRecordTTL
   331  	if endpoint.RecordTTL.IsConfigured() {
   332  		ttl = int64(endpoint.RecordTTL)
   333  	}
   334  	switch dns.RecordType(endpoint.RecordType) {
   335  	case dns.RecordTypeA:
   336  		aRecords := make([]*dns.ARecord, len(endpoint.Targets))
   337  		for i, target := range endpoint.Targets {
   338  			aRecords[i] = &dns.ARecord{
   339  				IPv4Address: to.Ptr(target),
   340  			}
   341  		}
   342  		return dns.RecordSet{
   343  			Properties: &dns.RecordSetProperties{
   344  				TTL:      to.Ptr(ttl),
   345  				ARecords: aRecords,
   346  			},
   347  		}, nil
   348  	case dns.RecordTypeAAAA:
   349  		aaaaRecords := make([]*dns.AaaaRecord, len(endpoint.Targets))
   350  		for i, target := range endpoint.Targets {
   351  			aaaaRecords[i] = &dns.AaaaRecord{
   352  				IPv6Address: to.Ptr(target),
   353  			}
   354  		}
   355  		return dns.RecordSet{
   356  			Properties: &dns.RecordSetProperties{
   357  				TTL:         to.Ptr(ttl),
   358  				AaaaRecords: aaaaRecords,
   359  			},
   360  		}, nil
   361  	case dns.RecordTypeCNAME:
   362  		return dns.RecordSet{
   363  			Properties: &dns.RecordSetProperties{
   364  				TTL: to.Ptr(ttl),
   365  				CnameRecord: &dns.CnameRecord{
   366  					Cname: to.Ptr(endpoint.Targets[0]),
   367  				},
   368  			},
   369  		}, nil
   370  	case dns.RecordTypeMX:
   371  		mxRecords := make([]*dns.MxRecord, len(endpoint.Targets))
   372  		for i, target := range endpoint.Targets {
   373  			mxRecord, err := parseMxTarget[dns.MxRecord](target)
   374  			if err != nil {
   375  				return dns.RecordSet{}, err
   376  			}
   377  			mxRecords[i] = &mxRecord
   378  		}
   379  		return dns.RecordSet{
   380  			Properties: &dns.RecordSetProperties{
   381  				TTL:       to.Ptr(ttl),
   382  				MxRecords: mxRecords,
   383  			},
   384  		}, nil
   385  	case dns.RecordTypeTXT:
   386  		return dns.RecordSet{
   387  			Properties: &dns.RecordSetProperties{
   388  				TTL: to.Ptr(ttl),
   389  				TxtRecords: []*dns.TxtRecord{
   390  					{
   391  						Value: []*string{
   392  							&endpoint.Targets[0],
   393  						},
   394  					},
   395  				},
   396  			},
   397  		}, nil
   398  	}
   399  	return dns.RecordSet{}, fmt.Errorf("unsupported record type '%s'", endpoint.RecordType)
   400  }
   401  
   402  // Helper function (shared with test code)
   403  func formatAzureDNSName(recordName, zoneName string) string {
   404  	if recordName == "@" {
   405  		return zoneName
   406  	}
   407  	return fmt.Sprintf("%s.%s", recordName, zoneName)
   408  }
   409  
   410  // Helper function (shared with text code)
   411  func extractAzureTargets(recordSet *dns.RecordSet) []string {
   412  	properties := recordSet.Properties
   413  	if properties == nil {
   414  		return []string{}
   415  	}
   416  
   417  	// Check for A records
   418  	aRecords := properties.ARecords
   419  	if len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil {
   420  		targets := make([]string, len(aRecords))
   421  		for i, aRecord := range aRecords {
   422  			targets[i] = *aRecord.IPv4Address
   423  		}
   424  		return targets
   425  	}
   426  
   427  	// Check for AAAA records
   428  	aaaaRecords := properties.AaaaRecords
   429  	if len(aaaaRecords) > 0 && (aaaaRecords)[0].IPv6Address != nil {
   430  		targets := make([]string, len(aaaaRecords))
   431  		for i, aaaaRecord := range aaaaRecords {
   432  			targets[i] = *aaaaRecord.IPv6Address
   433  		}
   434  		return targets
   435  	}
   436  
   437  	// Check for CNAME records
   438  	cnameRecord := properties.CnameRecord
   439  	if cnameRecord != nil && cnameRecord.Cname != nil {
   440  		return []string{*cnameRecord.Cname}
   441  	}
   442  
   443  	// Check for MX records
   444  	mxRecords := properties.MxRecords
   445  	if len(mxRecords) > 0 && (mxRecords)[0].Exchange != nil {
   446  		targets := make([]string, len(mxRecords))
   447  		for i, mxRecord := range mxRecords {
   448  			targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange)
   449  		}
   450  		return targets
   451  	}
   452  
   453  	// Check for TXT records
   454  	txtRecords := properties.TxtRecords
   455  	if len(txtRecords) > 0 && (txtRecords)[0].Value != nil {
   456  		values := (txtRecords)[0].Value
   457  		if len(values) > 0 {
   458  			return []string{*(values)[0]}
   459  		}
   460  	}
   461  	return []string{}
   462  }