github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/link/link.go (about)

     1  package link
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"path/filepath"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/BurntSushi/toml"
    14  	"github.com/Redstoneguy129/cli/internal/migration/repair"
    15  	"github.com/Redstoneguy129/cli/internal/utils"
    16  	"github.com/Redstoneguy129/cli/internal/utils/credentials"
    17  	"github.com/Redstoneguy129/cli/pkg/api"
    18  	"github.com/jackc/pgconn"
    19  	"github.com/jackc/pgx/v4"
    20  	"github.com/spf13/afero"
    21  	"golang.org/x/term"
    22  )
    23  
    24  var updatedConfig map[string]interface{} = make(map[string]interface{})
    25  
    26  func PreRun(projectRef string, fsys afero.Fs) error {
    27  	// Sanity checks
    28  	if !utils.ProjectRefPattern.MatchString(projectRef) {
    29  		return errors.New("Invalid project ref format. Must be like `abcdefghijklmnopqrst`.")
    30  	}
    31  	return utils.LoadConfigFS(fsys)
    32  }
    33  
    34  func Run(ctx context.Context, projectRef, username, password, database string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    35  	// 1. Check postgrest config
    36  	if err := linkPostgrest(ctx, projectRef); err != nil {
    37  		return err
    38  	}
    39  
    40  	// 2. Check database connection
    41  	if len(password) > 0 {
    42  		host := utils.GetSupabaseDbHost(projectRef)
    43  		if err := linkDatabase(ctx, username, password, database, host, options...); err != nil {
    44  			return err
    45  		}
    46  		// Save database password
    47  		if err := credentials.Set(projectRef, password); err != nil {
    48  			fmt.Fprintln(os.Stderr, "Failed to save database password:", err)
    49  		}
    50  	}
    51  
    52  	// 3. Save project ref
    53  	if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(utils.ProjectRefPath)); err != nil {
    54  		return err
    55  	}
    56  	return afero.WriteFile(fsys, utils.ProjectRefPath, []byte(projectRef), 0644)
    57  }
    58  
    59  func PostRun(projectRef string, stdout io.Writer, fsys afero.Fs) error {
    60  	fmt.Fprintln(stdout, "Finished "+utils.Aqua("supabase link")+".")
    61  	if len(updatedConfig) == 0 {
    62  		return nil
    63  	}
    64  	fmt.Fprintln(os.Stderr, "Local config differs from linked project. Try updating", utils.Bold(utils.ConfigPath))
    65  	enc := toml.NewEncoder(stdout)
    66  	enc.Indent = ""
    67  	return enc.Encode(updatedConfig)
    68  }
    69  
    70  func linkPostgrest(ctx context.Context, projectRef string) error {
    71  	resp, err := utils.GetSupabase().GetPostgRESTConfigWithResponse(ctx, projectRef)
    72  	if err != nil {
    73  		return err
    74  	}
    75  	if resp.JSON200 == nil {
    76  		return errors.New("Authorization failed for the access token and project ref pair: " + string(resp.Body))
    77  	}
    78  	updateApiConfig(*resp.JSON200)
    79  	return nil
    80  }
    81  
    82  func updateApiConfig(config api.PostgrestConfigResponse) {
    83  	maxRows := uint(config.MaxRows)
    84  	searchPath := readCsv(config.DbExtraSearchPath)
    85  	dbSchema := readCsv(config.DbSchema)
    86  	changed := utils.Config.Api.MaxRows != maxRows ||
    87  		!sliceEqual(utils.Config.Api.ExtraSearchPath, searchPath) ||
    88  		!sliceEqual(utils.Config.Api.Schemas, dbSchema)
    89  	if changed {
    90  		copy := utils.Config.Api
    91  		copy.MaxRows = maxRows
    92  		copy.ExtraSearchPath = searchPath
    93  		copy.Schemas = dbSchema
    94  		updatedConfig["api"] = copy
    95  	}
    96  }
    97  
    98  func readCsv(line string) []string {
    99  	var result []string
   100  	tokens := strings.Split(line, ",")
   101  	for _, t := range tokens {
   102  		trimmed := strings.TrimSpace(t)
   103  		if len(trimmed) > 0 {
   104  			result = append(result, trimmed)
   105  		}
   106  	}
   107  	return result
   108  }
   109  
   110  func sliceEqual(a, b []string) bool {
   111  	if len(a) != len(b) {
   112  		return false
   113  	}
   114  	for i, v := range a {
   115  		if v != b[i] {
   116  			return false
   117  		}
   118  	}
   119  	return true
   120  }
   121  
   122  func linkDatabase(ctx context.Context, username, password, database, host string, options ...func(*pgx.ConnConfig)) error {
   123  	conn, err := utils.ConnectRemotePostgres(ctx, username, password, database, host, options...)
   124  	if err != nil {
   125  		return err
   126  	}
   127  	defer conn.Close(context.Background())
   128  	updatePostgresConfig(conn)
   129  	// If `schema_migrations` doesn't exist on the remote database, create it.
   130  	batch := pgconn.Batch{}
   131  	batch.ExecParams(repair.CREATE_VERSION_SCHEMA, nil, nil, nil, nil)
   132  	batch.ExecParams(repair.CREATE_VERSION_TABLE, nil, nil, nil, nil)
   133  	_, err = conn.PgConn().ExecBatch(ctx, &batch).ReadAll()
   134  	return err
   135  }
   136  
   137  func updatePostgresConfig(conn *pgx.Conn) {
   138  	serverVersion := conn.PgConn().ParameterStatus("server_version")
   139  	// Safe to assume that supported Postgres version is 10.0 <= n < 100.0
   140  	majorDigits := len(serverVersion)
   141  	if majorDigits > 2 {
   142  		majorDigits = 2
   143  	}
   144  	dbMajorVersion, err := strconv.ParseUint(serverVersion[:majorDigits], 10, 7)
   145  	// Treat error as unchanged
   146  	if err == nil && uint64(utils.Config.Db.MajorVersion) != dbMajorVersion {
   147  		copy := utils.Config.Db
   148  		copy.MajorVersion = uint(dbMajorVersion)
   149  		updatedConfig["db"] = copy
   150  	}
   151  }
   152  
   153  func PromptPassword(stdin *os.File) string {
   154  	fmt.Fprint(os.Stderr, "Enter your database password: ")
   155  	bytepw, err := term.ReadPassword(int(stdin.Fd()))
   156  	fmt.Println()
   157  	if err != nil {
   158  		return ""
   159  	}
   160  	return string(bytepw)
   161  }