github.com/systematiccaos/gorm@v1.22.6/schema/schema.go (about) 1 package schema 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "go/ast" 8 "reflect" 9 "sync" 10 11 "github.com/systematiccaos/gorm/clause" 12 "github.com/systematiccaos/gorm/logger" 13 ) 14 15 // ErrUnsupportedDataType unsupported data type 16 var ErrUnsupportedDataType = errors.New("unsupported data type") 17 18 type Schema struct { 19 Name string 20 ModelType reflect.Type 21 Table string 22 PrioritizedPrimaryField *Field 23 DBNames []string 24 PrimaryFields []*Field 25 PrimaryFieldDBNames []string 26 Fields []*Field 27 FieldsByName map[string]*Field 28 FieldsByDBName map[string]*Field 29 FieldsWithDefaultDBValue []*Field // fields with default value assigned by database 30 Relationships Relationships 31 CreateClauses []clause.Interface 32 QueryClauses []clause.Interface 33 UpdateClauses []clause.Interface 34 DeleteClauses []clause.Interface 35 BeforeCreate, AfterCreate bool 36 BeforeUpdate, AfterUpdate bool 37 BeforeDelete, AfterDelete bool 38 BeforeSave, AfterSave bool 39 AfterFind bool 40 err error 41 initialized chan struct{} 42 namer Namer 43 cacheStore *sync.Map 44 } 45 46 func (schema Schema) String() string { 47 if schema.ModelType.Name() == "" { 48 return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) 49 } 50 return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) 51 } 52 53 func (schema Schema) MakeSlice() reflect.Value { 54 slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20) 55 results := reflect.New(slice.Type()) 56 results.Elem().Set(slice) 57 return results 58 } 59 60 func (schema Schema) LookUpField(name string) *Field { 61 if field, ok := schema.FieldsByDBName[name]; ok { 62 return field 63 } 64 if field, ok := schema.FieldsByName[name]; ok { 65 return field 66 } 67 return nil 68 } 69 70 type Tabler interface { 71 TableName() string 72 } 73 74 // Parse get data type from dialector 75 func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { 76 return ParseWithSpecialTableName(dest, cacheStore, namer, "") 77 } 78 79 // ParseWithSpecialTableName get data type from dialector with extra schema table 80 func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { 81 if dest == nil { 82 return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) 83 } 84 85 value := reflect.ValueOf(dest) 86 if value.Kind() == reflect.Ptr && value.IsNil() { 87 value = reflect.New(value.Type().Elem()) 88 } 89 modelType := reflect.Indirect(value).Type() 90 91 if modelType.Kind() == reflect.Interface { 92 modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() 93 } 94 95 for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { 96 modelType = modelType.Elem() 97 } 98 99 if modelType.Kind() != reflect.Struct { 100 if modelType.PkgPath() == "" { 101 return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) 102 } 103 return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) 104 } 105 106 // Cache the Schema for performance, 107 // Use the modelType or modelType + schemaTable (if it present) as cache key. 108 var schemaCacheKey interface{} 109 if specialTableName != "" { 110 schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) 111 } else { 112 schemaCacheKey = modelType 113 } 114 115 // Load exist schmema cache, return if exists 116 if v, ok := cacheStore.Load(schemaCacheKey); ok { 117 s := v.(*Schema) 118 // Wait for the initialization of other goroutines to complete 119 <-s.initialized 120 return s, s.err 121 } 122 123 modelValue := reflect.New(modelType) 124 tableName := namer.TableName(modelType.Name()) 125 if tabler, ok := modelValue.Interface().(Tabler); ok { 126 tableName = tabler.TableName() 127 } 128 if en, ok := namer.(embeddedNamer); ok { 129 tableName = en.Table 130 } 131 if specialTableName != "" && specialTableName != tableName { 132 tableName = specialTableName 133 } 134 135 schema := &Schema{ 136 Name: modelType.Name(), 137 ModelType: modelType, 138 Table: tableName, 139 FieldsByName: map[string]*Field{}, 140 FieldsByDBName: map[string]*Field{}, 141 Relationships: Relationships{Relations: map[string]*Relationship{}}, 142 cacheStore: cacheStore, 143 namer: namer, 144 initialized: make(chan struct{}), 145 } 146 // When the schema initialization is completed, the channel will be closed 147 defer close(schema.initialized) 148 149 // Load exist schmema cache, return if exists 150 if v, ok := cacheStore.Load(schemaCacheKey); ok { 151 s := v.(*Schema) 152 // Wait for the initialization of other goroutines to complete 153 <-s.initialized 154 return s, s.err 155 } 156 157 for i := 0; i < modelType.NumField(); i++ { 158 if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { 159 if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { 160 schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) 161 } else { 162 schema.Fields = append(schema.Fields, field) 163 } 164 } 165 } 166 167 for _, field := range schema.Fields { 168 if field.DBName == "" && field.DataType != "" { 169 field.DBName = namer.ColumnName(schema.Table, field.Name) 170 } 171 172 if field.DBName != "" { 173 // nonexistence or shortest path or first appear prioritized if has permission 174 if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { 175 if _, ok := schema.FieldsByDBName[field.DBName]; !ok { 176 schema.DBNames = append(schema.DBNames, field.DBName) 177 } 178 schema.FieldsByDBName[field.DBName] = field 179 schema.FieldsByName[field.Name] = field 180 181 if v != nil && v.PrimaryKey { 182 for idx, f := range schema.PrimaryFields { 183 if f == v { 184 schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) 185 } 186 } 187 } 188 189 if field.PrimaryKey { 190 schema.PrimaryFields = append(schema.PrimaryFields, field) 191 } 192 } 193 } 194 195 if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { 196 schema.FieldsByName[field.Name] = field 197 } 198 199 field.setupValuerAndSetter() 200 } 201 202 prioritizedPrimaryField := schema.LookUpField("id") 203 if prioritizedPrimaryField == nil { 204 prioritizedPrimaryField = schema.LookUpField("ID") 205 } 206 207 if prioritizedPrimaryField != nil { 208 if prioritizedPrimaryField.PrimaryKey { 209 schema.PrioritizedPrimaryField = prioritizedPrimaryField 210 } else if len(schema.PrimaryFields) == 0 { 211 prioritizedPrimaryField.PrimaryKey = true 212 schema.PrioritizedPrimaryField = prioritizedPrimaryField 213 schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) 214 } 215 } 216 217 if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { 218 schema.PrioritizedPrimaryField = schema.PrimaryFields[0] 219 } 220 221 for _, field := range schema.PrimaryFields { 222 schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) 223 } 224 225 for _, field := range schema.Fields { 226 if field.HasDefaultValue && field.DefaultValueInterface == nil { 227 schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) 228 } 229 } 230 231 if field := schema.PrioritizedPrimaryField; field != nil { 232 switch field.GORMDataType { 233 case Int, Uint: 234 if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { 235 if !field.HasDefaultValue || field.DefaultValueInterface != nil { 236 schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) 237 } 238 239 field.HasDefaultValue = true 240 field.AutoIncrement = true 241 } 242 } 243 } 244 245 callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} 246 for _, name := range callbacks { 247 if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { 248 switch methodValue.Type().String() { 249 case "func(*gorm.DB) error": // TODO hack 250 reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) 251 default: 252 logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) 253 } 254 } 255 } 256 257 // Cache the schema 258 if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { 259 s := v.(*Schema) 260 // Wait for the initialization of other goroutines to complete 261 <-s.initialized 262 return s, s.err 263 } 264 265 defer func() { 266 if schema.err != nil { 267 logger.Default.Error(context.Background(), schema.err.Error()) 268 cacheStore.Delete(modelType) 269 } 270 }() 271 272 if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded { 273 for _, field := range schema.Fields { 274 if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) { 275 if schema.parseRelation(field); schema.err != nil { 276 return schema, schema.err 277 } else { 278 schema.FieldsByName[field.Name] = field 279 } 280 } 281 282 fieldValue := reflect.New(field.IndirectFieldType) 283 fieldInterface := fieldValue.Interface() 284 if fc, ok := fieldInterface.(CreateClausesInterface); ok { 285 field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) 286 } 287 288 if fc, ok := fieldInterface.(QueryClausesInterface); ok { 289 field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) 290 } 291 292 if fc, ok := fieldInterface.(UpdateClausesInterface); ok { 293 field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) 294 } 295 296 if fc, ok := fieldInterface.(DeleteClausesInterface); ok { 297 field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) 298 } 299 } 300 } 301 302 return schema, schema.err 303 } 304 305 func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { 306 modelType := reflect.ValueOf(dest).Type() 307 for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { 308 modelType = modelType.Elem() 309 } 310 311 if modelType.Kind() != reflect.Struct { 312 if modelType.PkgPath() == "" { 313 return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) 314 } 315 return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) 316 } 317 318 if v, ok := cacheStore.Load(modelType); ok { 319 return v.(*Schema), nil 320 } 321 322 return Parse(dest, cacheStore, namer) 323 }