github.com/systematiccaos/gorm@v1.22.6/chainable_api.go (about)

     1  package gorm
     2  
     3  import (
     4  	"fmt"
     5  	"regexp"
     6  	"strings"
     7  
     8  	"github.com/systematiccaos/gorm/clause"
     9  	"github.com/systematiccaos/gorm/utils"
    10  )
    11  
    12  // Model specify the model you would like to run db operations
    13  //    // update all users's name to `hello`
    14  //    db.Model(&User{}).Update("name", "hello")
    15  //    // if user's primary key is non-blank, will use it as condition, then will only update the user's name to `hello`
    16  //    db.Model(&user).Update("name", "hello")
    17  func (db *DB) Model(value interface{}) (tx *DB) {
    18  	tx = db.getInstance()
    19  	tx.Statement.Model = value
    20  	return
    21  }
    22  
    23  // Clauses Add clauses
    24  func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) {
    25  	tx = db.getInstance()
    26  	var whereConds []interface{}
    27  
    28  	for _, cond := range conds {
    29  		if c, ok := cond.(clause.Interface); ok {
    30  			tx.Statement.AddClause(c)
    31  		} else if optimizer, ok := cond.(StatementModifier); ok {
    32  			optimizer.ModifyStatement(tx.Statement)
    33  		} else {
    34  			whereConds = append(whereConds, cond)
    35  		}
    36  	}
    37  
    38  	if len(whereConds) > 0 {
    39  		tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)})
    40  	}
    41  	return
    42  }
    43  
    44  var tableRegexp = regexp.MustCompile(`(?i).+? AS (\w+)\s*(?:$|,)`)
    45  
    46  // Table specify the table you would like to run db operations
    47  func (db *DB) Table(name string, args ...interface{}) (tx *DB) {
    48  	tx = db.getInstance()
    49  	if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 {
    50  		tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args}
    51  		if results := tableRegexp.FindStringSubmatch(name); len(results) == 2 {
    52  			tx.Statement.Table = results[1]
    53  		}
    54  	} else if tables := strings.Split(name, "."); len(tables) == 2 {
    55  		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
    56  		tx.Statement.Table = tables[1]
    57  	} else {
    58  		tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)}
    59  		tx.Statement.Table = name
    60  	}
    61  	return
    62  }
    63  
    64  // Distinct specify distinct fields that you want querying
    65  func (db *DB) Distinct(args ...interface{}) (tx *DB) {
    66  	tx = db.getInstance()
    67  	tx.Statement.Distinct = true
    68  	if len(args) > 0 {
    69  		tx = tx.Select(args[0], args[1:]...)
    70  	}
    71  	return
    72  }
    73  
    74  // Select specify fields that you want when querying, creating, updating
    75  func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
    76  	tx = db.getInstance()
    77  
    78  	switch v := query.(type) {
    79  	case []string:
    80  		tx.Statement.Selects = v
    81  
    82  		for _, arg := range args {
    83  			switch arg := arg.(type) {
    84  			case string:
    85  				tx.Statement.Selects = append(tx.Statement.Selects, arg)
    86  			case []string:
    87  				tx.Statement.Selects = append(tx.Statement.Selects, arg...)
    88  			default:
    89  				tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
    90  				return
    91  			}
    92  		}
    93  		delete(tx.Statement.Clauses, "SELECT")
    94  	case string:
    95  		if strings.Count(v, "?") >= len(args) && len(args) > 0 {
    96  			tx.Statement.AddClause(clause.Select{
    97  				Distinct:   db.Statement.Distinct,
    98  				Expression: clause.Expr{SQL: v, Vars: args},
    99  			})
   100  		} else if strings.Count(v, "@") > 0 && len(args) > 0 {
   101  			tx.Statement.AddClause(clause.Select{
   102  				Distinct:   db.Statement.Distinct,
   103  				Expression: clause.NamedExpr{SQL: v, Vars: args},
   104  			})
   105  		} else {
   106  			tx.Statement.Selects = []string{v}
   107  
   108  			for _, arg := range args {
   109  				switch arg := arg.(type) {
   110  				case string:
   111  					tx.Statement.Selects = append(tx.Statement.Selects, arg)
   112  				case []string:
   113  					tx.Statement.Selects = append(tx.Statement.Selects, arg...)
   114  				default:
   115  					tx.Statement.AddClause(clause.Select{
   116  						Distinct:   db.Statement.Distinct,
   117  						Expression: clause.Expr{SQL: v, Vars: args},
   118  					})
   119  					return
   120  				}
   121  			}
   122  
   123  			delete(tx.Statement.Clauses, "SELECT")
   124  		}
   125  	default:
   126  		tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
   127  	}
   128  
   129  	return
   130  }
   131  
   132  // Omit specify fields that you want to ignore when creating, updating and querying
   133  func (db *DB) Omit(columns ...string) (tx *DB) {
   134  	tx = db.getInstance()
   135  
   136  	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
   137  		tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsValidDBNameChar)
   138  	} else {
   139  		tx.Statement.Omits = columns
   140  	}
   141  	return
   142  }
   143  
   144  // Where add conditions
   145  func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
   146  	tx = db.getInstance()
   147  	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
   148  		tx.Statement.AddClause(clause.Where{Exprs: conds})
   149  	}
   150  	return
   151  }
   152  
   153  // Not add NOT conditions
   154  func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) {
   155  	tx = db.getInstance()
   156  	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
   157  		tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}})
   158  	}
   159  	return
   160  }
   161  
   162  // Or add OR conditions
   163  func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) {
   164  	tx = db.getInstance()
   165  	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
   166  		tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}})
   167  	}
   168  	return
   169  }
   170  
   171  // Joins specify Joins conditions
   172  //     db.Joins("Account").Find(&user)
   173  //     db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user)
   174  //     db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{}))
   175  func (db *DB) Joins(query string, args ...interface{}) (tx *DB) {
   176  	tx = db.getInstance()
   177  
   178  	if len(args) == 1 {
   179  		if db, ok := args[0].(*DB); ok {
   180  			if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok {
   181  				tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, On: &where})
   182  				return
   183  			}
   184  		}
   185  	}
   186  
   187  	tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args})
   188  	return
   189  }
   190  
   191  // Group specify the group method on the find
   192  func (db *DB) Group(name string) (tx *DB) {
   193  	tx = db.getInstance()
   194  
   195  	fields := strings.FieldsFunc(name, utils.IsValidDBNameChar)
   196  	tx.Statement.AddClause(clause.GroupBy{
   197  		Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}},
   198  	})
   199  	return
   200  }
   201  
   202  // Having specify HAVING conditions for GROUP BY
   203  func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) {
   204  	tx = db.getInstance()
   205  	tx.Statement.AddClause(clause.GroupBy{
   206  		Having: tx.Statement.BuildCondition(query, args...),
   207  	})
   208  	return
   209  }
   210  
   211  // Order specify order when retrieve records from database
   212  //     db.Order("name DESC")
   213  //     db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true})
   214  func (db *DB) Order(value interface{}) (tx *DB) {
   215  	tx = db.getInstance()
   216  
   217  	switch v := value.(type) {
   218  	case clause.OrderByColumn:
   219  		tx.Statement.AddClause(clause.OrderBy{
   220  			Columns: []clause.OrderByColumn{v},
   221  		})
   222  	case string:
   223  		if v != "" {
   224  			tx.Statement.AddClause(clause.OrderBy{
   225  				Columns: []clause.OrderByColumn{{
   226  					Column: clause.Column{Name: v, Raw: true},
   227  				}},
   228  			})
   229  		}
   230  	}
   231  	return
   232  }
   233  
   234  // Limit specify the number of records to be retrieved
   235  func (db *DB) Limit(limit int) (tx *DB) {
   236  	tx = db.getInstance()
   237  	tx.Statement.AddClause(clause.Limit{Limit: limit})
   238  	return
   239  }
   240  
   241  // Offset specify the number of records to skip before starting to return the records
   242  func (db *DB) Offset(offset int) (tx *DB) {
   243  	tx = db.getInstance()
   244  	tx.Statement.AddClause(clause.Limit{Offset: offset})
   245  	return
   246  }
   247  
   248  // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically
   249  //     func AmountGreaterThan1000(db *gorm.DB) *gorm.DB {
   250  //         return db.Where("amount > ?", 1000)
   251  //     }
   252  //
   253  //     func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB {
   254  //         return func (db *gorm.DB) *gorm.DB {
   255  //             return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status)
   256  //         }
   257  //     }
   258  //
   259  //     db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders)
   260  func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) {
   261  	tx = db.getInstance()
   262  	tx.Statement.scopes = append(tx.Statement.scopes, funcs...)
   263  	return tx
   264  }
   265  
   266  // Preload preload associations with given conditions
   267  //    db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users)
   268  func (db *DB) Preload(query string, args ...interface{}) (tx *DB) {
   269  	tx = db.getInstance()
   270  	if tx.Statement.Preloads == nil {
   271  		tx.Statement.Preloads = map[string][]interface{}{}
   272  	}
   273  	tx.Statement.Preloads[query] = args
   274  	return
   275  }
   276  
   277  func (db *DB) Attrs(attrs ...interface{}) (tx *DB) {
   278  	tx = db.getInstance()
   279  	tx.Statement.attrs = attrs
   280  	return
   281  }
   282  
   283  func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
   284  	tx = db.getInstance()
   285  	tx.Statement.assigns = attrs
   286  	return
   287  }
   288  
   289  func (db *DB) Unscoped() (tx *DB) {
   290  	tx = db.getInstance()
   291  	tx.Statement.Unscoped = true
   292  	return
   293  }
   294  
   295  func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) {
   296  	tx = db.getInstance()
   297  	tx.Statement.SQL = strings.Builder{}
   298  
   299  	if strings.Contains(sql, "@") {
   300  		clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement)
   301  	} else {
   302  		clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement)
   303  	}
   304  	return
   305  }