sigs.k8s.io/external-dns@v0.14.1/provider/pihole/client_test.go (about) 1 /* 2 Copyright 2017 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package pihole 18 19 import ( 20 "context" 21 "encoding/json" 22 "net/http" 23 "net/http/httptest" 24 "strings" 25 "testing" 26 27 "sigs.k8s.io/external-dns/endpoint" 28 ) 29 30 func newTestServer(t *testing.T, hdlr http.HandlerFunc) *httptest.Server { 31 t.Helper() 32 svr := httptest.NewServer(hdlr) 33 return svr 34 } 35 36 func TestNewPiholeClient(t *testing.T) { 37 // Test correct error on no server provided 38 _, err := newPiholeClient(PiholeConfig{}) 39 if err == nil { 40 t.Error("Expected error from config with no server") 41 } else if err != ErrNoPiholeServer { 42 t.Error("Expected ErrNoPiholeServer, got", err) 43 } 44 45 // Test new client with no password. Should create the 46 // client cleanly. 47 cl, err := newPiholeClient(PiholeConfig{ 48 Server: "test", 49 }) 50 if err != nil { 51 t.Fatal(err) 52 } 53 if _, ok := cl.(*piholeClient); !ok { 54 t.Error("Did not create a new pihole client") 55 } 56 57 // Create a test server for auth tests 58 srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { 59 r.ParseForm() 60 pw := r.Form.Get("pw") 61 if pw != "correct" { 62 // Pihole actually server side renders the fact that you failed, normal 200 63 w.Write([]byte("Invalid")) 64 return 65 } 66 // This is a subset of what happens on successful login 67 w.Write([]byte(` 68 <!doctype html> 69 <html lang="en"> 70 <body> 71 <div id="token" hidden>supersecret</div> 72 </body> 73 </html> 74 `)) 75 }) 76 defer srvr.Close() 77 78 // Test invalid password 79 _, err = newPiholeClient( 80 PiholeConfig{Server: srvr.URL, Password: "wrong"}, 81 ) 82 if err == nil { 83 t.Error("Expected error for creating client with invalid password") 84 } 85 86 // Test correct password 87 cl, err = newPiholeClient( 88 PiholeConfig{Server: srvr.URL, Password: "correct"}, 89 ) 90 if err != nil { 91 t.Fatal(err) 92 } 93 if cl.(*piholeClient).token != "supersecret" { 94 t.Error("Parsed invalid token from login response:", cl.(*piholeClient).token) 95 } 96 } 97 98 func TestListRecords(t *testing.T) { 99 srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { 100 r.ParseForm() 101 if r.Form.Get("action") != "get" { 102 t.Error("Expected 'get' action in form from client") 103 } 104 if strings.Contains(r.URL.Path, "cname") { 105 w.Write([]byte(` 106 { 107 "data": [ 108 ["test4.example.com", "cname.example.com"], 109 ["test5.example.com", "cname.example.com"], 110 ["test6.match.com", "cname.example.com"] 111 ] 112 } 113 `)) 114 return 115 } 116 w.Write([]byte(` 117 { 118 "data": [ 119 ["test1.example.com", "192.168.1.1"], 120 ["test2.example.com", "192.168.1.2"], 121 ["test3.match.com", "192.168.1.3"] 122 ] 123 } 124 `)) 125 }) 126 defer srvr.Close() 127 128 // Create a client 129 cfg := PiholeConfig{ 130 Server: srvr.URL, 131 } 132 cl, err := newPiholeClient(cfg) 133 if err != nil { 134 t.Fatal(err) 135 } 136 137 // Test retrieve A records unfiltered 138 arecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeA) 139 if err != nil { 140 t.Fatal(err) 141 } 142 if len(arecs) != 3 { 143 t.Fatal("Expected 3 A records returned, got:", len(arecs)) 144 } 145 // Ensure records were parsed correctly 146 expected := [][]string{ 147 {"test1.example.com", "192.168.1.1"}, 148 {"test2.example.com", "192.168.1.2"}, 149 {"test3.match.com", "192.168.1.3"}, 150 } 151 for idx, rec := range arecs { 152 if rec.DNSName != expected[idx][0] { 153 t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) 154 } 155 if rec.Targets[0] != expected[idx][1] { 156 t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) 157 } 158 } 159 160 // Test retrieve CNAME records unfiltered 161 cnamerecs, err := cl.listRecords(context.Background(), endpoint.RecordTypeCNAME) 162 if err != nil { 163 t.Fatal(err) 164 } 165 if len(cnamerecs) != 3 { 166 t.Fatal("Expected 3 CAME records returned, got:", len(cnamerecs)) 167 } 168 // Ensure records were parsed correctly 169 expected = [][]string{ 170 {"test4.example.com", "cname.example.com"}, 171 {"test5.example.com", "cname.example.com"}, 172 {"test6.match.com", "cname.example.com"}, 173 } 174 for idx, rec := range cnamerecs { 175 if rec.DNSName != expected[idx][0] { 176 t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) 177 } 178 if rec.Targets[0] != expected[idx][1] { 179 t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) 180 } 181 } 182 183 // Same tests but with a domain filter 184 185 cfg.DomainFilter = endpoint.NewDomainFilter([]string{"match.com"}) 186 cl, err = newPiholeClient(cfg) 187 if err != nil { 188 t.Fatal(err) 189 } 190 191 // Test retrieve A records filtered 192 arecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeA) 193 if err != nil { 194 t.Fatal(err) 195 } 196 if len(arecs) != 1 { 197 t.Fatal("Expected 1 A record returned, got:", len(arecs)) 198 } 199 // Ensure records were parsed correctly 200 expected = [][]string{ 201 {"test3.match.com", "192.168.1.3"}, 202 } 203 for idx, rec := range arecs { 204 if rec.DNSName != expected[idx][0] { 205 t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) 206 } 207 if rec.Targets[0] != expected[idx][1] { 208 t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) 209 } 210 } 211 212 // Test retrieve CNAME records filtered 213 cnamerecs, err = cl.listRecords(context.Background(), endpoint.RecordTypeCNAME) 214 if err != nil { 215 t.Fatal(err) 216 } 217 if len(cnamerecs) != 1 { 218 t.Fatal("Expected 1 CNAME record returned, got:", len(cnamerecs)) 219 } 220 // Ensure records were parsed correctly 221 expected = [][]string{ 222 {"test6.match.com", "cname.example.com"}, 223 } 224 for idx, rec := range cnamerecs { 225 if rec.DNSName != expected[idx][0] { 226 t.Error("Got invalid DNS Name:", rec.DNSName, "expected:", expected[idx][0]) 227 } 228 if rec.Targets[0] != expected[idx][1] { 229 t.Error("Got invalid target:", rec.Targets[0], "expected:", expected[idx][1]) 230 } 231 } 232 } 233 234 func TestCreateRecord(t *testing.T) { 235 var ep *endpoint.Endpoint 236 srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { 237 r.ParseForm() 238 if r.Form.Get("action") != "add" { 239 t.Error("Expected 'add' action in form from client") 240 } 241 if r.Form.Get("domain") != ep.DNSName { 242 t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName) 243 } 244 switch ep.RecordType { 245 case endpoint.RecordTypeA: 246 if r.Form.Get("ip") != ep.Targets[0] { 247 t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0]) 248 } 249 case endpoint.RecordTypeCNAME: 250 if r.Form.Get("target") != ep.Targets[0] { 251 t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0]) 252 } 253 } 254 out, err := json.Marshal(actionResponse{ 255 Success: true, 256 Message: "", 257 }) 258 if err != nil { 259 t.Fatal(err) 260 } 261 w.Write(out) 262 }) 263 defer srvr.Close() 264 265 // Create a client 266 cfg := PiholeConfig{ 267 Server: srvr.URL, 268 } 269 cl, err := newPiholeClient(cfg) 270 if err != nil { 271 t.Fatal(err) 272 } 273 274 // Test create A record 275 ep = &endpoint.Endpoint{ 276 DNSName: "test.example.com", 277 Targets: []string{"192.168.1.1"}, 278 RecordType: endpoint.RecordTypeA, 279 } 280 if err := cl.createRecord(context.Background(), ep); err != nil { 281 t.Fatal(err) 282 } 283 284 // Test create CNAME record 285 ep = &endpoint.Endpoint{ 286 DNSName: "test.example.com", 287 Targets: []string{"test.cname.com"}, 288 RecordType: endpoint.RecordTypeCNAME, 289 } 290 if err := cl.createRecord(context.Background(), ep); err != nil { 291 t.Fatal(err) 292 } 293 } 294 295 func TestDeleteRecord(t *testing.T) { 296 var ep *endpoint.Endpoint 297 srvr := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { 298 r.ParseForm() 299 if r.Form.Get("action") != "delete" { 300 t.Error("Expected 'delete' action in form from client") 301 } 302 if r.Form.Get("domain") != ep.DNSName { 303 t.Error("Invalid domain in form:", r.Form.Get("domain"), "Expected:", ep.DNSName) 304 } 305 switch ep.RecordType { 306 case endpoint.RecordTypeA: 307 if r.Form.Get("ip") != ep.Targets[0] { 308 t.Error("Invalid ip in form:", r.Form.Get("ip"), "Expected:", ep.Targets[0]) 309 } 310 case endpoint.RecordTypeCNAME: 311 if r.Form.Get("target") != ep.Targets[0] { 312 t.Error("Invalid target in form:", r.Form.Get("target"), "Expected:", ep.Targets[0]) 313 } 314 } 315 out, err := json.Marshal(actionResponse{ 316 Success: true, 317 Message: "", 318 }) 319 if err != nil { 320 t.Fatal(err) 321 } 322 w.Write(out) 323 }) 324 defer srvr.Close() 325 326 // Create a client 327 cfg := PiholeConfig{ 328 Server: srvr.URL, 329 } 330 cl, err := newPiholeClient(cfg) 331 if err != nil { 332 t.Fatal(err) 333 } 334 335 // Test delete A record 336 ep = &endpoint.Endpoint{ 337 DNSName: "test.example.com", 338 Targets: []string{"192.168.1.1"}, 339 RecordType: endpoint.RecordTypeA, 340 } 341 if err := cl.deleteRecord(context.Background(), ep); err != nil { 342 t.Fatal(err) 343 } 344 345 // Test delete CNAME record 346 ep = &endpoint.Endpoint{ 347 DNSName: "test.example.com", 348 Targets: []string{"test.cname.com"}, 349 RecordType: endpoint.RecordTypeCNAME, 350 } 351 if err := cl.deleteRecord(context.Background(), ep); err != nil { 352 t.Fatal(err) 353 } 354 }