github.com/MetalBlockchain/metalgo@v1.11.9/api/server/allowed_hosts_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package server
     5  
     6  import (
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  func TestAllowedHostsHandler_ServeHTTP(t *testing.T) {
    15  	tests := []struct {
    16  		name    string
    17  		allowed []string
    18  		host    string
    19  		serve   bool
    20  	}{
    21  		{
    22  			name:    "no host header",
    23  			allowed: []string{"www.foobar.com"},
    24  			host:    "",
    25  			serve:   true,
    26  		},
    27  		{
    28  			name:    "ip",
    29  			allowed: []string{"www.foobar.com"},
    30  			host:    "192.168.1.1",
    31  			serve:   true,
    32  		},
    33  		{
    34  			name:    "hostname not allowed",
    35  			allowed: []string{"www.foobar.com"},
    36  			host:    "www.evil.com",
    37  		},
    38  		{
    39  			name:    "hostname allowed",
    40  			allowed: []string{"www.foobar.com"},
    41  			host:    "www.foobar.com",
    42  			serve:   true,
    43  		},
    44  		{
    45  			name:    "wildcard",
    46  			allowed: []string{"*"},
    47  			host:    "www.foobar.com",
    48  			serve:   true,
    49  		},
    50  	}
    51  
    52  	for _, test := range tests {
    53  		t.Run(test.name, func(t *testing.T) {
    54  			require := require.New(t)
    55  
    56  			baseHandler := &testHandler{}
    57  
    58  			httpAllowedHostsHandler := filterInvalidHosts(
    59  				baseHandler,
    60  				test.allowed,
    61  			)
    62  
    63  			w := &httptest.ResponseRecorder{}
    64  			r := httptest.NewRequest("", "/", nil)
    65  			r.Host = test.host
    66  
    67  			httpAllowedHostsHandler.ServeHTTP(w, r)
    68  
    69  			if test.serve {
    70  				require.True(baseHandler.called)
    71  				return
    72  			}
    73  
    74  			require.Equal(http.StatusForbidden, w.Code)
    75  		})
    76  	}
    77  }