github.com/RevenueMonster/sqlike@v1.0.6/sql/dump/dump.go (about)

     1  package sqldump
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"database/sql"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/RevenueMonster/sqlike/sql/driver"
    15  	"github.com/RevenueMonster/sqlike/sqlike/actions"
    16  	"github.com/RevenueMonster/sqlike/types"
    17  	"github.com/RevenueMonster/sqlike/util"
    18  
    19  	"github.com/RevenueMonster/sqlike/sql/dialect"
    20  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    21  )
    22  
    23  // Column :
    24  type Column struct {
    25  	// column name
    26  	Name string
    27  
    28  	// column position in sql database
    29  	Position int
    30  
    31  	// column data type with precision or size, eg. VARCHAR(20)
    32  	Type string
    33  
    34  	// column data type without precision and size, eg. VARCHAR
    35  	DataType string
    36  
    37  	// whether column is nullable or not
    38  	IsNullable types.Boolean
    39  
    40  	// default value of the column
    41  	DefaultValue *string
    42  
    43  	// text character set encoding
    44  	Charset *string
    45  
    46  	// text collation for sorting
    47  	Collation *string
    48  
    49  	// column comment
    50  	Comment string
    51  
    52  	// extra information
    53  	Extra string
    54  }
    55  
    56  // Dumper :
    57  type Dumper struct {
    58  	mu      *sync.Mutex
    59  	driver  string
    60  	conn    driver.Queryer
    61  	dialect dialect.Dialect
    62  	mapper  map[string]Parser
    63  }
    64  
    65  // NewDumper :
    66  func NewDumper(driver string, conn driver.Queryer) *Dumper {
    67  	dumper := new(Dumper)
    68  	dumper.mu = new(sync.Mutex)
    69  	dumper.driver = strings.TrimSpace(strings.ToLower(driver))
    70  	dumper.conn = conn
    71  	dumper.dialect = dialect.GetDialectByDriver(driver)
    72  	dumper.mapper = map[string]Parser{
    73  		"VARCHAR":   byteToString,
    74  		"CHAR":      byteToString,
    75  		"ENUM":      byteToString,
    76  		"SET":       setToString,
    77  		"INT":       numToString,
    78  		"TINYINT":   numToString,
    79  		"SMALLINT":  numToString,
    80  		"MEDIUMINT": numToString,
    81  		"BIGINT":    numToString,
    82  		"TIMESTAMP": tsToString,
    83  		"DATETIME":  tsToString,
    84  		"DATE":      dateToString,
    85  		"JSON":      jsonToString,
    86  	}
    87  	return dumper
    88  }
    89  
    90  // RegisterParser :
    91  func (d *Dumper) RegisterParser(dataType string, parser Parser) {
    92  	if parser == nil {
    93  		panic("parser cannot be nil")
    94  	}
    95  	d.mu.Lock()
    96  	defer d.mu.Lock()
    97  	d.mapper[dataType] = parser
    98  }
    99  
   100  // BackupTo :
   101  func (d *Dumper) BackupTo(ctx context.Context, query interface{}, wr io.Writer) (affected int64, err error) {
   102  	w := bufio.NewWriter(wr)
   103  
   104  	var (
   105  		dbName string
   106  		table  string
   107  	)
   108  	switch v := query.(type) {
   109  	case *actions.FindActions:
   110  		dbName = v.Database
   111  		table = v.Table
   112  	case *actions.FindOneActions:
   113  		dbName = v.Database
   114  		table = v.Table
   115  	default:
   116  		return 0, errors.New("unsupported input")
   117  	}
   118  
   119  	columns, err := d.getColumns(ctx, dbName, table)
   120  	if err != nil {
   121  		return 0, err
   122  	}
   123  
   124  	stmt := sqlstmt.AcquireStmt(d.dialect)
   125  	defer sqlstmt.ReleaseStmt(stmt)
   126  
   127  	if err := d.dialect.SelectStmt(stmt, query); err != nil {
   128  		return 0, err
   129  	}
   130  
   131  	rows, err := d.conn.QueryContext(ctx, stmt.String(), stmt.Args()...)
   132  	if err != nil {
   133  		return 0, err
   134  	}
   135  	defer rows.Close()
   136  
   137  	version, err := d.getVersion(ctx)
   138  	if err != nil {
   139  		return 0, err
   140  	}
   141  
   142  	cols, _ := rows.Columns()
   143  	w.WriteString(`
   144  # ************************************************************
   145  # Sqlike Dumper
   146  #
   147  # https://github.com/RevenueMonster/sqlike
   148  #
   149  `)
   150  	w.WriteString("# Driver: " + d.driver + "\n")
   151  	w.WriteString("# Version: " + version + "\n")
   152  	// w.WriteString("# Host: rm-zf86x4n0wvyy6830yyo.mysql.kualalumpur.rds.aliyuncs.com\n")
   153  	w.WriteString("# Database: " + dbName + "\n")
   154  	w.WriteString("# Generation Time: " + time.Now().UTC().Format(time.RFC3339) + "\n")
   155  	w.WriteString("# ************************************************************\n")
   156  
   157  	w.WriteString(`
   158  /*!40101 SET @OLD_CHARACTER_SET_CLIENT=@@CHARACTER_SET_CLIENT */;
   159  /*!40101 SET @OLD_CHARACTER_SET_RESULTS=@@CHARACTER_SET_RESULTS */;
   160  /*!40101 SET @OLD_COLLATION_CONNECTION=@@COLLATION_CONNECTION */;
   161  /*!40101 SET NAMES utf8 */;
   162  SET NAMES utf8mb4;
   163  /*!40014 SET @OLD_FOREIGN_KEY_CHECKS=@@FOREIGN_KEY_CHECKS, FOREIGN_KEY_CHECKS=0 */;
   164  /*!40101 SET @OLD_SQL_MODE=@@SQL_MODE, SQL_MODE='NO_AUTO_VALUE_ON_ZERO' */;
   165  /*!40111 SET @OLD_SQL_NOTES=@@SQL_NOTES, SQL_NOTES=0 */;
   166  
   167  `)
   168  
   169  	table = d.dialect.Quote(table)
   170  
   171  	w.WriteString(fmt.Sprintf(`
   172  LOCK TABLES %s WRITE;
   173  /*!40000 ALTER TABLE %s DISABLE KEYS */;
   174  
   175  `, table, table))
   176  
   177  	defer func() {
   178  		w.WriteString(fmt.Sprintf(`
   179  
   180  /*!40000 ALTER TABLE %s ENABLE KEYS */;
   181  UNLOCK TABLES;
   182  
   183  /*!40111 SET SQL_NOTES=@OLD_SQL_NOTES */;
   184  /*!40101 SET SQL_MODE=@OLD_SQL_MODE */;
   185  /*!40014 SET FOREIGN_KEY_CHECKS=@OLD_FOREIGN_KEY_CHECKS */;
   186  /*!40101 SET CHARACTER_SET_CLIENT=@OLD_CHARACTER_SET_CLIENT */;
   187  /*!40101 SET CHARACTER_SET_RESULTS=@OLD_CHARACTER_SET_RESULTS */;
   188  /*!40101 SET COLLATION_CONNECTION=@OLD_COLLATION_CONNECTION */;
   189  `, table))
   190  		w.Flush()
   191  	}()
   192  
   193  	w.WriteString("INSERT INTO " + table + " ")
   194  	w.WriteByte('(')
   195  
   196  	for i, col := range cols {
   197  		if i > 0 {
   198  			w.WriteByte(',')
   199  		}
   200  		w.WriteString(d.dialect.Quote(col))
   201  	}
   202  	w.WriteByte(')')
   203  	w.WriteByte('\n')
   204  	w.WriteString("VALUES\n")
   205  
   206  	first := true
   207  	for rows.Next() {
   208  		if !first {
   209  			w.WriteByte(',')
   210  			w.WriteByte('\n')
   211  		}
   212  		length := len(cols)
   213  		data := make([]interface{}, length)
   214  		for i := 0; i < length; i++ {
   215  			data[i] = new(sql.RawBytes)
   216  		}
   217  
   218  		if err := rows.Scan(data...); err != nil {
   219  			return 0, err
   220  		}
   221  
   222  		w.WriteByte('(')
   223  		for i, col := range columns {
   224  			if i > 0 {
   225  				w.WriteByte(',')
   226  			}
   227  
   228  			x := (*data[i].(*sql.RawBytes))
   229  			if x == nil {
   230  				w.WriteString("NULL")
   231  				continue
   232  			}
   233  
   234  			parse, ok := d.mapper[col.DataType]
   235  			if !ok {
   236  				w.WriteString(byteToString(x))
   237  				continue
   238  			}
   239  
   240  			if _, err := w.WriteString(parse(x)); err != nil {
   241  				return 0, err
   242  			}
   243  		}
   244  		w.WriteByte(')')
   245  
   246  		first = false
   247  	}
   248  
   249  	w.WriteByte(';')
   250  	w.Flush()
   251  	return
   252  }
   253  
   254  func (d *Dumper) getVersion(ctx context.Context) (string, error) {
   255  	stmt := sqlstmt.AcquireStmt(d.dialect)
   256  	defer sqlstmt.ReleaseStmt(stmt)
   257  
   258  	d.dialect.GetVersion(stmt)
   259  	rows, err := d.conn.QueryContext(ctx, stmt.String(), stmt.Args()...)
   260  	if err != nil {
   261  		return "", err
   262  	}
   263  	defer rows.Close()
   264  
   265  	rows.Next()
   266  
   267  	var version string
   268  	if err := rows.Scan(&version); err != nil {
   269  		return "", err
   270  	}
   271  
   272  	return version, nil
   273  }
   274  
   275  func (d *Dumper) getColumns(ctx context.Context, dbName, table string) ([]Column, error) {
   276  	stmt := sqlstmt.AcquireStmt(d.dialect)
   277  	defer sqlstmt.ReleaseStmt(stmt)
   278  	d.dialect.GetColumns(stmt, dbName, table)
   279  
   280  	rows, err := d.conn.QueryContext(ctx, stmt.String(), stmt.Args()...)
   281  	if err != nil {
   282  		return nil, err
   283  	}
   284  	defer rows.Close()
   285  
   286  	columns := make([]Column, 0)
   287  	for i := 0; rows.Next(); i++ {
   288  		col := Column{}
   289  
   290  		if err := rows.Scan(
   291  			&col.Position,
   292  			&col.Name,
   293  			&col.Type,
   294  			&col.DefaultValue,
   295  			&col.IsNullable,
   296  			&col.DataType,
   297  			&col.Charset,
   298  			&col.Collation,
   299  			&col.Comment,
   300  			&col.Extra,
   301  		); err != nil {
   302  			return nil, err
   303  		}
   304  
   305  		col.Type = strings.ToUpper(col.Type)
   306  		col.DataType = strings.ToUpper(col.DataType)
   307  
   308  		columns = append(columns, col)
   309  	}
   310  
   311  	return columns, nil
   312  }
   313  
   314  func quoteString(str string, width int) string {
   315  	length := len(str)
   316  	blr := util.AcquireString()
   317  	defer util.ReleaseString(blr)
   318  	var lw int
   319  	for i := 0; i < length; i++ {
   320  		char := str[i]
   321  		switch char {
   322  		case '"':
   323  			blr.WriteString(`\"`)
   324  		default:
   325  			blr.WriteByte(char)
   326  		}
   327  
   328  		lw++
   329  		if lw >= width {
   330  			blr.WriteByte('\r')
   331  			lw = 0
   332  		}
   333  	}
   334  	return blr.String()
   335  }