github.com/keysonZZZ/kmg@v0.0.0-20151121023212-05317bfd7d39/kmgSql/SetTableData.go (about) 1 package kmgSql 2 3 import ( 4 "database/sql" 5 "fmt" 6 "strings" 7 8 "github.com/bronze1man/kmg/encoding/kmgYaml" 9 ) 10 11 //设置表数据 12 // 注意会删除数据 13 func MustSetTableDataYaml(yaml string) { 14 err := GetDb().SetTablesDataYaml(yaml) 15 if err != nil { 16 panic(err) 17 } 18 } 19 20 // @deprecated 21 func (db DB) MustSetTablesDataYaml(yaml string) { 22 err := db.SetTablesDataYaml(yaml) 23 if err != nil { 24 panic(err) 25 } 26 } 27 28 // @deprecated 29 func (db DB) SetTablesDataYaml(yaml string) (err error) { 30 data := make(map[string][]map[string]string) 31 err = kmgYaml.Unmarshal([]byte(yaml), &data) 32 if err != nil { 33 return err 34 } 35 if len(data) == 0 { 36 return fmt.Errorf("[SetTablesDataYaml] try to set tables with no data,wrong format?") 37 } 38 return db.SetTablesData(data) 39 } 40 41 // @deprecated 42 // Set some tables data in this database. 43 // mostly for test 44 // not guarantee next increment id will be!! 45 //设置表数据 46 // 注意: 47 // * 会删除数据 48 // * 保证 auto_increase 的值是数据里面的最大值+1 49 func (db DB) SetTablesData(data map[string][]map[string]string) (err error) { 50 tx, err := db.Begin() 51 if err != nil { 52 return err 53 } 54 err = setTablesDataTransaction(data, tx) 55 if err != nil { 56 errRoll := tx.Rollback() 57 if errRoll != nil { 58 return fmt.Errorf("error [transaction] %s,[rollback] %s", err, errRoll) 59 } 60 return err 61 } 62 err = tx.Commit() 63 if err != nil { 64 return err 65 } 66 return nil 67 } 68 func setTablesDataTransaction(data map[string][]map[string]string, tx *sql.Tx) error { 69 for tableName, tableData := range data { 70 sql := fmt.Sprintf("truncate `%s`", tableName) 71 _, err := tx.Exec(sql) 72 if err != nil { 73 return err 74 } 75 for _, row := range tableData { 76 colNameList := []string{} 77 placeHolderNum := len(row) 78 valueList := []interface{}{} 79 for name, value := range row { 80 colNameList = append(colNameList, name) 81 valueList = append(valueList, value) 82 } 83 sqlColNamePart := "`" + strings.Join(colNameList, "`, `") + "`" 84 sqlValuePart := strings.Repeat("?, ", placeHolderNum-1) + "?" 85 sql = fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)", tableName, sqlColNamePart, sqlValuePart) 86 _, err := tx.Exec(sql, valueList...) 87 if err != nil { 88 return err 89 } 90 } 91 } 92 return nil 93 }