github.com/mmatczuk/gohan@v0.0.0-20170206152520-30e45d9bdb69/db/sql/sql.go (about) 1 // Copyright (C) 2015 NTT Innovation Institute, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 12 // implied. 13 // See the License for the specific language governing permissions and 14 // limitations under the License. 15 16 package sql 17 18 import ( 19 "encoding/json" 20 "fmt" 21 "strconv" 22 "strings" 23 "time" 24 25 "github.com/cloudwan/gohan/db/pagination" 26 "github.com/cloudwan/gohan/db/transaction" 27 "github.com/cloudwan/gohan/util" 28 29 "database/sql" 30 31 "github.com/cloudwan/gohan/schema" 32 "github.com/jmoiron/sqlx" 33 sq "github.com/lann/squirrel" 34 // DB import 35 _ "github.com/go-sql-driver/mysql" 36 _ "github.com/mattn/go-sqlite3" 37 _ "github.com/nati/go-fakedb" 38 ) 39 40 const retryDB = 50 41 const retryDBWait = 10 42 43 const ( 44 configVersionColumnName = "config_version" 45 stateVersionColumnName = "state_version" 46 stateErrorColumnName = "state_error" 47 stateColumnName = "state" 48 stateMonitoringColumnName = "state_monitoring" 49 ) 50 51 //DB is sql implementation of DB 52 type DB struct { 53 sqlType, connectionString string 54 handlers map[string]propertyHandler 55 DB *sqlx.DB 56 } 57 58 //Transaction is sql implementation of Transaction 59 type Transaction struct { 60 transaction *sqlx.Tx 61 db *DB 62 closed bool 63 } 64 65 //NewDB constructor 66 func NewDB() *DB { 67 handlers := make(map[string]propertyHandler) 68 //TODO(nati) dynamic configuration 69 handlers["string"] = &stringHandler{} 70 handlers["number"] = &numberHandler{} 71 handlers["integer"] = &numberHandler{} 72 handlers["object"] = &jsonHandler{} 73 handlers["array"] = &jsonHandler{} 74 handlers["boolean"] = &boolHandler{} 75 return &DB{handlers: handlers} 76 } 77 78 //propertyHandler for each propertys 79 type propertyHandler interface { 80 encode(*schema.Property, interface{}) (interface{}, error) 81 decode(*schema.Property, interface{}) (interface{}, error) 82 dataType(*schema.Property) string 83 } 84 85 type defaultHandler struct { 86 } 87 88 func (handler *defaultHandler) encode(property *schema.Property, data interface{}) (interface{}, error) { 89 return data, nil 90 } 91 92 func (handler *defaultHandler) decode(property *schema.Property, data interface{}) (interface{}, error) { 93 return data, nil 94 } 95 96 func (handler *defaultHandler) dataType(property *schema.Property) (res string) { 97 // TODO(marcin) extend types for schema. Here is pretty ugly guessing 98 if property.ID == "id" || property.Relation != "" || property.Unique { 99 res = "varchar(255)" 100 } else { 101 res = "text" 102 } 103 return 104 } 105 106 type stringHandler struct { 107 defaultHandler 108 } 109 110 func (handler *stringHandler) encode(property *schema.Property, data interface{}) (interface{}, error) { 111 return data, nil 112 } 113 114 func (handler *stringHandler) decode(property *schema.Property, data interface{}) (interface{}, error) { 115 if bytes, ok := data.([]byte); ok { 116 return string(bytes), nil 117 } 118 return data, nil 119 } 120 121 type boolHandler struct{} 122 123 func (handler *boolHandler) encode(property *schema.Property, data interface{}) (interface{}, error) { 124 return data, nil 125 } 126 127 func (handler *boolHandler) decode(property *schema.Property, data interface{}) (res interface{}, err error) { 128 // different SQL drivers encode result with different type 129 // so we need to do manual checks 130 if data == nil { 131 return nil, nil 132 } 133 switch t := data.(type) { 134 default: 135 err = fmt.Errorf("unknown type %T", t) 136 return 137 case []uint8: // mysql 138 res, err = strconv.ParseUint(string(data.([]uint8)), 10, 64) 139 res = (res.(uint64) != 0) 140 case int64: //apparently also mysql 141 res = (data.(int64) != 0) 142 case bool: // sqlite3 143 res = data 144 } 145 return 146 } 147 148 func (handler *boolHandler) dataType(property *schema.Property) string { 149 return "boolean" 150 } 151 152 type numberHandler struct{} 153 154 func (handler *numberHandler) encode(property *schema.Property, data interface{}) (interface{}, error) { 155 return data, nil 156 } 157 158 func (handler *numberHandler) decode(property *schema.Property, data interface{}) (res interface{}, err error) { 159 // different SQL drivers encode result with different type 160 // so we need to do manual checks 161 if data == nil { 162 return nil, nil 163 } 164 switch t := data.(type) { 165 default: 166 return data, nil 167 case []uint8: // mysql 168 uintValue, _ := strconv.ParseUint(string(t), 10, 64) 169 res = int(uintValue) 170 case int64: // sqlite3 171 res = int(t) 172 } 173 return 174 } 175 176 func (handler *numberHandler) dataType(property *schema.Property) string { 177 return "numeric" 178 } 179 180 type jsonHandler struct { 181 } 182 183 func (handler *jsonHandler) encode(property *schema.Property, data interface{}) (interface{}, error) { 184 bytes, err := json.Marshal(data) 185 //TODO(nati) should handle encoding err 186 if err != nil { 187 return nil, err 188 } 189 return string(bytes), nil 190 } 191 192 func (handler *jsonHandler) decode(property *schema.Property, data interface{}) (interface{}, error) { 193 if bytes, ok := data.([]byte); ok { 194 var ret interface{} 195 err := json.Unmarshal(bytes, &ret) 196 return ret, err 197 } 198 return data, nil 199 } 200 201 func (handler *jsonHandler) dataType(property *schema.Property) string { 202 return "text" 203 } 204 205 func quote(str string) string { 206 return fmt.Sprintf("`%s`", str) 207 } 208 209 //Connect connec to the db 210 func (db *DB) Connect(sqlType, conn string, maxOpenConn int) (err error) { 211 db.sqlType = sqlType 212 db.connectionString = conn 213 rawDB, err := sql.Open(db.sqlType, db.connectionString) 214 if err != nil { 215 return err 216 } 217 rawDB.SetMaxOpenConns(maxOpenConn) 218 rawDB.SetMaxIdleConns(maxOpenConn) 219 db.DB = sqlx.NewDb(rawDB, db.sqlType) 220 221 if db.sqlType == "sqlite3" { 222 db.DB.Exec("PRAGMA foreign_keys = ON;") 223 } 224 225 for i := 0; i < retryDB; i++ { 226 err = db.DB.Ping() 227 if err == nil { 228 return nil 229 } 230 time.Sleep(retryDBWait * time.Second) 231 log.Info("Retrying db connection... (%s)", err) 232 } 233 234 return fmt.Errorf("Failed to connect db") 235 } 236 237 func (db *DB) Close() { 238 db.DB.Close() 239 } 240 241 //Begin starts new transaction 242 func (db *DB) Begin() (transaction.Transaction, error) { 243 transaction, err := db.DB.Beginx() 244 if err != nil { 245 return nil, err 246 } 247 if db.sqlType == "sqlite3" { 248 transaction.Exec("PRAGMA foreign_keys = ON;") 249 } 250 return &Transaction{ 251 db: db, 252 transaction: transaction, 253 closed: false, 254 }, nil 255 } 256 257 func (db *DB) genTableCols(s *schema.Schema, cascade bool, exclude []string) ([]string, []string, []string) { 258 var cols []string 259 var relations []string 260 var indices []string 261 schemaManager := schema.GetManager() 262 for _, property := range s.Properties { 263 if util.ContainsString(exclude, property.ID) { 264 continue 265 } 266 handler := db.handlers[property.Type] 267 sqlDataType := property.SQLType 268 sqlDataProperties := "" 269 if db.sqlType == "sqlite3" { 270 sqlDataType = strings.Replace(sqlDataType, "auto_increment", "autoincrement", 1) 271 } 272 if sqlDataType == "" { 273 sqlDataType = handler.dataType(&property) 274 if property.ID == "id" { 275 sqlDataProperties = " primary key" 276 } else { 277 if property.Nullable { 278 sqlDataProperties = " null" 279 } else { 280 sqlDataProperties = " not null" 281 } 282 if property.Unique { 283 sqlDataProperties = " unique" 284 } 285 } 286 } 287 sql := "`" + property.ID + "` " + sqlDataType + sqlDataProperties 288 289 cols = append(cols, sql) 290 if property.Relation != "" { 291 foreignSchema, _ := schemaManager.Schema(property.Relation) 292 if foreignSchema != nil { 293 cascadeString := "" 294 if cascade || 295 property.OnDeleteCascade || 296 (property.Relation == s.Parent && s.OnParentDeleteCascade) { 297 cascadeString = "on delete cascade" 298 } 299 300 relationColumn := "id" 301 if property.RelationColumn != "" { 302 relationColumn = property.RelationColumn 303 } 304 305 relations = append(relations, fmt.Sprintf("foreign key(`%s`) REFERENCES `%s`(%s) %s", 306 property.ID, foreignSchema.GetDbTableName(), relationColumn, cascadeString)) 307 } 308 } 309 310 if property.Indexed { 311 prefix := "" 312 if sqlDataType == "text" { 313 prefix = "(255)" 314 } 315 indices = append(indices, fmt.Sprintf("CREATE INDEX %s_%s_idx ON `%s`(`%s`%s);", s.Plural, property.ID, 316 s.Plural, property.ID, prefix)) 317 } 318 } 319 return cols, relations, indices 320 } 321 322 //AlterTableDef generates alter table sql 323 func (db *DB) AlterTableDef(s *schema.Schema, cascade bool) (string, []string, error) { 324 var existing []string 325 rows, err := db.DB.Query(fmt.Sprintf("select * from `%s` limit 1;", s.GetDbTableName())) 326 if err == nil { 327 defer rows.Close() 328 existing, err = rows.Columns() 329 } 330 331 if err != nil { 332 return "", nil, err 333 } 334 335 cols, relations, indices := db.genTableCols(s, cascade, existing) 336 cols = append(cols, relations...) 337 338 if len(cols) == 0 { 339 return "", nil, nil 340 } 341 alterTable := fmt.Sprintf("alter table`%s` add (%s);\n", s.GetDbTableName(), strings.Join(cols, ",")) 342 log.Debug("Altering table: " + alterTable) 343 log.Debug("Altering indices: " + strings.Join(indices, "")) 344 return alterTable, indices, nil 345 } 346 347 //GenTableDef generates create table sql 348 func (db *DB) GenTableDef(s *schema.Schema, cascade bool) (string, []string) { 349 cols, relations, indices := db.genTableCols(s, cascade, nil) 350 351 if s.StateVersioning() { 352 cols = append(cols, quote(configVersionColumnName)+"int not null default 1") 353 cols = append(cols, quote(stateVersionColumnName)+"int not null default 0") 354 cols = append(cols, quote(stateErrorColumnName)+"text not null default ''") 355 cols = append(cols, quote(stateColumnName)+"text not null default ''") 356 cols = append(cols, quote(stateMonitoringColumnName)+"text not null default ''") 357 } 358 359 cols = append(cols, relations...) 360 tableSQL := fmt.Sprintf("create table `%s` (%s);\n", s.GetDbTableName(), strings.Join(cols, ",")) 361 log.Debug("Creating table: " + tableSQL) 362 log.Debug("Creating indices: " + strings.Join(indices, "")) 363 return tableSQL, indices 364 } 365 366 //RegisterTable creates table in the db 367 func (db *DB) RegisterTable(s *schema.Schema, cascade bool) error { 368 if s.IsAbstract() { 369 return nil 370 } 371 tableDef, indices, err := db.AlterTableDef(s, cascade) 372 if err != nil { 373 tableDef, indices = db.GenTableDef(s, cascade) 374 } 375 if tableDef == "" { 376 return nil 377 } 378 _, err = db.DB.Exec(tableDef) 379 if err != nil && indices != nil { 380 for _, indexSql := range indices { 381 _, err = db.DB.Exec(indexSql) 382 if err != nil { 383 return err 384 } 385 } 386 } 387 return err 388 } 389 390 //DropTable drop table definition 391 func (db *DB) DropTable(s *schema.Schema) error { 392 if s.IsAbstract() { 393 return nil 394 } 395 sql := fmt.Sprintf("drop table if exists %s\n", quote(s.GetDbTableName())) 396 _, err := db.DB.Exec(sql) 397 return err 398 } 399 400 func escapeID(ID string) string { 401 return strings.Replace(ID, "-", "_escape_", -1) 402 } 403 404 func logQuery(sql string, args ...interface{}) { 405 sqlFormat := strings.Replace(sql, "?", "%s", -1) 406 query := fmt.Sprintf(sqlFormat, args...) 407 log.Debug("Executing SQL query '%s'", query) 408 } 409 410 // Exec executes sql in transaction 411 func (tx *Transaction) Exec(sql string, args ...interface{}) error { 412 logQuery(sql, args...) 413 _, err := tx.transaction.Exec(sql, args...) 414 return err 415 } 416 417 //Create create resource in the db 418 func (tx *Transaction) Create(resource *schema.Resource) error { 419 var cols []string 420 var values []interface{} 421 db := tx.db 422 s := resource.Schema() 423 data := resource.Data() 424 q := sq.Insert(quote(s.GetDbTableName())) 425 for _, attr := range s.Properties { 426 //TODO(nati) support optional value 427 if _, ok := data[attr.ID]; ok { 428 handler := db.handler(&attr) 429 cols = append(cols, quote(attr.ID)) 430 encoded, err := handler.encode(&attr, data[attr.ID]) 431 if err != nil { 432 return fmt.Errorf("SQL Create encoding error: %s", err) 433 } 434 values = append(values, encoded) 435 } 436 } 437 q = q.Columns(cols...).Values(values...) 438 sql, args, err := q.ToSql() 439 if err != nil { 440 return err 441 } 442 return tx.Exec(sql, args...) 443 } 444 445 func (tx *Transaction) updateQuery(resource *schema.Resource) (sq.UpdateBuilder, error) { 446 s := resource.Schema() 447 db := tx.db 448 data := resource.Data() 449 q := sq.Update(quote(s.GetDbTableName())) 450 for _, attr := range s.Properties { 451 //TODO(nati) support optional value 452 if _, ok := data[attr.ID]; ok { 453 handler := db.handler(&attr) 454 encoded, err := handler.encode(&attr, data[attr.ID]) 455 if err != nil { 456 return q, fmt.Errorf("SQL Update encoding error: %s", err) 457 } 458 q = q.Set(quote(attr.ID), encoded) 459 } 460 } 461 if s.Parent != "" { 462 q = q.Set(s.ParentSchemaPropertyID(), resource.ParentID()) 463 } 464 return q, nil 465 } 466 467 //Update update resource in the db 468 func (tx *Transaction) Update(resource *schema.Resource) error { 469 q, err := tx.updateQuery(resource) 470 if err != nil { 471 return err 472 } 473 sql, args, err := q.ToSql() 474 if err != nil { 475 return err 476 } 477 if resource.Schema().StateVersioning() { 478 sql += ", `" + configVersionColumnName + "` = `" + configVersionColumnName + "` + 1" 479 } 480 sql += " WHERE id = ?" 481 args = append(args, resource.ID()) 482 return tx.Exec(sql, args...) 483 } 484 485 //StateUpdate update resource state 486 func (tx *Transaction) StateUpdate(resource *schema.Resource, state *transaction.ResourceState) error { 487 q, err := tx.updateQuery(resource) 488 if err != nil { 489 return err 490 } 491 if resource.Schema().StateVersioning() && state != nil { 492 q = q.Set(quote(stateVersionColumnName), state.StateVersion) 493 q = q.Set(quote(stateErrorColumnName), state.Error) 494 q = q.Set(quote(stateColumnName), state.State) 495 q = q.Set(quote(stateMonitoringColumnName), state.Monitoring) 496 } 497 q = q.Where(sq.Eq{"id": resource.ID()}) 498 sql, args, err := q.ToSql() 499 if err != nil { 500 return err 501 } 502 return tx.Exec(sql, args...) 503 } 504 505 //Delete delete resource from db 506 func (tx *Transaction) Delete(s *schema.Schema, resourceID interface{}) error { 507 sql, args, err := sq.Delete(quote(s.GetDbTableName())).Where(sq.Eq{"id": resourceID}).ToSql() 508 if err != nil { 509 return err 510 } 511 return tx.Exec(sql, args...) 512 } 513 514 func (db *DB) handler(property *schema.Property) propertyHandler { 515 handler, ok := db.handlers[property.Type] 516 if ok { 517 return handler 518 } 519 return &defaultHandler{} 520 } 521 522 func makeColumnID(tableName string, property schema.Property) string { 523 return fmt.Sprintf("%s__%s", tableName, property.ID) 524 } 525 526 func makeColumn(tableName string, property schema.Property) string { 527 return fmt.Sprintf("%s.%s", tableName, quote(property.ID)) 528 } 529 530 func makeAliasTableName(tableName string, property schema.Property) string { 531 return fmt.Sprintf("%s__%s", tableName, property.RelationProperty) 532 } 533 534 // MakeColumns generates an array that has Gohan style colmun names 535 func MakeColumns(s *schema.Schema, tableName string, join bool) []string { 536 var cols []string 537 manager := schema.GetManager() 538 for _, property := range s.Properties { 539 cols = append(cols, makeColumn(tableName, property)+" as "+quote(makeColumnID(tableName, property))) 540 if property.RelationProperty != "" && join { 541 relatedSchema, _ := manager.Schema(property.Relation) 542 aliasTableName := makeAliasTableName(tableName, property) 543 cols = append(cols, MakeColumns(relatedSchema, aliasTableName, true)...) 544 } 545 } 546 return cols 547 } 548 549 func makeStateColumns(s *schema.Schema) (cols []string) { 550 dbTableName := s.GetDbTableName() 551 cols = append(cols, dbTableName+"."+configVersionColumnName+" as "+quote(configVersionColumnName)) 552 cols = append(cols, dbTableName+"."+stateVersionColumnName+" as "+quote(stateVersionColumnName)) 553 cols = append(cols, dbTableName+"."+stateErrorColumnName+" as "+quote(stateErrorColumnName)) 554 cols = append(cols, dbTableName+"."+stateColumnName+" as "+quote(stateColumnName)) 555 cols = append(cols, dbTableName+"."+stateMonitoringColumnName+" as "+quote(stateMonitoringColumnName)) 556 return cols 557 } 558 559 func makeJoin(s *schema.Schema, tableName string, q sq.SelectBuilder) sq.SelectBuilder { 560 manager := schema.GetManager() 561 for _, property := range s.Properties { 562 if property.RelationProperty == "" { 563 continue 564 } 565 relatedSchema, _ := manager.Schema(property.Relation) 566 aliasTableName := makeAliasTableName(tableName, property) 567 q = q.LeftJoin( 568 fmt.Sprintf("%s as %s on %s.%s = %s.id", quote(relatedSchema.GetDbTableName()), quote(aliasTableName), 569 quote(tableName), quote(property.ID), quote(aliasTableName))) 570 q = makeJoin(relatedSchema, aliasTableName, q) 571 } 572 return q 573 } 574 575 func (tx *Transaction) decode(s *schema.Schema, tableName string, data map[string]interface{}, resource map[string]interface{}) { 576 manager := schema.GetManager() 577 db := tx.db 578 for _, property := range s.Properties { 579 handler := db.handler(&property) 580 value := data[makeColumnID(tableName, property)] 581 if value != nil || property.Nullable { 582 decoded, err := handler.decode(&property, value) 583 if err != nil { 584 log.Error(fmt.Sprintf("SQL List decoding error: %s", err)) 585 } 586 resource[property.ID] = decoded 587 } 588 if property.RelationProperty != "" { 589 relatedSchema, _ := manager.Schema(property.Relation) 590 resourceData := map[string]interface{}{} 591 aliasTableName := makeAliasTableName(tableName, property) 592 tx.decode(relatedSchema, aliasTableName, data, resourceData) 593 resource[property.RelationProperty] = resourceData 594 } 595 } 596 } 597 598 func decodeState(data map[string]interface{}, state *transaction.ResourceState) error { 599 var ok bool 600 state.ConfigVersion, ok = data[configVersionColumnName].(int64) 601 if !ok { 602 return fmt.Errorf("Wrong state column %s returned from query", configVersionColumnName) 603 } 604 state.StateVersion, ok = data[stateVersionColumnName].(int64) 605 if !ok { 606 return fmt.Errorf("Wrong state column %s returned from query", stateVersionColumnName) 607 } 608 stateError, ok := data[stateErrorColumnName].([]byte) 609 if !ok { 610 return fmt.Errorf("Wrong state column %s returned from query", stateErrorColumnName) 611 } 612 state.Error = string(stateError) 613 stateState, ok := data[stateColumnName].([]byte) 614 if !ok { 615 return fmt.Errorf("Wrong state column %s returned from query", stateColumnName) 616 } 617 state.State = string(stateState) 618 stateMonitoring, ok := data[stateMonitoringColumnName].([]byte) 619 if !ok { 620 return fmt.Errorf("Wrong state column %s returned from query", stateMonitoringColumnName) 621 } 622 state.Monitoring = string(stateMonitoring) 623 return nil 624 } 625 626 //List resources in the db 627 func (tx *Transaction) List(s *schema.Schema, filter transaction.Filter, pg *pagination.Paginator) (list []*schema.Resource, total uint64, err error) { 628 cols := MakeColumns(s, s.GetDbTableName(), true) 629 q := sq.Select(cols...).From(quote(s.GetDbTableName())) 630 q, err = addFilterToQuery(s, q, filter, true) 631 if err != nil { 632 return nil, 0, err 633 } 634 if pg != nil { 635 property, err := s.GetPropertyByID(pg.Key) 636 if err == nil { 637 q = q.OrderBy(makeColumn(s.GetDbTableName(), *property) + " " + pg.Order) 638 if pg.Limit > 0 { 639 q = q.Limit(pg.Limit) 640 } 641 if pg.Offset > 0 { 642 q = q.Offset(pg.Offset) 643 } 644 } 645 } 646 q = makeJoin(s, s.GetDbTableName(), q) 647 648 sql, args, err := q.ToSql() 649 if err != nil { 650 return 651 } 652 logQuery(sql, args...) 653 rows, err := tx.transaction.Queryx(sql, args...) 654 if err != nil { 655 return 656 } 657 defer rows.Close() 658 list, err = tx.decodeRows(s, rows, list) 659 if err != nil { 660 return nil, 0, err 661 } 662 total, err = tx.count(s, filter) 663 return 664 } 665 666 // Query with raw sql string 667 func (tx *Transaction) Query(s *schema.Schema, query string, arguments []interface{}) (list []*schema.Resource, err error) { 668 logQuery(query, arguments...) 669 rows, err := tx.transaction.Queryx(query, arguments...) 670 if err != nil { 671 return nil, fmt.Errorf("Failed to run query: %s", query) 672 } 673 674 defer rows.Close() 675 list, err = tx.decodeRows(s, rows, list) 676 if err != nil { 677 return nil, err 678 } 679 680 return 681 } 682 683 func (tx *Transaction) decodeRows(s *schema.Schema, rows *sqlx.Rows, list []*schema.Resource) ([]*schema.Resource, error) { 684 for rows.Next() { 685 resourceData := map[string]interface{}{} 686 data := map[string]interface{}{} 687 rows.MapScan(data) 688 689 var resource *schema.Resource 690 tx.decode(s, s.GetDbTableName(), data, resourceData) 691 resource, err := schema.NewResource(s, resourceData) 692 if err != nil { 693 return nil, fmt.Errorf("Failed to decode rows") 694 } 695 list = append(list, resource) 696 } 697 return list, nil 698 } 699 700 //count count all matching resources in the db 701 func (tx *Transaction) count(s *schema.Schema, filter transaction.Filter) (res uint64, err error) { 702 q := sq.Select("Count(id) as count").From(quote(s.GetDbTableName())) 703 //Filter get already tested 704 q, _ = addFilterToQuery(s, q, filter, false) 705 sql, args, err := q.ToSql() 706 if err != nil { 707 return 708 } 709 result := map[string]interface{}{} 710 err = tx.transaction.QueryRowx(sql, args...).MapScan(result) 711 if err != nil { 712 return 713 } 714 count, _ := result["count"] 715 decoder := &numberHandler{} 716 decoded, decodeErr := decoder.decode(nil, count) 717 if decodeErr != nil { 718 err = fmt.Errorf("SQL List decoding error: %s", decodeErr) 719 return 720 } 721 res = uint64(decoded.(int)) 722 return 723 } 724 725 //Fetch resources by ID in the db 726 func (tx *Transaction) Fetch(s *schema.Schema, filter transaction.Filter) (*schema.Resource, error) { 727 list, _, err := tx.List(s, filter, nil) 728 if len(list) < 1 { 729 return nil, fmt.Errorf("Failed to fetch %s", filter) 730 } 731 return list[0], err 732 } 733 734 //StateFetch fetches the state of the specified resource 735 func (tx *Transaction) StateFetch(s *schema.Schema, filter transaction.Filter) (state transaction.ResourceState, err error) { 736 if !s.StateVersioning() { 737 err = fmt.Errorf("Schema %s does not support state versioning.", s.ID) 738 return 739 } 740 cols := makeStateColumns(s) 741 q := sq.Select(cols...).From(quote(s.GetDbTableName())) 742 q, _ = addFilterToQuery(s, q, filter, true) 743 sql, args, err := q.ToSql() 744 if err != nil { 745 return 746 } 747 logQuery(sql, args...) 748 rows, err := tx.transaction.Queryx(sql, args...) 749 if err != nil { 750 return 751 } 752 defer rows.Close() 753 if !rows.Next() { 754 err = fmt.Errorf("No resource found") 755 return 756 } 757 data := map[string]interface{}{} 758 rows.MapScan(data) 759 err = decodeState(data, &state) 760 return 761 } 762 763 //RawTransaction returns raw transaction 764 func (tx *Transaction) RawTransaction() *sqlx.Tx { 765 return tx.transaction 766 } 767 768 //SetIsolationLevel specify transaction isolation level 769 func (tx *Transaction) SetIsolationLevel(level transaction.Type) error { 770 if tx.db.sqlType == "mysql" { 771 err := tx.Exec(fmt.Sprintf("set session transaction isolation level %s", level)) 772 return err 773 } 774 return nil 775 } 776 777 //Commit commits transaction 778 func (tx *Transaction) Commit() error { 779 err := tx.transaction.Commit() 780 if err != nil { 781 return err 782 } 783 tx.closed = true 784 return nil 785 } 786 787 //Close closes connection 788 func (tx *Transaction) Close() error { 789 //Rollback if it isn't committed yet 790 var err error 791 if !tx.closed { 792 err = tx.transaction.Rollback() 793 if err != nil { 794 return err 795 } 796 tx.closed = true 797 } 798 return nil 799 } 800 801 //Closed returns whether the transaction is closed 802 func (tx *Transaction) Closed() bool { 803 return tx.closed 804 } 805 806 func addFilterToQuery(s *schema.Schema, q sq.SelectBuilder, filter map[string]interface{}, join bool) (sq.SelectBuilder, error) { 807 if filter == nil { 808 return q, nil 809 } 810 for key, value := range filter { 811 property, err := s.GetPropertyByID(key) 812 813 if err != nil { 814 return q, err 815 } 816 817 var column string 818 if join { 819 column = makeColumn(s.GetDbTableName(), *property) 820 } else { 821 column = quote(key) 822 } 823 824 queryValues, ok := value.([]string) 825 if ok && property.Type == "boolean" { 826 v := make([]bool, len(queryValues)) 827 for i, j := range queryValues { 828 v[i], _ = strconv.ParseBool(j) 829 } 830 q = q.Where(sq.Eq{column: v}) 831 } else { 832 q = q.Where(sq.Eq{column: value}) 833 } 834 } 835 return q, nil 836 } 837 838 //SetMaxOpenConns limit maximum connections 839 func (db *DB) SetMaxOpenConns(maxIdleConns int) { 840 // db.DB.SetMaxOpenConns(maxIdleConns) 841 // db.DB.SetMaxIdleConns(maxIdleConns) 842 }