github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/dtestutils/sql_server_driver/cmd.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sql_server_driver
    16  
    17  import (
    18  	"bufio"
    19  	"bytes"
    20  	"database/sql"
    21  	"fmt"
    22  	"io"
    23  	"log"
    24  	"net/url"
    25  	"os"
    26  	"os/exec"
    27  	"path/filepath"
    28  	"sync"
    29  	"time"
    30  
    31  	_ "github.com/go-sql-driver/mysql"
    32  )
    33  
    34  var DoltPath string
    35  var DelvePath string
    36  
    37  const TestUserName = "Bats Tests"
    38  const TestEmailAddress = "bats@email.fake"
    39  
    40  const ConnectAttempts = 50
    41  const RetrySleepDuration = 50 * time.Millisecond
    42  
    43  const EnvDoltBinPath = "DOLT_BIN_PATH"
    44  
    45  func init() {
    46  	path := os.Getenv(EnvDoltBinPath)
    47  	if path == "" {
    48  		path = "dolt"
    49  	}
    50  	path = filepath.Clean(path)
    51  	var err error
    52  
    53  	DoltPath, err = exec.LookPath(path)
    54  	if err != nil {
    55  		log.Printf("did not find dolt binary: %v\n", err.Error())
    56  	}
    57  
    58  	DelvePath, _ = exec.LookPath("dlv")
    59  }
    60  
    61  // DoltUser is an abstraction for a user account that calls `dolt` CLI
    62  // commands. All of our dolt binary invocations are done through DoltUser.
    63  //
    64  // For our purposes, it does the following:
    65  // * owns a tmpdir, to which it sets DOLT_ROOT_PATH when invoking dolt.
    66  // * some initial dolt global config,
    67  //   - user.name
    68  //   - user.email
    69  //   - metrics.disabled = true
    70  //
    71  // * can create repo stores, which will be a tmpdir to store a repo and/or subrepos.
    72  type DoltUser struct {
    73  	tmpdir string
    74  }
    75  
    76  var _ DoltCmdable = DoltUser{}
    77  var _ DoltDebuggable = DoltUser{}
    78  
    79  func NewDoltUser() (DoltUser, error) {
    80  	tmpdir, err := os.MkdirTemp("", "go-sql-server-driver-")
    81  	if err != nil {
    82  		return DoltUser{}, err
    83  	}
    84  	res := DoltUser{tmpdir}
    85  	err = res.DoltExec("config", "--global", "--add", "metrics.disabled", "true")
    86  	if err != nil {
    87  		return DoltUser{}, err
    88  	}
    89  	err = res.DoltExec("config", "--global", "--add", "user.name", TestUserName)
    90  	if err != nil {
    91  		return DoltUser{}, err
    92  	}
    93  	err = res.DoltExec("config", "--global", "--add", "user.email", TestEmailAddress)
    94  	if err != nil {
    95  		return DoltUser{}, err
    96  	}
    97  	return res, nil
    98  }
    99  
   100  func (u DoltUser) DoltCmd(args ...string) *exec.Cmd {
   101  	cmd := exec.Command(DoltPath, args...)
   102  	cmd.Dir = u.tmpdir
   103  	cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir)
   104  	ApplyCmdAttributes(cmd)
   105  	return cmd
   106  }
   107  
   108  func (u DoltUser) DoltDebug(debuggerPort int, args ...string) *exec.Cmd {
   109  	if DelvePath != "" {
   110  		dlvArgs := []string{
   111  			fmt.Sprintf("--listen=:%d", debuggerPort),
   112  			"--headless",
   113  			"--api-version=2",
   114  			"--accept-multiclient",
   115  			"exec",
   116  			DoltPath,
   117  			"--",
   118  		}
   119  		cmd := exec.Command(DelvePath, append(dlvArgs, args...)...)
   120  		cmd.Dir = u.tmpdir
   121  		cmd.Env = append(os.Environ(), "DOLT_ROOT_PATH="+u.tmpdir)
   122  		ApplyCmdAttributes(cmd)
   123  		return cmd
   124  	} else {
   125  		panic("dlv not found")
   126  	}
   127  }
   128  
   129  func (u DoltUser) DoltExec(args ...string) error {
   130  	cmd := u.DoltCmd(args...)
   131  	return cmd.Run()
   132  }
   133  
   134  func (u DoltUser) MakeRepoStore() (RepoStore, error) {
   135  	tmpdir, err := os.MkdirTemp(u.tmpdir, "repo-store-")
   136  	if err != nil {
   137  		return RepoStore{}, err
   138  	}
   139  	return RepoStore{u, tmpdir}, nil
   140  }
   141  
   142  func (u DoltUser) Cleanup() error {
   143  	return os.RemoveAll(u.tmpdir)
   144  }
   145  
   146  type RepoStore struct {
   147  	user DoltUser
   148  	Dir  string
   149  }
   150  
   151  var _ DoltCmdable = RepoStore{}
   152  var _ DoltDebuggable = RepoStore{}
   153  
   154  func (rs RepoStore) MakeRepo(name string) (Repo, error) {
   155  	path := filepath.Join(rs.Dir, name)
   156  	err := os.Mkdir(path, 0750)
   157  	if err != nil {
   158  		return Repo{}, err
   159  	}
   160  	ret := Repo{rs.user, path}
   161  	err = ret.DoltExec("init")
   162  	if err != nil {
   163  		return Repo{}, err
   164  	}
   165  	return ret, nil
   166  }
   167  
   168  func (rs RepoStore) DoltCmd(args ...string) *exec.Cmd {
   169  	cmd := rs.user.DoltCmd(args...)
   170  	cmd.Dir = rs.Dir
   171  	return cmd
   172  }
   173  
   174  func (rs RepoStore) DoltDebug(debuggerPort int, args ...string) *exec.Cmd {
   175  	cmd := rs.user.DoltDebug(debuggerPort, args...)
   176  	cmd.Dir = rs.Dir
   177  	return cmd
   178  }
   179  
   180  type Repo struct {
   181  	user DoltUser
   182  	Dir  string
   183  }
   184  
   185  func (r Repo) DoltCmd(args ...string) *exec.Cmd {
   186  	cmd := r.user.DoltCmd(args...)
   187  	cmd.Dir = r.Dir
   188  	return cmd
   189  }
   190  
   191  func (r Repo) DoltExec(args ...string) error {
   192  	cmd := r.DoltCmd(args...)
   193  	err := cmd.Start()
   194  	if err != nil {
   195  		return err
   196  	}
   197  	return cmd.Wait()
   198  }
   199  
   200  func (r Repo) CreateRemote(name, url string) error {
   201  	cmd := r.DoltCmd("remote", "add", name, url)
   202  	return cmd.Run()
   203  }
   204  
   205  type SqlServer struct {
   206  	Name        string
   207  	Done        chan struct{}
   208  	Cmd         *exec.Cmd
   209  	Port        int
   210  	DebugPort   int
   211  	Output      *bytes.Buffer
   212  	DBName      string
   213  	RecreateCmd func(args ...string) *exec.Cmd
   214  }
   215  
   216  type SqlServerOpt func(s *SqlServer)
   217  
   218  func WithArgs(args ...string) SqlServerOpt {
   219  	return func(s *SqlServer) {
   220  		s.Cmd.Args = append(s.Cmd.Args, args...)
   221  	}
   222  }
   223  
   224  func WithName(name string) SqlServerOpt {
   225  	return func(s *SqlServer) {
   226  		s.Name = name
   227  	}
   228  }
   229  
   230  func WithEnvs(envs ...string) SqlServerOpt {
   231  	return func(s *SqlServer) {
   232  		s.Cmd.Env = append(s.Cmd.Env, envs...)
   233  	}
   234  }
   235  
   236  func WithPort(port int) SqlServerOpt {
   237  	return func(s *SqlServer) {
   238  		s.Port = port
   239  	}
   240  }
   241  
   242  func WithDebugPort(port int) SqlServerOpt {
   243  	return func(s *SqlServer) {
   244  		s.DebugPort = port
   245  	}
   246  }
   247  
   248  type DoltCmdable interface {
   249  	DoltCmd(args ...string) *exec.Cmd
   250  }
   251  
   252  type DoltDebuggable interface {
   253  	DoltDebug(debuggerPort int, args ...string) *exec.Cmd
   254  }
   255  
   256  func StartSqlServer(dc DoltCmdable, opts ...SqlServerOpt) (*SqlServer, error) {
   257  	cmd := dc.DoltCmd("sql-server")
   258  	return runSqlServerCommand(dc, opts, cmd)
   259  }
   260  
   261  func DebugSqlServer(dc DoltCmdable, debuggerPort int, opts ...SqlServerOpt) (*SqlServer, error) {
   262  	ddb, ok := dc.(DoltDebuggable)
   263  	if !ok {
   264  		return nil, fmt.Errorf("%T does not implement DoltDebuggable", dc)
   265  	}
   266  
   267  	cmd := ddb.DoltDebug(debuggerPort, "sql-server")
   268  	return runSqlServerCommand(dc, append(opts, WithDebugPort(debuggerPort)), cmd)
   269  }
   270  
   271  func runSqlServerCommand(dc DoltCmdable, opts []SqlServerOpt, cmd *exec.Cmd) (*SqlServer, error) {
   272  	stdout, err := cmd.StdoutPipe()
   273  	if err != nil {
   274  		return nil, err
   275  	}
   276  	cmd.Stderr = cmd.Stdout
   277  	output := new(bytes.Buffer)
   278  	var wg sync.WaitGroup
   279  	wg.Add(1)
   280  	done := make(chan struct{})
   281  	go func() {
   282  		wg.Wait()
   283  		close(done)
   284  	}()
   285  
   286  	server := &SqlServer{
   287  		Done:   done,
   288  		Cmd:    cmd,
   289  		Port:   3306,
   290  		Output: output,
   291  	}
   292  	for _, o := range opts {
   293  		o(server)
   294  	}
   295  
   296  	go func() {
   297  		defer wg.Done()
   298  		multiCopyWithNamePrefix(os.Stdout, output, stdout, server.Name)
   299  	}()
   300  
   301  	server.RecreateCmd = func(args ...string) *exec.Cmd {
   302  		if server.DebugPort > 0 {
   303  			ddb, ok := dc.(DoltDebuggable)
   304  			if !ok {
   305  				panic(fmt.Sprintf("%T does not implement DoltDebuggable", dc))
   306  			}
   307  			return ddb.DoltDebug(server.DebugPort, args...)
   308  		} else {
   309  			return dc.DoltCmd(args...)
   310  		}
   311  	}
   312  
   313  	err = server.Cmd.Start()
   314  	if err != nil {
   315  		return nil, err
   316  	}
   317  	return server, nil
   318  }
   319  
   320  func (s *SqlServer) ErrorStop() error {
   321  	<-s.Done
   322  	return s.Cmd.Wait()
   323  }
   324  
   325  func multiCopyWithNamePrefix(stdout, captured io.Writer, in io.Reader, name string) {
   326  	reader := bufio.NewReader(in)
   327  	multiOut := io.MultiWriter(stdout, captured)
   328  	wantsPrefix := true
   329  	for {
   330  		line, isPrefix, err := reader.ReadLine()
   331  		if err != nil {
   332  			return
   333  		}
   334  		if wantsPrefix && name != "" {
   335  			stdout.Write([]byte("["))
   336  			stdout.Write([]byte(name))
   337  			stdout.Write([]byte("] "))
   338  		}
   339  		multiOut.Write(line)
   340  		if isPrefix {
   341  			wantsPrefix = false
   342  		} else {
   343  			multiOut.Write([]byte("\n"))
   344  			wantsPrefix = true
   345  		}
   346  	}
   347  }
   348  
   349  func (s *SqlServer) Restart(newargs *[]string, newenvs *[]string) error {
   350  	err := s.GracefulStop()
   351  	if err != nil {
   352  		return err
   353  	}
   354  	args := s.Cmd.Args[1:]
   355  	if newargs != nil {
   356  		args = append([]string{"sql-server"}, (*newargs)...)
   357  	}
   358  	s.Cmd = s.RecreateCmd(args...)
   359  	if newenvs != nil {
   360  		s.Cmd.Env = append(s.Cmd.Env, (*newenvs)...)
   361  	}
   362  	stdout, err := s.Cmd.StdoutPipe()
   363  	if err != nil {
   364  		return err
   365  	}
   366  	s.Cmd.Stderr = s.Cmd.Stdout
   367  	var wg sync.WaitGroup
   368  	wg.Add(1)
   369  	go func() {
   370  		defer wg.Done()
   371  		multiCopyWithNamePrefix(os.Stdout, s.Output, stdout, s.Name)
   372  	}()
   373  	s.Done = make(chan struct{})
   374  	go func() {
   375  		wg.Wait()
   376  		close(s.Done)
   377  	}()
   378  	return s.Cmd.Start()
   379  }
   380  
   381  func (s *SqlServer) DB(c Connection) (*sql.DB, error) {
   382  	var pass string
   383  	pass, err := c.Password()
   384  	if err != nil {
   385  		return nil, err
   386  	}
   387  	return ConnectDB(c.User, pass, s.DBName, "127.0.0.1", s.Port, c.DriverParams)
   388  }
   389  
   390  func ConnectDB(user, password, name, host string, port int, driverParams map[string]string) (*sql.DB, error) {
   391  	params := make(url.Values)
   392  	params.Set("allowAllFiles", "true")
   393  	params.Set("tls", "preferred")
   394  	for k, v := range driverParams {
   395  		params.Set(k, v)
   396  	}
   397  	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?%s", user, password, host, port, name, params.Encode())
   398  
   399  	db, err := sql.Open("mysql", dsn)
   400  	if err != nil {
   401  		return nil, err
   402  	}
   403  	for i := 0; i < ConnectAttempts; i++ {
   404  		err = db.Ping()
   405  		if err == nil {
   406  			return db, nil
   407  		}
   408  		time.Sleep(RetrySleepDuration)
   409  	}
   410  	if err != nil {
   411  		return nil, err
   412  	}
   413  	return db, nil
   414  }