github.com/matrixorigin/matrixone@v0.7.0/cmd/mo-dump/main.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"database/sql"
    21  	"flag"
    22  	"fmt"
    23  	_ "github.com/go-sql-driver/mysql"
    24  	"github.com/matrixorigin/matrixone/pkg/catalog"
    25  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    26  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    27  	"os"
    28  	"strings"
    29  	"sync"
    30  	"time"
    31  )
    32  
    33  const (
    34  	defaultUsername        = "dump"
    35  	defaultPassword        = "111"
    36  	defaultHost            = "127.0.0.1"
    37  	defaultPort            = 6001
    38  	defaultNetBufferLength = mpool.MB
    39  	minNetBufferLength     = mpool.KB * 16
    40  	maxNetBufferLength     = mpool.MB * 16
    41  	timeout                = 10 * time.Second
    42  )
    43  
    44  var (
    45  	conn *sql.DB
    46  )
    47  
    48  type Column struct {
    49  	Name string
    50  	Type string
    51  }
    52  
    53  type Table struct {
    54  	Name string
    55  	Kind string
    56  }
    57  
    58  type Tables []Table
    59  
    60  func (t *Tables) String() string {
    61  	return fmt.Sprint(*t)
    62  }
    63  
    64  func (t *Tables) Set(value string) error {
    65  	*t = append(*t, Table{value, ""})
    66  	return nil
    67  }
    68  
    69  func main() {
    70  	var (
    71  		username, password, host, database string
    72  		tables                             Tables
    73  		port, netBufferLength              int
    74  		createDb                           string
    75  		createTable                        []string
    76  		err                                error
    77  	)
    78  	dumpStart := time.Now()
    79  	defer func() {
    80  		if err != nil {
    81  			fmt.Fprintf(os.Stderr, "modump error: %v\n", err)
    82  		}
    83  		if conn != nil {
    84  			err := conn.Close()
    85  			if err != nil {
    86  				fmt.Fprintf(os.Stderr, "modump error while close connection: %v\n", err)
    87  			}
    88  		}
    89  		if err == nil {
    90  			fmt.Fprintf(os.Stdout, "/* MODUMP SUCCESS, COST %v */\n", time.Since(dumpStart))
    91  		}
    92  	}()
    93  
    94  	ctx := context.Background()
    95  	flag.StringVar(&username, "u", defaultUsername, "username")
    96  	flag.StringVar(&password, "p", defaultPassword, "password")
    97  	flag.StringVar(&host, "h", defaultHost, "hostname")
    98  	flag.IntVar(&port, "P", defaultPort, "portNumber")
    99  	flag.IntVar(&netBufferLength, "net-buffer-length", defaultNetBufferLength, "net_buffer_length")
   100  	flag.StringVar(&database, "db", "", "databaseName, must be specified")
   101  	flag.Var(&tables, "tbl", "tableNameList, default all")
   102  	flag.Parse()
   103  	if netBufferLength < minNetBufferLength {
   104  		fmt.Fprintf(os.Stderr, "net_buffer_length must be greater than %d, set to %d\n", minNetBufferLength, minNetBufferLength)
   105  		netBufferLength = minNetBufferLength
   106  	}
   107  	if netBufferLength > maxNetBufferLength {
   108  		fmt.Fprintf(os.Stderr, "net_buffer_length must be less than %d, set to %d\n", maxNetBufferLength, maxNetBufferLength)
   109  		netBufferLength = maxNetBufferLength
   110  	}
   111  	if len(database) == 0 {
   112  		err = moerr.NewInvalidInput(ctx, "database must be specified")
   113  		return
   114  	}
   115  	dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", username, password, host, port, database)
   116  	conn, err = sql.Open("mysql", dsn) // Open doesn't open a connection. Validate DSN data:
   117  	if err != nil {
   118  		return
   119  	}
   120  	ch := make(chan error)
   121  	go func() {
   122  		err := conn.Ping() // Before use, we must ping to validate DSN data:
   123  		ch <- err
   124  	}()
   125  
   126  	select {
   127  	case err = <-ch:
   128  	case <-time.After(timeout):
   129  		err = moerr.NewInternalError(ctx, "connect to %s timeout", dsn)
   130  	}
   131  	if err != nil {
   132  		return
   133  	}
   134  	if len(tables) == 0 { //dump all tables
   135  		createDb, err = getCreateDB(database)
   136  		if err != nil {
   137  			return
   138  		}
   139  		fmt.Printf("DROP DATABASE IF EXISTS `%s`;\n", database)
   140  		fmt.Println(createDb, ";")
   141  		fmt.Printf("USE `%s`;\n\n\n", database)
   142  	}
   143  	tables, err = getTables(database, tables)
   144  	if err != nil {
   145  		return
   146  	}
   147  	createTable = make([]string, len(tables))
   148  	for i, tbl := range tables {
   149  		createTable[i], err = getCreateTable(database, tbl.Name)
   150  		if err != nil {
   151  			return
   152  		}
   153  	}
   154  	bufPool := &sync.Pool{
   155  		New: func() any {
   156  			return &bytes.Buffer{}
   157  		},
   158  	}
   159  	for i, create := range createTable {
   160  		tbl := tables[i]
   161  		switch tbl.Kind {
   162  		case catalog.SystemOrdinaryRel:
   163  			fmt.Printf("DROP TABLE IF EXISTS `%s`;\n", tbl.Name)
   164  			showCreateTable(create, false)
   165  			err = showInsert(database, tbl.Name, bufPool, netBufferLength)
   166  			if err != nil {
   167  				return
   168  			}
   169  		case catalog.SystemExternalRel:
   170  			fmt.Printf("/*!EXTERNAL TABLE `%s`*/\n", tbl.Name)
   171  			fmt.Printf("DROP TABLE IF EXISTS `%s`;\n", tbl.Name)
   172  			showCreateTable(create, true)
   173  		case catalog.SystemViewRel:
   174  			fmt.Printf("DROP VIEW IF EXISTS `%s`;\n", tbl.Name)
   175  			showCreateTable(create, true)
   176  		default:
   177  			err = moerr.NewNotSupported(ctx, "table type %s", tbl.Kind)
   178  		}
   179  	}
   180  }
   181  
   182  func showCreateTable(createSql string, withNextLine bool) {
   183  	var suffix string
   184  	if !strings.HasSuffix(createSql, ";") {
   185  		suffix = ";"
   186  	}
   187  	if withNextLine {
   188  		suffix += "\n\n"
   189  	}
   190  	fmt.Printf("%s%s\n", createSql, suffix)
   191  }
   192  
   193  func getTables(db string, tables Tables) (Tables, error) {
   194  	sql := "select relname,relkind from mo_catalog.mo_tables where reldatabase = '" + db + "'"
   195  	if len(tables) > 0 {
   196  		sql += " and relname in ("
   197  		for i, tbl := range tables {
   198  			if i != 0 {
   199  				sql += ","
   200  			}
   201  			sql += "'" + tbl.Name + "'"
   202  		}
   203  		sql += ")"
   204  	}
   205  	r, err := conn.Query(sql) //TODO: after unified sys table prefix, add condition in where clause
   206  	if err != nil {
   207  		return nil, err
   208  	}
   209  	defer r.Close()
   210  
   211  	if tables == nil {
   212  		tables = Tables{}
   213  	}
   214  	tables = tables[:0]
   215  	for r.Next() {
   216  		var table string
   217  		var kind string
   218  		err = r.Scan(&table, &kind)
   219  		if err != nil {
   220  			return nil, err
   221  		}
   222  		if strings.HasPrefix(table, "__mo_") || strings.HasPrefix(table, "%!%") { //TODO: after adding condition in where clause, remove this
   223  			continue
   224  		}
   225  		tables = append(tables, Table{table, kind})
   226  	}
   227  	return tables, nil
   228  }
   229  
   230  func getCreateDB(db string) (string, error) {
   231  	r := conn.QueryRow("show create database `" + db + "`")
   232  	var create string
   233  	err := r.Scan(&db, &create)
   234  	if err != nil {
   235  		return "", err
   236  	}
   237  	return create, nil
   238  }
   239  
   240  func getCreateTable(db, tbl string) (string, error) {
   241  	r := conn.QueryRow("show create table `" + db + "`.`" + tbl + "`")
   242  	var create string
   243  	err := r.Scan(&tbl, &create)
   244  	if err != nil {
   245  		return "", err
   246  	}
   247  	return create, nil
   248  }
   249  
   250  func showInsert(db string, tbl string, bufPool *sync.Pool, netBufferLength int) error {
   251  	r, err := conn.Query("select * from `" + db + "`.`" + tbl + "`")
   252  	if err != nil {
   253  		return err
   254  	}
   255  	colTypes, err := r.ColumnTypes()
   256  	if err != nil {
   257  		return err
   258  	}
   259  	cols := make([]*Column, 0, len(colTypes))
   260  	for _, col := range colTypes {
   261  		var c Column
   262  		c.Name = col.Name()
   263  		c.Type = col.DatabaseTypeName()
   264  		cols = append(cols, &c)
   265  	}
   266  	args := make([]any, 0, len(cols))
   267  	for range cols {
   268  		var v sql.RawBytes
   269  		args = append(args, &v)
   270  	}
   271  	buf := bufPool.Get().(*bytes.Buffer)
   272  	curBuf := bufPool.Get().(*bytes.Buffer)
   273  	buf.Grow(netBufferLength)
   274  	initInert := "INSERT INTO `" + tbl + "` VALUES "
   275  	for {
   276  		buf.WriteString(initInert)
   277  		preLen := buf.Len()
   278  		first := true
   279  		if curBuf.Len() > 0 {
   280  			bts := curBuf.Bytes()
   281  			if bts[0] == ',' {
   282  				bts = bts[1:]
   283  			}
   284  			buf.Write(bts)
   285  			curBuf.Reset()
   286  			first = false
   287  		}
   288  		for r.Next() {
   289  			err = r.Scan(args...)
   290  			if err != nil {
   291  				return err
   292  			}
   293  			if !first {
   294  				curBuf.WriteString(",(")
   295  			} else {
   296  				curBuf.WriteString("(")
   297  				first = false
   298  			}
   299  
   300  			for i, v := range args {
   301  				if i > 0 {
   302  					curBuf.WriteString(",")
   303  				}
   304  				curBuf.WriteString(convertValue(v, cols[i].Type))
   305  			}
   306  			curBuf.WriteString(")")
   307  			if buf.Len()+curBuf.Len() >= netBufferLength {
   308  				break
   309  			}
   310  			buf.Write(curBuf.Bytes())
   311  			curBuf.Reset()
   312  		}
   313  		if buf.Len() > preLen {
   314  			buf.WriteString(";\n")
   315  			_, err = buf.WriteTo(os.Stdout)
   316  			if err != nil {
   317  				return err
   318  			}
   319  			continue
   320  		}
   321  		if curBuf.Len() > 0 {
   322  			continue
   323  		}
   324  		buf.Reset()
   325  		curBuf.Reset()
   326  		break
   327  	}
   328  	bufPool.Put(buf)
   329  	bufPool.Put(curBuf)
   330  	fmt.Printf("\n\n\n")
   331  	return nil
   332  }
   333  
   334  func convertValue(v any, typ string) string {
   335  	ret := *(v.(*sql.RawBytes))
   336  	if ret == nil {
   337  		return "NULL"
   338  	}
   339  	typ = strings.ToLower(typ)
   340  	switch typ {
   341  	case "int", "tinyint", "smallint", "bigint", "unsigned bigint", "unsigned int", "unsigned tinyint", "unsigned smallint", "float", "double", "bool", "boolean", "":
   342  		// why empty string in column type?
   343  		// see https://github.com/matrixorigin/matrixone/issues/8050#issuecomment-1431251524
   344  		return string(ret)
   345  	default:
   346  		return "'" + strings.Replace(string(ret), "'", "\\'", -1) + "'"
   347  	}
   348  }