github.com/teknogeek/dnscontrol/v2@v2.10.1-0.20200227202244-ae299b55ba42/providers/azuredns/azureDnsProvider.go (about)

     1  package azuredns
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"strings"
     8  	"time"
     9  
    10  	adns "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns"
    11  	aauth "github.com/Azure/go-autorest/autorest/azure/auth"
    12  	"github.com/Azure/go-autorest/autorest/to"
    13  
    14  	"github.com/StackExchange/dnscontrol/v2/models"
    15  	"github.com/StackExchange/dnscontrol/v2/providers"
    16  	"github.com/StackExchange/dnscontrol/v2/providers/diff"
    17  )
    18  
    19  type azureDnsProvider struct {
    20  	zonesClient   *adns.ZonesClient
    21  	recordsClient *adns.RecordSetsClient
    22  	zones         map[string]*adns.Zone
    23  	resourceGroup *string
    24  }
    25  
    26  func newAzureDnsDsp(conf map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
    27  	return newAzureDns(conf, metadata)
    28  }
    29  
    30  func newAzureDns(m map[string]string, metadata json.RawMessage) (*azureDnsProvider, error) {
    31  	subId, rg := m["SubscriptionID"], m["ResourceGroup"]
    32  
    33  	zonesClient := adns.NewZonesClient(subId)
    34  	recordsClient := adns.NewRecordSetsClient(subId)
    35  	clientCredentialAuthorizer := aauth.NewClientCredentialsConfig(m["ClientID"], m["ClientSecret"], m["TenantID"])
    36  	authorizer, authErr := clientCredentialAuthorizer.Authorizer()
    37  
    38  	if authErr != nil {
    39  		return nil, authErr
    40  	}
    41  
    42  	zonesClient.Authorizer = authorizer
    43  	recordsClient.Authorizer = authorizer
    44  	api := &azureDnsProvider{zonesClient: &zonesClient, recordsClient: &recordsClient, resourceGroup: to.StringPtr(rg)}
    45  	err := api.getZones()
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  	return api, nil
    50  }
    51  
    52  var features = providers.DocumentationNotes{
    53  	providers.CanUseAlias:            providers.Cannot("Only supported for Azure Resources. Not yet implemented"),
    54  	providers.DocCreateDomains:       providers.Can(),
    55  	providers.DocDualHost:            providers.Can("Azure does not permit modifying the existing NS records, only adding/removing additional records."),
    56  	providers.DocOfficiallySupported: providers.Can(),
    57  	providers.CanUsePTR:              providers.Can(),
    58  	providers.CanUseSRV:              providers.Can(),
    59  	providers.CanUseTXTMulti:         providers.Can(),
    60  	providers.CanUseCAA:              providers.Can(),
    61  	providers.CanUseRoute53Alias:     providers.Cannot(),
    62  	providers.CanUseNAPTR:            providers.Cannot(),
    63  	providers.CanUseSSHFP:            providers.Cannot(),
    64  	providers.CanUseTLSA:             providers.Cannot(),
    65  	providers.CanGetZones:            providers.Can(),
    66  }
    67  
    68  func init() {
    69  	providers.RegisterDomainServiceProviderType("AZURE_DNS", newAzureDnsDsp, features)
    70  }
    71  
    72  func (a *azureDnsProvider) getExistingZones() (*adns.ZoneListResult, error) {
    73  	ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
    74  	defer cancel()
    75  	zonesIterator, zonesErr := a.zonesClient.ListByResourceGroupComplete(ctx, *a.resourceGroup, to.Int32Ptr(100))
    76  	if zonesErr != nil {
    77  		return nil, zonesErr
    78  	}
    79  	zonesResult := zonesIterator.Response()
    80  	return &zonesResult, nil
    81  }
    82  
    83  func (a *azureDnsProvider) getZones() error {
    84  	a.zones = make(map[string]*adns.Zone)
    85  
    86  	zonesResult, err := a.getExistingZones()
    87  
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	for _, z := range *zonesResult.Value {
    93  		zone := z
    94  		domain := strings.TrimSuffix(*z.Name, ".")
    95  		a.zones[domain] = &zone
    96  	}
    97  
    98  	return nil
    99  }
   100  
   101  type errNoExist struct {
   102  	domain string
   103  }
   104  
   105  func (e errNoExist) Error() string {
   106  	return fmt.Sprintf("Domain %s not found in you Azure account", e.domain)
   107  }
   108  
   109  func (a *azureDnsProvider) GetNameservers(domain string) ([]*models.Nameserver, error) {
   110  	zone, ok := a.zones[domain]
   111  	if !ok {
   112  		return nil, errNoExist{domain}
   113  	}
   114  
   115  	var ns []*models.Nameserver
   116  	if zone.ZoneProperties != nil {
   117  		for _, azureNs := range *zone.ZoneProperties.NameServers {
   118  			ns = append(ns, &models.Nameserver{Name: azureNs})
   119  		}
   120  	}
   121  	return ns, nil
   122  }
   123  
   124  func (a *azureDnsProvider) ListZones() ([]string, error) {
   125  	zonesResult, err := a.getExistingZones()
   126  
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	var zones []string
   132  
   133  	for _, z := range *zonesResult.Value {
   134  		domain := strings.TrimSuffix(*z.Name, ".")
   135  		zones = append(zones, domain)
   136  	}
   137  
   138  	return zones, nil
   139  }
   140  
   141  // GetZoneRecords gets the records of a zone and returns them in RecordConfig format.
   142  func (a *azureDnsProvider) GetZoneRecords(domain string) (models.Records, error) {
   143  	existingRecords, _, _, err := a.getExistingRecords(domain)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  	return existingRecords, nil
   148  }
   149  
   150  func (a *azureDnsProvider) getExistingRecords(domain string) (models.Records, []*adns.RecordSet, string, error) {
   151  	zone, ok := a.zones[domain]
   152  	if !ok {
   153  		return nil, nil, "", errNoExist{domain}
   154  	}
   155  	var zoneName string
   156  	zoneName = *zone.Name
   157  	records, err := a.fetchRecordSets(zoneName)
   158  	if err != nil {
   159  		return nil, nil, "", err
   160  	}
   161  
   162  	var existingRecords models.Records
   163  	for _, set := range records {
   164  		existingRecords = append(existingRecords, nativeToRecords(set, zoneName)...)
   165  	}
   166  
   167  	models.PostProcessRecords(existingRecords)
   168  	return existingRecords, records, zoneName, nil
   169  }
   170  
   171  func (a *azureDnsProvider) GetDomainCorrections(dc *models.DomainConfig) ([]*models.Correction, error) {
   172  	err := dc.Punycode()
   173  
   174  	if err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	var corrections []*models.Correction
   179  
   180  	existingRecords, records, zoneName, err := a.getExistingRecords(dc.Name)
   181  	if err != nil {
   182  		return nil, err
   183  	}
   184  
   185  	differ := diff.New(dc)
   186  	namesToUpdate := differ.ChangedGroups(existingRecords)
   187  
   188  	if len(namesToUpdate) == 0 {
   189  		return nil, nil
   190  	}
   191  
   192  	updates := map[models.RecordKey][]*models.RecordConfig{}
   193  
   194  	for k := range namesToUpdate {
   195  		updates[k] = nil
   196  		for _, rc := range dc.Records {
   197  			if rc.Key() == k {
   198  				updates[k] = append(updates[k], rc)
   199  			}
   200  		}
   201  	}
   202  
   203  	for k, recs := range updates {
   204  		if len(recs) == 0 {
   205  			var rrset *adns.RecordSet
   206  			for _, r := range records {
   207  				if strings.TrimSuffix(*r.RecordSetProperties.Fqdn, ".") == k.NameFQDN && nativeToRecordType(r.Type) == nativeToRecordType(to.StringPtr(k.Type)) {
   208  					rrset = r
   209  					break
   210  				}
   211  			}
   212  			if rrset != nil {
   213  				corrections = append(corrections,
   214  					&models.Correction{
   215  						Msg: strings.Join(namesToUpdate[k], "\n"),
   216  						F: func() error {
   217  							ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
   218  							defer cancel()
   219  							_, err := a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, *rrset.Name, nativeToRecordType(rrset.Type), "")
   220  							// Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly.
   221  							time.Sleep(25 * time.Millisecond)
   222  							if err != nil {
   223  								return err
   224  							}
   225  							return nil
   226  						},
   227  					})
   228  			} else {
   229  				return nil, fmt.Errorf("no record set found to delete. Name: '%s'. Type: '%s'", k.NameFQDN, k.Type)
   230  			}
   231  		} else {
   232  			rrset, recordType := recordToNative(k, recs)
   233  			var recordName string
   234  			for _, r := range recs {
   235  				i := int64(r.TTL)
   236  				rrset.TTL = &i // TODO: make sure that ttls are consistent within a set
   237  				recordName = r.Name
   238  			}
   239  
   240  			for _, r := range records {
   241  				existingRecordType := nativeToRecordType(r.Type)
   242  				changedRecordType := nativeToRecordType(to.StringPtr(k.Type))
   243  				if strings.TrimSuffix(*r.RecordSetProperties.Fqdn, ".") == k.NameFQDN && (changedRecordType == adns.CNAME || existingRecordType == adns.CNAME) {
   244  					if existingRecordType == adns.A || existingRecordType == adns.AAAA || changedRecordType == adns.A || changedRecordType == adns.AAAA { //CNAME cannot coexist with an A or AA
   245  						corrections = append(corrections,
   246  							&models.Correction{
   247  								Msg: strings.Join(namesToUpdate[k], "\n"),
   248  								F: func() error {
   249  									ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
   250  									defer cancel()
   251  									_, err := a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, recordName, existingRecordType, "")
   252  									// Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly.
   253  									time.Sleep(25 * time.Millisecond)
   254  									if err != nil {
   255  										return err
   256  									}
   257  									return nil
   258  								},
   259  							})
   260  					}
   261  				}
   262  			}
   263  
   264  			corrections = append(corrections,
   265  				&models.Correction{
   266  					Msg: strings.Join(namesToUpdate[k], "\n"),
   267  					F: func() error {
   268  						ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
   269  						defer cancel()
   270  						_, err := a.recordsClient.CreateOrUpdate(ctx, *a.resourceGroup, zoneName, recordName, recordType, *rrset, "", "")
   271  						// Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly.
   272  						time.Sleep(25 * time.Millisecond)
   273  						if err != nil {
   274  							return err
   275  						}
   276  						return nil
   277  					},
   278  				})
   279  		}
   280  	}
   281  	return corrections, nil
   282  }
   283  
   284  func nativeToRecordType(recordType *string) adns.RecordType {
   285  	recordTypeStripped := strings.TrimPrefix(*recordType, "Microsoft.Network/dnszones/")
   286  	switch recordTypeStripped {
   287  	case "A":
   288  		return adns.A
   289  	case "AAAA":
   290  		return adns.AAAA
   291  	case "CAA":
   292  		return adns.CAA
   293  	case "CNAME":
   294  		return adns.CNAME
   295  	case "MX":
   296  		return adns.MX
   297  	case "NS":
   298  		return adns.NS
   299  	case "PTR":
   300  		return adns.PTR
   301  	case "SRV":
   302  		return adns.SRV
   303  	case "TXT":
   304  		return adns.TXT
   305  	case "SOA":
   306  		return adns.SOA
   307  	default:
   308  		panic(fmt.Errorf("rc.String rtype %v unimplemented", *recordType))
   309  	}
   310  }
   311  
   312  func nativeToRecords(set *adns.RecordSet, origin string) []*models.RecordConfig {
   313  	var results []*models.RecordConfig
   314  	switch rtype := *set.Type; rtype {
   315  	case "Microsoft.Network/dnszones/A":
   316  		if set.ARecords != nil {
   317  			for _, rec := range *set.ARecords {
   318  				rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   319  				rc.SetLabelFromFQDN(*set.Fqdn, origin)
   320  				rc.Type = "A"
   321  				_ = rc.SetTarget(*rec.Ipv4Address)
   322  				results = append(results, rc)
   323  			}
   324  		}
   325  	case "Microsoft.Network/dnszones/AAAA":
   326  		if set.AaaaRecords != nil {
   327  			for _, rec := range *set.AaaaRecords {
   328  				rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   329  				rc.SetLabelFromFQDN(*set.Fqdn, origin)
   330  				rc.Type = "AAAA"
   331  				_ = rc.SetTarget(*rec.Ipv6Address)
   332  				results = append(results, rc)
   333  			}
   334  		}
   335  	case "Microsoft.Network/dnszones/CNAME":
   336  		rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   337  		rc.SetLabelFromFQDN(*set.Fqdn, origin)
   338  		rc.Type = "CNAME"
   339  		_ = rc.SetTarget(*set.CnameRecord.Cname)
   340  		results = append(results, rc)
   341  	case "Microsoft.Network/dnszones/NS":
   342  		for _, rec := range *set.NsRecords {
   343  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   344  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   345  			rc.Type = "NS"
   346  			_ = rc.SetTarget(*rec.Nsdname)
   347  			results = append(results, rc)
   348  		}
   349  	case "Microsoft.Network/dnszones/PTR":
   350  		for _, rec := range *set.PtrRecords {
   351  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   352  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   353  			rc.Type = "PTR"
   354  			_ = rc.SetTarget(*rec.Ptrdname)
   355  			results = append(results, rc)
   356  		}
   357  	case "Microsoft.Network/dnszones/TXT":
   358  		if len(*set.TxtRecords) == 0 { // Empty String Record Parsing
   359  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   360  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   361  			rc.Type = "TXT"
   362  			_ = rc.SetTargetTXT("")
   363  			results = append(results, rc)
   364  		} else {
   365  			for _, rec := range *set.TxtRecords {
   366  				rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   367  				rc.SetLabelFromFQDN(*set.Fqdn, origin)
   368  				rc.Type = "TXT"
   369  				_ = rc.SetTargetTXTs(*rec.Value)
   370  				results = append(results, rc)
   371  			}
   372  		}
   373  	case "Microsoft.Network/dnszones/MX":
   374  		for _, rec := range *set.MxRecords {
   375  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   376  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   377  			rc.Type = "MX"
   378  			_ = rc.SetTargetMX(uint16(*rec.Preference), *rec.Exchange)
   379  			results = append(results, rc)
   380  		}
   381  	case "Microsoft.Network/dnszones/SRV":
   382  		for _, rec := range *set.SrvRecords {
   383  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   384  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   385  			rc.Type = "SRV"
   386  			_ = rc.SetTargetSRV(uint16(*rec.Priority), uint16(*rec.Weight), uint16(*rec.Port), *rec.Target)
   387  			results = append(results, rc)
   388  		}
   389  	case "Microsoft.Network/dnszones/CAA":
   390  		for _, rec := range *set.CaaRecords {
   391  			rc := &models.RecordConfig{TTL: uint32(*set.TTL)}
   392  			rc.SetLabelFromFQDN(*set.Fqdn, origin)
   393  			rc.Type = "CAA"
   394  			_ = rc.SetTargetCAA(uint8(*rec.Flags), *rec.Tag, *rec.Value)
   395  			results = append(results, rc)
   396  		}
   397  	case "Microsoft.Network/dnszones/SOA":
   398  	default:
   399  		panic(fmt.Errorf("rc.String rtype %v unimplemented", *set.Type))
   400  	}
   401  	return results
   402  }
   403  
   404  func recordToNative(recordKey models.RecordKey, recordConfig []*models.RecordConfig) (*adns.RecordSet, adns.RecordType) {
   405  	recordSet := &adns.RecordSet{Type: to.StringPtr(recordKey.Type), RecordSetProperties: &adns.RecordSetProperties{}}
   406  	for _, rec := range recordConfig {
   407  		switch recordKey.Type {
   408  		case "A":
   409  			if recordSet.ARecords == nil {
   410  				recordSet.ARecords = &[]adns.ARecord{}
   411  			}
   412  			*recordSet.ARecords = append(*recordSet.ARecords, adns.ARecord{Ipv4Address: to.StringPtr(rec.Target)})
   413  		case "AAAA":
   414  			if recordSet.AaaaRecords == nil {
   415  				recordSet.AaaaRecords = &[]adns.AaaaRecord{}
   416  			}
   417  			*recordSet.AaaaRecords = append(*recordSet.AaaaRecords, adns.AaaaRecord{Ipv6Address: to.StringPtr(rec.Target)})
   418  		case "CNAME":
   419  			recordSet.CnameRecord = &adns.CnameRecord{Cname: to.StringPtr(rec.Target)}
   420  		case "NS":
   421  			if recordSet.NsRecords == nil {
   422  				recordSet.NsRecords = &[]adns.NsRecord{}
   423  			}
   424  			*recordSet.NsRecords = append(*recordSet.NsRecords, adns.NsRecord{Nsdname: to.StringPtr(rec.Target)})
   425  		case "PTR":
   426  			if recordSet.PtrRecords == nil {
   427  				recordSet.PtrRecords = &[]adns.PtrRecord{}
   428  			}
   429  			*recordSet.PtrRecords = append(*recordSet.PtrRecords, adns.PtrRecord{Ptrdname: to.StringPtr(rec.Target)})
   430  		case "TXT":
   431  			if recordSet.TxtRecords == nil {
   432  				recordSet.TxtRecords = &[]adns.TxtRecord{}
   433  			}
   434  			// Empty TXT record needs to have no value set in it's properties
   435  			if !(len(rec.TxtStrings) == 1 && rec.TxtStrings[0] == "") {
   436  				*recordSet.TxtRecords = append(*recordSet.TxtRecords, adns.TxtRecord{Value: &rec.TxtStrings})
   437  			}
   438  		case "MX":
   439  			if recordSet.MxRecords == nil {
   440  				recordSet.MxRecords = &[]adns.MxRecord{}
   441  			}
   442  			*recordSet.MxRecords = append(*recordSet.MxRecords, adns.MxRecord{Exchange: to.StringPtr(rec.Target), Preference: to.Int32Ptr(int32(rec.MxPreference))})
   443  		case "SRV":
   444  			if recordSet.SrvRecords == nil {
   445  				recordSet.SrvRecords = &[]adns.SrvRecord{}
   446  			}
   447  			*recordSet.SrvRecords = append(*recordSet.SrvRecords, adns.SrvRecord{Target: to.StringPtr(rec.Target), Port: to.Int32Ptr(int32(rec.SrvPort)), Weight: to.Int32Ptr(int32(rec.SrvWeight)), Priority: to.Int32Ptr(int32(rec.SrvPriority))})
   448  		case "CAA":
   449  			if recordSet.CaaRecords == nil {
   450  				recordSet.CaaRecords = &[]adns.CaaRecord{}
   451  			}
   452  			*recordSet.CaaRecords = append(*recordSet.CaaRecords, adns.CaaRecord{Value: to.StringPtr(rec.Target), Tag: to.StringPtr(rec.CaaTag), Flags: to.Int32Ptr(int32(rec.CaaFlag))})
   453  		default:
   454  			panic(fmt.Errorf("rc.String rtype %v unimplemented", recordKey.Type))
   455  		}
   456  	}
   457  	return recordSet, nativeToRecordType(to.StringPtr(recordKey.Type))
   458  }
   459  
   460  func (a *azureDnsProvider) fetchRecordSets(zoneName string) ([]*adns.RecordSet, error) {
   461  	if zoneName == "" {
   462  		return nil, nil
   463  	}
   464  	var records []*adns.RecordSet
   465  	ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
   466  	defer cancel()
   467  	recordsIterator, recordsErr := a.recordsClient.ListAllByDNSZoneComplete(ctx, *a.resourceGroup, zoneName, to.Int32Ptr(1000), "")
   468  	if recordsErr != nil {
   469  		return nil, recordsErr
   470  	}
   471  	recordsResult := recordsIterator.Response()
   472  
   473  	for _, r := range *recordsResult.Value {
   474  		record := r
   475  		records = append(records, &record)
   476  	}
   477  
   478  	return records, nil
   479  }
   480  
   481  func (a *azureDnsProvider) EnsureDomainExists(domain string) error {
   482  	if _, ok := a.zones[domain]; ok {
   483  		return nil
   484  	}
   485  	fmt.Printf("Adding zone for %s to Azure dns account\n", domain)
   486  
   487  	ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second)
   488  	defer cancel()
   489  
   490  	_, err := a.zonesClient.CreateOrUpdate(ctx, *a.resourceGroup, domain, adns.Zone{Location: to.StringPtr("global")}, "", "")
   491  	if err != nil {
   492  		return err
   493  	}
   494  	return nil
   495  }