sigs.k8s.io/external-dns@v0.14.1/provider/webhook/api/httpapi_test.go

     1  /*
     2  Copyright 2023 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package api
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"encoding/json"
    23  	"fmt"
    24  	"io"
    25  	"net/http"
    26  	"net/http/httptest"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/stretchr/testify/require"
    31  	"sigs.k8s.io/external-dns/endpoint"
    32  	"sigs.k8s.io/external-dns/plan"
    33  )
    34  
    35  var records []*endpoint.Endpoint
    36  
    37  type FakeWebhookProvider struct {
    38  	err          error
    39  	domainFilter endpoint.DomainFilter
    40  }
    41  
    42  func (p FakeWebhookProvider) Records(ctx context.Context) ([]*endpoint.Endpoint, error) {
    43  	if p.err != nil {
    44  		return nil, p.err
    45  	}
    46  	return records, nil
    47  }
    48  
    49  func (p FakeWebhookProvider) ApplyChanges(ctx context.Context, changes *plan.Changes) error {
    50  	if p.err != nil {
    51  		return p.err
    52  	}
    53  	records = append(records, changes.Create...)
    54  	return nil
    55  }
    56  
    57  func (p FakeWebhookProvider) AdjustEndpoints(endpoints []*endpoint.Endpoint) ([]*endpoint.Endpoint, error) {
    58  	// for simplicity, we do not adjust endpoints in this test
    59  	if p.err != nil {
    60  		return nil, p.err
    61  	}
    62  	return endpoints, nil
    63  }
    64  
    65  func (p FakeWebhookProvider) GetDomainFilter() endpoint.DomainFilter {
    66  	return p.domainFilter
    67  }
    68  
    69  func TestMain(m *testing.M) {
    70  	records = []*endpoint.Endpoint{
    71  		{
    72  			DNSName:    "foo.bar.com",
    73  			RecordType: "A",
    74  		},
    75  	}
    76  	m.Run()
    77  }
    78  
    79  func TestRecordsHandlerRecords(t *testing.T) {
    80  	req := httptest.NewRequest(http.MethodGet, "/records", nil)
    81  	w := httptest.NewRecorder()
    82  
    83  	providerAPIServer := &WebhookServer{
    84  		Provider: &FakeWebhookProvider{
    85  			domainFilter: endpoint.NewDomainFilter([]string{"foo.bar.com"}),
    86  		},
    87  	}
    88  	providerAPIServer.RecordsHandler(w, req)
    89  	res := w.Result()
    90  	require.Equal(t, http.StatusOK, res.StatusCode)
    91  	// require that the res has the same endpoints as the records slice
    92  	defer res.Body.Close()
    93  	require.NotNil(t, res.Body)
    94  	endpoints := []*endpoint.Endpoint{}
    95  	if err := json.NewDecoder(res.Body).Decode(&endpoints); err != nil {
    96  		t.Errorf("Failed to decode response body: %s", err.Error())
    97  	}
    98  	require.Equal(t, records, endpoints)
    99  }
   100  
   101  func TestRecordsHandlerRecordsWithErrors(t *testing.T) {
   102  	req := httptest.NewRequest(http.MethodGet, "/records", nil)
   103  	w := httptest.NewRecorder()
   104  
   105  	providerAPIServer := &WebhookServer{
   106  		Provider: &FakeWebhookProvider{
   107  			err: fmt.Errorf("error"),
   108  		},
   109  	}
   110  	providerAPIServer.RecordsHandler(w, req)
   111  	res := w.Result()
   112  	require.Equal(t, http.StatusInternalServerError, res.StatusCode)
   113  }
   114  
   115  func TestRecordsHandlerApplyChangesWithBadRequest(t *testing.T) {
   116  	req := httptest.NewRequest(http.MethodPost, "/applychanges", nil)
   117  	w := httptest.NewRecorder()
   118  
   119  	providerAPIServer := &WebhookServer{
   120  		Provider: &FakeWebhookProvider{},
   121  	}
   122  	providerAPIServer.RecordsHandler(w, req)
   123  	res := w.Result()
   124  	require.Equal(t, http.StatusBadRequest, res.StatusCode)
   125  }
   126  
   127  func TestRecordsHandlerApplyChangesWithValidRequest(t *testing.T) {
   128  	changes := &plan.Changes{
   129  		Create: []*endpoint.Endpoint{
   130  			{
   131  				DNSName:    "foo.bar.com",
   132  				RecordType: "A",
   133  				Targets:    endpoint.Targets{},
   134  			},
   135  		},
   136  	}
   137  	j, err := json.Marshal(changes)
   138  	require.NoError(t, err)
   139  
   140  	reader := bytes.NewReader(j)
   141  
   142  	req := httptest.NewRequest(http.MethodPost, "/applychanges", reader)
   143  	w := httptest.NewRecorder()
   144  
   145  	providerAPIServer := &WebhookServer{
   146  		Provider: &FakeWebhookProvider{},
   147  	}
   148  	providerAPIServer.RecordsHandler(w, req)
   149  	res := w.Result()
   150  	require.Equal(t, http.StatusNoContent, res.StatusCode)
   151  }
   152  
   153  func TestRecordsHandlerApplyChangesWithErrors(t *testing.T) {
   154  	changes := &plan.Changes{
   155  		Create: []*endpoint.Endpoint{
   156  			{
   157  				DNSName:    "foo.bar.com",
   158  				RecordType: "A",
   159  				Targets:    endpoint.Targets{},
   160  			},
   161  		},
   162  	}
   163  	j, err := json.Marshal(changes)
   164  	require.NoError(t, err)
   165  
   166  	reader := bytes.NewReader(j)
   167  
   168  	req := httptest.NewRequest(http.MethodPost, "/applychanges", reader)
   169  	w := httptest.NewRecorder()
   170  
   171  	providerAPIServer := &WebhookServer{
   172  		Provider: &FakeWebhookProvider{
   173  			err: fmt.Errorf("error"),
   174  		},
   175  	}
   176  	providerAPIServer.RecordsHandler(w, req)
   177  	res := w.Result()
   178  	require.Equal(t, http.StatusInternalServerError, res.StatusCode)
   179  }
   180  
   181  func TestRecordsHandlerWithWrongHTTPMethod(t *testing.T) {
   182  	req := httptest.NewRequest(http.MethodPut, "/records", nil)
   183  	w := httptest.NewRecorder()
   184  
   185  	providerAPIServer := &WebhookServer{
   186  		Provider: &FakeWebhookProvider{},
   187  	}
   188  	providerAPIServer.RecordsHandler(w, req)
   189  	res := w.Result()
   190  	require.Equal(t, http.StatusBadRequest, res.StatusCode)
   191  }
   192  
   193  func TestAdjustEndpointsHandlerWithInvalidRequest(t *testing.T) {
   194  	req := httptest.NewRequest(http.MethodPost, "/adjustendpoints", nil)
   195  	w := httptest.NewRecorder()
   196  
   197  	providerAPIServer := &WebhookServer{
   198  		Provider: &FakeWebhookProvider{},
   199  	}
   200  	providerAPIServer.AdjustEndpointsHandler(w, req)
   201  	res := w.Result()
   202  	require.Equal(t, http.StatusBadRequest, res.StatusCode)
   203  
   204  	req = httptest.NewRequest(http.MethodGet, "/adjustendpoints", nil)
   205  
   206  	providerAPIServer.AdjustEndpointsHandler(w, req)
   207  	res = w.Result()
   208  	require.Equal(t, http.StatusBadRequest, res.StatusCode)
   209  }
   210  
   211  func TestAdjustEndpointsHandlerWithValidRequest(t *testing.T) {
   212  	pve := []*endpoint.Endpoint{
   213  		{
   214  			DNSName:    "foo.bar.com",
   215  			RecordType: "A",
   216  			Targets:    endpoint.Targets{},
   217  			RecordTTL:  0,
   218  		},
   219  	}
   220  
   221  	j, err := json.Marshal(pve)
   222  	require.NoError(t, err)
   223  
   224  	reader := bytes.NewReader(j)
   225  	req := httptest.NewRequest(http.MethodPost, "/adjustendpoints", reader)
   226  	w := httptest.NewRecorder()
   227  
   228  	providerAPIServer := &WebhookServer{
   229  		Provider: &FakeWebhookProvider{},
   230  	}
   231  	providerAPIServer.AdjustEndpointsHandler(w, req)
   232  	res := w.Result()
   233  	require.Equal(t, http.StatusOK, res.StatusCode)
   234  	require.NotNil(t, res.Body)
   235  }
   236  
   237  func TestAdjustEndpointsHandlerWithError(t *testing.T) {
   238  	pve := []*endpoint.Endpoint{
   239  		{
   240  			DNSName:    "foo.bar.com",
   241  			RecordType: "A",
   242  			Targets:    endpoint.Targets{},
   243  			RecordTTL:  0,
   244  		},
   245  	}
   246  
   247  	j, err := json.Marshal(pve)
   248  	require.NoError(t, err)
   249  
   250  	reader := bytes.NewReader(j)
   251  	req := httptest.NewRequest(http.MethodPost, "/adjustendpoints", reader)
   252  	w := httptest.NewRecorder()
   253  
   254  	providerAPIServer := &WebhookServer{
   255  		Provider: &FakeWebhookProvider{
   256  			err: fmt.Errorf("error"),
   257  		},
   258  	}
   259  	providerAPIServer.AdjustEndpointsHandler(w, req)
   260  	res := w.Result()
   261  	require.Equal(t, http.StatusInternalServerError, res.StatusCode)
   262  	require.NotNil(t, res.Body)
   263  }
   264  
   265  func TestStartHTTPApi(t *testing.T) {
   266  	startedChan := make(chan struct{})
   267  	go StartHTTPApi(FakeWebhookProvider{}, startedChan, 5*time.Second, 10*time.Second, "127.0.0.1:8887")
   268  	<-startedChan
   269  	resp, err := http.Get("http://127.0.0.1:8887")
   270  	require.NoError(t, err)
   271  	// check that resp has a valid domain filter
   272  	defer resp.Body.Close()
   273  
   274  	df := endpoint.DomainFilter{}
   275  	b, err := io.ReadAll(resp.Body)
   276  	require.NoError(t, err)
   277  	require.NoError(t, df.UnmarshalJSON(b))
   278  }