/*
Copyright 2017 The Kubernetes Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package pihole

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strings"

	"github.com/linki/instrumented_http"
	log "github.com/sirupsen/logrus"
	"golang.org/x/net/html"

	"sigs.k8s.io/external-dns/endpoint"
)

// piholeAPI declares the "API" actions performed against the Pihole server.
// Only A and CNAME records are handled by the concrete client in this file.
type piholeAPI interface {
	// listRecords returns endpoints for the given record type (A or CNAME).
	listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error)
	// createRecord will create a new record for the given endpoint.
	createRecord(ctx context.Context, ep *endpoint.Endpoint) error
	// deleteRecord will delete the given record.
	deleteRecord(ctx context.Context, ep *endpoint.Endpoint) error
}

// piholeClient implements the piholeAPI by POSTing form data to the Pi-hole
// admin PHP scripts.
type piholeClient struct {
	// cfg holds the provider configuration (server URL, password, TLS
	// verification setting, domain filter and dry-run flag).
	cfg PiholeConfig
	// httpClient carries a cookie jar so the PHP session established at
	// login is reused on subsequent requests.
	httpClient *http.Client
	// token is the form token scraped from the login page; it stays empty
	// when no password is configured and is only sent when non-empty.
	token string
}

// newPiholeClient creates a new Pihole API client.
56 func newPiholeClient(cfg PiholeConfig) (piholeAPI, error) { 57 if cfg.Server == "" { 58 return nil, ErrNoPiholeServer 59 } 60 61 // Setup a persistent cookiejar for storing PHP session information 62 jar, err := cookiejar.New(&cookiejar.Options{}) 63 if err != nil { 64 return nil, err 65 } 66 // Setup an HTTP client using the cookiejar 67 httpClient := &http.Client{ 68 Jar: jar, 69 Transport: &http.Transport{ 70 TLSClientConfig: &tls.Config{ 71 InsecureSkipVerify: cfg.TLSInsecureSkipVerify, 72 }, 73 }, 74 } 75 cl := instrumented_http.NewClient(httpClient, &instrumented_http.Callbacks{}) 76 77 p := &piholeClient{ 78 cfg: cfg, 79 httpClient: cl, 80 } 81 82 if cfg.Password != "" { 83 if err := p.retrieveNewToken(context.Background()); err != nil { 84 return nil, err 85 } 86 } 87 88 return p, nil 89 } 90 91 func (p *piholeClient) listRecords(ctx context.Context, rtype string) ([]*endpoint.Endpoint, error) { 92 form := &url.Values{} 93 form.Add("action", "get") 94 if p.token != "" { 95 form.Add("token", p.token) 96 } 97 98 url, err := p.urlForRecordType(rtype) 99 if err != nil { 100 return nil, err 101 } 102 103 log.Debugf("Listing %s records from %s", rtype, url) 104 105 req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode())) 106 if err != nil { 107 return nil, err 108 } 109 req.Header.Add("content-type", "application/x-www-form-urlencoded") 110 111 body, err := p.do(req) 112 if err != nil { 113 return nil, err 114 } 115 defer body.Close() 116 raw, err := io.ReadAll(body) 117 if err != nil { 118 return nil, err 119 } 120 121 // Response is a map of "data" to a list of lists where the first element in each 122 // list is the dns name and the second is the target. 123 // Pi-Hole does not allow for a record to have multiple targets. 124 var res map[string][][]string 125 if err := json.Unmarshal(raw, &res); err != nil { 126 // Unfortunately this could also just mean we needed to authenticate (still returns a 200). 
127 // Thankfully the body is a short and concise error. 128 err = errors.New(string(raw)) 129 if strings.Contains(err.Error(), "expired") && p.cfg.Password != "" { 130 // Try to fetch a new token and redo the request. 131 // Full error message at time of writing: 132 // "Not allowed (login session invalid or expired, please relogin on the Pi-hole dashboard)!" 133 log.Info("Pihole token has expired, fetching a new one") 134 if err := p.retrieveNewToken(ctx); err != nil { 135 return nil, err 136 } 137 return p.listRecords(ctx, rtype) 138 } 139 // Return raw body as error. 140 return nil, err 141 } 142 143 out := make([]*endpoint.Endpoint, 0) 144 data, ok := res["data"] 145 if !ok { 146 return out, nil 147 } 148 for _, rec := range data { 149 name := rec[0] 150 target := rec[1] 151 if !p.cfg.DomainFilter.Match(name) { 152 log.Debugf("Skipping %s that does not match domain filter", name) 153 continue 154 } 155 out = append(out, &endpoint.Endpoint{ 156 DNSName: name, 157 Targets: []string{target}, 158 RecordType: rtype, 159 }) 160 } 161 162 return out, nil 163 } 164 165 func (p *piholeClient) createRecord(ctx context.Context, ep *endpoint.Endpoint) error { 166 return p.apply(ctx, "add", ep) 167 } 168 169 func (p *piholeClient) deleteRecord(ctx context.Context, ep *endpoint.Endpoint) error { 170 return p.apply(ctx, "delete", ep) 171 } 172 173 func (p *piholeClient) aRecordsScript() string { 174 return fmt.Sprintf("%s/admin/scripts/pi-hole/php/customdns.php", p.cfg.Server) 175 } 176 177 func (p *piholeClient) cnameRecordsScript() string { 178 return fmt.Sprintf("%s/admin/scripts/pi-hole/php/customcname.php", p.cfg.Server) 179 } 180 181 func (p *piholeClient) urlForRecordType(rtype string) (string, error) { 182 switch rtype { 183 case endpoint.RecordTypeA: 184 return p.aRecordsScript(), nil 185 case endpoint.RecordTypeCNAME: 186 return p.cnameRecordsScript(), nil 187 default: 188 return "", fmt.Errorf("unsupported record type: %s", rtype) 189 } 190 } 191 192 type 
actionResponse struct { 193 Success bool `json:"success"` 194 Message string `json:"message"` 195 } 196 197 func (p *piholeClient) apply(ctx context.Context, action string, ep *endpoint.Endpoint) error { 198 if !p.cfg.DomainFilter.Match(ep.DNSName) { 199 log.Debugf("Skipping %s %s that does not match domain filter", action, ep.DNSName) 200 return nil 201 } 202 url, err := p.urlForRecordType(ep.RecordType) 203 if err != nil { 204 log.Warnf("Skipping unsupported endpoint %s %s %v", ep.DNSName, ep.RecordType, ep.Targets) 205 return nil 206 } 207 208 if p.cfg.DryRun { 209 log.Infof("DRY RUN: %s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0]) 210 return nil 211 } 212 213 log.Infof("%s %s IN %s -> %s", action, ep.DNSName, ep.RecordType, ep.Targets[0]) 214 215 form := p.newDNSActionForm(action, ep) 216 req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode())) 217 if err != nil { 218 return err 219 } 220 req.Header.Add("content-type", "application/x-www-form-urlencoded") 221 222 body, err := p.do(req) 223 if err != nil { 224 return err 225 } 226 defer body.Close() 227 228 raw, err := io.ReadAll(body) 229 if err != nil { 230 return nil 231 } 232 233 var res actionResponse 234 if err := json.Unmarshal(raw, &res); err != nil { 235 // Unfortunately this could also be a generic server or auth error. 236 err = errors.New(string(raw)) 237 if strings.Contains(err.Error(), "expired") && p.cfg.Password != "" { 238 // Try to fetch a new token and redo the request. 239 log.Info("Pihole token has expired, fetching a new one") 240 if err := p.retrieveNewToken(ctx); err != nil { 241 return err 242 } 243 return p.apply(ctx, action, ep) 244 } 245 // Return raw body as error. 
246 return err 247 } 248 249 if !res.Success { 250 return errors.New(res.Message) 251 } 252 253 return nil 254 } 255 256 func (p *piholeClient) retrieveNewToken(ctx context.Context) error { 257 if p.cfg.Password == "" { 258 return nil 259 } 260 261 form := &url.Values{} 262 form.Add("pw", p.cfg.Password) 263 url := fmt.Sprintf("%s/admin/index.php?login", p.cfg.Server) 264 log.Debugf("Fetching new token from %s", url) 265 266 req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(form.Encode())) 267 if err != nil { 268 return err 269 } 270 req.Header.Add("content-type", "application/x-www-form-urlencoded") 271 272 body, err := p.do(req) 273 if err != nil { 274 return err 275 } 276 defer body.Close() 277 278 // If successful the request will redirect us to an HTML page with a hidden 279 // div containing the token...The token gives us access to other PHP 280 // endpoints via a form value. 281 p.token, err = parseTokenFromLogin(body) 282 return err 283 } 284 285 func (p *piholeClient) newDNSActionForm(action string, ep *endpoint.Endpoint) *url.Values { 286 form := &url.Values{} 287 form.Add("action", action) 288 form.Add("domain", ep.DNSName) 289 switch ep.RecordType { 290 case endpoint.RecordTypeA: 291 form.Add("ip", ep.Targets[0]) 292 case endpoint.RecordTypeCNAME: 293 form.Add("target", ep.Targets[0]) 294 } 295 if p.token != "" { 296 form.Add("token", p.token) 297 } 298 return form 299 } 300 301 func (p *piholeClient) do(req *http.Request) (io.ReadCloser, error) { 302 res, err := p.httpClient.Do(req) 303 if err != nil { 304 return nil, err 305 } 306 if res.StatusCode != http.StatusOK { 307 defer res.Body.Close() 308 return nil, fmt.Errorf("received non-200 status code from request: %s", res.Status) 309 } 310 return res.Body, nil 311 } 312 313 func parseTokenFromLogin(body io.ReadCloser) (string, error) { 314 doc, err := html.Parse(body) 315 if err != nil { 316 return "", err 317 } 318 319 tokenNode := getElementById(doc, "token") 320 if 
tokenNode == nil { 321 return "", errors.New("could not parse token from login response") 322 } 323 324 return tokenNode.FirstChild.Data, nil 325 } 326 327 func getAttribute(n *html.Node, key string) (string, bool) { 328 for _, attr := range n.Attr { 329 if attr.Key == key { 330 return attr.Val, true 331 } 332 } 333 return "", false 334 } 335 336 func hasID(n *html.Node, id string) bool { 337 if n.Type == html.ElementNode { 338 s, ok := getAttribute(n, "id") 339 if ok && s == id { 340 return true 341 } 342 } 343 return false 344 } 345 346 func traverse(n *html.Node, id string) *html.Node { 347 if hasID(n, id) { 348 return n 349 } 350 351 for c := n.FirstChild; c != nil; c = c.NextSibling { 352 result := traverse(c, id) 353 if result != nil { 354 return result 355 } 356 } 357 358 return nil 359 } 360 361 func getElementById(n *html.Node, id string) *html.Node { 362 return traverse(n, id) 363 }