github.com/supabase/cli@v1.168.1/internal/link/link.go (about)

     1  package link
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"strconv"
     9  	"strings"
    10  	"sync"
    11  
    12  	"github.com/BurntSushi/toml"
    13  	"github.com/go-errors/errors"
    14  	"github.com/jackc/pgconn"
    15  	"github.com/jackc/pgx/v4"
    16  	"github.com/spf13/afero"
    17  	"github.com/spf13/viper"
    18  	"github.com/supabase/cli/internal/migration/history"
    19  	"github.com/supabase/cli/internal/utils"
    20  	"github.com/supabase/cli/internal/utils/credentials"
    21  	"github.com/supabase/cli/internal/utils/flags"
    22  	"github.com/supabase/cli/internal/utils/tenant"
    23  	"github.com/supabase/cli/pkg/api"
    24  )
    25  
    26  var updatedConfig ConfigCopy
    27  
    28  type ConfigCopy struct {
    29  	Api    interface{} `toml:"api"`
    30  	Db     interface{} `toml:"db"`
    31  	Pooler interface{} `toml:"db.pooler"`
    32  }
    33  
    34  func (c ConfigCopy) IsEmpty() bool {
    35  	return c.Api == nil && c.Db == nil && c.Pooler == nil
    36  }
    37  
    38  func Run(ctx context.Context, projectRef string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    39  	// 1. Check service config
    40  	keys, err := tenant.GetApiKeys(ctx, projectRef)
    41  	if err != nil {
    42  		return err
    43  	}
    44  	LinkServices(ctx, projectRef, keys.Anon, fsys)
    45  
    46  	// 2. Check database connection
    47  	config := flags.GetDbConfigOptionalPassword(projectRef)
    48  	if len(config.Password) > 0 {
    49  		if err := linkDatabase(ctx, config, options...); err != nil {
    50  			return err
    51  		}
    52  		// Save database password
    53  		if err := credentials.Set(projectRef, config.Password); err != nil {
    54  			fmt.Fprintln(os.Stderr, "Failed to save database password:", err)
    55  		}
    56  	}
    57  
    58  	// 3. Save project ref
    59  	return utils.WriteFile(utils.ProjectRefPath, []byte(projectRef), fsys)
    60  }
    61  
    62  func PostRun(projectRef string, stdout io.Writer, fsys afero.Fs) error {
    63  	fmt.Fprintln(stdout, "Finished "+utils.Aqua("supabase link")+".")
    64  	if updatedConfig.IsEmpty() {
    65  		return nil
    66  	}
    67  	fmt.Fprintln(os.Stderr, utils.Yellow("Warning:"), "Local config differs from linked project. Try updating", utils.Bold(utils.ConfigPath))
    68  	enc := toml.NewEncoder(stdout)
    69  	enc.Indent = ""
    70  	if err := enc.Encode(updatedConfig); err != nil {
    71  		return errors.Errorf("failed to marshal toml config: %w", err)
    72  	}
    73  	return nil
    74  }
    75  
    76  func LinkServices(ctx context.Context, projectRef, anonKey string, fsys afero.Fs) {
    77  	// Ignore non-fatal errors linking services
    78  	var wg sync.WaitGroup
    79  	wg.Add(6)
    80  	go func() {
    81  		defer wg.Done()
    82  		if err := linkDatabaseVersion(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
    83  			fmt.Fprintln(os.Stderr, err)
    84  		}
    85  	}()
    86  	go func() {
    87  		defer wg.Done()
    88  		if err := linkPostgrest(ctx, projectRef); err != nil && viper.GetBool("DEBUG") {
    89  			fmt.Fprintln(os.Stderr, err)
    90  		}
    91  	}()
    92  	go func() {
    93  		defer wg.Done()
    94  		if err := linkPooler(ctx, projectRef, fsys); err != nil && viper.GetBool("DEBUG") {
    95  			fmt.Fprintln(os.Stderr, err)
    96  		}
    97  	}()
    98  	api := tenant.NewTenantAPI(ctx, projectRef, anonKey)
    99  	go func() {
   100  		defer wg.Done()
   101  		if err := linkPostgrestVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
   102  			fmt.Fprintln(os.Stderr, err)
   103  		}
   104  	}()
   105  	go func() {
   106  		defer wg.Done()
   107  		if err := linkGotrueVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
   108  			fmt.Fprintln(os.Stderr, err)
   109  		}
   110  	}()
   111  	go func() {
   112  		defer wg.Done()
   113  		if err := linkStorageVersion(ctx, api, fsys); err != nil && viper.GetBool("DEBUG") {
   114  			fmt.Fprintln(os.Stderr, err)
   115  		}
   116  	}()
   117  	wg.Wait()
   118  }
   119  
   120  func linkPostgrest(ctx context.Context, projectRef string) error {
   121  	resp, err := utils.GetSupabase().GetPostgRESTConfigWithResponse(ctx, projectRef)
   122  	if err != nil {
   123  		return errors.Errorf("failed to get postgrest config: %w", err)
   124  	}
   125  	if resp.JSON200 == nil {
   126  		return errors.Errorf("%w: %s", tenant.ErrAuthToken, string(resp.Body))
   127  	}
   128  	updateApiConfig(*resp.JSON200)
   129  	return nil
   130  }
   131  
   132  func linkPostgrestVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs) error {
   133  	version, err := api.GetPostgrestVersion(ctx)
   134  	if err != nil {
   135  		return err
   136  	}
   137  	return utils.WriteFile(utils.RestVersionPath, []byte(version), fsys)
   138  }
   139  
   140  func updateApiConfig(config api.PostgrestConfigWithJWTSecretResponse) {
   141  	copy := utils.Config.Api
   142  	copy.MaxRows = uint(config.MaxRows)
   143  	copy.ExtraSearchPath = readCsv(config.DbExtraSearchPath)
   144  	copy.Schemas = readCsv(config.DbSchema)
   145  	changed := utils.Config.Api.MaxRows != copy.MaxRows ||
   146  		!utils.SliceEqual(utils.Config.Api.ExtraSearchPath, copy.ExtraSearchPath) ||
   147  		!utils.SliceEqual(utils.Config.Api.Schemas, copy.Schemas)
   148  	if changed {
   149  		updatedConfig.Api = copy
   150  	}
   151  }
   152  
   153  func readCsv(line string) []string {
   154  	var result []string
   155  	tokens := strings.Split(line, ",")
   156  	for _, t := range tokens {
   157  		trimmed := strings.TrimSpace(t)
   158  		if len(trimmed) > 0 {
   159  			result = append(result, trimmed)
   160  		}
   161  	}
   162  	return result
   163  }
   164  
   165  func linkGotrueVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs) error {
   166  	version, err := api.GetGotrueVersion(ctx)
   167  	if err != nil {
   168  		return err
   169  	}
   170  	return utils.WriteFile(utils.GotrueVersionPath, []byte(version), fsys)
   171  }
   172  
   173  func linkStorageVersion(ctx context.Context, api tenant.TenantAPI, fsys afero.Fs) error {
   174  	version, err := api.GetStorageVersion(ctx)
   175  	if err != nil {
   176  		return err
   177  	}
   178  	return utils.WriteFile(utils.StorageVersionPath, []byte(version), fsys)
   179  }
   180  
   181  func linkDatabase(ctx context.Context, config pgconn.Config, options ...func(*pgx.ConnConfig)) error {
   182  	conn, err := utils.ConnectByConfig(ctx, config, options...)
   183  	if err != nil {
   184  		return err
   185  	}
   186  	defer conn.Close(context.Background())
   187  	updatePostgresConfig(conn)
   188  	// If `schema_migrations` doesn't exist on the remote database, create it.
   189  	return history.CreateMigrationTable(ctx, conn)
   190  }
   191  
   192  func linkDatabaseVersion(ctx context.Context, projectRef string, fsys afero.Fs) error {
   193  	version, err := tenant.GetDatabaseVersion(ctx, projectRef)
   194  	if err != nil {
   195  		return err
   196  	}
   197  	return utils.WriteFile(utils.PostgresVersionPath, []byte(version), fsys)
   198  }
   199  
   200  func updatePostgresConfig(conn *pgx.Conn) {
   201  	serverVersion := conn.PgConn().ParameterStatus("server_version")
   202  	// Safe to assume that supported Postgres version is 10.0 <= n < 100.0
   203  	majorDigits := len(serverVersion)
   204  	if majorDigits > 2 {
   205  		majorDigits = 2
   206  	}
   207  	dbMajorVersion, err := strconv.ParseUint(serverVersion[:majorDigits], 10, 7)
   208  	// Treat error as unchanged
   209  	if err == nil && uint64(utils.Config.Db.MajorVersion) != dbMajorVersion {
   210  		copy := utils.Config.Db
   211  		copy.MajorVersion = uint(dbMajorVersion)
   212  		updatedConfig.Db = copy
   213  	}
   214  }
   215  
   216  func linkPooler(ctx context.Context, projectRef string, fsys afero.Fs) error {
   217  	resp, err := utils.GetSupabase().V1GetPgbouncerConfigWithResponse(ctx, projectRef)
   218  	if err != nil {
   219  		return errors.Errorf("failed to get pooler config: %w", err)
   220  	}
   221  	if resp.JSON200 == nil {
   222  		return errors.Errorf("%w: %s", tenant.ErrAuthToken, string(resp.Body))
   223  	}
   224  	updatePoolerConfig(*resp.JSON200)
   225  	if resp.JSON200.ConnectionString != nil {
   226  		utils.Config.Db.Pooler.ConnectionString = *resp.JSON200.ConnectionString
   227  		return utils.WriteFile(utils.PoolerUrlPath, []byte(utils.Config.Db.Pooler.ConnectionString), fsys)
   228  	}
   229  	return nil
   230  }
   231  
   232  func updatePoolerConfig(config api.V1PgbouncerConfigResponse) {
   233  	copy := utils.Config.Db.Pooler
   234  	if config.PoolMode != nil {
   235  		copy.PoolMode = utils.PoolMode(*config.PoolMode)
   236  	}
   237  	if config.DefaultPoolSize != nil {
   238  		copy.DefaultPoolSize = uint(*config.DefaultPoolSize)
   239  	}
   240  	if config.MaxClientConn != nil {
   241  		copy.MaxClientConn = uint(*config.MaxClientConn)
   242  	}
   243  	changed := utils.Config.Db.Pooler.PoolMode != copy.PoolMode ||
   244  		utils.Config.Db.Pooler.DefaultPoolSize != copy.DefaultPoolSize ||
   245  		utils.Config.Db.Pooler.MaxClientConn != copy.MaxClientConn
   246  	if changed {
   247  		updatedConfig.Pooler = copy
   248  	}
   249  }