sigs.k8s.io/external-dns@v0.14.1/provider/azure/azure_private_dns.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  //nolint:staticcheck // Required due to the current dependency on a deprecated version of azure-sdk-for-go
    18  package azure
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"strings"
    24  
    25  	azcoreruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
    26  	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
    27  	privatedns "github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/privatedns/armprivatedns"
    28  	log "github.com/sirupsen/logrus"
    29  
    30  	"sigs.k8s.io/external-dns/endpoint"
    31  	"sigs.k8s.io/external-dns/plan"
    32  	"sigs.k8s.io/external-dns/provider"
    33  )
    34  
    35  // PrivateZonesClient is an interface of privatedns.PrivateZoneClient that can be stubbed for testing.
    36  type PrivateZonesClient interface {
    37  	NewListByResourceGroupPager(resourceGroupName string, options *privatedns.PrivateZonesClientListByResourceGroupOptions) *azcoreruntime.Pager[privatedns.PrivateZonesClientListByResourceGroupResponse]
    38  }
    39  
    40  // PrivateRecordSetsClient is an interface of privatedns.RecordSetsClient that can be stubbed for testing.
    41  type PrivateRecordSetsClient interface {
    42  	NewListPager(resourceGroupName string, privateZoneName string, options *privatedns.RecordSetsClientListOptions) *azcoreruntime.Pager[privatedns.RecordSetsClientListResponse]
    43  	Delete(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, options *privatedns.RecordSetsClientDeleteOptions) (privatedns.RecordSetsClientDeleteResponse, error)
    44  	CreateOrUpdate(ctx context.Context, resourceGroupName string, privateZoneName string, recordType privatedns.RecordType, relativeRecordSetName string, parameters privatedns.RecordSet, options *privatedns.RecordSetsClientCreateOrUpdateOptions) (privatedns.RecordSetsClientCreateOrUpdateResponse, error)
    45  }
    46  
    47  // AzurePrivateDNSProvider implements the DNS provider for Microsoft's Azure Private DNS service
    48  type AzurePrivateDNSProvider struct {
    49  	provider.BaseProvider
    50  	domainFilter                 endpoint.DomainFilter
    51  	zoneIDFilter                 provider.ZoneIDFilter
    52  	dryRun                       bool
    53  	resourceGroup                string
    54  	userAssignedIdentityClientID string
    55  	zonesClient                  PrivateZonesClient
    56  	recordSetsClient             PrivateRecordSetsClient
    57  }
    58  
    59  // NewAzurePrivateDNSProvider creates a new Azure Private DNS provider.
    60  //
    61  // Returns the provider or an error if a provider could not be created.
    62  func NewAzurePrivateDNSProvider(configFile string, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, subscriptionID string, resourceGroup string, userAssignedIdentityClientID string, dryRun bool) (*AzurePrivateDNSProvider, error) {
    63  	cfg, err := getConfig(configFile, subscriptionID, resourceGroup, userAssignedIdentityClientID)
    64  	if err != nil {
    65  		return nil, fmt.Errorf("failed to read Azure config file '%s': %v", configFile, err)
    66  	}
    67  	cred, clientOpts, err := getCredentials(*cfg)
    68  	if err != nil {
    69  		return nil, fmt.Errorf("failed to get credentials: %w", err)
    70  	}
    71  
    72  	zonesClient, err := privatedns.NewPrivateZonesClient(cfg.SubscriptionID, cred, clientOpts)
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  	recordSetsClient, err := privatedns.NewRecordSetsClient(cfg.SubscriptionID, cred, clientOpts)
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  	return &AzurePrivateDNSProvider{
    81  		domainFilter:                 domainFilter,
    82  		zoneIDFilter:                 zoneIDFilter,
    83  		dryRun:                       dryRun,
    84  		resourceGroup:                cfg.ResourceGroup,
    85  		userAssignedIdentityClientID: cfg.UserAssignedIdentityID,
    86  		zonesClient:                  zonesClient,
    87  		recordSetsClient:             recordSetsClient,
    88  	}, nil
    89  }
    90  
    91  // Records gets the current records.
    92  //
    93  // Returns the current records or an error if the operation failed.
    94  func (p *AzurePrivateDNSProvider) Records(ctx context.Context) (endpoints []*endpoint.Endpoint, _ error) {
    95  	zones, err := p.zones(ctx)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	log.Debugf("Retrieving Azure Private DNS Records for resource group '%s'", p.resourceGroup)
   101  
   102  	for _, zone := range zones {
   103  		pager := p.recordSetsClient.NewListPager(p.resourceGroup, *zone.Name, &privatedns.RecordSetsClientListOptions{Top: nil})
   104  		for pager.More() {
   105  			nextResult, err := pager.NextPage(ctx)
   106  			if err != nil {
   107  				return nil, provider.NewSoftError(fmt.Errorf("failed to fetch dns records: %w", err))
   108  			}
   109  
   110  			for _, recordSet := range nextResult.Value {
   111  				var recordType string
   112  				if recordSet.Type == nil {
   113  					log.Debugf("Skipping invalid record set with missing type.")
   114  					continue
   115  				}
   116  				recordType = strings.TrimPrefix(*recordSet.Type, "Microsoft.Network/privateDnsZones/")
   117  
   118  				var name string
   119  				if recordSet.Name == nil {
   120  					log.Debugf("Skipping invalid record set with missing name.")
   121  					continue
   122  				}
   123  				name = formatAzureDNSName(*recordSet.Name, *zone.Name)
   124  
   125  				targets := extractAzurePrivateDNSTargets(recordSet)
   126  				if len(targets) == 0 {
   127  					log.Debugf("Failed to extract targets for '%s' with type '%s'.", name, recordType)
   128  					continue
   129  				}
   130  
   131  				var ttl endpoint.TTL
   132  				if recordSet.Properties.TTL != nil {
   133  					ttl = endpoint.TTL(*recordSet.Properties.TTL)
   134  				}
   135  
   136  				ep := endpoint.NewEndpointWithTTL(name, recordType, ttl, targets...)
   137  				log.Debugf(
   138  					"Found %s record for '%s' with target '%s'.",
   139  					ep.RecordType,
   140  					ep.DNSName,
   141  					ep.Targets,
   142  				)
   143  				endpoints = append(endpoints, ep)
   144  			}
   145  		}
   146  	}
   147  
   148  	log.Debugf("Returning %d Azure Private DNS Records for resource group '%s'", len(endpoints), p.resourceGroup)
   149  
   150  	return endpoints, nil
   151  }
   152  
   153  // ApplyChanges applies the given changes.
   154  //
   155  // Returns nil if the operation was successful or an error if the operation failed.
   156  func (p *AzurePrivateDNSProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
   157  	log.Debugf("Received %d changes to process", len(changes.Create)+len(changes.Delete)+len(changes.UpdateNew)+len(changes.UpdateOld))
   158  
   159  	zones, err := p.zones(ctx)
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	deleted, updated := p.mapChanges(zones, changes)
   165  	p.deleteRecords(ctx, deleted)
   166  	p.updateRecords(ctx, updated)
   167  	return nil
   168  }
   169  
   170  func (p *AzurePrivateDNSProvider) zones(ctx context.Context) ([]privatedns.PrivateZone, error) {
   171  	log.Debugf("Retrieving Azure Private DNS zones for Resource Group '%s'", p.resourceGroup)
   172  
   173  	var zones []privatedns.PrivateZone
   174  
   175  	pager := p.zonesClient.NewListByResourceGroupPager(p.resourceGroup, &privatedns.PrivateZonesClientListByResourceGroupOptions{Top: nil})
   176  	for pager.More() {
   177  		nextResult, err := pager.NextPage(ctx)
   178  		if err != nil {
   179  			return nil, err
   180  		}
   181  		for _, zone := range nextResult.Value {
   182  			log.Debugf("Validating Zone: %v", *zone.Name)
   183  
   184  			if zone.Name != nil && p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.ID) {
   185  				zones = append(zones, *zone)
   186  			}
   187  		}
   188  	}
   189  
   190  	log.Debugf("Found %d Azure Private DNS zone(s).", len(zones))
   191  	return zones, nil
   192  }
   193  
   194  type azurePrivateDNSChangeMap map[string][]*endpoint.Endpoint
   195  
   196  func (p *AzurePrivateDNSProvider) mapChanges(zones []privatedns.PrivateZone, changes *plan.Changes) (azurePrivateDNSChangeMap, azurePrivateDNSChangeMap) {
   197  	ignored := map[string]bool{}
   198  	deleted := azurePrivateDNSChangeMap{}
   199  	updated := azurePrivateDNSChangeMap{}
   200  	zoneNameIDMapper := provider.ZoneIDName{}
   201  	for _, z := range zones {
   202  		if z.Name != nil {
   203  			zoneNameIDMapper.Add(*z.Name, *z.Name)
   204  		}
   205  	}
   206  	mapChange := func(changeMap azurePrivateDNSChangeMap, change *endpoint.Endpoint) {
   207  		zone, _ := zoneNameIDMapper.FindZone(change.DNSName)
   208  		if zone == "" {
   209  			if _, ok := ignored[change.DNSName]; !ok {
   210  				ignored[change.DNSName] = true
   211  				log.Infof("Ignoring changes to '%s' because a suitable Azure Private DNS zone was not found.", change.DNSName)
   212  			}
   213  			return
   214  		}
   215  		// Ensure the record type is suitable
   216  		changeMap[zone] = append(changeMap[zone], change)
   217  	}
   218  
   219  	for _, change := range changes.Delete {
   220  		mapChange(deleted, change)
   221  	}
   222  
   223  	for _, change := range changes.Create {
   224  		mapChange(updated, change)
   225  	}
   226  
   227  	for _, change := range changes.UpdateNew {
   228  		mapChange(updated, change)
   229  	}
   230  	return deleted, updated
   231  }
   232  
   233  func (p *AzurePrivateDNSProvider) deleteRecords(ctx context.Context, deleted azurePrivateDNSChangeMap) {
   234  	log.Debugf("Records to be deleted: %d", len(deleted))
   235  	// Delete records first
   236  	for zone, endpoints := range deleted {
   237  		for _, ep := range endpoints {
   238  			name := p.recordSetNameForZone(zone, ep)
   239  			if p.dryRun {
   240  				log.Infof("Would delete %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone)
   241  			} else {
   242  				log.Infof("Deleting %s record named '%s' for Azure Private DNS zone '%s'.", ep.RecordType, name, zone)
   243  				if _, err := p.recordSetsClient.Delete(ctx, p.resourceGroup, zone, privatedns.RecordType(ep.RecordType), name, nil); err != nil {
   244  					log.Errorf(
   245  						"Failed to delete %s record named '%s' for Azure Private DNS zone '%s': %v",
   246  						ep.RecordType,
   247  						name,
   248  						zone,
   249  						err,
   250  					)
   251  				}
   252  			}
   253  		}
   254  	}
   255  }
   256  
   257  func (p *AzurePrivateDNSProvider) updateRecords(ctx context.Context, updated azurePrivateDNSChangeMap) {
   258  	log.Debugf("Records to be updated: %d", len(updated))
   259  	for zone, endpoints := range updated {
   260  		for _, ep := range endpoints {
   261  			name := p.recordSetNameForZone(zone, ep)
   262  			if p.dryRun {
   263  				log.Infof(
   264  					"Would update %s record named '%s' to '%s' for Azure Private DNS zone '%s'.",
   265  					ep.RecordType,
   266  					name,
   267  					ep.Targets,
   268  					zone,
   269  				)
   270  				continue
   271  			}
   272  
   273  			log.Infof(
   274  				"Updating %s record named '%s' to '%s' for Azure Private DNS zone '%s'.",
   275  				ep.RecordType,
   276  				name,
   277  				ep.Targets,
   278  				zone,
   279  			)
   280  
   281  			recordSet, err := p.newRecordSet(ep)
   282  			if err == nil {
   283  				_, err = p.recordSetsClient.CreateOrUpdate(
   284  					ctx,
   285  					p.resourceGroup,
   286  					zone,
   287  					privatedns.RecordType(ep.RecordType),
   288  					name,
   289  					recordSet,
   290  					nil,
   291  				)
   292  			}
   293  			if err != nil {
   294  				log.Errorf(
   295  					"Failed to update %s record named '%s' to '%s' for Azure Private DNS zone '%s': %v",
   296  					ep.RecordType,
   297  					name,
   298  					ep.Targets,
   299  					zone,
   300  					err,
   301  				)
   302  			}
   303  		}
   304  	}
   305  }
   306  
   307  func (p *AzurePrivateDNSProvider) recordSetNameForZone(zone string, endpoint *endpoint.Endpoint) string {
   308  	// Remove the zone from the record set
   309  	name := endpoint.DNSName
   310  	name = name[:len(name)-len(zone)]
   311  	name = strings.TrimSuffix(name, ".")
   312  
   313  	// For root, use @
   314  	if name == "" {
   315  		return "@"
   316  	}
   317  	return name
   318  }
   319  
   320  func (p *AzurePrivateDNSProvider) newRecordSet(endpoint *endpoint.Endpoint) (privatedns.RecordSet, error) {
   321  	var ttl int64 = azureRecordTTL
   322  	if endpoint.RecordTTL.IsConfigured() {
   323  		ttl = int64(endpoint.RecordTTL)
   324  	}
   325  	switch privatedns.RecordType(endpoint.RecordType) {
   326  	case privatedns.RecordTypeA:
   327  		aRecords := make([]*privatedns.ARecord, len(endpoint.Targets))
   328  		for i, target := range endpoint.Targets {
   329  			aRecords[i] = &privatedns.ARecord{
   330  				IPv4Address: to.Ptr(target),
   331  			}
   332  		}
   333  		return privatedns.RecordSet{
   334  			Properties: &privatedns.RecordSetProperties{
   335  				TTL:      to.Ptr(ttl),
   336  				ARecords: aRecords,
   337  			},
   338  		}, nil
   339  	case privatedns.RecordTypeAAAA:
   340  		aaaaRecords := make([]*privatedns.AaaaRecord, len(endpoint.Targets))
   341  		for i, target := range endpoint.Targets {
   342  			aaaaRecords[i] = &privatedns.AaaaRecord{
   343  				IPv6Address: to.Ptr(target),
   344  			}
   345  		}
   346  		return privatedns.RecordSet{
   347  			Properties: &privatedns.RecordSetProperties{
   348  				TTL:         to.Ptr(ttl),
   349  				AaaaRecords: aaaaRecords,
   350  			},
   351  		}, nil
   352  	case privatedns.RecordTypeCNAME:
   353  		return privatedns.RecordSet{
   354  			Properties: &privatedns.RecordSetProperties{
   355  				TTL: to.Ptr(ttl),
   356  				CnameRecord: &privatedns.CnameRecord{
   357  					Cname: to.Ptr(endpoint.Targets[0]),
   358  				},
   359  			},
   360  		}, nil
   361  	case privatedns.RecordTypeMX:
   362  		mxRecords := make([]*privatedns.MxRecord, len(endpoint.Targets))
   363  		for i, target := range endpoint.Targets {
   364  			mxRecord, err := parseMxTarget[privatedns.MxRecord](target)
   365  			if err != nil {
   366  				return privatedns.RecordSet{}, err
   367  			}
   368  			mxRecords[i] = &mxRecord
   369  		}
   370  		return privatedns.RecordSet{
   371  			Properties: &privatedns.RecordSetProperties{
   372  				TTL:       to.Ptr(ttl),
   373  				MxRecords: mxRecords,
   374  			},
   375  		}, nil
   376  	case privatedns.RecordTypeTXT:
   377  		return privatedns.RecordSet{
   378  			Properties: &privatedns.RecordSetProperties{
   379  				TTL: to.Ptr(ttl),
   380  				TxtRecords: []*privatedns.TxtRecord{
   381  					{
   382  						Value: []*string{
   383  							&endpoint.Targets[0],
   384  						},
   385  					},
   386  				},
   387  			},
   388  		}, nil
   389  	}
   390  	return privatedns.RecordSet{}, fmt.Errorf("unsupported record type '%s'", endpoint.RecordType)
   391  }
   392  
   393  // Helper function (shared with test code)
   394  func extractAzurePrivateDNSTargets(recordSet *privatedns.RecordSet) []string {
   395  	properties := recordSet.Properties
   396  	if properties == nil {
   397  		return []string{}
   398  	}
   399  
   400  	// Check for A records
   401  	aRecords := properties.ARecords
   402  	if len(aRecords) > 0 && (aRecords)[0].IPv4Address != nil {
   403  		targets := make([]string, len(aRecords))
   404  		for i, aRecord := range aRecords {
   405  			targets[i] = *aRecord.IPv4Address
   406  		}
   407  		return targets
   408  	}
   409  
   410  	// Check for AAAA records
   411  	aaaaRecords := properties.AaaaRecords
   412  	if len(aaaaRecords) > 0 && (aaaaRecords)[0].IPv6Address != nil {
   413  		targets := make([]string, len(aaaaRecords))
   414  		for i, aaaaRecord := range aaaaRecords {
   415  			targets[i] = *aaaaRecord.IPv6Address
   416  		}
   417  		return targets
   418  	}
   419  
   420  	// Check for CNAME records
   421  	cnameRecord := properties.CnameRecord
   422  	if cnameRecord != nil && cnameRecord.Cname != nil {
   423  		return []string{*cnameRecord.Cname}
   424  	}
   425  
   426  	// Check for MX records
   427  	mxRecords := properties.MxRecords
   428  	if len(mxRecords) > 0 && (mxRecords)[0].Exchange != nil {
   429  		targets := make([]string, len(mxRecords))
   430  		for i, mxRecord := range mxRecords {
   431  			targets[i] = fmt.Sprintf("%d %s", *mxRecord.Preference, *mxRecord.Exchange)
   432  		}
   433  		return targets
   434  	}
   435  
   436  	// Check for TXT records
   437  	txtRecords := properties.TxtRecords
   438  	if len(txtRecords) > 0 && (txtRecords)[0].Value != nil {
   439  		values := (txtRecords)[0].Value
   440  		if len(values) > 0 {
   441  			return []string{*(values)[0]}
   442  		}
   443  	}
   444  	return []string{}
   445  }