github.com/Redstoneguy129/cli@v0.0.0-20230211220159-15dca4e91917/internal/link/link.go (about) 1 package link 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "os" 9 "path/filepath" 10 "strconv" 11 "strings" 12 13 "github.com/BurntSushi/toml" 14 "github.com/Redstoneguy129/cli/internal/migration/repair" 15 "github.com/Redstoneguy129/cli/internal/utils" 16 "github.com/Redstoneguy129/cli/internal/utils/credentials" 17 "github.com/Redstoneguy129/cli/pkg/api" 18 "github.com/jackc/pgconn" 19 "github.com/jackc/pgx/v4" 20 "github.com/spf13/afero" 21 "golang.org/x/term" 22 ) 23 24 var updatedConfig map[string]interface{} = make(map[string]interface{}) 25 26 func PreRun(projectRef string, fsys afero.Fs) error { 27 // Sanity checks 28 if !utils.ProjectRefPattern.MatchString(projectRef) { 29 return errors.New("Invalid project ref format. Must be like `abcdefghijklmnopqrst`.") 30 } 31 return utils.LoadConfigFS(fsys) 32 } 33 34 func Run(ctx context.Context, projectRef, username, password, database string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error { 35 // 1. Check postgrest config 36 if err := linkPostgrest(ctx, projectRef); err != nil { 37 return err 38 } 39 40 // 2. Check database connection 41 if len(password) > 0 { 42 host := utils.GetSupabaseDbHost(projectRef) 43 if err := linkDatabase(ctx, username, password, database, host, options...); err != nil { 44 return err 45 } 46 // Save database password 47 if err := credentials.Set(projectRef, password); err != nil { 48 fmt.Fprintln(os.Stderr, "Failed to save database password:", err) 49 } 50 } 51 52 // 3. Save project ref 53 if err := utils.MkdirIfNotExistFS(fsys, filepath.Dir(utils.ProjectRefPath)); err != nil { 54 return err 55 } 56 return afero.WriteFile(fsys, utils.ProjectRefPath, []byte(projectRef), 0644) 57 } 58 59 func PostRun(projectRef string, stdout io.Writer, fsys afero.Fs) error { 60 fmt.Fprintln(stdout, "Finished "+utils.Aqua("supabase link")+".") 61 if len(updatedConfig) == 0 { 62 return nil 63 } 64 fmt.Fprintln(os.Stderr, "Local config differs from linked project. Try updating", utils.Bold(utils.ConfigPath)) 65 enc := toml.NewEncoder(stdout) 66 enc.Indent = "" 67 return enc.Encode(updatedConfig) 68 } 69 70 func linkPostgrest(ctx context.Context, projectRef string) error { 71 resp, err := utils.GetSupabase().GetPostgRESTConfigWithResponse(ctx, projectRef) 72 if err != nil { 73 return err 74 } 75 if resp.JSON200 == nil { 76 return errors.New("Authorization failed for the access token and project ref pair: " + string(resp.Body)) 77 } 78 updateApiConfig(*resp.JSON200) 79 return nil 80 } 81 82 func updateApiConfig(config api.PostgrestConfigResponse) { 83 maxRows := uint(config.MaxRows) 84 searchPath := readCsv(config.DbExtraSearchPath) 85 dbSchema := readCsv(config.DbSchema) 86 changed := utils.Config.Api.MaxRows != maxRows || 87 !sliceEqual(utils.Config.Api.ExtraSearchPath, searchPath) || 88 !sliceEqual(utils.Config.Api.Schemas, dbSchema) 89 if changed { 90 copy := utils.Config.Api 91 copy.MaxRows = maxRows 92 copy.ExtraSearchPath = searchPath 93 copy.Schemas = dbSchema 94 updatedConfig["api"] = copy 95 } 96 } 97 98 func readCsv(line string) []string { 99 var result []string 100 tokens := strings.Split(line, ",") 101 for _, t := range tokens { 102 trimmed := strings.TrimSpace(t) 103 if len(trimmed) > 0 { 104 result = append(result, trimmed) 105 } 106 } 107 return result 108 } 109 110 func sliceEqual(a, b []string) bool { 111 if len(a) != len(b) { 112 return false 113 } 114 for i, v := range a { 115 if v != b[i] { 116 return false 117 } 118 } 119 return true 120 } 121 122 func linkDatabase(ctx context.Context, username, password, database, host string, options ...func(*pgx.ConnConfig)) error { 123 conn, err := utils.ConnectRemotePostgres(ctx, username, password, database, host, options...) 124 if err != nil { 125 return err 126 } 127 defer conn.Close(context.Background()) 128 updatePostgresConfig(conn) 129 // If `schema_migrations` doesn't exist on the remote database, create it. 130 batch := pgconn.Batch{} 131 batch.ExecParams(repair.CREATE_VERSION_SCHEMA, nil, nil, nil, nil) 132 batch.ExecParams(repair.CREATE_VERSION_TABLE, nil, nil, nil, nil) 133 _, err = conn.PgConn().ExecBatch(ctx, &batch).ReadAll() 134 return err 135 } 136 137 func updatePostgresConfig(conn *pgx.Conn) { 138 serverVersion := conn.PgConn().ParameterStatus("server_version") 139 // Safe to assume that supported Postgres version is 10.0 <= n < 100.0 140 majorDigits := len(serverVersion) 141 if majorDigits > 2 { 142 majorDigits = 2 143 } 144 dbMajorVersion, err := strconv.ParseUint(serverVersion[:majorDigits], 10, 7) 145 // Treat error as unchanged 146 if err == nil && uint64(utils.Config.Db.MajorVersion) != dbMajorVersion { 147 copy := utils.Config.Db 148 copy.MajorVersion = uint(dbMajorVersion) 149 updatedConfig["db"] = copy 150 } 151 } 152 153 func PromptPassword(stdin *os.File) string { 154 fmt.Fprint(os.Stderr, "Enter your database password: ") 155 bytepw, err := term.ReadPassword(int(stdin.Fd())) 156 fmt.Println() 157 if err != nil { 158 return "" 159 } 160 return string(bytepw) 161 }