github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_remote.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dprocedures 16 17 import ( 18 "fmt" 19 "strings" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 23 "github.com/dolthub/dolt/go/cmd/dolt/cli" 24 "github.com/dolthub/dolt/go/libraries/doltcore/branch_control" 25 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 26 "github.com/dolthub/dolt/go/libraries/doltcore/env" 27 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 28 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" 29 "github.com/dolthub/dolt/go/libraries/utils/argparser" 30 "github.com/dolthub/dolt/go/libraries/utils/config" 31 ) 32 33 // doltRemote is the stored procedure version for the CLI command `dolt remote`. 34 func doltRemote(ctx *sql.Context, args ...string) (sql.RowIter, error) { 35 res, err := doDoltRemote(ctx, args) 36 if err != nil { 37 return nil, err 38 } 39 return rowToIter(res), nil 40 } 41 42 // doDoltRemote is used as sql dolt_remote command for only creating or deleting remotes, not listing. 43 // To list remotes, dolt_remotes system table is used. 44 func doDoltRemote(ctx *sql.Context, args []string) (int, error) { 45 dbName := ctx.GetCurrentDatabase() 46 if len(dbName) == 0 { 47 return 1, fmt.Errorf("Empty database name.") 48 } 49 if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil { 50 return 1, err 51 } 52 dSess := dsess.DSessFromSess(ctx.Session) 53 dbData, ok := dSess.GetDbData(ctx, dbName) 54 if !ok { 55 return 1, fmt.Errorf("Could not load database %s", dbName) 56 } 57 58 apr, err := cli.CreateRemoteArgParser().Parse(args) 59 if err != nil { 60 return 1, err 61 } 62 63 if apr.NArg() == 0 { 64 return 1, fmt.Errorf("error: invalid argument, use 'dolt_remotes' system table to list remotes") 65 } 66 67 var rsc doltdb.ReplicationStatusController 68 69 switch apr.Arg(0) { 70 case "add": 71 err = addRemote(ctx, dbName, dbData, apr, dSess) 72 case "remove", "rm": 73 err = removeRemote(ctx, dbData, apr, &rsc) 74 default: 75 err = fmt.Errorf("error: invalid argument") 76 } 77 78 if err != nil { 79 return 1, err 80 } 81 82 dsess.WaitForReplicationController(ctx, rsc) 83 84 return 0, nil 85 } 86 87 func addRemote(_ *sql.Context, dbName string, dbd env.DbData, apr *argparser.ArgParseResults, sess *dsess.DoltSession) error { 88 if apr.NArg() != 3 { 89 return fmt.Errorf("error: invalid argument") 90 } 91 92 remoteName := strings.TrimSpace(apr.Arg(1)) 93 remoteUrl := apr.Arg(2) 94 95 dbFs, err := sess.Provider().FileSystemForDatabase(dbName) 96 if err != nil { 97 return err 98 } 99 100 _, absRemoteUrl, err := env.GetAbsRemoteUrl(dbFs, &config.MapConfig{}, remoteUrl) 101 if err != nil { 102 return err 103 } 104 105 r := env.NewRemote(remoteName, absRemoteUrl, map[string]string{}) 106 return dbd.Rsw.AddRemote(r) 107 } 108 109 func removeRemote(ctx *sql.Context, dbd env.DbData, apr *argparser.ArgParseResults, rsc *doltdb.ReplicationStatusController) error { 110 if apr.NArg() != 2 { 111 return fmt.Errorf("error: invalid argument") 112 } 113 114 old := strings.TrimSpace(apr.Arg(1)) 115 116 remotes, err := dbd.Rsr.GetRemotes() 117 if err != nil { 118 return err 119 } 120 121 remote, ok := remotes.Get(old) 122 if !ok { 123 return fmt.Errorf("error: unknown remote: '%s'", old) 124 } 125 126 ddb := dbd.Ddb 127 refs, err := ddb.GetRemoteRefs(ctx) 128 if err != nil { 129 return fmt.Errorf("error: %w, cause: %s", env.ErrFailedToReadFromDb, err.Error()) 130 } 131 132 for _, r := range refs { 133 rr := r.(ref.RemoteRef) 134 135 if rr.GetRemote() == remote.Name { 136 err = ddb.DeleteBranch(ctx, rr, rsc) 137 138 if err != nil { 139 return fmt.Errorf("%w; failed to delete remote tracking ref '%s'; %s", env.ErrFailedToDeleteRemote, rr.String(), err.Error()) 140 } 141 } 142 } 143 144 return dbd.Rsw.RemoveRemote(ctx, remote.Name) 145 }