github.com/teknogeek/dnscontrol/v2@v2.10.1-0.20200227202244-ae299b55ba42/providers/gcloud/gcloudProvider.go (about)

     1  package google
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"strings"
     8  
     9  	gauth "golang.org/x/oauth2/google"
    10  	gdns "google.golang.org/api/dns/v1"
    11  
    12  	"github.com/StackExchange/dnscontrol/v2/models"
    13  	"github.com/StackExchange/dnscontrol/v2/providers"
    14  	"github.com/StackExchange/dnscontrol/v2/providers/diff"
    15  )
    16  
    17  var features = providers.DocumentationNotes{
    18  	providers.DocCreateDomains:       providers.Can(),
    19  	providers.DocDualHost:            providers.Can(),
    20  	providers.DocOfficiallySupported: providers.Can(),
    21  	providers.CanUsePTR:              providers.Can(),
    22  	providers.CanUseSRV:              providers.Can(),
    23  	providers.CanUseCAA:              providers.Can(),
    24  	providers.CanUseTXTMulti:         providers.Can(),
    25  	providers.CanGetZones:            providers.Can(),
    26  }
    27  
    28  func sPtr(s string) *string {
    29  	return &s
    30  }
    31  
    32  func init() {
    33  	providers.RegisterDomainServiceProviderType("GCLOUD", New, features)
    34  }
    35  
    36  type gcloud struct {
    37  	client        *gdns.Service
    38  	project       string
    39  	nameServerSet *string
    40  	zones         map[string]*gdns.ManagedZone
    41  }
    42  
    43  type errNoExist struct {
    44  	domain string
    45  }
    46  
    47  func (e errNoExist) Error() string {
    48  	return fmt.Sprintf("Domain '%s' not found in gcloud account", e.domain)
    49  }
    50  
    51  // New creates a new gcloud provider
    52  func New(cfg map[string]string, metadata json.RawMessage) (providers.DNSServiceProvider, error) {
    53  	// the key as downloaded is json encoded with literal "\n" instead of newlines.
    54  	// in some cases (round-tripping through env vars) this tends to get messed up.
    55  	// fix it if we find that.
    56  	if key, ok := cfg["private_key"]; ok {
    57  		cfg["private_key"] = strings.Replace(key, "\\n", "\n", -1)
    58  	}
    59  	raw, err := json.Marshal(cfg)
    60  	if err != nil {
    61  		return nil, err
    62  	}
    63  	config, err := gauth.JWTConfigFromJSON(raw, "https://www.googleapis.com/auth/ndev.clouddns.readwrite")
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	ctx := context.Background()
    68  	hc := config.Client(ctx)
    69  	dcli, err := gdns.New(hc)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  	var nss *string = nil
    74  	if val, ok := cfg["name_server_set"]; ok {
    75  		fmt.Printf("GCLOUD :name_server_set %s configured\n", val)
    76  		nss = sPtr(val)
    77  	}
    78  
    79  	g := &gcloud{
    80  		client:        dcli,
    81  		nameServerSet: nss,
    82  		project:       cfg["project_id"],
    83  	}
    84  	return g, g.loadZoneInfo()
    85  }
    86  
    87  func (g *gcloud) loadZoneInfo() error {
    88  	if g.zones != nil {
    89  		return nil
    90  	}
    91  	g.zones = map[string]*gdns.ManagedZone{}
    92  	pageToken := ""
    93  	for {
    94  		resp, err := g.client.ManagedZones.List(g.project).PageToken(pageToken).Do()
    95  		if err != nil {
    96  			return err
    97  		}
    98  		for _, z := range resp.ManagedZones {
    99  			g.zones[z.DnsName] = z
   100  		}
   101  		if pageToken = resp.NextPageToken; pageToken == "" {
   102  			break
   103  		}
   104  	}
   105  	return nil
   106  }
   107  
   108  // ListZones returns the list of zones (domains) in this account.
   109  func (g *gcloud) ListZones() ([]string, error) {
   110  	var zones []string
   111  	for i := range g.zones {
   112  		zones = append(zones, strings.TrimSuffix(i, "."))
   113  	}
   114  	return zones, nil
   115  }
   116  
   117  func (g *gcloud) getZone(domain string) (*gdns.ManagedZone, error) {
   118  	return g.zones[domain+"."], nil
   119  }
   120  
   121  func (g *gcloud) GetNameservers(domain string) ([]*models.Nameserver, error) {
   122  	zone, err := g.getZone(domain)
   123  	if err != nil {
   124  		return nil, err
   125  	}
   126  	return models.StringsToNameservers(zone.NameServers), nil
   127  }
   128  
   129  type key struct {
   130  	Type string
   131  	Name string
   132  }
   133  
   134  func keyFor(r *gdns.ResourceRecordSet) key {
   135  	return key{Type: r.Type, Name: r.Name}
   136  }
   137  func keyForRec(r *models.RecordConfig) key {
   138  	return key{Type: r.Type, Name: r.GetLabelFQDN() + "."}
   139  }
   140  
   141  // GetZoneRecords gets the records of a zone and returns them in RecordConfig format.
   142  func (g *gcloud) GetZoneRecords(domain string) (models.Records, error) {
   143  	existingRecords, _, _, err := g.getZoneSets(domain)
   144  	return existingRecords, err
   145  }
   146  
   147  func (g *gcloud) getZoneSets(domain string) (models.Records, map[key]*gdns.ResourceRecordSet, string, error) {
   148  	rrs, zoneName, err := g.getRecords(domain)
   149  	if err != nil {
   150  		return nil, nil, "", err
   151  	}
   152  	// convert to dnscontrol RecordConfig format
   153  	existingRecords := []*models.RecordConfig{}
   154  	oldRRs := map[key]*gdns.ResourceRecordSet{}
   155  	for _, set := range rrs {
   156  		oldRRs[keyFor(set)] = set
   157  		for _, rec := range set.Rrdatas {
   158  			existingRecords = append(existingRecords, nativeToRecord(set, rec, domain))
   159  		}
   160  	}
   161  	return existingRecords, oldRRs, zoneName, err
   162  }
   163  
   164  func (g *gcloud) GetDomainCorrections(dc *models.DomainConfig) ([]*models.Correction, error) {
   165  	if err := dc.Punycode(); err != nil {
   166  		return nil, err
   167  	}
   168  	existingRecords, oldRRs, zoneName, err := g.getZoneSets(dc.Name)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  
   173  	// Normalize
   174  	models.PostProcessRecords(existingRecords)
   175  
   176  	// first collect keys that have changed
   177  	differ := diff.New(dc)
   178  	_, create, delete, modify := differ.IncrementalDiff(existingRecords)
   179  	changedKeys := map[key]bool{}
   180  	desc := ""
   181  	for _, c := range create {
   182  		desc += fmt.Sprintln(c)
   183  		changedKeys[keyForRec(c.Desired)] = true
   184  	}
   185  	for _, d := range delete {
   186  		desc += fmt.Sprintln(d)
   187  		changedKeys[keyForRec(d.Existing)] = true
   188  	}
   189  	for _, m := range modify {
   190  		desc += fmt.Sprintln(m)
   191  		changedKeys[keyForRec(m.Existing)] = true
   192  	}
   193  	if len(changedKeys) == 0 {
   194  		return nil, nil
   195  	}
   196  	chg := &gdns.Change{Kind: "dns#change"}
   197  	for ck := range changedKeys {
   198  		// remove old version (if present)
   199  		if old, ok := oldRRs[ck]; ok {
   200  			chg.Deletions = append(chg.Deletions, old)
   201  		}
   202  		// collect records to replace with
   203  		newRRs := &gdns.ResourceRecordSet{
   204  			Name: ck.Name,
   205  			Type: ck.Type,
   206  			Kind: "dns#resourceRecordSet",
   207  		}
   208  		for _, r := range dc.Records {
   209  			if keyForRec(r) == ck {
   210  				newRRs.Rrdatas = append(newRRs.Rrdatas, r.GetTargetCombined())
   211  				newRRs.Ttl = int64(r.TTL)
   212  			}
   213  		}
   214  		if len(newRRs.Rrdatas) > 0 {
   215  			chg.Additions = append(chg.Additions, newRRs)
   216  		}
   217  	}
   218  
   219  	runChange := func() error {
   220  		_, err := g.client.Changes.Create(g.project, zoneName, chg).Do()
   221  		return err
   222  	}
   223  	return []*models.Correction{{
   224  		Msg: desc,
   225  		F:   runChange,
   226  	}}, nil
   227  }
   228  
   229  func nativeToRecord(set *gdns.ResourceRecordSet, rec, origin string) *models.RecordConfig {
   230  	r := &models.RecordConfig{}
   231  	r.SetLabelFromFQDN(set.Name, origin)
   232  	r.TTL = uint32(set.Ttl)
   233  	if err := r.PopulateFromString(set.Type, rec, origin); err != nil {
   234  		panic(fmt.Errorf("unparsable record received from GCLOUD: %w", err))
   235  	}
   236  	return r
   237  }
   238  
   239  func (g *gcloud) getRecords(domain string) ([]*gdns.ResourceRecordSet, string, error) {
   240  	zone, err := g.getZone(domain)
   241  	if err != nil {
   242  		return nil, "", err
   243  	}
   244  	pageToken := ""
   245  	sets := []*gdns.ResourceRecordSet{}
   246  	for {
   247  		call := g.client.ResourceRecordSets.List(g.project, zone.Name)
   248  		if pageToken != "" {
   249  			call = call.PageToken(pageToken)
   250  		}
   251  		resp, err := call.Do()
   252  		if err != nil {
   253  			return nil, "", err
   254  		}
   255  		for _, rrs := range resp.Rrsets {
   256  			if rrs.Type == "SOA" {
   257  				continue
   258  			}
   259  			sets = append(sets, rrs)
   260  		}
   261  		if pageToken = resp.NextPageToken; pageToken == "" {
   262  			break
   263  		}
   264  	}
   265  	return sets, zone.Name, nil
   266  }
   267  
   268  func (g *gcloud) EnsureDomainExists(domain string) error {
   269  	z, err := g.getZone(domain)
   270  	if err != nil {
   271  		if _, ok := err.(errNoExist); !ok {
   272  			return err
   273  		}
   274  	}
   275  	if z != nil {
   276  		return nil
   277  	}
   278  	var mz *gdns.ManagedZone
   279  	if g.nameServerSet != nil {
   280  		fmt.Printf("Adding zone for %s to gcloud account with name_server_set %s\n", domain, *g.nameServerSet)
   281  		mz = &gdns.ManagedZone{
   282  			DnsName:       domain + ".",
   283  			NameServerSet: *g.nameServerSet,
   284  			Name:          "zone-" + strings.Replace(domain, ".", "-", -1),
   285  			Description:   "zone added by dnscontrol",
   286  		}
   287  	} else {
   288  		fmt.Printf("Adding zone for %s to gcloud account \n", domain)
   289  		mz = &gdns.ManagedZone{
   290  			DnsName:     domain + ".",
   291  			Name:        "zone-" + strings.Replace(domain, ".", "-", -1),
   292  			Description: "zone added by dnscontrol",
   293  		}
   294  	}
   295  	g.zones = nil // reset cache
   296  	_, err = g.client.ManagedZones.Create(g.project, mz).Do()
   297  	return err
   298  }