github.com/supabase/cli@v1.168.1/internal/restrictions/update/update_test.go (about)

     1  package update
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"testing"
     7  
     8  	"github.com/go-errors/errors"
     9  	"github.com/stretchr/testify/assert"
    10  	"github.com/supabase/cli/internal/testing/apitest"
    11  	"github.com/supabase/cli/internal/utils"
    12  	"github.com/supabase/cli/pkg/api"
    13  	"gopkg.in/h2non/gock.v1"
    14  )
    15  
    16  func TestUpdateRestrictionsCommand(t *testing.T) {
    17  	projectRef := apitest.RandomProjectRef()
    18  	// Setup valid access token
    19  	token := apitest.RandomAccessToken(t)
    20  	t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
    21  
    22  	t.Run("updates v4 and v6 CIDR", func(t *testing.T) {
    23  		// Setup mock api
    24  		defer gock.OffAll()
    25  		gock.New(utils.DefaultApiHost).
    26  			Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
    27  			MatchType("json").
    28  			JSON(api.NetworkRestrictionsRequest{
    29  				DbAllowedCidrs:   &[]string{"12.3.4.5/32", "1.2.3.1/24"},
    30  				DbAllowedCidrsV6: &[]string{"2001:db8:abcd:0012::0/64"},
    31  			}).
    32  			Reply(http.StatusCreated).
    33  			JSON(api.NetworkRestrictionsResponse{
    34  				Status: api.NetworkRestrictionsResponseStatus("applied"),
    35  			})
    36  		// Run test
    37  		err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "2001:db8:abcd:0012::0/64", "1.2.3.1/24"}, false)
    38  		// Check error
    39  		assert.NoError(t, err)
    40  		assert.Empty(t, apitest.ListUnmatchedRequests())
    41  	})
    42  
    43  	t.Run("throws error on network failure", func(t *testing.T) {
    44  		errNetwork := errors.New("network error")
    45  		// Setup mock api
    46  		defer gock.OffAll()
    47  		gock.New(utils.DefaultApiHost).
    48  			Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
    49  			MatchType("json").
    50  			JSON(api.NetworkRestrictionsRequest{
    51  				DbAllowedCidrs:   &[]string{},
    52  				DbAllowedCidrsV6: &[]string{},
    53  			}).
    54  			ReplyError(errNetwork)
    55  		// Run test
    56  		err := Run(context.Background(), projectRef, []string{}, true)
    57  		// Check error
    58  		assert.ErrorIs(t, err, errNetwork)
    59  		assert.Empty(t, apitest.ListUnmatchedRequests())
    60  	})
    61  
    62  	t.Run("throws error on server unavailable", func(t *testing.T) {
    63  		// Setup mock api
    64  		defer gock.OffAll()
    65  		gock.New(utils.DefaultApiHost).
    66  			Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
    67  			MatchType("json").
    68  			JSON(api.NetworkRestrictionsRequest{
    69  				DbAllowedCidrs:   &[]string{},
    70  				DbAllowedCidrsV6: &[]string{},
    71  			}).
    72  			Reply(http.StatusServiceUnavailable)
    73  		// Run test
    74  		err := Run(context.Background(), projectRef, []string{}, true)
    75  		// Check error
    76  		assert.ErrorContains(t, err, "failed to apply network restrictions:")
    77  		assert.Empty(t, apitest.ListUnmatchedRequests())
    78  	})
    79  }
    80  
    81  func TestValidateCIDR(t *testing.T) {
    82  	projectRef := apitest.RandomProjectRef()
    83  	// Setup valid access token
    84  	token := apitest.RandomAccessToken(t)
    85  	t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
    86  
    87  	t.Run("bypasses private subnet checks", func(t *testing.T) {
    88  		// Setup mock api
    89  		defer gock.OffAll()
    90  		gock.New(utils.DefaultApiHost).
    91  			Post("/v1/projects/" + projectRef + "/network-restrictions/apply").
    92  			MatchType("json").
    93  			JSON(api.NetworkRestrictionsRequest{
    94  				DbAllowedCidrs:   &[]string{"10.0.0.0/8"},
    95  				DbAllowedCidrsV6: &[]string{},
    96  			}).
    97  			Reply(http.StatusCreated).
    98  			JSON(api.NetworkRestrictionsResponse{
    99  				Status: api.NetworkRestrictionsResponseStatus("applied"),
   100  			})
   101  		// Run test
   102  		err := Run(context.Background(), projectRef, []string{"10.0.0.0/8"}, true)
   103  		// Check error
   104  		assert.NoError(t, err)
   105  		assert.Empty(t, apitest.ListUnmatchedRequests())
   106  	})
   107  
   108  	t.Run("throws error on private subnet", func(t *testing.T) {
   109  		// Run test
   110  		err := Run(context.Background(), projectRef, []string{"12.3.4.5/32", "10.0.0.0/8", "1.2.3.1/24"}, false)
   111  		// Check error
   112  		assert.ErrorContains(t, err, "private IP provided: 10.0.0.0/8")
   113  	})
   114  
   115  	t.Run("throws error on invalid subnet", func(t *testing.T) {
   116  		// Run test
   117  		err := Run(context.Background(), projectRef, []string{"12.3.4.5", "10.0.0.0/8", "1.2.3.1/24"}, false)
   118  		// Check error
   119  		assert.ErrorContains(t, err, "failed to parse IP: 12.3.4.5")
   120  	})
   121  }