sigs.k8s.io/external-dns@v0.14.1/provider/pihole/client_test.go (about)

     1  /*
     2  Copyright 2017 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package pihole
    18  
import (
	"context"
	"encoding/json"
	"errors"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"sigs.k8s.io/external-dns/endpoint"
)
    29  
    30  func newTestServer(t *testing.T, hdlr http.HandlerFunc) *httptest.Server {
    31  	t.Helper()
    32  	svr := httptest.NewServer(hdlr)
    33  	return svr
    34  }
    35  
    36  func TestNewPiholeClient(t *testing.T) {
    37  	// Test correct error on no server provided
    38  	_, err := newPiholeClient(PiholeConfig{})
    39  	if err == nil {
    40  		t.Error("Expected error from config with no server")
    41  	} else if err != ErrNoPiholeServer {
    42  		t.Error("Expected ErrNoPiholeServer, got", err)
    43  	}
    44  
    45  	// Test new client with no password. Should create the
    46  	// client cleanly.
    47  	cl, err := newPiholeClient(PiholeConfig{
    48  		Server: "test",
    49  	})
    50  	if err != nil {
    51  		t.Fatal(err)
    52  	}
    53  	if _, ok := cl.(*piholeClient); !ok {
    54  		t.Error("Did not create a new pihole client")
    55  	}
    56  
    57  	// Create a test server for auth tests
    58  	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
    59  		r.ParseForm()
    60  		pw := r.Form.Get("pw")
    61  		if pw != "correct" {
    62  			// Pihole actually server side renders the fact that you failed, normal 200
    63  			w.Write([]byte("Invalid"))
    64  			return
    65  		}
    66  		// This is a subset of what happens on successful login
    67  		w.Write([]byte(`
    68  		<!doctype html>
    69  		<html lang="en">
    70  			<body>
    71  				<div id="token" hidden>supersecret</div>
    72  			</body>
    73  		</html>
    74  		`))
    75  	})
    76  	defer srvr.Close()
    77  
    78  	// Test invalid password
    79  	_, err = newPiholeClient(
    80  		PiholeConfig{Server: srvr.URL, Password: "wrong"},
    81  	)
    82  	if err == nil {
    83  		t.Error("Expected error for creating client with invalid password")
    84  	}
    85  
    86  	// Test correct password
    87  	cl, err = newPiholeClient(
    88  		PiholeConfig{Server: srvr.URL, Password: "correct"},
    89  	)
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	if cl.(*piholeClient).token != "supersecret" {
    94  		t.Error("Parsed invalid token from login response:", cl.(*piholeClient).token)
    95  	}
    96  }
    97  
    98  func TestListRecords(t *testing.T) {
    99  	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
   100  		r.ParseForm()
   101  		if r.Form.Get("action") != "get" {
   102  			t.Error("Expected 'get' action in form from client")
   103  		}
   104  		if strings.Contains(r.URL.Path, "cname") {
   105  			w.Write([]byte(`
   106  			{
   107  				"data": [
   108  					["test4.example.com", "cname.example.com"],
   109  					["test5.example.com", "cname.example.com"],
   110  					["test6.match.com", "cname.example.com"]
   111  				]
   112  			}
   113  			`))
   114  			return
   115  		}
   116  		w.Write([]byte(`
   117  		{
   118  			"data": [
   119  				["test1.example.com", "192.168.1.1"],
   120  				["test2.example.com", "192.168.1.2"],
   121  				["test3.match.com", "192.168.1.3"]
   122  			]
   123  		}
   124  		`))
   125  	})
   126  	defer srvr.Close()
   127  
   128  	// Create a client
   129  	cfg := PiholeConfig{
   130  		Server: srvr.URL,
   131  	}
   132  	cl, err := newPiholeClient(cfg)
   133  	if err != nil {
   134  		t.Fatal(err)
   135  	}
   136  
   137  	// Test retrieve A records unfiltered
   138  	arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA)
   139  	if err != nil {
   140  		t.Fatal(err)
   141  	}
   142  	if len(arecs) != 3 {
   143  		t.Fatal("Expected 3 A records returned, got:", len(arecs))
   144  	}
   145  	// Ensure records were parsed correctly
   146  	expected := [][]string{
   147  		{"test1.example.com", "192.168.1.1"},
   148  		{"test2.example.com", "192.168.1.2"},
   149  		{"test3.match.com", "192.168.1.3"},
   150  	}
   151  	for idx, rec := range arecs {
   152  		if rec.DNSName != expected[idx][0] {
   153  			t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
   154  		}
   155  		if rec.Targets[0] != expected[idx][1] {
   156  			t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
   157  		}
   158  	}
   159  
   160  	// Test retrieve CNAME records unfiltered
   161  	cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
   162  	if err != nil {
   163  		t.Fatal(err)
   164  	}
   165  	if len(cnamerecs) != 3 {
   166  		t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs))
   167  	}
   168  	// Ensure records were parsed correctly
   169  	expected = [][]string{
   170  		{"test4.example.com", "cname.example.com"},
   171  		{"test5.example.com", "cname.example.com"},
   172  		{"test6.match.com", "cname.example.com"},
   173  	}
   174  	for idx, rec := range cnamerecs {
   175  		if rec.DNSName != expected[idx][0] {
   176  			t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
   177  		}
   178  		if rec.Targets[0] != expected[idx][1] {
   179  			t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
   180  		}
   181  	}
   182  
   183  	// Same tests but with a domain filter
   184  
   185  	cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"})
   186  	cl, err = newPiholeClient(cfg)
   187  	if err != nil {
   188  		t.Fatal(err)
   189  	}
   190  
   191  	// Test retrieve A records filtered
   192  	arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeA)
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  	if len(arecs) != 1 {
   197  		t.Fatal("Expected 1 A record returned, got:", len(arecs))
   198  	}
   199  	// Ensure records were parsed correctly
   200  	expected = [][]string{
   201  		{"test3.match.com", "192.168.1.3"},
   202  	}
   203  	for idx, rec := range arecs {
   204  		if rec.DNSName != expected[idx][0] {
   205  			t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
   206  		}
   207  		if rec.Targets[0] != expected[idx][1] {
   208  			t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
   209  		}
   210  	}
   211  
   212  	// Test retrieve CNAME records filtered
   213  	cnamerecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeCNAME)
   214  	if err != nil {
   215  		t.Fatal(err)
   216  	}
   217  	if len(cnamerecs) != 1 {
   218  		t.Fatal("Expected 1 CNAME record returned, got:", len(cnamerecs))
   219  	}
   220  	// Ensure records were parsed correctly
   221  	expected = [][]string{
   222  		{"test6.match.com", "cname.example.com"},
   223  	}
   224  	for idx, rec := range cnamerecs {
   225  		if rec.DNSName != expected[idx][0] {
   226  			t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0])
   227  		}
   228  		if rec.Targets[0] != expected[idx][1] {
   229  			t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1])
   230  		}
   231  	}
   232  }
   233  
   234  func TestCreateRecord(t *testing.T) {
   235  	var ep *endpoint.Endpoint
   236  	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
   237  		r.ParseForm()
   238  		if r.Form.Get("action") != "add" {
   239  			t.Error("Expected 'add' action in form from client")
   240  		}
   241  		if r.Form.Get("domain") != ep.DNSName {
   242  			t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName)
   243  		}
   244  		switch ep.RecordType {
   245  		case endpoint.RecordTypeA:
   246  			if r.Form.Get("ip") != ep.Targets[0] {
   247  				t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0])
   248  			}
   249  		case endpoint.RecordTypeCNAME:
   250  			if r.Form.Get("target") != ep.Targets[0] {
   251  				t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0])
   252  			}
   253  		}
   254  		out, err := json.Marshal(actionResponse{
   255  			Success: true,
   256  			Message: "",
   257  		})
   258  		if err != nil {
   259  			t.Fatal(err)
   260  		}
   261  		w.Write(out)
   262  	})
   263  	defer srvr.Close()
   264  
   265  	// Create a client
   266  	cfg := PiholeConfig{
   267  		Server: srvr.URL,
   268  	}
   269  	cl, err := newPiholeClient(cfg)
   270  	if err != nil {
   271  		t.Fatal(err)
   272  	}
   273  
   274  	// Test create A record
   275  	ep = &endpoint.Endpoint{
   276  		DNSName:    "test.example.com",
   277  		Targets:    []string{"192.168.1.1"},
   278  		RecordType: endpoint.RecordTypeA,
   279  	}
   280  	if err := cl.createRecord(context.Background(), ep); err != nil {
   281  		t.Fatal(err)
   282  	}
   283  
   284  	// Test create CNAME record
   285  	ep = &endpoint.Endpoint{
   286  		DNSName:    "test.example.com",
   287  		Targets:    []string{"test.cname.com"},
   288  		RecordType: endpoint.RecordTypeCNAME,
   289  	}
   290  	if err := cl.createRecord(context.Background(), ep); err != nil {
   291  		t.Fatal(err)
   292  	}
   293  }
   294  
   295  func TestDeleteRecord(t *testing.T) {
   296  	var ep *endpoint.Endpoint
   297  	srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
   298  		r.ParseForm()
   299  		if r.Form.Get("action") != "delete" {
   300  			t.Error("Expected 'delete' action in form from client")
   301  		}
   302  		if r.Form.Get("domain") != ep.DNSName {
   303  			t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName)
   304  		}
   305  		switch ep.RecordType {
   306  		case endpoint.RecordTypeA:
   307  			if r.Form.Get("ip") != ep.Targets[0] {
   308  				t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0])
   309  			}
   310  		case endpoint.RecordTypeCNAME:
   311  			if r.Form.Get("target") != ep.Targets[0] {
   312  				t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0])
   313  			}
   314  		}
   315  		out, err := json.Marshal(actionResponse{
   316  			Success: true,
   317  			Message: "",
   318  		})
   319  		if err != nil {
   320  			t.Fatal(err)
   321  		}
   322  		w.Write(out)
   323  	})
   324  	defer srvr.Close()
   325  
   326  	// Create a client
   327  	cfg := PiholeConfig{
   328  		Server: srvr.URL,
   329  	}
   330  	cl, err := newPiholeClient(cfg)
   331  	if err != nil {
   332  		t.Fatal(err)
   333  	}
   334  
   335  	// Test delete A record
   336  	ep = &endpoint.Endpoint{
   337  		DNSName:    "test.example.com",
   338  		Targets:    []string{"192.168.1.1"},
   339  		RecordType: endpoint.RecordTypeA,
   340  	}
   341  	if err := cl.deleteRecord(context.Background(), ep); err != nil {
   342  		t.Fatal(err)
   343  	}
   344  
   345  	// Test delete CNAME record
   346  	ep = &endpoint.Endpoint{
   347  		DNSName:    "test.example.com",
   348  		Targets:    []string{"test.cname.com"},
   349  		RecordType: endpoint.RecordTypeCNAME,
   350  	}
   351  	if err := cl.deleteRecord(context.Background(), ep); err != nil {
   352  		t.Fatal(err)
   353  	}
   354  }