github.com/teknogeek/dnscontrol/v2@v2.10.1-0.20200227202244-ae299b55ba42/providers/azuredns/azureDnsProvider.go (about) 1 package azuredns 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "strings" 8 "time" 9 10 adns "github.com/Azure/azure-sdk-for-go/services/dns/mgmt/2018-05-01/dns" 11 aauth "github.com/Azure/go-autorest/autorest/azure/auth" 12 "github.com/Azure/go-autorest/autorest/to" 13 14 "github.com/StackExchange/dnscontrol/v2/models" 15 "github.com/StackExchange/dnscontrol/v2/providers" 16 "github.com/StackExchange/dnscontrol/v2/providers/diff" 17 ) 18 19 type azureDnsProvider struct { 20 zonesClient *adns.ZonesClient 21 recordsClient *adns.RecordSetsClient 22 zones map[string]*adns.Zone 23 resourceGroup *string 24 } 25 26 func newAzureDnsDsp(conf map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) { 27 return newAzureDns(conf, metadata) 28 } 29 30 func newAzureDns(m map[string]string, metadata json.RawMessage) (*azureDnsProvider, error) { 31 subId, rg := m["SubscriptionID"], m["ResourceGroup"] 32 33 zonesClient := adns.NewZonesClient(subId) 34 recordsClient := adns.NewRecordSetsClient(subId) 35 clientCredentialAuthorizer := aauth.NewClientCredentialsConfig(m["ClientID"], m["ClientSecret"], m["TenantID"]) 36 authorizer, authErr := clientCredentialAuthorizer.Authorizer() 37 38 if authErr != nil { 39 return nil, authErr 40 } 41 42 zonesClient.Authorizer = authorizer 43 recordsClient.Authorizer = authorizer 44 api := &azureDnsProvider{zonesClient: &zonesClient, recordsClient: &recordsClient, resourceGroup: to.StringPtr(rg)} 45 err := api.getZones() 46 if err != nil { 47 return nil, err 48 } 49 return api, nil 50 } 51 52 var features = providers.DocumentationNotes{ 53 providers.CanUseAlias: providers.Cannot("Only supported for Azure Resources. Not yet implemented"), 54 providers.DocCreateDomains: providers.Can(), 55 providers.DocDualHost: providers.Can("Azure does not permit modifying the existing NS records, only adding/removing additional records."), 56 providers.DocOfficiallySupported: providers.Can(), 57 providers.CanUsePTR: providers.Can(), 58 providers.CanUseSRV: providers.Can(), 59 providers.CanUseTXTMulti: providers.Can(), 60 providers.CanUseCAA: providers.Can(), 61 providers.CanUseRoute53Alias: providers.Cannot(), 62 providers.CanUseNAPTR: providers.Cannot(), 63 providers.CanUseSSHFP: providers.Cannot(), 64 providers.CanUseTLSA: providers.Cannot(), 65 providers.CanGetZones: providers.Can(), 66 } 67 68 func init() { 69 providers.RegisterDomainServiceProviderType("AZURE_DNS", newAzureDnsDsp, features) 70 } 71 72 func (a *azureDnsProvider) getExistingZones() (*adns.ZoneListResult, error) { 73 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 74 defer cancel() 75 zonesIterator, zonesErr := a.zonesClient.ListByResourceGroupComplete(ctx, *a.resourceGroup, to.Int32Ptr(100)) 76 if zonesErr != nil { 77 return nil, zonesErr 78 } 79 zonesResult := zonesIterator.Response() 80 return &zonesResult, nil 81 } 82 83 func (a *azureDnsProvider) getZones() error { 84 a.zones = make(map[string]*adns.Zone) 85 86 zonesResult, err := a.getExistingZones() 87 88 if err != nil { 89 return err 90 } 91 92 for _, z := range *zonesResult.Value { 93 zone := z 94 domain := strings.TrimSuffix(*z.Name, ".") 95 a.zones[domain] = &zone 96 } 97 98 return nil 99 } 100 101 type errNoExist struct { 102 domain string 103 } 104 105 func (e errNoExist) Error() string { 106 return fmt.Sprintf("Domain %s not found in you Azure account", e.domain) 107 } 108 109 func (a *azureDnsProvider) GetNameservers(domain string) ([]*models.Nameserver, error) { 110 zone, ok := a.zones[domain] 111 if !ok { 112 return nil, errNoExist{domain} 113 } 114 115 var ns []*models.Nameserver 116 if zone.ZoneProperties != nil { 117 for _, azureNs := range *zone.ZoneProperties.NameServers { 118 ns = append(ns, &models.Nameserver{Name: azureNs}) 119 } 120 } 121 return ns, nil 122 } 123 124 func (a *azureDnsProvider) ListZones() ([]string, error) { 125 zonesResult, err := a.getExistingZones() 126 127 if err != nil { 128 return nil, err 129 } 130 131 var zones []string 132 133 for _, z := range *zonesResult.Value { 134 domain := strings.TrimSuffix(*z.Name, ".") 135 zones = append(zones, domain) 136 } 137 138 return zones, nil 139 } 140 141 // GetZoneRecords gets the records of a zone and returns them in RecordConfig format. 142 func (a *azureDnsProvider) GetZoneRecords(domain string) (models.Records, error) { 143 existingRecords, _, _, err := a.getExistingRecords(domain) 144 if err != nil { 145 return nil, err 146 } 147 return existingRecords, nil 148 } 149 150 func (a *azureDnsProvider) getExistingRecords(domain string) (models.Records, []*adns.RecordSet, string, error) { 151 zone, ok := a.zones[domain] 152 if !ok { 153 return nil, nil, "", errNoExist{domain} 154 } 155 var zoneName string 156 zoneName = *zone.Name 157 records, err := a.fetchRecordSets(zoneName) 158 if err != nil { 159 return nil, nil, "", err 160 } 161 162 var existingRecords models.Records 163 for _, set := range records { 164 existingRecords = append(existingRecords, nativeToRecords(set, zoneName)...) 165 } 166 167 models.PostProcessRecords(existingRecords) 168 return existingRecords, records, zoneName, nil 169 } 170 171 func (a *azureDnsProvider) GetDomainCorrections(dc *models.DomainConfig) ([]*models.Correction, error) { 172 err := dc.Punycode() 173 174 if err != nil { 175 return nil, err 176 } 177 178 var corrections []*models.Correction 179 180 existingRecords, records, zoneName, err := a.getExistingRecords(dc.Name) 181 if err != nil { 182 return nil, err 183 } 184 185 differ := diff.New(dc) 186 namesToUpdate := differ.ChangedGroups(existingRecords) 187 188 if len(namesToUpdate) == 0 { 189 return nil, nil 190 } 191 192 updates := map[models.RecordKey][]*models.RecordConfig{} 193 194 for k := range namesToUpdate { 195 updates[k] = nil 196 for _, rc := range dc.Records { 197 if rc.Key() == k { 198 updates[k] = append(updates[k], rc) 199 } 200 } 201 } 202 203 for k, recs := range updates { 204 if len(recs) == 0 { 205 var rrset *adns.RecordSet 206 for _, r := range records { 207 if strings.TrimSuffix(*r.RecordSetProperties.Fqdn, ".") == k.NameFQDN && nativeToRecordType(r.Type) == nativeToRecordType(to.StringPtr(k.Type)) { 208 rrset = r 209 break 210 } 211 } 212 if rrset != nil { 213 corrections = append(corrections, 214 &models.Correction{ 215 Msg: strings.Join(namesToUpdate[k], "\n"), 216 F: func() error { 217 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 218 defer cancel() 219 _, err := a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, *rrset.Name, nativeToRecordType(rrset.Type), "") 220 // Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly. 221 time.Sleep(25 * time.Millisecond) 222 if err != nil { 223 return err 224 } 225 return nil 226 }, 227 }) 228 } else { 229 return nil, fmt.Errorf("no record set found to delete. Name: '%s'. Type: '%s'", k.NameFQDN, k.Type) 230 } 231 } else { 232 rrset, recordType := recordToNative(k, recs) 233 var recordName string 234 for _, r := range recs { 235 i := int64(r.TTL) 236 rrset.TTL = &i // TODO: make sure that ttls are consistent within a set 237 recordName = r.Name 238 } 239 240 for _, r := range records { 241 existingRecordType := nativeToRecordType(r.Type) 242 changedRecordType := nativeToRecordType(to.StringPtr(k.Type)) 243 if strings.TrimSuffix(*r.RecordSetProperties.Fqdn, ".") == k.NameFQDN && (changedRecordType == adns.CNAME || existingRecordType == adns.CNAME) { 244 if existingRecordType == adns.A || existingRecordType == adns.AAAA || changedRecordType == adns.A || changedRecordType == adns.AAAA { //CNAME cannot coexist with an A or AA 245 corrections = append(corrections, 246 &models.Correction{ 247 Msg: strings.Join(namesToUpdate[k], "\n"), 248 F: func() error { 249 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 250 defer cancel() 251 _, err := a.recordsClient.Delete(ctx, *a.resourceGroup, zoneName, recordName, existingRecordType, "") 252 // Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly. 253 time.Sleep(25 * time.Millisecond) 254 if err != nil { 255 return err 256 } 257 return nil 258 }, 259 }) 260 } 261 } 262 } 263 264 corrections = append(corrections, 265 &models.Correction{ 266 Msg: strings.Join(namesToUpdate[k], "\n"), 267 F: func() error { 268 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 269 defer cancel() 270 _, err := a.recordsClient.CreateOrUpdate(ctx, *a.resourceGroup, zoneName, recordName, recordType, *rrset, "", "") 271 // Artifically slow things down after a delete, as the API can take time to register it. The tests fail if we delete and then recheck too quickly. 272 time.Sleep(25 * time.Millisecond) 273 if err != nil { 274 return err 275 } 276 return nil 277 }, 278 }) 279 } 280 } 281 return corrections, nil 282 } 283 284 func nativeToRecordType(recordType *string) adns.RecordType { 285 recordTypeStripped := strings.TrimPrefix(*recordType, "Microsoft.Network/dnszones/") 286 switch recordTypeStripped { 287 case "A": 288 return adns.A 289 case "AAAA": 290 return adns.AAAA 291 case "CAA": 292 return adns.CAA 293 case "CNAME": 294 return adns.CNAME 295 case "MX": 296 return adns.MX 297 case "NS": 298 return adns.NS 299 case "PTR": 300 return adns.PTR 301 case "SRV": 302 return adns.SRV 303 case "TXT": 304 return adns.TXT 305 case "SOA": 306 return adns.SOA 307 default: 308 panic(fmt.Errorf("rc.String rtype %v unimplemented", *recordType)) 309 } 310 } 311 312 func nativeToRecords(set *adns.RecordSet, origin string) []*models.RecordConfig { 313 var results []*models.RecordConfig 314 switch rtype := *set.Type; rtype { 315 case "Microsoft.Network/dnszones/A": 316 if set.ARecords != nil { 317 for _, rec := range *set.ARecords { 318 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 319 rc.SetLabelFromFQDN(*set.Fqdn, origin) 320 rc.Type = "A" 321 _ = rc.SetTarget(*rec.Ipv4Address) 322 results = append(results, rc) 323 } 324 } 325 case "Microsoft.Network/dnszones/AAAA": 326 if set.AaaaRecords != nil { 327 for _, rec := range *set.AaaaRecords { 328 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 329 rc.SetLabelFromFQDN(*set.Fqdn, origin) 330 rc.Type = "AAAA" 331 _ = rc.SetTarget(*rec.Ipv6Address) 332 results = append(results, rc) 333 } 334 } 335 case "Microsoft.Network/dnszones/CNAME": 336 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 337 rc.SetLabelFromFQDN(*set.Fqdn, origin) 338 rc.Type = "CNAME" 339 _ = rc.SetTarget(*set.CnameRecord.Cname) 340 results = append(results, rc) 341 case "Microsoft.Network/dnszones/NS": 342 for _, rec := range *set.NsRecords { 343 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 344 rc.SetLabelFromFQDN(*set.Fqdn, origin) 345 rc.Type = "NS" 346 _ = rc.SetTarget(*rec.Nsdname) 347 results = append(results, rc) 348 } 349 case "Microsoft.Network/dnszones/PTR": 350 for _, rec := range *set.PtrRecords { 351 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 352 rc.SetLabelFromFQDN(*set.Fqdn, origin) 353 rc.Type = "PTR" 354 _ = rc.SetTarget(*rec.Ptrdname) 355 results = append(results, rc) 356 } 357 case "Microsoft.Network/dnszones/TXT": 358 if len(*set.TxtRecords) == 0 { // Empty String Record Parsing 359 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 360 rc.SetLabelFromFQDN(*set.Fqdn, origin) 361 rc.Type = "TXT" 362 _ = rc.SetTargetTXT("") 363 results = append(results, rc) 364 } else { 365 for _, rec := range *set.TxtRecords { 366 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 367 rc.SetLabelFromFQDN(*set.Fqdn, origin) 368 rc.Type = "TXT" 369 _ = rc.SetTargetTXTs(*rec.Value) 370 results = append(results, rc) 371 } 372 } 373 case "Microsoft.Network/dnszones/MX": 374 for _, rec := range *set.MxRecords { 375 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 376 rc.SetLabelFromFQDN(*set.Fqdn, origin) 377 rc.Type = "MX" 378 _ = rc.SetTargetMX(uint16(*rec.Preference), *rec.Exchange) 379 results = append(results, rc) 380 } 381 case "Microsoft.Network/dnszones/SRV": 382 for _, rec := range *set.SrvRecords { 383 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 384 rc.SetLabelFromFQDN(*set.Fqdn, origin) 385 rc.Type = "SRV" 386 _ = rc.SetTargetSRV(uint16(*rec.Priority), uint16(*rec.Weight), uint16(*rec.Port), *rec.Target) 387 results = append(results, rc) 388 } 389 case "Microsoft.Network/dnszones/CAA": 390 for _, rec := range *set.CaaRecords { 391 rc := &models.RecordConfig{TTL: uint32(*set.TTL)} 392 rc.SetLabelFromFQDN(*set.Fqdn, origin) 393 rc.Type = "CAA" 394 _ = rc.SetTargetCAA(uint8(*rec.Flags), *rec.Tag, *rec.Value) 395 results = append(results, rc) 396 } 397 case "Microsoft.Network/dnszones/SOA": 398 default: 399 panic(fmt.Errorf("rc.String rtype %v unimplemented", *set.Type)) 400 } 401 return results 402 } 403 404 func recordToNative(recordKey models.RecordKey, recordConfig []*models.RecordConfig) (*adns.RecordSet, adns.RecordType) { 405 recordSet := &adns.RecordSet{Type: to.StringPtr(recordKey.Type), RecordSetProperties: &adns.RecordSetProperties{}} 406 for _, rec := range recordConfig { 407 switch recordKey.Type { 408 case "A": 409 if recordSet.ARecords == nil { 410 recordSet.ARecords = &[]adns.ARecord{} 411 } 412 *recordSet.ARecords = append(*recordSet.ARecords, adns.ARecord{Ipv4Address: to.StringPtr(rec.Target)}) 413 case "AAAA": 414 if recordSet.AaaaRecords == nil { 415 recordSet.AaaaRecords = &[]adns.AaaaRecord{} 416 } 417 *recordSet.AaaaRecords = append(*recordSet.AaaaRecords, adns.AaaaRecord{Ipv6Address: to.StringPtr(rec.Target)}) 418 case "CNAME": 419 recordSet.CnameRecord = &adns.CnameRecord{Cname: to.StringPtr(rec.Target)} 420 case "NS": 421 if recordSet.NsRecords == nil { 422 recordSet.NsRecords = &[]adns.NsRecord{} 423 } 424 *recordSet.NsRecords = append(*recordSet.NsRecords, adns.NsRecord{Nsdname: to.StringPtr(rec.Target)}) 425 case "PTR": 426 if recordSet.PtrRecords == nil { 427 recordSet.PtrRecords = &[]adns.PtrRecord{} 428 } 429 *recordSet.PtrRecords = append(*recordSet.PtrRecords, adns.PtrRecord{Ptrdname: to.StringPtr(rec.Target)}) 430 case "TXT": 431 if recordSet.TxtRecords == nil { 432 recordSet.TxtRecords = &[]adns.TxtRecord{} 433 } 434 // Empty TXT record needs to have no value set in it's properties 435 if !(len(rec.TxtStrings) == 1 && rec.TxtStrings[0] == "") { 436 *recordSet.TxtRecords = append(*recordSet.TxtRecords, adns.TxtRecord{Value: &rec.TxtStrings}) 437 } 438 case "MX": 439 if recordSet.MxRecords == nil { 440 recordSet.MxRecords = &[]adns.MxRecord{} 441 } 442 *recordSet.MxRecords = append(*recordSet.MxRecords, adns.MxRecord{Exchange: to.StringPtr(rec.Target), Preference: to.Int32Ptr(int32(rec.MxPreference))}) 443 case "SRV": 444 if recordSet.SrvRecords == nil { 445 recordSet.SrvRecords = &[]adns.SrvRecord{} 446 } 447 *recordSet.SrvRecords = append(*recordSet.SrvRecords, adns.SrvRecord{Target: to.StringPtr(rec.Target), Port: to.Int32Ptr(int32(rec.SrvPort)), Weight: to.Int32Ptr(int32(rec.SrvWeight)), Priority: to.Int32Ptr(int32(rec.SrvPriority))}) 448 case "CAA": 449 if recordSet.CaaRecords == nil { 450 recordSet.CaaRecords = &[]adns.CaaRecord{} 451 } 452 *recordSet.CaaRecords = append(*recordSet.CaaRecords, adns.CaaRecord{Value: to.StringPtr(rec.Target), Tag: to.StringPtr(rec.CaaTag), Flags: to.Int32Ptr(int32(rec.CaaFlag))}) 453 default: 454 panic(fmt.Errorf("rc.String rtype %v unimplemented", recordKey.Type)) 455 } 456 } 457 return recordSet, nativeToRecordType(to.StringPtr(recordKey.Type)) 458 } 459 460 func (a *azureDnsProvider) fetchRecordSets(zoneName string) ([]*adns.RecordSet, error) { 461 if zoneName == "" { 462 return nil, nil 463 } 464 var records []*adns.RecordSet 465 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 466 defer cancel() 467 recordsIterator, recordsErr := a.recordsClient.ListAllByDNSZoneComplete(ctx, *a.resourceGroup, zoneName, to.Int32Ptr(1000), "") 468 if recordsErr != nil { 469 return nil, recordsErr 470 } 471 recordsResult := recordsIterator.Response() 472 473 for _, r := range *recordsResult.Value { 474 record := r 475 records = append(records, &record) 476 } 477 478 return records, nil 479 } 480 481 func (a *azureDnsProvider) EnsureDomainExists(domain string) error { 482 if _, ok := a.zones[domain]; ok { 483 return nil 484 } 485 fmt.Printf("Adding zone for %s to Azure dns account\n", domain) 486 487 ctx, cancel := context.WithTimeout(context.Background(), 6000*time.Second) 488 defer cancel() 489 490 _, err := a.zonesClient.CreateOrUpdate(ctx, *a.resourceGroup, domain, adns.Zone{Location: to.StringPtr("global")}, "", "") 491 if err != nil { 492 return err 493 } 494 return nil 495 }