github.com/systematiccaos/gorm@v1.22.6/schema/schema.go (about)

     1  package schema
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"go/ast"
     8  	"reflect"
     9  	"sync"
    10  
    11  	"github.com/systematiccaos/gorm/clause"
    12  	"github.com/systematiccaos/gorm/logger"
    13  )
    14  
    15  // ErrUnsupportedDataType unsupported data type
    16  var ErrUnsupportedDataType = errors.New("unsupported data type")
    17  
    18  type Schema struct {
    19  	Name                      string
    20  	ModelType                 reflect.Type
    21  	Table                     string
    22  	PrioritizedPrimaryField   *Field
    23  	DBNames                   []string
    24  	PrimaryFields             []*Field
    25  	PrimaryFieldDBNames       []string
    26  	Fields                    []*Field
    27  	FieldsByName              map[string]*Field
    28  	FieldsByDBName            map[string]*Field
    29  	FieldsWithDefaultDBValue  []*Field // fields with default value assigned by database
    30  	Relationships             Relationships
    31  	CreateClauses             []clause.Interface
    32  	QueryClauses              []clause.Interface
    33  	UpdateClauses             []clause.Interface
    34  	DeleteClauses             []clause.Interface
    35  	BeforeCreate, AfterCreate bool
    36  	BeforeUpdate, AfterUpdate bool
    37  	BeforeDelete, AfterDelete bool
    38  	BeforeSave, AfterSave     bool
    39  	AfterFind                 bool
    40  	err                       error
    41  	initialized               chan struct{}
    42  	namer                     Namer
    43  	cacheStore                *sync.Map
    44  }
    45  
    46  func (schema Schema) String() string {
    47  	if schema.ModelType.Name() == "" {
    48  		return fmt.Sprintf("%s(%s)", schema.Name, schema.Table)
    49  	}
    50  	return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name())
    51  }
    52  
    53  func (schema Schema) MakeSlice() reflect.Value {
    54  	slice := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(schema.ModelType)), 0, 20)
    55  	results := reflect.New(slice.Type())
    56  	results.Elem().Set(slice)
    57  	return results
    58  }
    59  
    60  func (schema Schema) LookUpField(name string) *Field {
    61  	if field, ok := schema.FieldsByDBName[name]; ok {
    62  		return field
    63  	}
    64  	if field, ok := schema.FieldsByName[name]; ok {
    65  		return field
    66  	}
    67  	return nil
    68  }
    69  
    70  type Tabler interface {
    71  	TableName() string
    72  }
    73  
    74  // Parse get data type from dialector
    75  func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
    76  	return ParseWithSpecialTableName(dest, cacheStore, namer, "")
    77  }
    78  
    79  // ParseWithSpecialTableName get data type from dialector with extra schema table
    80  func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) {
    81  	if dest == nil {
    82  		return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
    83  	}
    84  
    85  	value := reflect.ValueOf(dest)
    86  	if value.Kind() == reflect.Ptr && value.IsNil() {
    87  		value = reflect.New(value.Type().Elem())
    88  	}
    89  	modelType := reflect.Indirect(value).Type()
    90  
    91  	if modelType.Kind() == reflect.Interface {
    92  		modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type()
    93  	}
    94  
    95  	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
    96  		modelType = modelType.Elem()
    97  	}
    98  
    99  	if modelType.Kind() != reflect.Struct {
   100  		if modelType.PkgPath() == "" {
   101  			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
   102  		}
   103  		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
   104  	}
   105  
   106  	// Cache the Schema for performance,
   107  	// Use the modelType or modelType + schemaTable (if it present) as cache key.
   108  	var schemaCacheKey interface{}
   109  	if specialTableName != "" {
   110  		schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName)
   111  	} else {
   112  		schemaCacheKey = modelType
   113  	}
   114  
   115  	// Load exist schmema cache, return if exists
   116  	if v, ok := cacheStore.Load(schemaCacheKey); ok {
   117  		s := v.(*Schema)
   118  		// Wait for the initialization of other goroutines to complete
   119  		<-s.initialized
   120  		return s, s.err
   121  	}
   122  
   123  	modelValue := reflect.New(modelType)
   124  	tableName := namer.TableName(modelType.Name())
   125  	if tabler, ok := modelValue.Interface().(Tabler); ok {
   126  		tableName = tabler.TableName()
   127  	}
   128  	if en, ok := namer.(embeddedNamer); ok {
   129  		tableName = en.Table
   130  	}
   131  	if specialTableName != "" && specialTableName != tableName {
   132  		tableName = specialTableName
   133  	}
   134  
   135  	schema := &Schema{
   136  		Name:           modelType.Name(),
   137  		ModelType:      modelType,
   138  		Table:          tableName,
   139  		FieldsByName:   map[string]*Field{},
   140  		FieldsByDBName: map[string]*Field{},
   141  		Relationships:  Relationships{Relations: map[string]*Relationship{}},
   142  		cacheStore:     cacheStore,
   143  		namer:          namer,
   144  		initialized:    make(chan struct{}),
   145  	}
   146  	// When the schema initialization is completed, the channel will be closed
   147  	defer close(schema.initialized)
   148  
   149  	// Load exist schmema cache, return if exists
   150  	if v, ok := cacheStore.Load(schemaCacheKey); ok {
   151  		s := v.(*Schema)
   152  		// Wait for the initialization of other goroutines to complete
   153  		<-s.initialized
   154  		return s, s.err
   155  	}
   156  
   157  	for i := 0; i < modelType.NumField(); i++ {
   158  		if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
   159  			if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil {
   160  				schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...)
   161  			} else {
   162  				schema.Fields = append(schema.Fields, field)
   163  			}
   164  		}
   165  	}
   166  
   167  	for _, field := range schema.Fields {
   168  		if field.DBName == "" && field.DataType != "" {
   169  			field.DBName = namer.ColumnName(schema.Table, field.Name)
   170  		}
   171  
   172  		if field.DBName != "" {
   173  			// nonexistence or shortest path or first appear prioritized if has permission
   174  			if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) {
   175  				if _, ok := schema.FieldsByDBName[field.DBName]; !ok {
   176  					schema.DBNames = append(schema.DBNames, field.DBName)
   177  				}
   178  				schema.FieldsByDBName[field.DBName] = field
   179  				schema.FieldsByName[field.Name] = field
   180  
   181  				if v != nil && v.PrimaryKey {
   182  					for idx, f := range schema.PrimaryFields {
   183  						if f == v {
   184  							schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...)
   185  						}
   186  					}
   187  				}
   188  
   189  				if field.PrimaryKey {
   190  					schema.PrimaryFields = append(schema.PrimaryFields, field)
   191  				}
   192  			}
   193  		}
   194  
   195  		if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" {
   196  			schema.FieldsByName[field.Name] = field
   197  		}
   198  
   199  		field.setupValuerAndSetter()
   200  	}
   201  
   202  	prioritizedPrimaryField := schema.LookUpField("id")
   203  	if prioritizedPrimaryField == nil {
   204  		prioritizedPrimaryField = schema.LookUpField("ID")
   205  	}
   206  
   207  	if prioritizedPrimaryField != nil {
   208  		if prioritizedPrimaryField.PrimaryKey {
   209  			schema.PrioritizedPrimaryField = prioritizedPrimaryField
   210  		} else if len(schema.PrimaryFields) == 0 {
   211  			prioritizedPrimaryField.PrimaryKey = true
   212  			schema.PrioritizedPrimaryField = prioritizedPrimaryField
   213  			schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField)
   214  		}
   215  	}
   216  
   217  	if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 {
   218  		schema.PrioritizedPrimaryField = schema.PrimaryFields[0]
   219  	}
   220  
   221  	for _, field := range schema.PrimaryFields {
   222  		schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName)
   223  	}
   224  
   225  	for _, field := range schema.Fields {
   226  		if field.HasDefaultValue && field.DefaultValueInterface == nil {
   227  			schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
   228  		}
   229  	}
   230  
   231  	if field := schema.PrioritizedPrimaryField; field != nil {
   232  		switch field.GORMDataType {
   233  		case Int, Uint:
   234  			if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok {
   235  				if !field.HasDefaultValue || field.DefaultValueInterface != nil {
   236  					schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field)
   237  				}
   238  
   239  				field.HasDefaultValue = true
   240  				field.AutoIncrement = true
   241  			}
   242  		}
   243  	}
   244  
   245  	callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"}
   246  	for _, name := range callbacks {
   247  		if methodValue := modelValue.MethodByName(name); methodValue.IsValid() {
   248  			switch methodValue.Type().String() {
   249  			case "func(*gorm.DB) error": // TODO hack
   250  				reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true)
   251  			default:
   252  				logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name)
   253  			}
   254  		}
   255  	}
   256  
   257  	// Cache the schema
   258  	if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded {
   259  		s := v.(*Schema)
   260  		// Wait for the initialization of other goroutines to complete
   261  		<-s.initialized
   262  		return s, s.err
   263  	}
   264  
   265  	defer func() {
   266  		if schema.err != nil {
   267  			logger.Default.Error(context.Background(), schema.err.Error())
   268  			cacheStore.Delete(modelType)
   269  		}
   270  	}()
   271  
   272  	if _, embedded := schema.cacheStore.Load(embeddedCacheKey); !embedded {
   273  		for _, field := range schema.Fields {
   274  			if field.DataType == "" && (field.Creatable || field.Updatable || field.Readable) {
   275  				if schema.parseRelation(field); schema.err != nil {
   276  					return schema, schema.err
   277  				} else {
   278  					schema.FieldsByName[field.Name] = field
   279  				}
   280  			}
   281  
   282  			fieldValue := reflect.New(field.IndirectFieldType)
   283  			fieldInterface := fieldValue.Interface()
   284  			if fc, ok := fieldInterface.(CreateClausesInterface); ok {
   285  				field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...)
   286  			}
   287  
   288  			if fc, ok := fieldInterface.(QueryClausesInterface); ok {
   289  				field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...)
   290  			}
   291  
   292  			if fc, ok := fieldInterface.(UpdateClausesInterface); ok {
   293  				field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...)
   294  			}
   295  
   296  			if fc, ok := fieldInterface.(DeleteClausesInterface); ok {
   297  				field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...)
   298  			}
   299  		}
   300  	}
   301  
   302  	return schema, schema.err
   303  }
   304  
   305  func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) {
   306  	modelType := reflect.ValueOf(dest).Type()
   307  	for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr {
   308  		modelType = modelType.Elem()
   309  	}
   310  
   311  	if modelType.Kind() != reflect.Struct {
   312  		if modelType.PkgPath() == "" {
   313  			return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest)
   314  		}
   315  		return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name())
   316  	}
   317  
   318  	if v, ok := cacheStore.Load(modelType); ok {
   319  		return v.(*Schema), nil
   320  	}
   321  
   322  	return Parse(dest, cacheStore, namer)
   323  }