github.com/systematiccaos/gorm@v1.22.6/callbacks/query.go (about)

     1  package callbacks
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sort"
     7  	"strings"
     8  
     9  	"github.com/systematiccaos/gorm"
    10  	"github.com/systematiccaos/gorm/clause"
    11  )
    12  
    13  func Query(db *gorm.DB) {
    14  	if db.Error == nil {
    15  		BuildQuerySQL(db)
    16  
    17  		if !db.DryRun && db.Error == nil {
    18  			rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)
    19  			if err != nil {
    20  				db.AddError(err)
    21  				return
    22  			}
    23  			gorm.Scan(rows, db, 0)
    24  			db.AddError(rows.Close())
    25  		}
    26  	}
    27  }
    28  
    29  func BuildQuerySQL(db *gorm.DB) {
    30  	if db.Statement.Schema != nil {
    31  		for _, c := range db.Statement.Schema.QueryClauses {
    32  			db.Statement.AddClause(c)
    33  		}
    34  	}
    35  
    36  	if db.Statement.SQL.Len() == 0 {
    37  		db.Statement.SQL.Grow(100)
    38  		clauseSelect := clause.Select{Distinct: db.Statement.Distinct}
    39  
    40  		if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType {
    41  			var conds []clause.Expression
    42  			for _, primaryField := range db.Statement.Schema.PrimaryFields {
    43  				if v, isZero := primaryField.ValueOf(db.Statement.ReflectValue); !isZero {
    44  					conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v})
    45  				}
    46  			}
    47  
    48  			if len(conds) > 0 {
    49  				db.Statement.AddClause(clause.Where{Exprs: conds})
    50  			}
    51  		}
    52  
    53  		if len(db.Statement.Selects) > 0 {
    54  			clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects))
    55  			for idx, name := range db.Statement.Selects {
    56  				if db.Statement.Schema == nil {
    57  					clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
    58  				} else if f := db.Statement.Schema.LookUpField(name); f != nil {
    59  					clauseSelect.Columns[idx] = clause.Column{Name: f.DBName}
    60  				} else {
    61  					clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true}
    62  				}
    63  			}
    64  		} else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 {
    65  			selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false)
    66  			clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames))
    67  			for _, dbName := range db.Statement.Schema.DBNames {
    68  				if v, ok := selectColumns[dbName]; (ok && v) || !ok {
    69  					clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName})
    70  				}
    71  			}
    72  		} else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() {
    73  			queryFields := db.QueryFields
    74  			if !queryFields {
    75  				switch db.Statement.ReflectValue.Kind() {
    76  				case reflect.Struct:
    77  					queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType
    78  				case reflect.Slice:
    79  					queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType
    80  				}
    81  			}
    82  
    83  			if queryFields {
    84  				stmt := gorm.Statement{DB: db}
    85  				// smaller struct
    86  				if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) {
    87  					clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames))
    88  
    89  					for idx, dbName := range stmt.Schema.DBNames {
    90  						clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
    91  					}
    92  				}
    93  			}
    94  		}
    95  
    96  		// inline joins
    97  		joins := []clause.Join{}
    98  		if fromClause, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok {
    99  			joins = fromClause.Joins
   100  		}
   101  
   102  		if len(db.Statement.Joins) != 0 || len(joins) != 0 {
   103  			if len(db.Statement.Selects) == 0 && db.Statement.Schema != nil {
   104  				clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames))
   105  				for idx, dbName := range db.Statement.Schema.DBNames {
   106  					clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName}
   107  				}
   108  			}
   109  
   110  			for _, join := range db.Statement.Joins {
   111  				if db.Statement.Schema == nil {
   112  					joins = append(joins, clause.Join{
   113  						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
   114  					})
   115  				} else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok {
   116  					tableAliasName := relation.Name
   117  
   118  					for _, s := range relation.FieldSchema.DBNames {
   119  						clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{
   120  							Table: tableAliasName,
   121  							Name:  s,
   122  							Alias: tableAliasName + "__" + s,
   123  						})
   124  					}
   125  
   126  					exprs := make([]clause.Expression, len(relation.References))
   127  					for idx, ref := range relation.References {
   128  						if ref.OwnPrimaryKey {
   129  							exprs[idx] = clause.Eq{
   130  								Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName},
   131  								Value:  clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
   132  							}
   133  						} else {
   134  							if ref.PrimaryValue == "" {
   135  								exprs[idx] = clause.Eq{
   136  									Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName},
   137  									Value:  clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName},
   138  								}
   139  							} else {
   140  								exprs[idx] = clause.Eq{
   141  									Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName},
   142  									Value:  ref.PrimaryValue,
   143  								}
   144  							}
   145  						}
   146  					}
   147  
   148  					if join.On != nil {
   149  						onStmt := gorm.Statement{Table: tableAliasName, DB: db}
   150  						join.On.Build(&onStmt)
   151  						onSQL := onStmt.SQL.String()
   152  						vars := onStmt.Vars
   153  						for idx, v := range onStmt.Vars {
   154  							bindvar := strings.Builder{}
   155  							onStmt.Vars = vars[0 : idx+1]
   156  							db.Dialector.BindVarTo(&bindvar, &onStmt, v)
   157  							onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1)
   158  						}
   159  
   160  						exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars})
   161  					}
   162  
   163  					joins = append(joins, clause.Join{
   164  						Type:  clause.LeftJoin,
   165  						Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName},
   166  						ON:    clause.Where{Exprs: exprs},
   167  					})
   168  				} else {
   169  					joins = append(joins, clause.Join{
   170  						Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds},
   171  					})
   172  				}
   173  			}
   174  
   175  			db.Statement.Joins = nil
   176  			db.Statement.AddClause(clause.From{Joins: joins})
   177  		} else {
   178  			db.Statement.AddClauseIfNotExists(clause.From{})
   179  		}
   180  
   181  		db.Statement.AddClauseIfNotExists(clauseSelect)
   182  
   183  		db.Statement.Build(db.Statement.BuildClauses...)
   184  	}
   185  }
   186  
   187  func Preload(db *gorm.DB) {
   188  	if db.Error == nil && len(db.Statement.Preloads) > 0 {
   189  		preloadMap := map[string]map[string][]interface{}{}
   190  		for name := range db.Statement.Preloads {
   191  			preloadFields := strings.Split(name, ".")
   192  			if preloadFields[0] == clause.Associations {
   193  				for _, rel := range db.Statement.Schema.Relationships.Relations {
   194  					if rel.Schema == db.Statement.Schema {
   195  						if _, ok := preloadMap[rel.Name]; !ok {
   196  							preloadMap[rel.Name] = map[string][]interface{}{}
   197  						}
   198  
   199  						if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
   200  							preloadMap[rel.Name][value] = db.Statement.Preloads[name]
   201  						}
   202  					}
   203  				}
   204  			} else {
   205  				if _, ok := preloadMap[preloadFields[0]]; !ok {
   206  					preloadMap[preloadFields[0]] = map[string][]interface{}{}
   207  				}
   208  
   209  				if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" {
   210  					preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name]
   211  				}
   212  			}
   213  		}
   214  
   215  		preloadNames := make([]string, 0, len(preloadMap))
   216  		for key := range preloadMap {
   217  			preloadNames = append(preloadNames, key)
   218  		}
   219  		sort.Strings(preloadNames)
   220  
   221  		for _, name := range preloadNames {
   222  			if rel := db.Statement.Schema.Relationships.Relations[name]; rel != nil {
   223  				preload(db, rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])
   224  			} else {
   225  				db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name))
   226  			}
   227  		}
   228  	}
   229  }
   230  
   231  func AfterQuery(db *gorm.DB) {
   232  	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 {
   233  		callMethod(db, func(value interface{}, tx *gorm.DB) bool {
   234  			if i, ok := value.(AfterFindInterface); ok {
   235  				db.AddError(i.AfterFind(tx))
   236  				return true
   237  			}
   238  			return false
   239  		})
   240  	}
   241  }