github.phpd.cn/amacneil/dbmate@v1.4.1/pkg/dbmate/utils.go (about)

     1  package dbmate
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"database/sql"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net/url"
    11  	"os"
    12  	"os/exec"
    13  	"strings"
    14  	"unicode"
    15  )
    16  
    17  // databaseName returns the database name from a URL
    18  func databaseName(u *url.URL) string {
    19  	name := u.Path
    20  	if len(name) > 0 && name[:1] == "/" {
    21  		name = name[1:]
    22  	}
    23  
    24  	return name
    25  }
    26  
    27  // mustClose ensures a stream is closed
    28  func mustClose(c io.Closer) {
    29  	if err := c.Close(); err != nil {
    30  		panic(err)
    31  	}
    32  }
    33  
    34  // ensureDir creates a directory if it does not already exist
    35  func ensureDir(dir string) error {
    36  	if err := os.MkdirAll(dir, 0755); err != nil {
    37  		return fmt.Errorf("unable to create directory `%s`", dir)
    38  	}
    39  
    40  	return nil
    41  }
    42  
    43  // runCommand runs a command and returns the stdout if successful
    44  func runCommand(name string, args ...string) ([]byte, error) {
    45  	var stdout, stderr bytes.Buffer
    46  	cmd := exec.Command(name, args...)
    47  	cmd.Stdout = &stdout
    48  	cmd.Stderr = &stderr
    49  
    50  	if err := cmd.Run(); err != nil {
    51  		// return stderr if available
    52  		if s := strings.TrimSpace(stderr.String()); s != "" {
    53  			return nil, errors.New(s)
    54  		}
    55  
    56  		// otherwise return error
    57  		return nil, err
    58  	}
    59  
    60  	// return stdout
    61  	return stdout.Bytes(), nil
    62  }
    63  
    64  // trimLeadingSQLComments removes sql comments and blank lines from the beginning of text
    65  // generally when performing sql dumps these contain host-specific information such as
    66  // client/server version numbers
    67  func trimLeadingSQLComments(data []byte) ([]byte, error) {
    68  	// create decent size buffer
    69  	out := bytes.NewBuffer(make([]byte, 0, len(data)))
    70  
    71  	// iterate over sql lines
    72  	preamble := true
    73  	scanner := bufio.NewScanner(bytes.NewReader(data))
    74  	for scanner.Scan() {
    75  		// we read bytes directly for premature performance optimization
    76  		line := scanner.Bytes()
    77  
    78  		if preamble && (len(line) == 0 || bytes.Equal(line[0:2], []byte("--"))) {
    79  			// header section, skip this line in output buffer
    80  			continue
    81  		}
    82  
    83  		// header section is over
    84  		preamble = false
    85  
    86  		// trim trailing whitespace
    87  		line = bytes.TrimRightFunc(line, unicode.IsSpace)
    88  
    89  		// copy bytes to output buffer
    90  		if _, err := out.Write(line); err != nil {
    91  			return nil, err
    92  		}
    93  		if _, err := out.WriteString("\n"); err != nil {
    94  			return nil, err
    95  		}
    96  	}
    97  	if err := scanner.Err(); err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	return out.Bytes(), nil
   102  }
   103  
   104  // queryColumn runs a SQL statement and returns a slice of strings
   105  // it is assumed that the statement returns only one column
   106  // e.g. schema_migrations table
   107  func queryColumn(db *sql.DB, query string) ([]string, error) {
   108  	rows, err := db.Query(query)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	defer mustClose(rows)
   113  
   114  	// read into slice
   115  	var result []string
   116  	for rows.Next() {
   117  		var v string
   118  		if err := rows.Scan(&v); err != nil {
   119  			return nil, err
   120  		}
   121  
   122  		result = append(result, v)
   123  	}
   124  	if err = rows.Err(); err != nil {
   125  		return nil, err
   126  	}
   127  
   128  	return result, nil
   129  }