sigs.k8s.io/external-dns@v0.14.1/provider/oci/oci.go (about)

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package oci
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"os"
    23  	"strings"
    24  	"time"
    25  
    26  	"github.com/oracle/oci-go-sdk/v65/common"
    27  	"github.com/oracle/oci-go-sdk/v65/common/auth"
    28  	"github.com/oracle/oci-go-sdk/v65/dns"
    29  	"github.com/pkg/errors"
    30  	log "github.com/sirupsen/logrus"
    31  	yaml "gopkg.in/yaml.v2"
    32  
    33  	"sigs.k8s.io/external-dns/endpoint"
    34  	"sigs.k8s.io/external-dns/plan"
    35  	"sigs.k8s.io/external-dns/provider"
    36  )
    37  
// ociRecordTTL is the TTL, in seconds, used for record operations whose
// endpoint does not have a TTL configured.
const ociRecordTTL = 300
    39  
// OCIAuthConfig holds connection parameters for the OCI API.
//
// When neither UseInstancePrincipal nor UseWorkloadIdentity is set,
// NewOCIProvider passes the tenancy, user, region, fingerprint, key and
// passphrase values to the OCI SDK's raw configuration provider.
type OCIAuthConfig struct {
	// Region is the OCI region. It is also exported to the OCI SDK through an
	// environment variable when workload identity is used.
	Region               string `yaml:"region"`
	// TenancyID is read from the "tenancy" key of the config file.
	TenancyID            string `yaml:"tenancy"`
	// UserID is read from the "user" key of the config file.
	UserID               string `yaml:"user"`
	// PrivateKey is read from the "key" key of the config file and handed to
	// the SDK as the private key argument.
	PrivateKey           string `yaml:"key"`
	// Fingerprint is the fingerprint handed to the SDK alongside PrivateKey.
	Fingerprint          string `yaml:"fingerprint"`
	// Passphrase is handed to the SDK by pointer together with PrivateKey.
	Passphrase           string `yaml:"passphrase"`
	// UseInstancePrincipal selects the SDK's instance principal configuration
	// provider. It must not be combined with UseWorkloadIdentity.
	UseInstancePrincipal bool   `yaml:"useInstancePrincipal"`
	// UseWorkloadIdentity selects the SDK's OKE workload identity
	// configuration provider. It must not be combined with
	// UseInstancePrincipal.
	UseWorkloadIdentity  bool   `yaml:"useWorkloadIdentity"`
}
    51  
// OCIConfig holds the configuration for the OCI Provider.
type OCIConfig struct {
	// Auth holds the credentials / authentication mode used to build the OCI
	// DNS client.
	Auth              OCIAuthConfig `yaml:"auth"`
	// CompartmentID is sent as the compartment on every list, get and patch
	// request issued by the provider.
	CompartmentID     string        `yaml:"compartment"`
	// ZoneCacheDuration is used as the duration of the provider's zone cache.
	// It carries no yaml tag; presumably it is set by the caller rather than
	// in the config file — TODO confirm.
	ZoneCacheDuration time.Duration
}
    58  
// OCIProvider is an implementation of Provider for Oracle Cloud Infrastructure
// (OCI) DNS.
type OCIProvider struct {
	provider.BaseProvider
	// client performs the OCI DNS API calls (list zones, get and patch zone
	// records).
	client ociDNSClient
	// cfg is the configuration the provider was created with.
	cfg    OCIConfig

	// domainFilter is matched against zone names when listing zones and
	// against endpoint DNS names when building record operations.
	domainFilter endpoint.DomainFilter
	// zoneIDFilter is matched against zone IDs when listing zones.
	zoneIDFilter provider.ZoneIDFilter
	// zoneScope is the zone scope to list; when empty, every scope value
	// known to the OCI SDK is listed.
	zoneScope    string
	// zoneCache stores the last listed zones so they can be reused until the
	// cache expires.
	zoneCache    *zoneCache
	// dryRun makes ApplyChanges log the planned operations without sending
	// them to the OCI API.
	dryRun       bool
}
    72  
// ociDNSClient is the subset of the OCI DNS API required by the OCI Provider.
// Keeping it as an interface lets the concrete SDK client be replaced by any
// other implementation of these three calls.
type ociDNSClient interface {
	// ListZones is used to enumerate the zones of the configured compartment,
	// one page at a time.
	ListZones(ctx context.Context, request dns.ListZonesRequest) (response dns.ListZonesResponse, err error)
	// GetZoneRecords is used to read the records of a single zone, one page
	// at a time.
	GetZoneRecords(ctx context.Context, request dns.GetZoneRecordsRequest) (response dns.GetZoneRecordsResponse, err error)
	// PatchZoneRecords is used to apply a batch of record operations to a
	// single zone.
	PatchZoneRecords(ctx context.Context, request dns.PatchZoneRecordsRequest) (response dns.PatchZoneRecordsResponse, err error)
}
    79  
    80  // LoadOCIConfig reads and parses the OCI ExternalDNS config file at the given
    81  // path.
    82  func LoadOCIConfig(path string) (*OCIConfig, error) {
    83  	contents, err := os.ReadFile(path)
    84  	if err != nil {
    85  		return nil, fmt.Errorf("reading OCI config file %q: %w", path, err)
    86  	}
    87  
    88  	cfg := OCIConfig{}
    89  	if err := yaml.Unmarshal(contents, &cfg); err != nil {
    90  		return nil, fmt.Errorf("parsing OCI config file %q: %w", path, err)
    91  	}
    92  	return &cfg, nil
    93  }
    94  
    95  // NewOCIProvider initializes a new OCI DNS based Provider.
    96  func NewOCIProvider(cfg OCIConfig, domainFilter endpoint.DomainFilter, zoneIDFilter provider.ZoneIDFilter, zoneScope string, dryRun bool) (*OCIProvider, error) {
    97  	var client ociDNSClient
    98  	var err error
    99  	var configProvider common.ConfigurationProvider
   100  	if cfg.Auth.UseInstancePrincipal && cfg.Auth.UseWorkloadIdentity {
   101  		return nil, errors.New("only one of 'useInstancePrincipal' and 'useWorkloadIdentity' may be enabled for Oracle authentication")
   102  	}
   103  	if cfg.Auth.UseWorkloadIdentity {
   104  		// OCI SDK requires specific, dynamic environment variables for workload identity.
   105  		if err := os.Setenv(auth.ResourcePrincipalVersionEnvVar, auth.ResourcePrincipalVersion2_2); err != nil {
   106  			return nil, fmt.Errorf("unable to set OCI SDK environment variable: %s: %w", auth.ResourcePrincipalVersionEnvVar, err)
   107  		}
   108  		if err := os.Setenv(auth.ResourcePrincipalRegionEnvVar, cfg.Auth.Region); err != nil {
   109  			return nil, fmt.Errorf("unable to set OCI SDK environment variable: %s: %w", auth.ResourcePrincipalRegionEnvVar, err)
   110  		}
   111  		configProvider, err = auth.OkeWorkloadIdentityConfigurationProvider()
   112  		if err != nil {
   113  			return nil, fmt.Errorf("error creating OCI workload identity config provider: %w", err)
   114  		}
   115  	} else if cfg.Auth.UseInstancePrincipal {
   116  		configProvider, err = auth.InstancePrincipalConfigurationProvider()
   117  		if err != nil {
   118  			return nil, fmt.Errorf("error creating OCI instance principal config provider: %w", err)
   119  		}
   120  	} else {
   121  		configProvider = common.NewRawConfigurationProvider(
   122  			cfg.Auth.TenancyID,
   123  			cfg.Auth.UserID,
   124  			cfg.Auth.Region,
   125  			cfg.Auth.Fingerprint,
   126  			cfg.Auth.PrivateKey,
   127  			&cfg.Auth.Passphrase,
   128  		)
   129  	}
   130  
   131  	client, err = dns.NewDnsClientWithConfigurationProvider(configProvider)
   132  	if err != nil {
   133  		return nil, fmt.Errorf("initializing OCI DNS API client: %w", err)
   134  	}
   135  
   136  	return &OCIProvider{
   137  		client:       client,
   138  		cfg:          cfg,
   139  		domainFilter: domainFilter,
   140  		zoneIDFilter: zoneIDFilter,
   141  		zoneScope:    zoneScope,
   142  		zoneCache: &zoneCache{
   143  			duration: cfg.ZoneCacheDuration,
   144  		},
   145  		dryRun: dryRun,
   146  	}, nil
   147  }
   148  
   149  func (p *OCIProvider) zones(ctx context.Context) (map[string]dns.ZoneSummary, error) {
   150  	if !p.zoneCache.Expired() {
   151  		log.Debug("Using cached zones list")
   152  		return p.zoneCache.zones, nil
   153  	}
   154  	zones := make(map[string]dns.ZoneSummary)
   155  	scopes := []dns.GetZoneScopeEnum{dns.GetZoneScopeEnum(p.zoneScope)}
   156  	// If zone scope is empty, list all zones types.
   157  	if p.zoneScope == "" {
   158  		scopes = dns.GetGetZoneScopeEnumValues()
   159  	}
   160  	log.Debugf("Matching zones against domain filters: %v", p.domainFilter.Filters)
   161  	for _, scope := range scopes {
   162  		if err := p.addPaginatedZones(ctx, zones, scope); err != nil {
   163  			return nil, err
   164  		}
   165  	}
   166  	if len(zones) == 0 {
   167  		log.Warnf("No zones in compartment %q match domain filters %v", p.cfg.CompartmentID, p.domainFilter)
   168  	}
   169  	p.zoneCache.Reset(zones)
   170  	return zones, nil
   171  }
   172  
   173  func (p *OCIProvider) addPaginatedZones(ctx context.Context, zones map[string]dns.ZoneSummary, scope dns.GetZoneScopeEnum) error {
   174  	var page *string
   175  	// Loop until we have listed all zones.
   176  	for {
   177  		resp, err := p.client.ListZones(ctx, dns.ListZonesRequest{
   178  			CompartmentId: &p.cfg.CompartmentID,
   179  			ZoneType:      dns.ListZonesZoneTypePrimary,
   180  			Scope:         dns.ListZonesScopeEnum(scope),
   181  			Page:          page,
   182  		})
   183  		if err != nil {
   184  			return provider.NewSoftError(fmt.Errorf("listing zones in %s: %w", p.cfg.CompartmentID, err))
   185  		}
   186  		for _, zone := range resp.Items {
   187  			if p.domainFilter.Match(*zone.Name) && p.zoneIDFilter.Match(*zone.Id) {
   188  				zones[*zone.Id] = zone
   189  				log.Debugf("Matched %q (%q)", *zone.Name, *zone.Id)
   190  			} else {
   191  				log.Debugf("Filtered %q (%q)", *zone.Name, *zone.Id)
   192  			}
   193  		}
   194  		if page = resp.OpcNextPage; resp.OpcNextPage == nil {
   195  			break
   196  		}
   197  	}
   198  	return nil
   199  }
   200  
   201  func (p *OCIProvider) newFilteredRecordOperations(endpoints []*endpoint.Endpoint, opType dns.RecordOperationOperationEnum) []dns.RecordOperation {
   202  	ops := []dns.RecordOperation{}
   203  	for _, endpoint := range endpoints {
   204  		if p.domainFilter.Match(endpoint.DNSName) {
   205  			ops = append(ops, newRecordOperation(endpoint, opType))
   206  		}
   207  	}
   208  	return ops
   209  }
   210  
   211  // Records returns the list of records in a given hosted zone.
   212  func (p *OCIProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
   213  	zones, err := p.zones(ctx)
   214  	if err != nil {
   215  		return nil, provider.NewSoftError(fmt.Errorf("getting zones: %w", err))
   216  	}
   217  
   218  	endpoints := []*endpoint.Endpoint{}
   219  	for _, zone := range zones {
   220  		var page *string
   221  		for {
   222  			resp, err := p.client.GetZoneRecords(ctx, dns.GetZoneRecordsRequest{
   223  				ZoneNameOrId:  zone.Id,
   224  				Page:          page,
   225  				CompartmentId: &p.cfg.CompartmentID,
   226  			})
   227  			if err != nil {
   228  				return nil, provider.NewSoftError(fmt.Errorf("getting records for zone %q: %w", *zone.Id, err))
   229  			}
   230  
   231  			for _, record := range resp.Items {
   232  				if !provider.SupportedRecordType(*record.Rtype) {
   233  					continue
   234  				}
   235  				endpoints = append(endpoints,
   236  					endpoint.NewEndpointWithTTL(
   237  						*record.Domain,
   238  						*record.Rtype,
   239  						endpoint.TTL(*record.Ttl),
   240  						*record.Rdata,
   241  					),
   242  				)
   243  			}
   244  
   245  			if page = resp.OpcNextPage; resp.OpcNextPage == nil {
   246  				break
   247  			}
   248  		}
   249  	}
   250  
   251  	return endpoints, nil
   252  }
   253  
   254  // ApplyChanges applies a given set of changes to a given zone.
   255  func (p *OCIProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
   256  	log.Debugf("Processing changes: %+v", changes)
   257  
   258  	ops := []dns.RecordOperation{}
   259  	ops = append(ops, p.newFilteredRecordOperations(changes.Create, dns.RecordOperationOperationAdd)...)
   260  
   261  	ops = append(ops, p.newFilteredRecordOperations(changes.UpdateNew, dns.RecordOperationOperationAdd)...)
   262  	ops = append(ops, p.newFilteredRecordOperations(changes.UpdateOld, dns.RecordOperationOperationRemove)...)
   263  
   264  	ops = append(ops, p.newFilteredRecordOperations(changes.Delete, dns.RecordOperationOperationRemove)...)
   265  
   266  	if len(ops) == 0 {
   267  		log.Info("All records are already up to date")
   268  		return nil
   269  	}
   270  
   271  	zones, err := p.zones(ctx)
   272  	if err != nil {
   273  		return provider.NewSoftError(fmt.Errorf("fetching zones: %w", err))
   274  	}
   275  
   276  	// Separate into per-zone change sets to be passed to OCI API.
   277  	opsByZone := operationsByZone(zones, ops)
   278  	for zoneID, ops := range opsByZone {
   279  		log.Infof("Change zone: %q", zoneID)
   280  		for _, op := range ops {
   281  			log.Info(op)
   282  		}
   283  	}
   284  
   285  	if p.dryRun {
   286  		return nil
   287  	}
   288  
   289  	for zoneID, ops := range opsByZone {
   290  		if _, err := p.client.PatchZoneRecords(ctx, dns.PatchZoneRecordsRequest{
   291  			CompartmentId:           &p.cfg.CompartmentID,
   292  			ZoneNameOrId:            &zoneID,
   293  			PatchZoneRecordsDetails: dns.PatchZoneRecordsDetails{Items: ops},
   294  		}); err != nil {
   295  			return provider.NewSoftError(err)
   296  		}
   297  	}
   298  
   299  	return nil
   300  }
   301  
   302  // newRecordOperation returns a RecordOperation based on a given endpoint.
   303  func newRecordOperation(ep *endpoint.Endpoint, opType dns.RecordOperationOperationEnum) dns.RecordOperation {
   304  	targets := make([]string, len(ep.Targets))
   305  	copy(targets, ep.Targets)
   306  	if ep.RecordType == endpoint.RecordTypeCNAME {
   307  		targets[0] = provider.EnsureTrailingDot(targets[0])
   308  	}
   309  	rdata := strings.Join(targets, " ")
   310  
   311  	ttl := ociRecordTTL
   312  	if ep.RecordTTL.IsConfigured() {
   313  		ttl = int(ep.RecordTTL)
   314  	}
   315  
   316  	return dns.RecordOperation{
   317  		Domain:    &ep.DNSName,
   318  		Rdata:     &rdata,
   319  		Ttl:       &ttl,
   320  		Rtype:     &ep.RecordType,
   321  		Operation: opType,
   322  	}
   323  }
   324  
   325  // operationsByZone segments a slice of RecordOperations by their zone.
   326  func operationsByZone(zones map[string]dns.ZoneSummary, ops []dns.RecordOperation) map[string][]dns.RecordOperation {
   327  	changes := make(map[string][]dns.RecordOperation)
   328  
   329  	zoneNameIDMapper := provider.ZoneIDName{}
   330  	for _, z := range zones {
   331  		zoneNameIDMapper.Add(*z.Id, *z.Name)
   332  		changes[*z.Id] = []dns.RecordOperation{}
   333  	}
   334  
   335  	for _, op := range ops {
   336  		if zoneID, _ := zoneNameIDMapper.FindZone(*op.Domain); zoneID != "" {
   337  			changes[zoneID] = append(changes[zoneID], op)
   338  		} else {
   339  			log.Warnf("No matching zone for record operation %s", op)
   340  		}
   341  	}
   342  
   343  	// Remove zones that don't have any changes.
   344  	for zone, ops := range changes {
   345  		if len(ops) == 0 {
   346  			delete(changes, zone)
   347  		}
   348  	}
   349  
   350  	return changes
   351  }