github.com/keysonZZZ/kmg@v0.0.0-20151121023212-05317bfd7d39/kmgSql/SetTableData.go (about)

     1  package kmgSql
     2  
import (
	"database/sql"
	"fmt"
	"sort"
	"strings"

	"github.com/bronze1man/kmg/encoding/kmgYaml"
)
    10  
    11  //设置表数据
    12  // 注意会删除数据
    13  func MustSetTableDataYaml(yaml string) {
    14  	err := GetDb().SetTablesDataYaml(yaml)
    15  	if err != nil {
    16  		panic(err)
    17  	}
    18  }
    19  
    20  // @deprecated
    21  func (db DB) MustSetTablesDataYaml(yaml string) {
    22  	err := db.SetTablesDataYaml(yaml)
    23  	if err != nil {
    24  		panic(err)
    25  	}
    26  }
    27  
    28  // @deprecated
    29  func (db DB) SetTablesDataYaml(yaml string) (err error) {
    30  	data := make(map[string][]map[string]string)
    31  	err = kmgYaml.Unmarshal([]byte(yaml), &data)
    32  	if err != nil {
    33  		return err
    34  	}
    35  	if len(data) == 0 {
    36  		return fmt.Errorf("[SetTablesDataYaml] try to set tables with no data,wrong format?")
    37  	}
    38  	return db.SetTablesData(data)
    39  }
    40  
    41  // @deprecated
    42  // Set some tables data in this database.
    43  // mostly for test
    44  // not guarantee next increment id will be!!
    45  //设置表数据
    46  // 注意:
    47  //   * 会删除数据
    48  //   * 保证 auto_increase 的值是数据里面的最大值+1
    49  func (db DB) SetTablesData(data map[string][]map[string]string) (err error) {
    50  	tx, err := db.Begin()
    51  	if err != nil {
    52  		return err
    53  	}
    54  	err = setTablesDataTransaction(data, tx)
    55  	if err != nil {
    56  		errRoll := tx.Rollback()
    57  		if errRoll != nil {
    58  			return fmt.Errorf("error [transaction] %s,[rollback] %s", err, errRoll)
    59  		}
    60  		return err
    61  	}
    62  	err = tx.Commit()
    63  	if err != nil {
    64  		return err
    65  	}
    66  	return nil
    67  }
    68  func setTablesDataTransaction(data map[string][]map[string]string, tx *sql.Tx) error {
    69  	for tableName, tableData := range data {
    70  		sql := fmt.Sprintf("truncate `%s`", tableName)
    71  		_, err := tx.Exec(sql)
    72  		if err != nil {
    73  			return err
    74  		}
    75  		for _, row := range tableData {
    76  			colNameList := []string{}
    77  			placeHolderNum := len(row)
    78  			valueList := []interface{}{}
    79  			for name, value := range row {
    80  				colNameList = append(colNameList, name)
    81  				valueList = append(valueList, value)
    82  			}
    83  			sqlColNamePart := "`" + strings.Join(colNameList, "`, `") + "`"
    84  			sqlValuePart := strings.Repeat("?, ", placeHolderNum-1) + "?"
    85  			sql = fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, sqlColNamePart, sqlValuePart)
    86  			_, err := tx.Exec(sql, valueList...)
    87  			if err != nil {
    88  				return err
    89  			}
    90  		}
    91  	}
    92  	return nil
    93  }