sigs.k8s.io/external-dns@v0.14.1/provider/transip/transip.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package transip
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  
    25  	log "github.com/sirupsen/logrus"
    26  	"github.com/transip/gotransip/v6"
    27  	"github.com/transip/gotransip/v6/domain"
    28  
    29  	"sigs.k8s.io/external-dns/endpoint"
    30  	"sigs.k8s.io/external-dns/plan"
    31  	"sigs.k8s.io/external-dns/provider"
    32  )
    33  
    34  const (
    35  	// 60 seconds is the current minimal TTL for TransIP and will replace unconfigured
    36  	// TTL's for Endpoints
    37  	transipMinimalValidTTL = 60
    38  )
    39  
    40  // TransIPProvider is an implementation of Provider for TransIP.
    41  type TransIPProvider struct {
    42  	provider.BaseProvider
    43  	domainRepo   domain.Repository
    44  	domainFilter endpoint.DomainFilter
    45  	dryRun       bool
    46  
    47  	zoneMap provider.ZoneIDName
    48  }
    49  
    50  // NewTransIPProvider initializes a new TransIP Provider.
    51  func NewTransIPProvider(accountName, privateKeyFile string, domainFilter endpoint.DomainFilter, dryRun bool) (*TransIPProvider, error) {
    52  	// check given arguments
    53  	if accountName == "" {
    54  		return nil, errors.New("required --transip-account not set")
    55  	}
    56  
    57  	if privateKeyFile == "" {
    58  		return nil, errors.New("required --transip-keyfile not set")
    59  	}
    60  
    61  	var apiMode gotransip.APIMode
    62  	if dryRun {
    63  		apiMode = gotransip.APIModeReadOnly
    64  	} else {
    65  		apiMode = gotransip.APIModeReadWrite
    66  	}
    67  
    68  	// create new TransIP API client
    69  	client, err := gotransip.NewClient(gotransip.ClientConfiguration{
    70  		AccountName:    accountName,
    71  		PrivateKeyPath: privateKeyFile,
    72  		Mode:           apiMode,
    73  	})
    74  	if err != nil {
    75  		return nil, fmt.Errorf("could not setup TransIP API client: %w", err)
    76  	}
    77  
    78  	// return TransIPProvider struct
    79  	return &TransIPProvider{
    80  		domainRepo:   domain.Repository{Client: client},
    81  		domainFilter: domainFilter,
    82  		dryRun:       dryRun,
    83  		zoneMap:      provider.ZoneIDName{},
    84  	}, nil
    85  }
    86  
    87  // ApplyChanges applies a given set of changes in a given zone.
    88  func (p *TransIPProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
    89  	// fetch all zones we currently have
    90  	// this does NOT include any DNS entries, so we'll have to fetch these for
    91  	// each zone that gets updated
    92  	zones, err := p.domainRepo.GetAll()
    93  	if err != nil {
    94  		return err
    95  	}
    96  
    97  	// refresh zone mapping
    98  	zoneMap := provider.ZoneIDName{}
    99  	for _, zone := range zones {
   100  		// TransIP API doesn't expose a unique identifier for zones, other than than
   101  		// the domain name itself
   102  		zoneMap.Add(zone.Name, zone.Name)
   103  	}
   104  	p.zoneMap = zoneMap
   105  
   106  	// first remove obsolete DNS records
   107  	for _, ep := range changes.Delete {
   108  		epLog := log.WithFields(log.Fields{
   109  			"record": ep.DNSName,
   110  			"type":   ep.RecordType,
   111  		})
   112  		epLog.Info("endpoint has to go")
   113  
   114  		zoneName, entries, err := p.entriesForEndpoint(ep)
   115  		if err != nil {
   116  			epLog.WithError(err).Error("could not get DNS entries")
   117  			return err
   118  		}
   119  
   120  		epLog = epLog.WithField("zone", zoneName)
   121  
   122  		if len(entries) == 0 {
   123  			epLog.Info("no matching entries found")
   124  			continue
   125  		}
   126  
   127  		if p.dryRun {
   128  			epLog.Info("not removing DNS entries in dry-run mode")
   129  			continue
   130  		}
   131  
   132  		for _, entry := range entries {
   133  			log.WithFields(log.Fields{
   134  				"domain":  zoneName,
   135  				"name":    entry.Name,
   136  				"type":    entry.Type,
   137  				"content": entry.Content,
   138  				"ttl":     entry.Expire,
   139  			}).Info("removing DNS entry")
   140  
   141  			err = p.domainRepo.RemoveDNSEntry(zoneName, entry)
   142  			if err != nil {
   143  				epLog.WithError(err).Error("could not remove DNS entry")
   144  				return err
   145  			}
   146  		}
   147  	}
   148  
   149  	// then create new DNS records
   150  	for _, ep := range changes.Create {
   151  		epLog := log.WithFields(log.Fields{
   152  			"record": ep.DNSName,
   153  			"type":   ep.RecordType,
   154  		})
   155  		epLog.Info("endpoint should be created")
   156  
   157  		zoneName, err := p.zoneNameForDNSName(ep.DNSName)
   158  		if err != nil {
   159  			epLog.WithError(err).Warn("could not find zone for endpoint")
   160  			continue
   161  		}
   162  
   163  		epLog = epLog.WithField("zone", zoneName)
   164  
   165  		if p.dryRun {
   166  			epLog.Info("not adding DNS entries in dry-run mode")
   167  			continue
   168  		}
   169  
   170  		for _, entry := range dnsEntriesForEndpoint(ep, zoneName) {
   171  			log.WithFields(log.Fields{
   172  				"domain":  zoneName,
   173  				"name":    entry.Name,
   174  				"type":    entry.Type,
   175  				"content": entry.Content,
   176  				"ttl":     entry.Expire,
   177  			}).Info("creating DNS entry")
   178  
   179  			err = p.domainRepo.AddDNSEntry(zoneName, entry)
   180  			if err != nil {
   181  				epLog.WithError(err).Error("could not add DNS entry")
   182  				return err
   183  			}
   184  		}
   185  	}
   186  
   187  	// then update existing DNS records
   188  	for _, ep := range changes.UpdateNew {
   189  		epLog := log.WithFields(log.Fields{
   190  			"record": ep.DNSName,
   191  			"type":   ep.RecordType,
   192  		})
   193  		epLog.Debug("endpoint needs updating")
   194  
   195  		zoneName, entries, err := p.entriesForEndpoint(ep)
   196  		if err != nil {
   197  			epLog.WithError(err).Error("could not get DNS entries")
   198  			return err
   199  		}
   200  
   201  		epLog = epLog.WithField("zone", zoneName)
   202  
   203  		if len(entries) == 0 {
   204  			epLog.Info("no matching entries found")
   205  			continue
   206  		}
   207  
   208  		newEntries := dnsEntriesForEndpoint(ep, zoneName)
   209  
   210  		// check to see if actually anything changed in the DNSEntry set
   211  		if dnsEntriesAreEqual(newEntries, entries) {
   212  			epLog.Debug("not updating identical DNS entries")
   213  			continue
   214  		}
   215  
   216  		if p.dryRun {
   217  			epLog.Info("not updating DNS entries in dry-run mode")
   218  			continue
   219  		}
   220  
   221  		// TransIP API client does have an UpdateDNSEntry call but that does only
   222  		// allow you to update the content of a DNSEntry, not the TTL
   223  		// to work around this, remove the old entry first and add the new entry
   224  		for _, entry := range entries {
   225  			log.WithFields(log.Fields{
   226  				"domain":  zoneName,
   227  				"name":    entry.Name,
   228  				"type":    entry.Type,
   229  				"content": entry.Content,
   230  				"ttl":     entry.Expire,
   231  			}).Info("removing DNS entry")
   232  
   233  			err = p.domainRepo.RemoveDNSEntry(zoneName, entry)
   234  			if err != nil {
   235  				epLog.WithError(err).Error("could not remove DNS entry")
   236  				return err
   237  			}
   238  		}
   239  
   240  		for _, entry := range newEntries {
   241  			log.WithFields(log.Fields{
   242  				"domain":  zoneName,
   243  				"name":    entry.Name,
   244  				"type":    entry.Type,
   245  				"content": entry.Content,
   246  				"ttl":     entry.Expire,
   247  			}).Info("adding DNS entry")
   248  
   249  			err = p.domainRepo.AddDNSEntry(zoneName, entry)
   250  			if err != nil {
   251  				epLog.WithError(err).Error("could not add DNS entry")
   252  				return err
   253  			}
   254  		}
   255  	}
   256  
   257  	return nil
   258  }
   259  
   260  // Records returns the list of records in all zones
   261  func (p *TransIPProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
   262  	zones, err := p.domainRepo.GetAll()
   263  	if err != nil {
   264  		return nil, err
   265  	}
   266  
   267  	var endpoints []*endpoint.Endpoint
   268  	// go over all zones and their DNS entries and create endpoints for them
   269  	for _, zone := range zones {
   270  		entries, err := p.domainRepo.GetDNSEntries(zone.Name)
   271  		if err != nil {
   272  			return nil, err
   273  		}
   274  
   275  		for _, r := range entries {
   276  			if !provider.SupportedRecordType(r.Type) {
   277  				continue
   278  			}
   279  
   280  			name := endpointNameForRecord(r, zone.Name)
   281  			endpoints = append(endpoints, endpoint.NewEndpointWithTTL(name, r.Type, endpoint.TTL(r.Expire), r.Content))
   282  		}
   283  	}
   284  
   285  	return endpoints, nil
   286  }
   287  
   288  func (p *TransIPProvider) entriesForEndpoint(ep *endpoint.Endpoint) (string, []domain.DNSEntry, error) {
   289  	zoneName, err := p.zoneNameForDNSName(ep.DNSName)
   290  	if err != nil {
   291  		return "", nil, err
   292  	}
   293  
   294  	epName := recordNameForEndpoint(ep, zoneName)
   295  	dnsEntries, err := p.domainRepo.GetDNSEntries(zoneName)
   296  	if err != nil {
   297  		return zoneName, nil, err
   298  	}
   299  
   300  	matches := []domain.DNSEntry{}
   301  	for _, entry := range dnsEntries {
   302  		if ep.RecordType != entry.Type {
   303  			continue
   304  		}
   305  
   306  		if entry.Name == epName {
   307  			matches = append(matches, entry)
   308  		}
   309  	}
   310  
   311  	return zoneName, matches, nil
   312  }
   313  
   314  // endpointNameForRecord returns "www.example.org" for DNSEntry with Name "www" and
   315  // Domain with Name "example.org"
   316  func endpointNameForRecord(r domain.DNSEntry, zoneName string) string {
   317  	// root name is identified by "@" and should be translated to domain name for
   318  	// the endpoint entry.
   319  	if r.Name == "@" {
   320  		return zoneName
   321  	}
   322  
   323  	return fmt.Sprintf("%s.%s", r.Name, zoneName)
   324  }
   325  
   326  // recordNameForEndpoint returns "www" for Endpoint with DNSName "www.example.org"
   327  // and Domain with Name "example.org"
   328  func recordNameForEndpoint(ep *endpoint.Endpoint, zoneName string) string {
   329  	// root name is identified by "@" and should be translated to domain name for
   330  	// the endpoint entry.
   331  	if ep.DNSName == zoneName {
   332  		return "@"
   333  	}
   334  
   335  	return strings.TrimSuffix(ep.DNSName, "."+zoneName)
   336  }
   337  
   338  // getMinimalValidTTL returns max between given Endpoint's RecordTTL and
   339  // transipMinimalValidTTL
   340  func getMinimalValidTTL(ep *endpoint.Endpoint) int {
   341  	// TTL cannot be lower than transipMinimalValidTTL
   342  	if ep.RecordTTL < transipMinimalValidTTL {
   343  		return transipMinimalValidTTL
   344  	}
   345  
   346  	return int(ep.RecordTTL)
   347  }
   348  
   349  // dnsEntriesAreEqual compares the entries in 2 sets and returns true if the
   350  // content of the entries is equal
   351  func dnsEntriesAreEqual(a, b []domain.DNSEntry) bool {
   352  	if len(a) != len(b) {
   353  		return false
   354  	}
   355  
   356  	match := 0
   357  	for _, aa := range a {
   358  		for _, bb := range b {
   359  			if aa.Content != bb.Content {
   360  				continue
   361  			}
   362  
   363  			if aa.Name != bb.Name {
   364  				continue
   365  			}
   366  
   367  			if aa.Expire != bb.Expire {
   368  				continue
   369  			}
   370  
   371  			if aa.Type != bb.Type {
   372  				continue
   373  			}
   374  
   375  			match++
   376  		}
   377  	}
   378  
   379  	return (len(a) == match)
   380  }
   381  
   382  // dnsEntriesForEndpoint creates DNS entries for given endpoint and returns
   383  // resulting DNS entry set
   384  func dnsEntriesForEndpoint(ep *endpoint.Endpoint, zoneName string) []domain.DNSEntry {
   385  	ttl := getMinimalValidTTL(ep)
   386  
   387  	entries := []domain.DNSEntry{}
   388  	for _, target := range ep.Targets {
   389  		// external hostnames require a trailing dot in TransIP API
   390  		if ep.RecordType == "CNAME" {
   391  			target = provider.EnsureTrailingDot(target)
   392  		}
   393  
   394  		entries = append(entries, domain.DNSEntry{
   395  			Name:    recordNameForEndpoint(ep, zoneName),
   396  			Expire:  ttl,
   397  			Type:    ep.RecordType,
   398  			Content: target,
   399  		})
   400  	}
   401  
   402  	return entries
   403  }
   404  
   405  // zoneForZoneName returns the zone mapped to given name or error if zone could
   406  // not be found
   407  func (p *TransIPProvider) zoneNameForDNSName(name string) (string, error) {
   408  	_, zoneName := p.zoneMap.FindZone(name)
   409  	if zoneName == "" {
   410  		return "", fmt.Errorf("could not find zoneName for %s", name)
   411  	}
   412  
   413  	return zoneName, nil
   414  }