sigs.k8s.io/external-dns@v0.14.1/provider/ovh/ovh.go (about)

     1  /*
     2  Copyright 2020 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package ovh
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/miekg/dns"
    27  	"github.com/ovh/go-ovh/ovh"
    28  	"github.com/patrickmn/go-cache"
    29  	log "github.com/sirupsen/logrus"
    30  	"golang.org/x/sync/errgroup"
    31  
    32  	"sigs.k8s.io/external-dns/endpoint"
    33  	"sigs.k8s.io/external-dns/pkg/apis/externaldns"
    34  	"sigs.k8s.io/external-dns/plan"
    35  	"sigs.k8s.io/external-dns/provider"
    36  
    37  	"go.uber.org/ratelimit"
    38  )
    39  
    40  const (
    41  	ovhDefaultTTL = 0
    42  	ovhCreate     = iota
    43  	ovhDelete
    44  )
    45  
    46  var (
    47  	// ErrRecordToMutateNotFound when ApplyChange has to update/delete and didn't found the record in the existing zone (Change with no record ID)
    48  	ErrRecordToMutateNotFound = errors.New("record to mutate not found in current zone")
    49  	// ErrNoDryRun No dry run support for the moment
    50  	ErrNoDryRun = errors.New("dry run not supported")
    51  )
    52  
    53  // OVHProvider is an implementation of Provider for OVH DNS.
    54  type OVHProvider struct {
    55  	provider.BaseProvider
    56  
    57  	client ovhClient
    58  
    59  	apiRateLimiter ratelimit.Limiter
    60  
    61  	domainFilter endpoint.DomainFilter
    62  	DryRun       bool
    63  
    64  	// UseCache controls if the OVHProvider will cache records in memory, and serve them
    65  	// without recontacting the OVHcloud API if the SOA of the domain zone hasn't changed.
    66  	// Note that, when disabling cache, OVHcloud API has rate-limiting that will hit if
    67  	// your refresh rate/number of records is too big, which might cause issue with the
    68  	// provider.
    69  	// Default value: true
    70  	UseCache bool
    71  
    72  	cacheInstance *cache.Cache
    73  	dnsClient     dnsClient
    74  }
    75  
    76  type ovhClient interface {
    77  	Post(string, interface{}, interface{}) error
    78  	Get(string, interface{}) error
    79  	Delete(string, interface{}) error
    80  }
    81  
    82  type dnsClient interface {
    83  	ExchangeContext(ctx context.Context, m *dns.Msg, a string) (*dns.Msg, time.Duration, error)
    84  }
    85  
    86  type ovhRecordFields struct {
    87  	FieldType string `json:"fieldType"`
    88  	SubDomain string `json:"subDomain"`
    89  	TTL       int64  `json:"ttl"`
    90  	Target    string `json:"target"`
    91  }
    92  
    93  type ovhRecord struct {
    94  	ovhRecordFields
    95  	ID   uint64 `json:"id"`
    96  	Zone string `json:"zone"`
    97  }
    98  
    99  type ovhChange struct {
   100  	ovhRecord
   101  	Action int
   102  }
   103  
   104  // NewOVHProvider initializes a new OVH DNS based Provider.
   105  func NewOVHProvider(ctx context.Context, domainFilter endpoint.DomainFilter, endpoint string, apiRateLimit int, dryRun bool) (*OVHProvider, error) {
   106  	client, err := ovh.NewEndpointClient(endpoint)
   107  	if err != nil {
   108  		return nil, err
   109  	}
   110  
   111  	client.UserAgent = externaldns.Version
   112  
   113  	// TODO: Add Dry Run support
   114  	if dryRun {
   115  		return nil, ErrNoDryRun
   116  	}
   117  	return &OVHProvider{
   118  		client:         client,
   119  		domainFilter:   domainFilter,
   120  		apiRateLimiter: ratelimit.New(apiRateLimit),
   121  		DryRun:         dryRun,
   122  		cacheInstance:  cache.New(cache.NoExpiration, cache.NoExpiration),
   123  		dnsClient:      new(dns.Client),
   124  		UseCache:       true,
   125  	}, nil
   126  }
   127  
   128  // Records returns the list of records in all relevant zones.
   129  func (p *OVHProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
   130  	_, records, err := p.zonesRecords(ctx)
   131  	if err != nil {
   132  		return nil, err
   133  	}
   134  	endpoints := ovhGroupByNameAndType(records)
   135  	log.Infof("OVH: %d endpoints have been found", len(endpoints))
   136  	return endpoints, nil
   137  }
   138  
   139  // ApplyChanges applies a given set of changes in a given zone.
   140  func (p *OVHProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
   141  	zones, records, err := p.zonesRecords(ctx)
   142  	zonesChangeUniques := map[string]bool{}
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	allChanges := make([]ovhChange, 0, countTargets(changes.Create, changes.UpdateNew, changes.UpdateOld, changes.Delete))
   148  
   149  	allChanges = append(allChanges, newOvhChange(ovhCreate, changes.Create, zones, records)...)
   150  	allChanges = append(allChanges, newOvhChange(ovhCreate, changes.UpdateNew, zones, records)...)
   151  
   152  	allChanges = append(allChanges, newOvhChange(ovhDelete, changes.UpdateOld, zones, records)...)
   153  	allChanges = append(allChanges, newOvhChange(ovhDelete, changes.Delete, zones, records)...)
   154  
   155  	log.Infof("OVH: %d changes will be done", len(allChanges))
   156  
   157  	eg, _ := errgroup.WithContext(ctx)
   158  	for _, change := range allChanges {
   159  		change := change
   160  		zonesChangeUniques[change.Zone] = true
   161  		eg.Go(func() error { return p.change(change) })
   162  	}
   163  	if err := eg.Wait(); err != nil {
   164  		return err
   165  	}
   166  
   167  	log.Infof("OVH: %d zones will be refreshed", len(zonesChangeUniques))
   168  
   169  	eg, _ = errgroup.WithContext(ctx)
   170  	for zone := range zonesChangeUniques {
   171  		zone := zone
   172  		eg.Go(func() error { return p.refresh(zone) })
   173  	}
   174  	if err := eg.Wait(); err != nil {
   175  		return err
   176  	}
   177  	return nil
   178  }
   179  
   180  func (p *OVHProvider) refresh(zone string) error {
   181  	log.Debugf("OVH: Refresh %s zone", zone)
   182  
   183  	p.apiRateLimiter.Take()
   184  	return p.client.Post(fmt.Sprintf("/domain/zone/%s/refresh", zone), nil, nil)
   185  }
   186  
   187  func (p *OVHProvider) change(change ovhChange) error {
   188  	p.apiRateLimiter.Take()
   189  
   190  	switch change.Action {
   191  	case ovhCreate:
   192  		log.Debugf("OVH: Add an entry to %s", change.String())
   193  		return p.client.Post(fmt.Sprintf("/domain/zone/%s/record", change.Zone), change.ovhRecordFields, nil)
   194  	case ovhDelete:
   195  		if change.ID == 0 {
   196  			return ErrRecordToMutateNotFound
   197  		}
   198  		log.Debugf("OVH: Delete an entry to %s", change.String())
   199  		return p.client.Delete(fmt.Sprintf("/domain/zone/%s/record/%d", change.Zone, change.ID), nil)
   200  	}
   201  	return nil
   202  }
   203  
   204  func (p *OVHProvider) zonesRecords(ctx context.Context) ([]string, []ovhRecord, error) {
   205  	var allRecords []ovhRecord
   206  	zones, err := p.zones()
   207  	if err != nil {
   208  		return nil, nil, err
   209  	}
   210  
   211  	chRecords := make(chan []ovhRecord, len(zones))
   212  	eg, ctx := errgroup.WithContext(ctx)
   213  	for _, zone := range zones {
   214  		zone := zone
   215  		eg.Go(func() error { return p.records(&ctx, &zone, chRecords) })
   216  	}
   217  	if err := eg.Wait(); err != nil {
   218  		return nil, nil, err
   219  	}
   220  	close(chRecords)
   221  	for records := range chRecords {
   222  		allRecords = append(allRecords, records...)
   223  	}
   224  	return zones, allRecords, nil
   225  }
   226  
   227  func (p *OVHProvider) zones() ([]string, error) {
   228  	zones := []string{}
   229  	filteredZones := []string{}
   230  
   231  	p.apiRateLimiter.Take()
   232  	if err := p.client.Get("/domain/zone", &zones); err != nil {
   233  		return nil, err
   234  	}
   235  
   236  	for _, zoneName := range zones {
   237  		if p.domainFilter.Match(zoneName) {
   238  			filteredZones = append(filteredZones, zoneName)
   239  		}
   240  	}
   241  	log.Infof("OVH: %d zones found", len(filteredZones))
   242  	return filteredZones, nil
   243  }
   244  
   245  type ovhSoa struct {
   246  	Server  string `json:"server"`
   247  	Serial  uint32 `json:"serial"`
   248  	records []ovhRecord
   249  }
   250  
   251  func (p *OVHProvider) records(ctx *context.Context, zone *string, records chan<- []ovhRecord) error {
   252  	var recordsIds []uint64
   253  	ovhRecords := make([]ovhRecord, len(recordsIds))
   254  	eg, _ := errgroup.WithContext(*ctx)
   255  
   256  	if p.UseCache {
   257  		if cachedSoaItf, ok := p.cacheInstance.Get(*zone + "#soa"); ok {
   258  			cachedSoa := cachedSoaItf.(ovhSoa)
   259  
   260  			m := new(dns.Msg)
   261  			m.SetQuestion(dns.Fqdn(*zone), dns.TypeSOA)
   262  			in, _, err := p.dnsClient.ExchangeContext(*ctx, m, strings.TrimSuffix(cachedSoa.Server, ".")+":53")
   263  			if err == nil {
   264  				if s, ok := in.Answer[0].(*dns.SOA); ok {
   265  					// do something with t.Txt
   266  					if s.Serial == cachedSoa.Serial {
   267  						records <- cachedSoa.records
   268  						return nil
   269  					}
   270  				}
   271  			}
   272  
   273  			p.cacheInstance.Delete(*zone + "#soa")
   274  		}
   275  	}
   276  
   277  	log.Debugf("OVH: Getting records for %s", *zone)
   278  
   279  	p.apiRateLimiter.Take()
   280  	var soa ovhSoa
   281  	if p.UseCache {
   282  		if err := p.client.Get("/domain/zone/"+*zone+"/soa", &soa); err != nil {
   283  			return err
   284  		}
   285  	}
   286  
   287  	if err := p.client.Get(fmt.Sprintf("/domain/zone/%s/record", *zone), &recordsIds); err != nil {
   288  		return err
   289  	}
   290  	chRecords := make(chan ovhRecord, len(recordsIds))
   291  	for _, id := range recordsIds {
   292  		id := id
   293  		eg.Go(func() error { return p.record(zone, id, chRecords) })
   294  	}
   295  	if err := eg.Wait(); err != nil {
   296  		return err
   297  	}
   298  	close(chRecords)
   299  	for record := range chRecords {
   300  		ovhRecords = append(ovhRecords, record)
   301  	}
   302  
   303  	if p.UseCache {
   304  		soa.records = ovhRecords
   305  		_ = p.cacheInstance.Add(*zone+"#soa", soa, time.Hour)
   306  	}
   307  
   308  	records <- ovhRecords
   309  	return nil
   310  }
   311  
   312  func (p *OVHProvider) record(zone *string, id uint64, records chan<- ovhRecord) error {
   313  	record := ovhRecord{}
   314  
   315  	log.Debugf("OVH: Getting record %d for %s", id, *zone)
   316  
   317  	p.apiRateLimiter.Take()
   318  	if err := p.client.Get(fmt.Sprintf("/domain/zone/%s/record/%d", *zone, id), &record); err != nil {
   319  		return err
   320  	}
   321  	if provider.SupportedRecordType(record.FieldType) {
   322  		log.Debugf("OVH: Record %d for %s is %+v", id, *zone, record)
   323  		records <- record
   324  	}
   325  	return nil
   326  }
   327  
   328  func ovhGroupByNameAndType(records []ovhRecord) []*endpoint.Endpoint {
   329  	endpoints := []*endpoint.Endpoint{}
   330  
   331  	// group supported records by name and type
   332  	groups := map[string][]ovhRecord{}
   333  
   334  	for _, r := range records {
   335  		groupBy := r.Zone + r.SubDomain + r.FieldType
   336  		if _, ok := groups[groupBy]; !ok {
   337  			groups[groupBy] = []ovhRecord{}
   338  		}
   339  
   340  		groups[groupBy] = append(groups[groupBy], r)
   341  	}
   342  
   343  	// create single endpoint with all the targets for each name/type
   344  	for _, records := range groups {
   345  		targets := []string{}
   346  		for _, record := range records {
   347  			targets = append(targets, record.Target)
   348  		}
   349  		endpoint := endpoint.NewEndpointWithTTL(
   350  			strings.TrimPrefix(records[0].SubDomain+"."+records[0].Zone, "."),
   351  			records[0].FieldType,
   352  			endpoint.TTL(records[0].TTL),
   353  			targets...,
   354  		)
   355  		endpoints = append(endpoints, endpoint)
   356  	}
   357  
   358  	return endpoints
   359  }
   360  
   361  func newOvhChange(action int, endpoints []*endpoint.Endpoint, zones []string, records []ovhRecord) []ovhChange {
   362  	zoneNameIDMapper := provider.ZoneIDName{}
   363  	ovhChanges := make([]ovhChange, 0, countTargets(endpoints))
   364  	for _, zone := range zones {
   365  		zoneNameIDMapper.Add(zone, zone)
   366  	}
   367  
   368  	for _, e := range endpoints {
   369  		zone, _ := zoneNameIDMapper.FindZone(e.DNSName)
   370  		if zone == "" {
   371  			log.Debugf("Skipping record %s because no hosted zone matching record DNS Name was detected", e.DNSName)
   372  			continue
   373  		}
   374  		for _, target := range e.Targets {
   375  			if e.RecordType == endpoint.RecordTypeCNAME {
   376  				target = target + "."
   377  			}
   378  			change := ovhChange{
   379  				Action: action,
   380  				ovhRecord: ovhRecord{
   381  					Zone: zone,
   382  					ovhRecordFields: ovhRecordFields{
   383  						FieldType: e.RecordType,
   384  						SubDomain: strings.TrimSuffix(e.DNSName, "."+zone),
   385  						TTL:       ovhDefaultTTL,
   386  						Target:    target,
   387  					},
   388  				},
   389  			}
   390  			if e.RecordTTL.IsConfigured() {
   391  				change.TTL = int64(e.RecordTTL)
   392  			}
   393  			for _, record := range records {
   394  				if record.Zone == change.Zone && record.SubDomain == change.SubDomain && record.FieldType == change.FieldType && record.Target == change.Target {
   395  					change.ID = record.ID
   396  				}
   397  			}
   398  			ovhChanges = append(ovhChanges, change)
   399  		}
   400  	}
   401  
   402  	return ovhChanges
   403  }
   404  
   405  func countTargets(allEndpoints ...[]*endpoint.Endpoint) int {
   406  	count := 0
   407  	for _, endpoints := range allEndpoints {
   408  		for _, endpoint := range endpoints {
   409  			count += len(endpoint.Targets)
   410  		}
   411  	}
   412  	return count
   413  }
   414  
   415  func (c *ovhChange) String() string {
   416  	if c.ID != 0 {
   417  		return fmt.Sprintf("%s zone (ID : %d) : %s %d IN %s %s", c.Zone, c.ID, c.SubDomain, c.TTL, c.FieldType, c.Target)
   418  	}
   419  	return fmt.Sprintf("%s zone : %s %d IN %s %s", c.Zone, c.SubDomain, c.TTL, c.FieldType, c.Target)
   420  }