sigs.k8s.io/external-dns@v0.14.1/provider/pihole/client.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package pihole
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/cookiejar"
    28  	"net/url"
    29  	"strings"
    30  
    31  	"github.com/linki/instrumented_http"
    32  	log "github.com/sirupsen/logrus"
    33  	"golang.org/x/net/html"
    34  
    35  	"sigs.k8s.io/external-dns/endpoint"
    36  )
    37  
    38  // piholeAPI declares the "API" actions performed against the Pihole server.
    39  type piholeAPI interface {
    40  	// listRecords returns endpoints for the given record type (A or CNAME).
    41  	listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error)
    42  	// createRecord will create a new record for the given endpoint.
    43  	createRecord(ctx context.Context, ep *endpoint.Endpoint) error
    44  	// deleteRecord will delete the given record.
    45  	deleteRecord(ctx context.Context, ep *endpoint.Endpoint) error
    46  }
    47  
    48  // piholeClient implements the piholeAPI.
    49  type piholeClient struct {
    50  	cfg        PiholeConfig
    51  	httpClient *http.Client
    52  	token      string
    53  }
    54  
    55  // newPiholeClient creates a new Pihole API client.
    56  func newPiholeClient(cfg PiholeConfig) (piholeAPI, error) {
    57  	if cfg.Server == "" {
    58  		return nil, ErrNoPiholeServer
    59  	}
    60  
    61  	// Setup a persistent cookiejar for storing PHP session information
    62  	jar, err := cookiejar.New(&cookiejar.Options{})
    63  	if err != nil {
    64  		return nil, err
    65  	}
    66  	// Setup an HTTP client using the cookiejar
    67  	httpClient := &http.Client{
    68  		Jar: jar,
    69  		Transport: &http.Transport{
    70  			TLSClientConfig: &tls.Config{
    71  				InsecureSkipVerify: cfg.TLSInsecureSkipVerify,
    72  			},
    73  		},
    74  	}
    75  	cl := instrumented_http.NewClient(httpClient, &instrumented_http.Callbacks{})
    76  
    77  	p := &piholeClient{
    78  		cfg:        cfg,
    79  		httpClient: cl,
    80  	}
    81  
    82  	if cfg.Password != "" {
    83  		if err := p.retrieveNewToken(context.Background()); err != nil {
    84  			return nil, err
    85  		}
    86  	}
    87  
    88  	return p, nil
    89  }
    90  
    91  func (p *piholeClient) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) {
    92  	form := &url.Values{}
    93  	form.Add("action", "get")
    94  	if p.token != "" {
    95  		form.Add("token", p.token)
    96  	}
    97  
    98  	url, err := p.urlForRecordType(rtype)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	log.Debugf("Listing %s records from %s", rtype, url)
   104  
   105  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode()))
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	req.Header.Add("content-type", "application/x-www-form-urlencoded")
   110  
   111  	body, err := p.do(req)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	defer body.Close()
   116  	raw, err := io.ReadAll(body)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  
   121  	// Response is a map of "data" to a list of lists where the first element in each
   122  	// list is the dns name and the second is the target.
   123  	// Pi-Hole does not allow for a record to have multiple targets.
   124  	var res map[string][][]string
   125  	if err := json.Unmarshal(raw, &res); err != nil {
   126  		// Unfortunately this could also just mean we needed to authenticate (still returns a 200).
   127  		// Thankfully the body is a short and concise error.
   128  		err = errors.New(string(raw))
   129  		if strings.Contains(err.Error(), "expired") && p.cfg.Password != "" {
   130  			// Try to fetch a new token and redo the request.
   131  			// Full error message at time of writing:
   132  			// "Not allowed (login session invalid or expired, please relogin on the Pi-hole dashboard)!"
   133  			log.Info("Pihole token has expired, fetching a new one")
   134  			if err := p.retrieveNewToken(ctx); err != nil {
   135  				return nil, err
   136  			}
   137  			return p.listRecords(ctx, rtype)
   138  		}
   139  		// Return raw body as error.
   140  		return nil, err
   141  	}
   142  
   143  	out := make([]*endpoint.Endpoint, 0)
   144  	data, ok := res["data"]
   145  	if !ok {
   146  		return out, nil
   147  	}
   148  	for _, rec := range data {
   149  		name := rec[0]
   150  		target := rec[1]
   151  		if !p.cfg.DomainFilter.Match(name) {
   152  			log.Debugf("Skipping %s that does not match domain filter", name)
   153  			continue
   154  		}
   155  		out = append(out, &endpoint.Endpoint{
   156  			DNSName:    name,
   157  			Targets:    []string{target},
   158  			RecordType: rtype,
   159  		})
   160  	}
   161  
   162  	return out, nil
   163  }
   164  
   165  func (p *piholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error {
   166  	return p.apply(ctx, "add", ep)
   167  }
   168  
   169  func (p *piholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoint) error {
   170  	return p.apply(ctx, "delete", ep)
   171  }
   172  
   173  func (p *piholeClient) aRecordsScript() string {
   174  	return fmt.Sprintf("%s/admin/scripts/pi-hole/php/customdns.php", p.cfg.Server)
   175  }
   176  
   177  func (p *piholeClient) cnameRecordsScript() string {
   178  	return fmt.Sprintf("%s/admin/scripts/pi-hole/php/customcname.php", p.cfg.Server)
   179  }
   180  
   181  func (p *piholeClient) urlForRecordType(rtype string) (string, error) {
   182  	switch rtype {
   183  	case endpoint.RecordTypeA:
   184  		return p.aRecordsScript(), nil
   185  	case endpoint.RecordTypeCNAME:
   186  		return p.cnameRecordsScript(), nil
   187  	default:
   188  		return "", fmt.Errorf("unsupported record type: %s", rtype)
   189  	}
   190  }
   191  
   192  type actionResponse struct {
   193  	Success bool   `json:"success"`
   194  	Message string `json:"message"`
   195  }
   196  
   197  func (p *piholeClient) apply(ctx context.Context, action string, ep *endpoint.Endpoint) error {
   198  	if !p.cfg.DomainFilter.Match(ep.DNSName) {
   199  		log.Debugf("Skipping %s %s that does not match domain filter", action, ep.DNSName)
   200  		return nil
   201  	}
   202  	url, err := p.urlForRecordType(ep.RecordType)
   203  	if err != nil {
   204  		log.Warnf("Skipping unsupported endpoint %s %s %v", ep.DNSName, ep.RecordType, ep.Targets)
   205  		return nil
   206  	}
   207  
   208  	if p.cfg.DryRun {
   209  		log.Infof("DRY RUN: %s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0])
   210  		return nil
   211  	}
   212  
   213  	log.Infof("%s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0])
   214  
   215  	form := p.newDNSActionForm(action, ep)
   216  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode()))
   217  	if err != nil {
   218  		return err
   219  	}
   220  	req.Header.Add("content-type", "application/x-www-form-urlencoded")
   221  
   222  	body, err := p.do(req)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	defer body.Close()
   227  
   228  	raw, err := io.ReadAll(body)
   229  	if err != nil {
   230  		return nil
   231  	}
   232  
   233  	var res actionResponse
   234  	if err := json.Unmarshal(raw, &res); err != nil {
   235  		// Unfortunately this could also be a generic server or auth error.
   236  		err = errors.New(string(raw))
   237  		if strings.Contains(err.Error(), "expired") && p.cfg.Password != "" {
   238  			// Try to fetch a new token and redo the request.
   239  			log.Info("Pihole token has expired, fetching a new one")
   240  			if err := p.retrieveNewToken(ctx); err != nil {
   241  				return err
   242  			}
   243  			return p.apply(ctx, action, ep)
   244  		}
   245  		// Return raw body as error.
   246  		return err
   247  	}
   248  
   249  	if !res.Success {
   250  		return errors.New(res.Message)
   251  	}
   252  
   253  	return nil
   254  }
   255  
   256  func (p *piholeClient) retrieveNewToken(ctx context.Context) error {
   257  	if p.cfg.Password == "" {
   258  		return nil
   259  	}
   260  
   261  	form := &url.Values{}
   262  	form.Add("pw", p.cfg.Password)
   263  	url := fmt.Sprintf("%s/admin/index.php?login", p.cfg.Server)
   264  	log.Debugf("Fetching new token from %s", url)
   265  
   266  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode()))
   267  	if err != nil {
   268  		return err
   269  	}
   270  	req.Header.Add("content-type", "application/x-www-form-urlencoded")
   271  
   272  	body, err := p.do(req)
   273  	if err != nil {
   274  		return err
   275  	}
   276  	defer body.Close()
   277  
   278  	// If successful the request will redirect us to an HTML page with a hidden
   279  	// div containing the token...The token gives us access to other PHP
   280  	// endpoints via a form value.
   281  	p.token, err = parseTokenFromLogin(body)
   282  	return err
   283  }
   284  
   285  func (p *piholeClient) newDNSActionForm(action string, ep *endpoint.Endpoint) *url.Values {
   286  	form := &url.Values{}
   287  	form.Add("action", action)
   288  	form.Add("domain", ep.DNSName)
   289  	switch ep.RecordType {
   290  	case endpoint.RecordTypeA:
   291  		form.Add("ip", ep.Targets[0])
   292  	case endpoint.RecordTypeCNAME:
   293  		form.Add("target", ep.Targets[0])
   294  	}
   295  	if p.token != "" {
   296  		form.Add("token", p.token)
   297  	}
   298  	return form
   299  }
   300  
   301  func (p *piholeClient) do(req *http.Request) (io.ReadCloser, error) {
   302  	res, err := p.httpClient.Do(req)
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  	if res.StatusCode != http.StatusOK {
   307  		defer res.Body.Close()
   308  		return nil, fmt.Errorf("received non-200 status code from request: %s", res.Status)
   309  	}
   310  	return res.Body, nil
   311  }
   312  
   313  func parseTokenFromLogin(body io.ReadCloser) (string, error) {
   314  	doc, err := html.Parse(body)
   315  	if err != nil {
   316  		return "", err
   317  	}
   318  
   319  	tokenNode := getElementById(doc, "token")
   320  	if tokenNode == nil {
   321  		return "", errors.New("could not parse token from login response")
   322  	}
   323  
   324  	return tokenNode.FirstChild.Data, nil
   325  }
   326  
   327  func getAttribute(n *html.Node, key string) (string, bool) {
   328  	for _, attr := range n.Attr {
   329  		if attr.Key == key {
   330  			return attr.Val, true
   331  		}
   332  	}
   333  	return "", false
   334  }
   335  
   336  func hasID(n *html.Node, id string) bool {
   337  	if n.Type == html.ElementNode {
   338  		s, ok := getAttribute(n, "id")
   339  		if ok && s == id {
   340  			return true
   341  		}
   342  	}
   343  	return false
   344  }
   345  
   346  func traverse(n *html.Node, id string) *html.Node {
   347  	if hasID(n, id) {
   348  		return n
   349  	}
   350  
   351  	for c := n.FirstChild; c != nil; c = c.NextSibling {
   352  		result := traverse(c, id)
   353  		if result != nil {
   354  			return result
   355  		}
   356  	}
   357  
   358  	return nil
   359  }
   360  
   361  func getElementById(n *html.Node, id string) *html.Node {
   362  	return traverse(n, id)
   363  }