github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/session_state_adapter.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dsess 16 17 import ( 18 "context" 19 "fmt" 20 "strings" 21 22 "github.com/dolthub/go-mysql-server/sql" 23 24 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 25 "github.com/dolthub/dolt/go/libraries/doltcore/env" 26 "github.com/dolthub/dolt/go/libraries/doltcore/ref" 27 "github.com/dolthub/dolt/go/libraries/utils/concurrentmap" 28 ) 29 30 // SessionStateAdapter is an adapter for env.RepoStateReader in SQL contexts, getting information about the repo state 31 // from the session. 32 type SessionStateAdapter struct { 33 session *DoltSession 34 dbName string 35 remotes *concurrentmap.Map[string, env.Remote] 36 backups *concurrentmap.Map[string, env.Remote] 37 branches *concurrentmap.Map[string, env.BranchConfig] 38 } 39 40 func (s SessionStateAdapter) SetCWBHeadRef(ctx context.Context, newRef ref.MarshalableRef) error { 41 return fmt.Errorf("Cannot set cwb head ref with a SessionStateAdapter") 42 } 43 44 var _ env.RepoStateReader = SessionStateAdapter{} 45 var _ env.RepoStateWriter = SessionStateAdapter{} 46 var _ env.RootsProvider = SessionStateAdapter{} 47 48 func NewSessionStateAdapter(session *DoltSession, dbName string, remotes *concurrentmap.Map[string, env.Remote], branches *concurrentmap.Map[string, env.BranchConfig], backups *concurrentmap.Map[string, env.Remote]) SessionStateAdapter { 49 if branches == nil { 50 branches = concurrentmap.New[string, env.BranchConfig]() 51 } 52 return SessionStateAdapter{session: session, dbName: dbName, remotes: remotes, branches: branches, backups: backups} 53 } 54 55 func (s SessionStateAdapter) GetRoots(ctx context.Context) (doltdb.Roots, error) { 56 sqlCtx := sql.NewContext(ctx) 57 state, _, err := s.session.lookupDbState(sqlCtx, s.dbName) 58 if err != nil { 59 return doltdb.Roots{}, err 60 } 61 62 return state.roots(), nil 63 } 64 65 func (s SessionStateAdapter) CWBHeadRef() (ref.DoltRef, error) { 66 workingSet, err := s.session.WorkingSet(sql.NewContext(context.Background()), s.dbName) 67 if err != nil { 68 return nil, err 69 } 70 71 headRef, err := workingSet.Ref().ToHeadRef() 72 if err != nil { 73 return nil, err 74 } 75 return headRef, nil 76 } 77 78 func (s SessionStateAdapter) CWBHeadSpec() (*doltdb.CommitSpec, error) { 79 // TODO: get rid of this 80 ref, err := s.CWBHeadRef() 81 if err != nil { 82 return nil, err 83 } 84 spec, err := doltdb.NewCommitSpec(ref.GetPath()) 85 if err != nil { 86 panic(err) 87 } 88 return spec, nil 89 } 90 91 func (s SessionStateAdapter) GetRemotes() (*concurrentmap.Map[string, env.Remote], error) { 92 return s.remotes, nil 93 } 94 95 func (s SessionStateAdapter) GetBackups() (*concurrentmap.Map[string, env.Remote], error) { 96 return s.backups, nil 97 } 98 99 func (s SessionStateAdapter) GetBranches() (*concurrentmap.Map[string, env.BranchConfig], error) { 100 return s.branches, nil 101 } 102 103 func (s SessionStateAdapter) UpdateBranch(name string, new env.BranchConfig) error { 104 s.branches.Set(name, new) 105 106 fs, err := s.session.Provider().FileSystemForDatabase(s.dbName) 107 if err != nil { 108 return err 109 } 110 111 repoState, err := env.LoadRepoState(fs) 112 if err != nil { 113 return err 114 } 115 repoState.Branches.Set(name, new) 116 117 return repoState.Save(fs) 118 } 119 120 func (s SessionStateAdapter) AddRemote(remote env.Remote) error { 121 if _, ok := s.remotes.Get(remote.Name); ok { 122 return env.ErrRemoteAlreadyExists 123 } 124 125 if strings.IndexAny(remote.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 { 126 return env.ErrInvalidBackupName 127 } 128 129 fs, err := s.session.Provider().FileSystemForDatabase(s.dbName) 130 if err != nil { 131 return err 132 } 133 134 repoState, err := env.LoadRepoState(fs) 135 if err != nil { 136 return err 137 } 138 139 // can have multiple remotes with the same address, but no conflicting backups 140 if rem, found := env.CheckRemoteAddressConflict(remote.Url, nil, repoState.Backups); found { 141 return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, rem.Name, rem.Url) 142 } 143 144 s.remotes.Set(remote.Name, remote) 145 repoState.AddRemote(remote) 146 return repoState.Save(fs) 147 } 148 149 func (s SessionStateAdapter) AddBackup(backup env.Remote) error { 150 if _, ok := s.backups.Get(backup.Name); ok { 151 return env.ErrBackupAlreadyExists 152 } 153 154 if strings.IndexAny(backup.Name, " \t\n\r./\\!@#$%^&*(){}[],.<>'\"?=+|") != -1 { 155 return env.ErrInvalidBackupName 156 } 157 158 fs, err := s.session.Provider().FileSystemForDatabase(s.dbName) 159 if err != nil { 160 return err 161 } 162 163 repoState, err := env.LoadRepoState(fs) 164 if err != nil { 165 return err 166 } 167 168 // no conflicting remote or backup addresses 169 if bac, found := env.CheckRemoteAddressConflict(backup.Url, repoState.Remotes, repoState.Backups); found { 170 return fmt.Errorf("%w: '%s' -> %s", env.ErrRemoteAddressConflict, bac.Name, bac.Url) 171 } 172 173 s.backups.Set(backup.Name, backup) 174 repoState.AddBackup(backup) 175 return repoState.Save(fs) 176 } 177 178 func (s SessionStateAdapter) RemoveRemote(_ context.Context, name string) error { 179 remote, ok := s.remotes.Get(name) 180 if !ok { 181 return env.ErrRemoteNotFound 182 } 183 s.remotes.Delete(remote.Name) 184 185 fs, err := s.session.Provider().FileSystemForDatabase(s.dbName) 186 if err != nil { 187 return err 188 } 189 190 repoState, err := env.LoadRepoState(fs) 191 if err != nil { 192 return err 193 } 194 195 remote, ok = repoState.Remotes.Get(name) 196 if !ok { 197 // sanity check 198 return env.ErrRemoteNotFound 199 } 200 repoState.Remotes.Delete(name) 201 return repoState.Save(fs) 202 } 203 204 func (s SessionStateAdapter) RemoveBackup(_ context.Context, name string) error { 205 backup, ok := s.backups.Get(name) 206 if !ok { 207 return env.ErrBackupNotFound 208 } 209 s.backups.Delete(backup.Name) 210 211 fs, err := s.session.Provider().FileSystemForDatabase(s.dbName) 212 if err != nil { 213 return err 214 } 215 216 repoState, err := env.LoadRepoState(fs) 217 if err != nil { 218 return err 219 } 220 221 backup, ok = repoState.Backups.Get(name) 222 if !ok { 223 // sanity check 224 return env.ErrBackupNotFound 225 } 226 repoState.Backups.Delete(name) 227 return repoState.Save(fs) 228 } 229 230 func (s SessionStateAdapter) TempTableFilesDir() (string, error) { 231 branchState, _, err := s.session.lookupDbState(sql.NewContext(context.Background()), s.dbName) 232 if err != nil { 233 return "", err 234 } 235 236 return branchState.dbState.tmpFileDir, nil 237 }