github.com/kunlun-qilian/sqlx/v3@v3.0.0/builder/utils_.go (about)

     1  package builder
     2  
import (
	"context"
	"fmt"
	"reflect"
	"sort"
	"strings"

	contextx "github.com/go-courier/x/context"
	reflectx "github.com/go-courier/x/reflect"
	typesx "github.com/go-courier/x/types"
)
    13  
// FieldValues maps a struct field's StructField.FieldName to its value, as
// built by FieldValuesFromStructBy and FieldValuesFromStructByNonZero.
type FieldValues map[string]interface{}
    15  
// StructFieldValue is what ForEachStructFieldValue hands to its callback for
// each visited struct field.
type StructFieldValue struct {
	// Field is the column metadata of the struct field.
	Field StructField
	// TableName is the table alias if one is set, otherwise the table name;
	// empty when neither applies.
	TableName string
	// Value is the field's value inside the struct being walked.
	Value reflect.Value
}
    21  
// contextKeyTableName is the context key under which WithTableName stores the
// table name and TableNameFromContext reads it.
type contextKeyTableName struct{}
    23  
    24  func WithTableName(tableName string) func(ctx context.Context) context.Context {
    25  	return func(ctx context.Context) context.Context {
    26  		return contextx.WithValue(ctx, contextKeyTableName{}, tableName)
    27  	}
    28  }
    29  
    30  func TableNameFromContext(ctx context.Context) string {
    31  	if tableName, ok := ctx.Value(contextKeyTableName{}).(string); ok {
    32  		return tableName
    33  	}
    34  	return ""
    35  }
    36  
// contextKeyTableAlias is the context key type for the table alias. The key
// used by WithTableAlias and TableAliasFromContext is contextKeyTableAlias(1).
type contextKeyTableAlias int
    38  
    39  func WithTableAlias(tableName string) func(ctx context.Context) context.Context {
    40  	return func(ctx context.Context) context.Context {
    41  		return contextx.WithValue(ctx, contextKeyTableAlias(1), tableName)
    42  	}
    43  }
    44  
    45  func TableAliasFromContext(ctx context.Context) string {
    46  	if tableName, ok := ctx.Value(contextKeyTableAlias(1)).(string); ok {
    47  		return tableName
    48  	}
    49  	return ""
    50  }
    51  
    52  func ColumnsByStruct(v interface{}) *Ex {
    53  	ctx := context.Background()
    54  
    55  	fields := StructFieldsFor(ctx, typesx.FromRType(reflect.TypeOf(v)))
    56  
    57  	e := Expr("")
    58  	e.Grow(len(fields))
    59  
    60  	i := 0
    61  
    62  	ForEachStructFieldValue(context.Background(), reflect.ValueOf(v), func(field *StructFieldValue) {
    63  		if i > 0 {
    64  			e.WriteQuery(", ")
    65  		}
    66  
    67  		if field.TableName != "" {
    68  			e.WriteQuery(field.TableName)
    69  			e.WriteQueryByte('.')
    70  			e.WriteQuery(field.Field.Name)
    71  			e.WriteQuery(" AS ")
    72  			e.WriteQuery(field.TableName)
    73  			e.WriteQuery("__")
    74  			e.WriteQuery(field.Field.Name)
    75  		} else {
    76  			e.WriteQuery(field.Field.Name)
    77  		}
    78  
    79  		i++
    80  	})
    81  
    82  	return e
    83  }
    84  
    85  func ForEachStructFieldValue(ctx context.Context, v interface{}, fn func(*StructFieldValue)) {
    86  	rv, ok := v.(reflect.Value)
    87  	if ok {
    88  		if rv.Kind() == reflect.Ptr && rv.IsNil() {
    89  			rv.Set(reflectx.New(rv.Type()))
    90  		}
    91  		v = rv.Interface()
    92  	}
    93  
    94  	if m, ok := v.(Model); ok {
    95  		ctx = WithTableName(m.TableName())(ctx)
    96  	}
    97  
    98  	fields := StructFieldsFor(ctx, typesx.FromRType(reflect.TypeOf(v)))
    99  
   100  	rv = reflectx.Indirect(reflect.ValueOf(v))
   101  
   102  	for i := range fields {
   103  		f := fields[i]
   104  
   105  		tagDB := f.Tags["db"]
   106  
   107  		if tagDB.HasFlag("deprecated") {
   108  			continue
   109  		}
   110  
   111  		if tableAlias, ok := f.Tags["alias"]; ok {
   112  			ctx = WithTableAlias(tableAlias.Name())(ctx)
   113  		} else {
   114  			if len(f.ModelLoc) > 0 {
   115  				fpv := f.FieldModelValue(rv)
   116  				if fpv.IsValid() {
   117  					if m, ok := fpv.Interface().(Model); ok {
   118  						ctx = WithTableName(m.TableName())(ctx)
   119  					}
   120  				}
   121  			}
   122  		}
   123  
   124  		sf := &StructFieldValue{}
   125  
   126  		sf.Field = *f
   127  		sf.Value = f.FieldValue(rv)
   128  
   129  		sf.TableName = TableNameFromContext(ctx)
   130  
   131  		if tableAlias := TableAliasFromContext(ctx); tableAlias != "" {
   132  			sf.TableName = tableAlias
   133  		}
   134  
   135  		fn(sf)
   136  	}
   137  
   138  }
   139  
   140  func GetColumnName(fieldName, tagValue string) string {
   141  	i := strings.Index(tagValue, ",")
   142  	if tagValue != "" && (i > 0 || i == -1) {
   143  		if i == -1 {
   144  			return strings.ToLower(tagValue)
   145  		}
   146  		return strings.ToLower(tagValue[0:i])
   147  	}
   148  	return "f_" + strings.ToLower(fieldName)
   149  }
   150  
   151  func ToMap(list []string) map[string]bool {
   152  	m := make(map[string]bool, len(list))
   153  	for _, fieldName := range list {
   154  		m[fieldName] = true
   155  	}
   156  	return m
   157  }
   158  
   159  func FieldValuesFromStructBy(structValue interface{}, fieldNames []string) (fieldValues FieldValues) {
   160  	fieldValues = FieldValues{}
   161  	rv := reflect.Indirect(reflect.ValueOf(structValue))
   162  	fieldMap := ToMap(fieldNames)
   163  	ForEachStructFieldValue(context.Background(), rv, func(sf *StructFieldValue) {
   164  		if fieldMap != nil && fieldMap[sf.Field.FieldName] {
   165  			fieldValues[sf.Field.FieldName] = sf.Value.Interface()
   166  		}
   167  	})
   168  	return fieldValues
   169  }
   170  
   171  func FieldValuesFromStructByNonZero(structValue interface{}, excludes ...string) (fieldValues FieldValues) {
   172  	fieldValues = FieldValues{}
   173  	rv := reflect.Indirect(reflect.ValueOf(structValue))
   174  	fieldMap := ToMap(excludes)
   175  	ForEachStructFieldValue(context.Background(), rv, func(sf *StructFieldValue) {
   176  		if !reflectx.IsEmptyValue(sf.Value) || (fieldMap != nil && fieldMap[sf.Field.FieldName]) {
   177  			fieldValues[sf.Field.FieldName] = sf.Value.Interface()
   178  		}
   179  	})
   180  	return
   181  }
   182  
   183  func TableFromModel(model Model) *Table {
   184  	tpe := reflect.TypeOf(model)
   185  	if tpe.Kind() != reflect.Ptr {
   186  		panic(fmt.Errorf("model %s must be a pointer", tpe.Name()))
   187  	}
   188  	tpe = tpe.Elem()
   189  	if tpe.Kind() != reflect.Struct {
   190  		panic(fmt.Errorf("model %s must be a struct", tpe.Name()))
   191  	}
   192  
   193  	table := T(model.TableName())
   194  	table.Model = model
   195  
   196  	ScanDefToTable(table, model)
   197  
   198  	return table
   199  }
   200  
   201  func ScanDefToTable(table *Table, i interface{}) {
   202  	tpe := typesx.Deref(typesx.FromRType(reflect.TypeOf(i)))
   203  
   204  	EachStructField(context.Background(), tpe, func(f *StructField) bool {
   205  		table.AddCol(&Column{
   206  			FieldName:  f.FieldName,
   207  			Name:       f.Name,
   208  			ColumnType: &f.ColumnType,
   209  		})
   210  		return true
   211  	})
   212  
   213  	if withTableDescription, ok := i.(WithTableDescription); ok {
   214  		desc := withTableDescription.TableDescription()
   215  		table.Description = desc
   216  	}
   217  
   218  	if withComments, ok := i.(WithComments); ok {
   219  		for fieldName, comment := range withComments.Comments() {
   220  			field := table.F(fieldName)
   221  			if field != nil {
   222  				field.Comment = comment
   223  			}
   224  		}
   225  	}
   226  
   227  	if withColDescriptions, ok := i.(WithColDescriptions); ok {
   228  		for fieldName, desc := range withColDescriptions.ColDescriptions() {
   229  			field := table.F(fieldName)
   230  			if field != nil {
   231  				field.Description = desc
   232  			}
   233  		}
   234  	}
   235  
   236  	if withRelations, ok := i.(WithRelations); ok {
   237  		for fieldName, rel := range withRelations.ColRelations() {
   238  			field := table.F(fieldName)
   239  			if field != nil {
   240  				field.Relation = rel
   241  			}
   242  		}
   243  	}
   244  
   245  	if primaryKeyHook, ok := i.(WithPrimaryKey); ok {
   246  		table.AddKey(&Key{
   247  			Name:     "primary",
   248  			IsUnique: true,
   249  			Def:      *ParseIndexDef(primaryKeyHook.PrimaryKey()...),
   250  		})
   251  	}
   252  
   253  	if uniqueIndexesHook, ok := i.(WithUniqueIndexes); ok {
   254  		for indexNameAndMethod, fieldNames := range uniqueIndexesHook.UniqueIndexes() {
   255  			indexName, method := ResolveIndexNameAndMethod(indexNameAndMethod)
   256  
   257  			table.AddKey(&Key{
   258  				Name:     indexName,
   259  				Method:   method,
   260  				IsUnique: true,
   261  				Def:      *ParseIndexDef(fieldNames...),
   262  			})
   263  		}
   264  	}
   265  
   266  	if indexesHook, ok := i.(WithIndexes); ok {
   267  		for indexNameAndMethod, fieldNames := range indexesHook.Indexes() {
   268  			indexName, method := ResolveIndexNameAndMethod(indexNameAndMethod)
   269  
   270  			table.AddKey(&Key{
   271  				Name:   indexName,
   272  				Method: method,
   273  				Def:    *ParseIndexDef(fieldNames...),
   274  			})
   275  		}
   276  	}
   277  
   278  	if partitionHook, ok := i.(WithPartition); ok {
   279  		args := partitionHook.Partition()
   280  		table.AddKey(&Key{
   281  			Name:   "partition",
   282  			Method: strings.ToUpper(args[0]),
   283  			Def:    *ParseIndexDef(args[1:]...),
   284  		})
   285  	}
   286  }
   287  
   288  func ResolveIndexNameAndMethod(n string) (name string, method string) {
   289  	nameAndMethod := strings.Split(n, "/")
   290  	name = strings.ToLower(nameAndMethod[0])
   291  	if len(nameAndMethod) > 1 {
   292  		method = nameAndMethod[1]
   293  	}
   294  	return
   295  }
   296  
   297  // ParseIndexDefine
   298  // @def index i_xxx/BTREE Name
   299  // @def index i_xxx USING GIST (#TEST gist_trgm_ops)
   300  func ParseIndexDefine(def string) *IndexDefine {
   301  	d := IndexDefine{}
   302  
   303  	for i := strings.Index(def, " "); i != -1; i = strings.Index(def, " ") {
   304  		part := def[0:i]
   305  
   306  		if part != "" {
   307  			if d.Kind == "" {
   308  				d.Kind = part
   309  			} else if d.Name == "" && d.Kind != "primary" {
   310  				d.Name, d.Method = ResolveIndexNameAndMethod(part)
   311  			} else {
   312  				break
   313  			}
   314  		}
   315  
   316  		def = def[i+1:]
   317  	}
   318  
   319  	d.IndexDef = *ParseIndexDef(strings.TrimSpace(def))
   320  
   321  	return &d
   322  }
   323  
// IndexDefine is the parsed form of an index definition string; see
// ParseIndexDefine.
type IndexDefine struct {
	// Kind is the first space-separated token, e.g. "index" or "primary".
	Kind string
	// Name is the lowercased index name. ParseIndexDefine leaves it empty when
	// Kind is "primary".
	Name string
	// Method is the optional part after "/" in the name token, e.g. "BTREE".
	Method string
	// IndexDef holds the remaining definition, as parsed by ParseIndexDef.
	IndexDef
}
   330  
   331  func (i IndexDefine) ID() string {
   332  	if i.Method != "" {
   333  		return i.Name + "/" + i.Method
   334  	}
   335  	return i.Name
   336  }