github.com/systematiccaos/gorm@v1.22.6/callbacks/query.go (about) 1 package callbacks 2 3 import ( 4 "fmt" 5 "reflect" 6 "sort" 7 "strings" 8 9 "github.com/systematiccaos/gorm" 10 "github.com/systematiccaos/gorm/clause" 11 ) 12 13 func Query(db *gorm.DB) { 14 if db.Error == nil { 15 BuildQuerySQL(db) 16 17 if !db.DryRun && db.Error == nil { 18 rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 19 if err != nil { 20 db.AddError(err) 21 return 22 } 23 gorm.Scan(rows, db, 0) 24 db.AddError(rows.Close()) 25 } 26 } 27 } 28 29 func BuildQuerySQL(db *gorm.DB) { 30 if db.Statement.Schema != nil { 31 for _, c := range db.Statement.Schema.QueryClauses { 32 db.Statement.AddClause(c) 33 } 34 } 35 36 if db.Statement.SQL.Len() == 0 { 37 db.Statement.SQL.Grow(100) 38 clauseSelect := clause.Select{Distinct: db.Statement.Distinct} 39 40 if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { 41 var conds []clause.Expression 42 for _, primaryField := range db.Statement.Schema.PrimaryFields { 43 if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero { 44 conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) 45 } 46 } 47 48 if len(conds) > 0 { 49 db.Statement.AddClause(clause.Where{Exprs: conds}) 50 } 51 } 52 53 if len(db.Statement.Selects) > 0 { 54 clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) 55 for idx, name := range db.Statement.Selects { 56 if db.Statement.Schema == nil { 57 clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} 58 } else if f := db.Statement.Schema.LookUpField(name); f != nil { 59 clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} 60 } else { 61 clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} 62 } 63 } 64 } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { 65 selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) 66 clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) 67 for _, dbName := range db.Statement.Schema.DBNames { 68 if v, ok := selectColumns[dbName]; (ok && v) || !ok { 69 clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) 70 } 71 } 72 } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { 73 queryFields := db.QueryFields 74 if !queryFields { 75 switch db.Statement.ReflectValue.Kind() { 76 case reflect.Struct: 77 queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType 78 case reflect.Slice: 79 queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType 80 } 81 } 82 83 if queryFields { 84 stmt := gorm.Statement{DB: db} 85 // smaller struct 86 if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { 87 clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) 88 89 for idx, dbName := range stmt.Schema.DBNames { 90 clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} 91 } 92 } 93 } 94 } 95 96 // inline joins 97 joins := []clause.Join{} 98 if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { 99 joins = fromClause.Joins 100 } 101 102 if len(db.Statement.Joins) != 0 || len(joins) != 0 { 103 if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil { 104 clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) 105 for idx, dbName := range db.Statement.Schema.DBNames { 106 clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} 107 } 108 } 109 110 for _, join := range db.Statement.Joins { 111 if db.Statement.Schema == nil { 112 joins = append(joins, clause.Join{ 113 Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, 114 }) 115 } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { 116 tableAliasName := relation.Name 117 118 for _, s := range relation.FieldSchema.DBNames { 119 clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ 120 Table: tableAliasName, 121 Name: s, 122 Alias: tableAliasName + "__" + s, 123 }) 124 } 125 126 exprs := make([]clause.Expression, len(relation.References)) 127 for idx, ref := range relation.References { 128 if ref.OwnPrimaryKey { 129 exprs[idx] = clause.Eq{ 130 Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, 131 Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, 132 } 133 } else { 134 if ref.PrimaryValue == "" { 135 exprs[idx] = clause.Eq{ 136 Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, 137 Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, 138 } 139 } else { 140 exprs[idx] = clause.Eq{ 141 Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, 142 Value: ref.PrimaryValue, 143 } 144 } 145 } 146 } 147 148 if join.On != nil { 149 onStmt := gorm.Statement{Table: tableAliasName, DB: db} 150 join.On.Build(&onStmt) 151 onSQL := onStmt.SQL.String() 152 vars := onStmt.Vars 153 for idx, v := range onStmt.Vars { 154 bindvar := strings.Builder{} 155 onStmt.Vars = vars[0 : idx+1] 156 db.Dialector.BindVarTo(&bindvar, &onStmt, v) 157 onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) 158 } 159 160 exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) 161 } 162 163 joins = append(joins, clause.Join{ 164 Type: clause.LeftJoin, 165 Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, 166 ON: clause.Where{Exprs: exprs}, 167 }) 168 } else { 169 joins = append(joins, clause.Join{ 170 Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, 171 }) 172 } 173 } 174 175 db.Statement.Joins = nil 176 db.Statement.AddClause(clause.From{Joins: joins}) 177 } else { 178 db.Statement.AddClauseIfNotExists(clause.From{}) 179 } 180 181 db.Statement.AddClauseIfNotExists(clauseSelect) 182 183 db.Statement.Build(db.Statement.BuildClauses...) 184 } 185 } 186 187 func Preload(db *gorm.DB) { 188 if db.Error == nil && len(db.Statement.Preloads) > 0 { 189 preloadMap := map[string]map[string][]interface{}{} 190 for name := range db.Statement.Preloads { 191 preloadFields := strings.Split(name, ".") 192 if preloadFields[0] == clause.Associations { 193 for _, rel := range db.Statement.Schema.Relationships.Relations { 194 if rel.Schema == db.Statement.Schema { 195 if _, ok := preloadMap[rel.Name]; !ok { 196 preloadMap[rel.Name] = map[string][]interface{}{} 197 } 198 199 if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { 200 preloadMap[rel.Name][value] = db.Statement.Preloads[name] 201 } 202 } 203 } 204 } else { 205 if _, ok := preloadMap[preloadFields[0]]; !ok { 206 preloadMap[preloadFields[0]] = map[string][]interface{}{} 207 } 208 209 if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { 210 preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] 211 } 212 } 213 } 214 215 preloadNames := make([]string, 0, len(preloadMap)) 216 for key := range preloadMap { 217 preloadNames = append(preloadNames, key) 218 } 219 sort.Strings(preloadNames) 220 221 for _, name := range preloadNames { 222 if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil { 223 preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name]) 224 } else { 225 db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) 226 } 227 } 228 } 229 } 230 231 func AfterQuery(db *gorm.DB) { 232 if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { 233 callMethod(db, func(value interface{}, tx *gorm.DB) bool { 234 if i, ok := value.(AfterFindInterface); ok { 235 db.AddError(i.AfterFind(tx)) 236 return true 237 } 238 return false 239 }) 240 } 241 }