github.com/supabase/cli@v1.168.1/internal/sso/update/update_test.go (about)

     1  package update
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net/http"
     8  	"testing"
     9  
    10  	"github.com/google/uuid"
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/supabase/cli/internal/testing/apitest"
    13  	"github.com/supabase/cli/internal/utils"
    14  	"github.com/supabase/cli/pkg/api"
    15  	"gopkg.in/h2non/gock.v1"
    16  )
    17  
    18  func response(providerId string, domains []string) map[string]any {
    19  	resp := map[string]any{
    20  		"id":         providerId,
    21  		"created_at": "2023-03-28T13:50:14.464Z",
    22  		"updated_at": "2023-03-28T13:50:14.464Z",
    23  		"saml": map[string]any{
    24  			"id":           "8682fcf4-4056-455c-bd93-f33295604929",
    25  			"metadata_url": "https://example.com",
    26  			"metadata_xml": "<?xml version=\"2.0\"?>",
    27  			"entity_id":    "https://example.com",
    28  			"attribute_mapping": map[string]any{
    29  				"keys": map[string]any{
    30  					"a": map[string]any{
    31  						"name": "xyz",
    32  						"names": []string{
    33  							"x",
    34  							"y",
    35  							"z",
    36  						},
    37  						"default": 3,
    38  					},
    39  				},
    40  			},
    41  			"created_at": "2023-03-28T13:50:14.464Z",
    42  			"updated_at": "2023-03-28T13:50:14.464Z",
    43  		},
    44  		"domains": []map[string]any{},
    45  	}
    46  
    47  	for _, domain := range domains {
    48  		respDomains := resp["domains"].([]map[string]any)
    49  		resp["domains"] = append(respDomains, map[string]any{
    50  			"id":         "9484591c-a203-4500-bea7-d0aaa845e2f5",
    51  			"domain":     domain,
    52  			"created_at": "2023-03-28T13:50:14.464Z",
    53  			"updated_at": "2023-03-28T13:50:14.464Z",
    54  		})
    55  	}
    56  
    57  	return resp
    58  }
    59  
    60  func TestSSOProvidersUpdateCommand(t *testing.T) {
    61  	t.Run("update provider", func(t *testing.T) {
    62  		// Setup valid access token
    63  		token := apitest.RandomAccessToken(t)
    64  		t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
    65  
    66  		// Flush pending mocks after test execution
    67  		defer gock.OffAll()
    68  
    69  		projectRef := apitest.RandomProjectRef()
    70  		providerId := uuid.New().String()
    71  
    72  		gock.New(utils.DefaultApiHost).
    73  			Get("/v1/projects/" + projectRef + "/config/auth/sso/providers/" + providerId).
    74  			Reply(200).
    75  			JSON(response(providerId, []string{"example.com"}))
    76  
    77  		gock.New(utils.DefaultApiHost).
    78  			Put("/v1/projects/" + projectRef + "/config/auth/sso/providers/" + providerId).
    79  			Reply(200).
    80  			JSON(response(providerId, []string{"new-domain.com"}))
    81  
    82  		observed := 0
    83  		gock.Observe(func(r *http.Request, mock gock.Mock) {
    84  			if r.Method != http.MethodPut {
    85  				return
    86  			}
    87  			observed += 1
    88  
    89  			var body api.UpdateProviderByIdJSONRequestBody
    90  			assert.NoError(t, json.NewDecoder(r.Body).Decode(&body))
    91  
    92  			assert.NotNil(t, body.Domains)
    93  			assert.Equal(t, 1, len(*body.Domains))
    94  			assert.Equal(t, "new-domain.com", (*body.Domains)[0])
    95  		})
    96  
    97  		// Run test
    98  		assert.NoError(t, Run(context.Background(), RunParams{
    99  			ProjectRef: projectRef,
   100  			ProviderID: providerId,
   101  			Format:     utils.OutputPretty,
   102  
   103  			Domains: []string{
   104  				"new-domain.com",
   105  			},
   106  		}))
   107  		// Validate api
   108  		assert.Empty(t, apitest.ListUnmatchedRequests())
   109  		assert.Equal(t, 1, observed)
   110  	})
   111  
   112  	t.Run("update provider with --add-domains and --remove-domains", func(t *testing.T) {
   113  		// Setup valid access token
   114  		token := apitest.RandomAccessToken(t)
   115  		t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
   116  
   117  		// Flush pending mocks after test execution
   118  		defer gock.OffAll()
   119  
   120  		projectRef := apitest.RandomProjectRef()
   121  		providerId := uuid.New().String()
   122  
   123  		gock.New(utils.DefaultApiHost).
   124  			Get("/v1/projects/" + projectRef + "/config/auth/sso/providers/" + providerId).
   125  			Reply(200).
   126  			JSON(response(providerId, []string{"example.com"}))
   127  
   128  		gock.New(utils.DefaultApiHost).
   129  			Put("/v1/projects/" + projectRef + "/config/auth/sso/providers/" + providerId).
   130  			Reply(200).
   131  			JSON(response(providerId, []string{"new-domain.com"}))
   132  
   133  		observed := 0
   134  		gock.Observe(func(r *http.Request, mock gock.Mock) {
   135  			if r.Method != http.MethodPut {
   136  				return
   137  			}
   138  			observed += 1
   139  
   140  			var body api.UpdateProviderByIdJSONRequestBody
   141  			assert.NoError(t, json.NewDecoder(r.Body).Decode(&body))
   142  
   143  			assert.NotNil(t, body.Domains)
   144  			assert.Equal(t, 1, len(*body.Domains))
   145  			assert.Equal(t, "new-domain.com", (*body.Domains)[0])
   146  		})
   147  
   148  		// Run test
   149  		assert.NoError(t, Run(context.Background(), RunParams{
   150  			ProjectRef: projectRef,
   151  			ProviderID: providerId,
   152  			Format:     utils.OutputPretty,
   153  
   154  			AddDomains: []string{
   155  				"new-domain.com",
   156  			},
   157  			RemoveDomains: []string{
   158  				"example.com",
   159  			},
   160  		}))
   161  		// Validate api
   162  		assert.Empty(t, apitest.ListUnmatchedRequests())
   163  		assert.Equal(t, 1, observed)
   164  	})
   165  
   166  	t.Run("update provider that does not exist", func(t *testing.T) {
   167  		// Setup valid access token
   168  		token := apitest.RandomAccessToken(t)
   169  		t.Setenv("SUPABASE_ACCESS_TOKEN", string(token))
   170  
   171  		// Flush pending mocks after test execution
   172  		defer gock.OffAll()
   173  
   174  		projectRef := apitest.RandomProjectRef()
   175  		providerId := uuid.New().String()
   176  
   177  		gock.New(utils.DefaultApiHost).
   178  			Get("/v1/projects/" + projectRef + "/config/auth/sso/providers/" + providerId).
   179  			Reply(404).
   180  			JSON(map[string]string{})
   181  
   182  		err := Run(context.Background(), RunParams{
   183  			ProjectRef: projectRef,
   184  			ProviderID: providerId,
   185  			Format:     utils.OutputPretty,
   186  		})
   187  
   188  		// Run test
   189  		assert.Error(t, err)
   190  		assert.Equal(t, err.Error(), fmt.Sprintf("An identity provider with ID %q could not be found.", providerId))
   191  
   192  		// Validate api
   193  		assert.Empty(t, apitest.ListUnmatchedRequests())
   194  	})
   195  }