github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/dtestutils/sql_server_driver/cmd.go (about) 1 // Copyright 2022 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sql_server_driver 16 17 import ( 18 "bufio" 19 "bytes" 20 "database/sql" 21 "fmt" 22 "io" 23 "log" 24 "net/url" 25 "os" 26 "os/exec" 27 "path/filepath" 28 "sync" 29 "time" 30 31 _ "github.com/go-sql-driver/mysql" 32 ) 33 34 var DoltPath string 35 var DelvePath string 36 37 const TestUserName = "Bats Tests" 38 const TestEmailAddress = "bats@email.fake" 39 40 const ConnectAttempts = 50 41 const RetrySleepDuration = 50 * time.Millisecond 42 43 const EnvDoltBinPath = "DOLT_BIN_PATH" 44 45 func init() { 46 path := os.Getenv(EnvDoltBinPath) 47 if path == "" { 48 path = "dolt" 49 } 50 path = filepath.Clean(path) 51 var err error 52 53 DoltPath, err = exec.LookPath(path) 54 if err != nil { 55 log.Printf("did not find dolt binary: %v\n", err.Error()) 56 } 57 58 DelvePath, _ = exec.LookPath("dlv") 59 } 60 61 // DoltUser is an abstraction for a user account that calls `dolt` CLI 62 // commands. All of our dolt binary invocations are done through DoltUser. 63 // 64 // For our purposes, it does the following: 65 // * owns a tmpdir, to which it sets DOLT_ROOT_PATH when invoking dolt. 66 // * some initial dolt global config, 67 // - user.name 68 // - user.email 69 // - metrics.disabled = true 70 // 71 // * can create repo stores, which will be a tmpdir to store a repo and/or subrepos. 72 type DoltUser struct { 73 tmpdir string 74 } 75 76 var _ DoltCmdable = DoltUser{} 77 var _ DoltDebuggable = DoltUser{} 78 79 func NewDoltUser() (DoltUser, error) { 80 tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-") 81 if err != nil { 82 return DoltUser{}, err 83 } 84 res := DoltUser{tmpdir} 85 err = res.DoltExec("config", "--global", "--add", "metrics.disabled", "true") 86 if err != nil { 87 return DoltUser{}, err 88 } 89 err = res.DoltExec("config", "--global", "--add", "user.name", TestUserName) 90 if err != nil { 91 return DoltUser{}, err 92 } 93 err = res.DoltExec("config", "--global", "--add", "user.email", TestEmailAddress) 94 if err != nil { 95 return DoltUser{}, err 96 } 97 return res, nil 98 } 99 100 func (u DoltUser) DoltCmd(args ...string) *exec.Cmd { 101 cmd := exec.Command(DoltPath, args...) 102 cmd.Dir = u.tmpdir 103 cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir) 104 ApplyCmdAttributes(cmd) 105 return cmd 106 } 107 108 func (u DoltUser) DoltDebug(debuggerPort int, args ...string) *exec.Cmd { 109 if DelvePath != "" { 110 dlvArgs := []string{ 111 fmt.Sprintf("--listen=:%d", debuggerPort), 112 "--headless", 113 "--api-version=2", 114 "--accept-multiclient", 115 "exec", 116 DoltPath, 117 "--", 118 } 119 cmd := exec.Command(DelvePath, append(dlvArgs, args...)...) 120 cmd.Dir = u.tmpdir 121 cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir) 122 ApplyCmdAttributes(cmd) 123 return cmd 124 } else { 125 panic("dlv not found") 126 } 127 } 128 129 func (u DoltUser) DoltExec(args ...string) error { 130 cmd := u.DoltCmd(args...) 131 return cmd.Run() 132 } 133 134 func (u DoltUser) MakeRepoStore() (RepoStore, error) { 135 tmpdir, err := os.MkdirTemp(u.tmpdir, "repo-store-") 136 if err != nil { 137 return RepoStore{}, err 138 } 139 return RepoStore{u, tmpdir}, nil 140 } 141 142 func (u DoltUser) Cleanup() error { 143 return os.RemoveAll(u.tmpdir) 144 } 145 146 type RepoStore struct { 147 user DoltUser 148 Dir string 149 } 150 151 var _ DoltCmdable = RepoStore{} 152 var _ DoltDebuggable = RepoStore{} 153 154 func (rs RepoStore) MakeRepo(name string) (Repo, error) { 155 path := filepath.Join(rs.Dir, name) 156 err := os.Mkdir(path, 0750) 157 if err != nil { 158 return Repo{}, err 159 } 160 ret := Repo{rs.user, path} 161 err = ret.DoltExec("init") 162 if err != nil { 163 return Repo{}, err 164 } 165 return ret, nil 166 } 167 168 func (rs RepoStore) DoltCmd(args ...string) *exec.Cmd { 169 cmd := rs.user.DoltCmd(args...) 170 cmd.Dir = rs.Dir 171 return cmd 172 } 173 174 func (rs RepoStore) DoltDebug(debuggerPort int, args ...string) *exec.Cmd { 175 cmd := rs.user.DoltDebug(debuggerPort, args...) 176 cmd.Dir = rs.Dir 177 return cmd 178 } 179 180 type Repo struct { 181 user DoltUser 182 Dir string 183 } 184 185 func (r Repo) DoltCmd(args ...string) *exec.Cmd { 186 cmd := r.user.DoltCmd(args...) 187 cmd.Dir = r.Dir 188 return cmd 189 } 190 191 func (r Repo) DoltExec(args ...string) error { 192 cmd := r.DoltCmd(args...) 193 err := cmd.Start() 194 if err != nil { 195 return err 196 } 197 return cmd.Wait() 198 } 199 200 func (r Repo) CreateRemote(name, url string) error { 201 cmd := r.DoltCmd("remote", "add", name, url) 202 return cmd.Run() 203 } 204 205 type SqlServer struct { 206 Name string 207 Done chan struct{} 208 Cmd *exec.Cmd 209 Port int 210 DebugPort int 211 Output *bytes.Buffer 212 DBName string 213 RecreateCmd func(args ...string) *exec.Cmd 214 } 215 216 type SqlServerOpt func(s *SqlServer) 217 218 func WithArgs(args ...string) SqlServerOpt { 219 return func(s *SqlServer) { 220 s.Cmd.Args = append(s.Cmd.Args, args...) 221 } 222 } 223 224 func WithName(name string) SqlServerOpt { 225 return func(s *SqlServer) { 226 s.Name = name 227 } 228 } 229 230 func WithEnvs(envs ...string) SqlServerOpt { 231 return func(s *SqlServer) { 232 s.Cmd.Env = append(s.Cmd.Env, envs...) 233 } 234 } 235 236 func WithPort(port int) SqlServerOpt { 237 return func(s *SqlServer) { 238 s.Port = port 239 } 240 } 241 242 func WithDebugPort(port int) SqlServerOpt { 243 return func(s *SqlServer) { 244 s.DebugPort = port 245 } 246 } 247 248 type DoltCmdable interface { 249 DoltCmd(args ...string) *exec.Cmd 250 } 251 252 type DoltDebuggable interface { 253 DoltDebug(debuggerPort int, args ...string) *exec.Cmd 254 } 255 256 func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) { 257 cmd := dc.DoltCmd("sql-server") 258 return runSqlServerCommand(dc, opts, cmd) 259 } 260 261 func DebugSqlServer(dc DoltCmdable, debuggerPort int, opts ...SqlServerOpt) (*SqlServer, error) { 262 ddb, ok := dc.(DoltDebuggable) 263 if !ok { 264 return nil, fmt.Errorf("%T does not implement DoltDebuggable", dc) 265 } 266 267 cmd := ddb.DoltDebug(debuggerPort, "sql-server") 268 return runSqlServerCommand(dc, append(opts, WithDebugPort(debuggerPort)), cmd) 269 } 270 271 func runSqlServerCommand(dc DoltCmdable, opts []SqlServerOpt, cmd *exec.Cmd) (*SqlServer, error) { 272 stdout, err := cmd.StdoutPipe() 273 if err != nil { 274 return nil, err 275 } 276 cmd.Stderr = cmd.Stdout 277 output := new(bytes.Buffer) 278 var wg sync.WaitGroup 279 wg.Add(1) 280 done := make(chan struct{}) 281 go func() { 282 wg.Wait() 283 close(done) 284 }() 285 286 server := &SqlServer{ 287 Done: done, 288 Cmd: cmd, 289 Port: 3306, 290 Output: output, 291 } 292 for _, o := range opts { 293 o(server) 294 } 295 296 go func() { 297 defer wg.Done() 298 multiCopyWithNamePrefix(os.Stdout, output, stdout, server.Name) 299 }() 300 301 server.RecreateCmd = func(args ...string) *exec.Cmd { 302 if server.DebugPort > 0 { 303 ddb, ok := dc.(DoltDebuggable) 304 if !ok { 305 panic(fmt.Sprintf("%T does not implement DoltDebuggable", dc)) 306 } 307 return ddb.DoltDebug(server.DebugPort, args...) 308 } else { 309 return dc.DoltCmd(args...) 310 } 311 } 312 313 err = server.Cmd.Start() 314 if err != nil { 315 return nil, err 316 } 317 return server, nil 318 } 319 320 func (s *SqlServer) ErrorStop() error { 321 <-s.Done 322 return s.Cmd.Wait() 323 } 324 325 func multiCopyWithNamePrefix(stdout, captured io.Writer, in io.Reader, name string) { 326 reader := bufio.NewReader(in) 327 multiOut := io.MultiWriter(stdout, captured) 328 wantsPrefix := true 329 for { 330 line, isPrefix, err := reader.ReadLine() 331 if err != nil { 332 return 333 } 334 if wantsPrefix && name != "" { 335 stdout.Write([]byte("[")) 336 stdout.Write([]byte(name)) 337 stdout.Write([]byte("] ")) 338 } 339 multiOut.Write(line) 340 if isPrefix { 341 wantsPrefix = false 342 } else { 343 multiOut.Write([]byte("\n")) 344 wantsPrefix = true 345 } 346 } 347 } 348 349 func (s *SqlServer) Restart(newargs *[]string, newenvs *[]string) error { 350 err := s.GracefulStop() 351 if err != nil { 352 return err 353 } 354 args := s.Cmd.Args[1:] 355 if newargs != nil { 356 args = append([]string{"sql-server"}, (*newargs)...) 357 } 358 s.Cmd = s.RecreateCmd(args...) 359 if newenvs != nil { 360 s.Cmd.Env = append(s.Cmd.Env, (*newenvs)...) 361 } 362 stdout, err := s.Cmd.StdoutPipe() 363 if err != nil { 364 return err 365 } 366 s.Cmd.Stderr = s.Cmd.Stdout 367 var wg sync.WaitGroup 368 wg.Add(1) 369 go func() { 370 defer wg.Done() 371 multiCopyWithNamePrefix(os.Stdout, s.Output, stdout, s.Name) 372 }() 373 s.Done = make(chan struct{}) 374 go func() { 375 wg.Wait() 376 close(s.Done) 377 }() 378 return s.Cmd.Start() 379 } 380 381 func (s *SqlServer) DB(c Connection) (*sql.DB, error) { 382 var pass string 383 pass, err := c.Password() 384 if err != nil { 385 return nil, err 386 } 387 return ConnectDB(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams) 388 } 389 390 func ConnectDB(user, password, name, host string, port int, driverParams map[string]string) (*sql.DB, error) { 391 params := make(url.Values) 392 params.Set("allowAllFiles", "true") 393 params.Set("tls", "preferred") 394 for k, v := range driverParams { 395 params.Set(k, v) 396 } 397 dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", user, password, host, port, name, params.Encode()) 398 399 db, err := sql.Open("mysql", dsn) 400 if err != nil { 401 return nil, err 402 } 403 for i := 0; i < ConnectAttempts; i++ { 404 err = db.Ping() 405 if err == nil { 406 return db, nil 407 } 408 time.Sleep(RetrySleepDuration) 409 } 410 if err != nil { 411 return nil, err 412 } 413 return db, nil 414 }