github.com/supabase/cli@v1.168.1/internal/utils/api_test.go (about)

     1  package utils
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"net/http"
     8  	"testing"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/mock"
    12  	"github.com/supabase/cli/internal/testing/apitest"
    13  	"github.com/supabase/cli/internal/utils/cloudflare"
    14  	"gopkg.in/h2non/gock.v1"
    15  )
    16  
    17  const host = "api.supabase.io"
    18  
    19  func TestLookupIP(t *testing.T) {
    20  	t.Run("resolves IPv4 with CloudFlare", func(t *testing.T) {
    21  		// Setup http mock
    22  		defer gock.OffAll()
    23  		gock.New("https://1.1.1.1").
    24  			Get("/dns-query").
    25  			MatchParam("name", host).
    26  			MatchHeader("accept", "application/dns-json").
    27  			Reply(http.StatusOK).
    28  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
    29  				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
    30  			}})
    31  		// Run test
    32  		ip, err := FallbackLookupIP(context.Background(), host)
    33  		// Validate output
    34  		assert.NoError(t, err)
    35  		assert.ElementsMatch(t, []string{"127.0.0.1"}, ip)
    36  		assert.Empty(t, apitest.ListUnmatchedRequests())
    37  	})
    38  
    39  	t.Run("resolves IPv6 recursively", func(t *testing.T) {
    40  		// Setup http mock
    41  		defer gock.OffAll()
    42  		gock.New("https://1.1.1.1").
    43  			Get("/dns-query").
    44  			MatchParam("name", "api.supabase.com").
    45  			MatchHeader("accept", "application/dns-json").
    46  			Reply(http.StatusOK).
    47  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
    48  				{Type: cloudflare.TypeCNAME, Data: "supabase-api.fly.dev."},
    49  				{Type: cloudflare.TypeAAAA, Data: "2606:2800:220:1:248:1893:25c8:1946"},
    50  			}})
    51  		// Run test
    52  		ip, err := FallbackLookupIP(context.Background(), "api.supabase.com")
    53  		// Validate output
    54  		assert.NoError(t, err)
    55  		assert.ElementsMatch(t, []string{"2606:2800:220:1:248:1893:25c8:1946"}, ip)
    56  		assert.Empty(t, apitest.ListUnmatchedRequests())
    57  	})
    58  
    59  	t.Run("returns immediately if already resolved", func(t *testing.T) {
    60  		// Run test
    61  		ip, err := FallbackLookupIP(context.Background(), "127.0.0.1")
    62  		// Validate output
    63  		assert.NoError(t, err)
    64  		assert.ElementsMatch(t, []string{"127.0.0.1"}, ip)
    65  		assert.Empty(t, apitest.ListUnmatchedRequests())
    66  	})
    67  
    68  	t.Run("empty on network failure", func(t *testing.T) {
    69  		// Setup http mock
    70  		defer gock.OffAll()
    71  		gock.New("https://1.1.1.1").
    72  			Get("/dns-query").
    73  			MatchParam("name", host).
    74  			MatchHeader("accept", "application/dns-json").
    75  			ReplyError(errors.New("network error"))
    76  		// Run test
    77  		ip, err := FallbackLookupIP(context.Background(), host)
    78  		// Validate output
    79  		assert.ErrorContains(t, err, "network error")
    80  		assert.Empty(t, ip)
    81  		assert.Empty(t, apitest.ListUnmatchedRequests())
    82  	})
    83  
    84  	t.Run("empty on service unavailable", func(t *testing.T) {
    85  		// Setup http mock
    86  		defer gock.OffAll()
    87  		gock.New("https://1.1.1.1").
    88  			Get("/dns-query").
    89  			MatchParam("name", host).
    90  			MatchHeader("accept", "application/dns-json").
    91  			Reply(http.StatusServiceUnavailable)
    92  		// Run test
    93  		ip, err := FallbackLookupIP(context.Background(), host)
    94  		// Validate output
    95  		assert.ErrorContains(t, err, "status 503")
    96  		assert.Empty(t, ip)
    97  		assert.Empty(t, apitest.ListUnmatchedRequests())
    98  	})
    99  
   100  	t.Run("empty on malformed json", func(t *testing.T) {
   101  		// Setup http mock
   102  		defer gock.OffAll()
   103  		gock.New("https://1.1.1.1").
   104  			Get("/dns-query").
   105  			MatchParam("name", host).
   106  			MatchHeader("accept", "application/dns-json").
   107  			Reply(http.StatusOK).
   108  			JSON("malformed")
   109  		// Run test
   110  		ip, err := FallbackLookupIP(context.Background(), host)
   111  		// Validate output
   112  		assert.ErrorContains(t, err, "invalid character 'm' looking for beginning of value")
   113  		assert.Empty(t, ip)
   114  		assert.Empty(t, apitest.ListUnmatchedRequests())
   115  	})
   116  
   117  	t.Run("empty on no answer", func(t *testing.T) {
   118  		// Setup http mock
   119  		defer gock.OffAll()
   120  		gock.New("https://1.1.1.1").
   121  			Get("/dns-query").
   122  			MatchParam("name", host).
   123  			MatchHeader("accept", "application/dns-json").
   124  			Reply(http.StatusOK).
   125  			JSON(&cloudflare.DNSResponse{})
   126  		// Run test
   127  		ip, err := FallbackLookupIP(context.Background(), host)
   128  		// Validate output
   129  		assert.ErrorContains(t, err, "failed to locate valid IP for api.supabase.io; resolves to []cloudflare.DNSAnswer(nil)")
   130  		assert.Empty(t, ip)
   131  		assert.Empty(t, apitest.ListUnmatchedRequests())
   132  	})
   133  }
   134  
   135  func TestResolveCNAME(t *testing.T) {
   136  	t.Run("resolves CNAMEs with CloudFlare", func(t *testing.T) {
   137  		defer gock.OffAll()
   138  		gock.New("https://1.1.1.1").
   139  			Get("/dns-query").
   140  			MatchParam("name", host).
   141  			MatchParam("type", "5").
   142  			MatchHeader("accept", "application/dns-json").
   143  			Reply(http.StatusOK).
   144  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
   145  				{Type: cloudflare.TypeCNAME, Data: "foobarbaz.supabase.co"},
   146  			}})
   147  		// Run test
   148  		cname, err := ResolveCNAME(context.Background(), host)
   149  		// Validate output
   150  		assert.Equal(t, "foobarbaz.supabase.co", cname)
   151  		assert.Nil(t, err)
   152  		assert.Empty(t, apitest.ListUnmatchedRequests())
   153  	})
   154  
   155  	t.Run("missing CNAMEs return an error", func(t *testing.T) {
   156  		defer gock.OffAll()
   157  		gock.New("https://1.1.1.1").
   158  			Get("/dns-query").
   159  			MatchParam("name", host).
   160  			MatchParam("type", "5").
   161  			MatchHeader("accept", "application/dns-json").
   162  			Reply(http.StatusOK).
   163  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{}})
   164  		// Run test
   165  		cname, err := ResolveCNAME(context.Background(), host)
   166  		// Validate output
   167  		assert.Empty(t, cname)
   168  		assert.ErrorContains(t, err, "failed to locate appropriate CNAME record for api.supabase.io")
   169  		assert.Empty(t, apitest.ListUnmatchedRequests())
   170  	})
   171  
   172  	t.Run("missing CNAMEs return an error", func(t *testing.T) {
   173  		defer gock.OffAll()
   174  		gock.New("https://1.1.1.1").
   175  			Get("/dns-query").
   176  			MatchParam("name", host).
   177  			MatchParam("type", "5").
   178  			MatchHeader("accept", "application/dns-json").
   179  			Reply(http.StatusOK).
   180  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
   181  				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
   182  			}})
   183  		// Run test
   184  		cname, err := ResolveCNAME(context.Background(), host)
   185  		// Validate output
   186  		assert.Empty(t, cname)
   187  		assert.ErrorContains(t, err, "failed to locate appropriate CNAME record for api.supabase.io")
   188  		assert.Empty(t, apitest.ListUnmatchedRequests())
   189  	})
   190  }
   191  
   192  type MockDialer struct {
   193  	mock.Mock
   194  }
   195  
   196  func (m *MockDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   197  	args := m.Called(ctx, network, address)
   198  	if conn, ok := args.Get(0).(net.Conn); ok {
   199  		return conn, args.Error(1)
   200  	}
   201  	return nil, args.Error(1)
   202  }
   203  
   204  func TestFallbackDNS(t *testing.T) {
   205  	errNetwork := errors.New("network error")
   206  	errDNS := &net.DNSError{
   207  		IsTimeout: true,
   208  	}
   209  
   210  	t.Run("overrides DialContext with DoH", func(t *testing.T) {
   211  		DNSResolver.Value = DNS_OVER_HTTPS
   212  		// Setup mock dialer
   213  		dialer := MockDialer{}
   214  		dialer.On("DialContext", mock.Anything, mock.Anything, "127.0.0.1:80").
   215  			Return(nil, errNetwork)
   216  		wrapped := withFallbackDNS(dialer.DialContext)
   217  		// Setup http mock
   218  		defer gock.OffAll()
   219  		gock.New("https://1.1.1.1").
   220  			Get("/dns-query").
   221  			MatchParam("name", host).
   222  			MatchHeader("accept", "application/dns-json").
   223  			Reply(http.StatusOK).
   224  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
   225  				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
   226  			}})
   227  		// Run test
   228  		conn, err := wrapped(context.Background(), "udp", host+":80")
   229  		// Check error
   230  		assert.ErrorIs(t, err, errNetwork)
   231  		assert.Nil(t, conn)
   232  		dialer.AssertExpectations(t)
   233  		assert.Empty(t, apitest.ListUnmatchedRequests())
   234  	})
   235  
   236  	t.Run("native with DoH fallback", func(t *testing.T) {
   237  		DNSResolver.Value = DNS_GO_NATIVE
   238  		// Setup mock dialer
   239  		dialer := MockDialer{}
   240  		dialer.On("DialContext", mock.Anything, mock.Anything, host+":80").
   241  			Return(nil, errDNS)
   242  		dialer.On("DialContext", mock.Anything, mock.Anything, "127.0.0.1:80").
   243  			Return(nil, nil)
   244  		wrapped := withFallbackDNS(dialer.DialContext)
   245  		// Setup http mock
   246  		defer gock.OffAll()
   247  		gock.New("https://1.1.1.1").
   248  			Get("/dns-query").
   249  			MatchParam("name", host).
   250  			MatchHeader("accept", "application/dns-json").
   251  			Reply(http.StatusOK).
   252  			JSON(&cloudflare.DNSResponse{Answer: []cloudflare.DNSAnswer{
   253  				{Type: cloudflare.TypeA, Data: "127.0.0.1"},
   254  			}})
   255  		// Run test
   256  		conn, err := wrapped(context.Background(), "udp", host+":80")
   257  		// Check error
   258  		assert.NoError(t, err)
   259  		assert.Nil(t, conn)
   260  		dialer.AssertExpectations(t)
   261  		assert.Empty(t, apitest.ListUnmatchedRequests())
   262  	})
   263  
   264  	t.Run("throws error on malformed address", func(t *testing.T) {
   265  		DNSResolver.Value = DNS_OVER_HTTPS
   266  		// Setup mock dialer
   267  		dialer := MockDialer{}
   268  		wrapped := withFallbackDNS(dialer.DialContext)
   269  		// Run test
   270  		conn, err := wrapped(context.Background(), "udp", "bad?url")
   271  		// Check error
   272  		assert.ErrorContains(t, err, "missing port in address")
   273  		assert.Nil(t, conn)
   274  		assert.Empty(t, apitest.ListUnmatchedRequests())
   275  	})
   276  
   277  	t.Run("throws error on fallback failure", func(t *testing.T) {
   278  		DNSResolver.Value = DNS_GO_NATIVE
   279  		// Setup mock dialer
   280  		dialer := MockDialer{}
   281  		dialer.On("DialContext", mock.Anything, mock.Anything, host+":80").
   282  			Return(nil, errDNS)
   283  		wrapped := withFallbackDNS(dialer.DialContext)
   284  		// Setup http mock
   285  		defer gock.OffAll()
   286  		gock.New("https://1.1.1.1").
   287  			Get("/dns-query").
   288  			MatchParam("name", host).
   289  			MatchHeader("accept", "application/dns-json").
   290  			ReplyError(errNetwork)
   291  		// Run test
   292  		conn, err := wrapped(context.Background(), "udp", host+":80")
   293  		// Check error
   294  		assert.ErrorIs(t, err, errDNS)
   295  		assert.Nil(t, conn)
   296  		dialer.AssertExpectations(t)
   297  		assert.Empty(t, apitest.ListUnmatchedRequests())
   298  	})
   299  }