github.com/supabase/cli@v1.168.1/internal/db/branch/switch_/switch_.go (about)

     1  package switch_
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  
     9  	"github.com/go-errors/errors"
    10  	"github.com/jackc/pgconn"
    11  	"github.com/jackc/pgx/v4"
    12  	"github.com/spf13/afero"
    13  	"github.com/supabase/cli/internal/db/reset"
    14  	"github.com/supabase/cli/internal/utils"
    15  )
    16  
    17  func Run(ctx context.Context, target string, fsys afero.Fs, options ...func(*pgx.ConnConfig)) error {
    18  	// 1. Sanity checks
    19  	{
    20  		if err := utils.LoadConfigFS(fsys); err != nil {
    21  			return err
    22  		}
    23  		if err := utils.AssertSupabaseDbIsRunning(); err != nil {
    24  			return err
    25  		}
    26  		if target != "main" && utils.IsBranchNameReserved(target) {
    27  			return errors.New("Cannot switch branch " + utils.Aqua(target) + ": branch name is reserved.")
    28  		}
    29  		branchPath := filepath.Join(filepath.Dir(utils.CurrBranchPath), target)
    30  		if _, err := fsys.Stat(branchPath); errors.Is(err, os.ErrNotExist) {
    31  			return errors.New("Branch " + utils.Aqua(target) + " does not exist.")
    32  		} else if err != nil {
    33  			return err
    34  		}
    35  	}
    36  
    37  	// 2. Check current branch
    38  	currBranch, err := utils.GetCurrentBranchFS(fsys)
    39  	if err != nil {
    40  		// Assume we are on main branch
    41  		currBranch = "main"
    42  	}
    43  
    44  	// 3. Switch Postgres database
    45  	if currBranch == target {
    46  		fmt.Println("Already on branch " + utils.Aqua(target) + ".")
    47  	} else if err := switchDatabase(ctx, currBranch, target, options...); err != nil {
    48  		return errors.New("Error switching to branch " + utils.Aqua(target) + ": " + err.Error())
    49  	} else {
    50  		fmt.Println("Switched to branch " + utils.Aqua(target) + ".")
    51  	}
    52  
    53  	// 4. Update current branch
    54  	if err := afero.WriteFile(fsys, utils.CurrBranchPath, []byte(target), 0644); err != nil {
    55  		return errors.New("Unable to update local branch file. Fix by running: echo '" + target + "' > " + utils.CurrBranchPath)
    56  	}
    57  	return nil
    58  }
    59  
    60  func switchDatabase(ctx context.Context, source, target string, options ...func(*pgx.ConnConfig)) error {
    61  	conn, err := utils.ConnectLocalPostgres(ctx, pgconn.Config{Database: "template1"}, options...)
    62  	if err != nil {
    63  		return err
    64  	}
    65  	defer conn.Close(context.Background())
    66  	if err := reset.DisconnectClients(ctx, conn); err != nil {
    67  		return err
    68  	}
    69  	defer func() {
    70  		if err := reset.RestartDatabase(context.Background(), os.Stderr); err != nil {
    71  			fmt.Fprintln(os.Stderr, "Failed to restart database:", err)
    72  		}
    73  	}()
    74  	backup := "ALTER DATABASE postgres RENAME TO " + source + ";"
    75  	if _, err := conn.Exec(ctx, backup); err != nil {
    76  		return err
    77  	}
    78  	rename := "ALTER DATABASE " + target + " RENAME TO postgres;"
    79  	if _, err := conn.Exec(ctx, rename); err != nil {
    80  		rollback := "ALTER DATABASE " + source + " RENAME TO postgres;"
    81  		if _, err := conn.Exec(ctx, rollback); err != nil {
    82  			fmt.Fprintln(os.Stderr, "Failed to rollback database:", err)
    83  		}
    84  		return err
    85  	}
    86  	return nil
    87  }