github.com/pelicanplatform/pelican@v1.0.5/director/director_test.go (about)

     1  package director
     2  
     3  import (
     4  	"encoding/json"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	"github.com/gin-gonic/gin"
    10  	"github.com/jellydator/ttlcache/v3"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  )
    14  
    15  func TestListServers(t *testing.T) {
    16  	router := gin.Default()
    17  
    18  	router.GET("/servers", listServers)
    19  
    20  	func() {
    21  		serverAdMutex.Lock()
    22  		defer serverAdMutex.Unlock()
    23  		serverAds.Set(mockOriginServerAd, mockNamespaceAds(5, "origin1"), ttlcache.DefaultTTL)
    24  		serverAds.Set(mockCacheServerAd, mockNamespaceAds(4, "cache1"), ttlcache.DefaultTTL)
    25  		require.True(t, serverAds.Has(mockOriginServerAd))
    26  		require.True(t, serverAds.Has(mockCacheServerAd))
    27  	}()
    28  
    29  	mocklistOriginRes := listServerResponse{
    30  		Name:      mockOriginServerAd.Name,
    31  		AuthURL:   mockOriginServerAd.AuthURL.String(),
    32  		URL:       mockOriginServerAd.URL.String(),
    33  		WebURL:    mockOriginServerAd.WebURL.String(),
    34  		Type:      mockOriginServerAd.Type,
    35  		Latitude:  mockOriginServerAd.Latitude,
    36  		Longitude: mockOriginServerAd.Longitude,
    37  	}
    38  	mocklistCacheRes := listServerResponse{
    39  		Name:      mockCacheServerAd.Name,
    40  		AuthURL:   mockCacheServerAd.AuthURL.String(),
    41  		URL:       mockCacheServerAd.URL.String(),
    42  		WebURL:    mockCacheServerAd.WebURL.String(),
    43  		Type:      mockCacheServerAd.Type,
    44  		Latitude:  mockCacheServerAd.Latitude,
    45  		Longitude: mockCacheServerAd.Longitude,
    46  	}
    47  
    48  	t.Run("query-origin", func(t *testing.T) {
    49  		// Create a request to the endpoint
    50  		w := httptest.NewRecorder()
    51  		req, _ := http.NewRequest("GET", "/servers?server_type=origin", nil)
    52  		router.ServeHTTP(w, req)
    53  
    54  		// Check the response
    55  		require.Equal(t, 200, w.Code)
    56  
    57  		var got []listServerResponse
    58  		err := json.Unmarshal(w.Body.Bytes(), &got)
    59  		if err != nil {
    60  			t.Fatalf("Failed to unmarshal response body: %v", err)
    61  		}
    62  		require.Equal(t, 1, len(got))
    63  		assert.Equal(t, mocklistOriginRes, got[0], "Response data does not match expected")
    64  	})
    65  
    66  	t.Run("query-cache", func(t *testing.T) {
    67  		// Create a request to the endpoint
    68  		w := httptest.NewRecorder()
    69  		req, _ := http.NewRequest("GET", "/servers?server_type=cache", nil)
    70  		router.ServeHTTP(w, req)
    71  
    72  		// Check the response
    73  		require.Equal(t, 200, w.Code)
    74  
    75  		var got []listServerResponse
    76  		err := json.Unmarshal(w.Body.Bytes(), &got)
    77  		if err != nil {
    78  			t.Fatalf("Failed to unmarshal response body: %v", err)
    79  		}
    80  		require.Equal(t, 1, len(got))
    81  		assert.Equal(t, mocklistCacheRes, got[0], "Response data does not match expected")
    82  	})
    83  
    84  	t.Run("query-all-with-empty-server-type", func(t *testing.T) {
    85  		// Create a request to the endpoint
    86  		w := httptest.NewRecorder()
    87  		req, _ := http.NewRequest("GET", "/servers?server_type=", nil)
    88  		router.ServeHTTP(w, req)
    89  
    90  		// Check the response
    91  		require.Equal(t, 200, w.Code)
    92  
    93  		var got []listServerResponse
    94  		err := json.Unmarshal(w.Body.Bytes(), &got)
    95  		if err != nil {
    96  			t.Fatalf("Failed to unmarshal response body: %v", err)
    97  		}
    98  		require.Equal(t, 2, len(got))
    99  	})
   100  
   101  	t.Run("query-all-without-query-param", func(t *testing.T) {
   102  		// Create a request to the endpoint
   103  		w := httptest.NewRecorder()
   104  		req, _ := http.NewRequest("GET", "/servers", nil)
   105  		router.ServeHTTP(w, req)
   106  
   107  		// Check the response
   108  		require.Equal(t, 200, w.Code)
   109  
   110  		var got []listServerResponse
   111  		err := json.Unmarshal(w.Body.Bytes(), &got)
   112  		if err != nil {
   113  			t.Fatalf("Failed to unmarshal response body: %v", err)
   114  		}
   115  		require.Equal(t, 2, len(got))
   116  	})
   117  
   118  	t.Run("query-with-invalid-param", func(t *testing.T) {
   119  		// Create a request to the endpoint
   120  		w := httptest.NewRecorder()
   121  		req, _ := http.NewRequest("GET", "/servers?server_type=staging", nil)
   122  		router.ServeHTTP(w, req)
   123  
   124  		// Check the response
   125  		require.Equal(t, 400, w.Code)
   126  	})
   127  }