github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/depends/kit/sqlx/builder/builder_util.go (about)

     1  package builder
     2  
     3  import (
     4  	"context"
     5  	"go/ast"
     6  	"reflect"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/machinefi/w3bstream/pkg/depends/x/contextx"
    11  	"github.com/machinefi/w3bstream/pkg/depends/x/misc/clone"
    12  	"github.com/machinefi/w3bstream/pkg/depends/x/misc/must"
    13  	"github.com/machinefi/w3bstream/pkg/depends/x/reflectx"
    14  	"github.com/machinefi/w3bstream/pkg/depends/x/typesx"
    15  )
    16  
// FieldValues maps a Go struct field name (StructField.FieldName) to that
// field's value. FieldValueFromStruct and FieldValueFromStructByNoneZero
// build it keyed this way.
type FieldValues map[string]interface{}
    18  
    19  type ctxKeyTableName struct{}
    20  
    21  func WithTableName(tbl string) func(ctx context.Context) context.Context {
    22  	return func(ctx context.Context) context.Context {
    23  		return contextx.WithValue(ctx, ctxKeyTableName{}, tbl)
    24  	}
    25  }
    26  
    27  func TableNameFromContext(ctx context.Context) (string, bool) {
    28  	tbl, ok := ctx.Value(ctxKeyTableName{}).(string)
    29  	return tbl, ok
    30  }
    31  
    32  type ctxKeyTableAlias struct{}
    33  
    34  func WithTabelAlias(tbl string) func(ctx context.Context) context.Context {
    35  	return func(ctx context.Context) context.Context {
    36  		return contextx.WithValue(ctx, ctxKeyTableAlias{}, tbl)
    37  	}
    38  }
    39  
    40  func TableAliasFromContext(ctx context.Context) (string, bool) {
    41  	tbl, ok := ctx.Value(ctxKeyTableAlias{}).(string)
    42  	return tbl, ok
    43  }
    44  
    45  func ColumnsByStruct(v interface{}) *Ex {
    46  	ctx := context.Background()
    47  	fields := StructFieldsFor(ctx, typesx.FromReflectType(reflect.TypeOf(v)))
    48  
    49  	e := Expr("")
    50  	e.Grow(len(fields))
    51  
    52  	i := 0
    53  	ForEachFieldValue(ctx, reflect.ValueOf(v), func(f *FieldValue) {
    54  		if i > 0 {
    55  			e.WriteQuery(", ")
    56  		}
    57  		if f.TableName != "" {
    58  			e.WriteQuery(f.TableName)
    59  			e.WriteQueryByte('.')
    60  			e.WriteQuery(f.Field.Name)
    61  			e.WriteQuery(" AS ")
    62  			e.WriteQuery(f.TableName)
    63  			e.WriteQuery("__")
    64  			e.WriteQuery(f.Field.Name)
    65  		} else {
    66  			e.WriteQuery(f.Field.Name)
    67  		}
    68  		i++
    69  	})
    70  
    71  	return e
    72  }
    73  
    74  func ForEachFieldValue(ctx context.Context, v interface{}, fn func(*FieldValue)) {
    75  	rv, ok := v.(reflect.Value)
    76  	if ok {
    77  		if rv.Kind() == reflect.Ptr && rv.IsNil() {
    78  			rv.Set(reflectx.New(rv.Type()))
    79  		}
    80  		v = rv.Interface()
    81  	}
    82  	if m, ok := v.(Model); ok {
    83  		ctx = WithTableName(m.TableName())(ctx)
    84  	}
    85  
    86  	fields := StructFieldsFor(ctx, typesx.FromReflectType(reflect.TypeOf(v)))
    87  	rv = reflectx.Indirect(reflect.ValueOf(v))
    88  
    89  	for i := range fields {
    90  		f := fields[i]
    91  		tag := f.Tags["db"]
    92  
    93  		if tag.HasFlag("deprecated") {
    94  			continue
    95  		}
    96  
    97  		if alias, ok := f.Tags["alias"]; ok {
    98  			ctx = WithTabelAlias(alias.Name())(ctx)
    99  		} else {
   100  			if len(f.ModelLoc) > 0 {
   101  				if fmv := f.FieldModelValue(rv); fmv.IsValid() {
   102  					if m, ok := fmv.Interface().(Model); ok {
   103  						ctx = WithTableName(m.TableName())(ctx)
   104  					}
   105  				}
   106  			}
   107  		}
   108  		sf := &FieldValue{
   109  			Field: *f,
   110  			Value: f.FieldValue(rv),
   111  		}
   112  		sf.TableName, _ = TableNameFromContext(ctx)
   113  		if alias, ok := TableAliasFromContext(ctx); ok && alias != "" {
   114  			sf.TableName = alias
   115  		}
   116  		fn(sf)
   117  	}
   118  }
   119  
   120  func GetColumnName(name, tag string) string {
   121  	i := strings.Index(tag, ",")
   122  	if tag != "" {
   123  		if i == -1 {
   124  			return strings.ToLower(tag)
   125  		}
   126  		if i > 0 {
   127  			return strings.ToLower(tag[0:i])
   128  		}
   129  	}
   130  	return "f_" + strings.ToLower(name)
   131  }
   132  
   133  func ToMap(lst []string) map[string]bool {
   134  	m := make(map[string]bool)
   135  	for _, name := range lst {
   136  		m[name] = true
   137  	}
   138  	return m
   139  }
   140  
   141  func FieldValueFromStruct(v interface{}, names []string) FieldValues {
   142  	fvs := FieldValues{}
   143  	rv := reflect.Indirect(reflect.ValueOf(v))
   144  	m := ToMap(names)
   145  	ForEachFieldValue(context.Background(), rv, func(fv *FieldValue) {
   146  		if name := fv.Field.FieldName; m != nil && m[name] {
   147  			fvs[name] = fv.Value.Interface()
   148  		}
   149  	})
   150  	return fvs
   151  }
   152  
   153  func FieldValueFromStructByNoneZero(v interface{}, excludes ...string) FieldValues {
   154  	fvs := FieldValues{}
   155  	rv := reflect.Indirect(reflect.ValueOf(v))
   156  	m := ToMap(excludes)
   157  	ForEachFieldValue(context.Background(), rv, func(fv *FieldValue) {
   158  		name := fv.Field.FieldName
   159  		if !reflectx.IsEmptyValue(fv.Value) || m != nil && m[name] {
   160  			fvs[name] = fv.Value.Interface()
   161  		}
   162  	})
   163  	return fvs
   164  }
   165  
   166  func TableFromModel(m Model) *Table {
   167  	t := reflect.TypeOf(m)
   168  	if t.Kind() != reflect.Ptr {
   169  		panic("model must be a ptr")
   170  	}
   171  	t = t.Elem()
   172  	if t.Kind() != reflect.Struct {
   173  		panic("model must be a struct")
   174  	}
   175  	tbl := T(m.TableName())
   176  	tbl.Model = m
   177  	ScanDefToTable(tbl, m)
   178  	return tbl
   179  }
   180  
   181  func ScanDefToTable(tbl *Table, i interface{}) {
   182  	t := typesx.DeRef(typesx.FromReflectType(reflect.TypeOf(i)))
   183  	EachField(context.Background(), t,
   184  		func(f *StructField) bool {
   185  			tbl.AddCol(&Column{
   186  				FieldName:  f.FieldName,
   187  				Name:       f.Name,
   188  				ColumnType: &f.ColumnType,
   189  			})
   190  			return true
   191  		},
   192  	)
   193  
   194  	if with, ok := i.(WithTableDesc); ok {
   195  		tbl.Desc = with.TableDesc()
   196  	}
   197  	if with, ok := i.(WithComments); ok {
   198  		for name, comment := range with.Comments() {
   199  			if col := tbl.ColByFieldName(name); col != nil {
   200  				col.Comment = comment
   201  			}
   202  		}
   203  	}
   204  	if with, ok := i.(WithColDesc); ok {
   205  		for name, desc := range with.ColDesc() {
   206  			if col := tbl.ColByFieldName(name); col != nil {
   207  				col.Desc = desc
   208  			}
   209  		}
   210  	}
   211  	if with, ok := i.(WithColRel); ok {
   212  		for name, rel := range with.ColRel() {
   213  			if col := tbl.ColByFieldName(name); col != nil {
   214  				col.Rel = rel
   215  			}
   216  		}
   217  	}
   218  	if with, ok := i.(WithPrimaryKey); ok {
   219  		tbl.AddKey(&Key{
   220  			Name:     "primary",
   221  			IsUnique: true,
   222  			Def:      *ParseIndexDef(with.PrimaryKey()...),
   223  		})
   224  	}
   225  	if with, ok := i.(WithUniqueIndexes); ok {
   226  		for _index, names := range with.UniqueIndexes() {
   227  			name, method := SplitIndexNameAndMethod(_index)
   228  			tbl.AddKey(&Key{
   229  				Name:     name,
   230  				Method:   method,
   231  				IsUnique: true,
   232  				Def:      *ParseIndexDef(names...),
   233  			})
   234  		}
   235  	}
   236  	if with, ok := i.(WithIndexes); ok {
   237  		for _index, names := range with.Indexes() {
   238  			name, method := SplitIndexNameAndMethod(_index)
   239  			tbl.AddKey(&Key{
   240  				Name:   name,
   241  				Method: method,
   242  				Def:    *ParseIndexDef(names...),
   243  			})
   244  		}
   245  	}
   246  }
   247  
   248  // SplitIndexNameAndMethod @def index name/method
   249  func SplitIndexNameAndMethod(v string) (string, string) {
   250  	parts := strings.Split(v, "/")
   251  	name := strings.ToLower(parts[0])
   252  	method := ""
   253  	if len(parts) > 1 {
   254  		method = parts[1]
   255  	}
   256  	return name, method
   257  }
   258  
// IndexDefine is the parsed form of an index definition string produced by
// ParseIndexDefine.
type IndexDefine struct {
	// Kind is the first token of the definition (e.g. "primary").
	Kind string
	// Name is the index name, lowercased by SplitIndexNameAndMethod. It is
	// left empty when Kind is "primary".
	Name string
	// Method is the optional part after '/' in the name token.
	Method string
	// IndexDef is the remainder of the definition, parsed by ParseIndexDef.
	IndexDef
}
   265  
   266  func (i IndexDefine) ID() string {
   267  	if i.Method != "" {
   268  		return i.Name + "/" + i.Method
   269  	}
   270  	return i.Name
   271  }
   272  
// ParseIndexDefine parses an index definition of the form
//
//	<kind> [<name>[/<method>]] <index-def>
//
// Tokens are separated by single ' ' characters. Empty tokens produced by
// repeated spaces are skipped, and tabs are not treated as separators. The
// first token becomes Kind. Unless Kind is "primary", the second token becomes
// Name and Method via SplitIndexNameAndMethod. Everything from the next token
// onward is trimmed and passed to ParseIndexDef.
//
// NOTE(review): the last space-delimited token is never consumed as Kind or
// Name. For example, "index i_name" yields Kind "index", an empty Name, and
// an IndexDef parsed from "i_name". A string with no space at all leaves Kind
// empty. Callers presumably always supply a trailing index definition; confirm.
func ParseIndexDefine(def string) *IndexDefine {
	d := &IndexDefine{}
	// Consume leading tokens one at a time. def is advanced past each
	// consumed token and its trailing space.
	for i := strings.Index(def, " "); i != -1; i = strings.Index(def, " ") {
		part := def[0:i]
		if part != "" {
			if d.Kind == "" {
				d.Kind = part
			} else if d.Name == "" && d.Kind != "primary" {
				d.Name, d.Method = SplitIndexNameAndMethod(part)
			} else {
				// Kind (and Name, if applicable) are already set. Stop without
				// advancing def, so this token stays part of the index
				// definition.
				break
			}
		}
		def = def[i+1:]
	}
	d.IndexDef = *ParseIndexDef(strings.TrimSpace(def))
	return d
}
   291  
// FieldValue pairs a field description with that field's concrete value
// inside one struct instance. ForEachFieldValue yields these.
type FieldValue struct {
	// Field is the static description of the field.
	Field StructField
	// TableName is the table name or alias qualifying the column. It may be
	// empty.
	TableName string
	// Value is the field's value, located by following Field.Loc.
	Value reflect.Value
}
   297  
// FieldsFactory caches, per struct type, the db-mapped fields computed by
// TableFieldsFor.
type FieldsFactory struct {
	// cache maps a type's unwrapped form (typesx.Type.Unwrap()) to its
	// []*StructField. sync.Map allows concurrent Load and Store.
	cache sync.Map
}

// gFields is the package-wide default field cache that StructFieldsFor uses.
var gFields = &FieldsFactory{}
   304  
   305  func StructFieldsFor(ctx context.Context, t typesx.Type) []*StructField {
   306  	return gFields.TableFieldsFor(ctx, t)
   307  }
   308  
   309  func (ft *FieldsFactory) TableFieldsFor(ctx context.Context, t typesx.Type) []*StructField {
   310  	t = typesx.DeRef(t)
   311  	i := t.Unwrap() // underlying
   312  
   313  	if v, ok := ft.cache.Load(i); ok {
   314  		return v.([]*StructField)
   315  	}
   316  
   317  	sf := make([]*StructField, 0)
   318  	EachField(ctx, t, func(f *StructField) bool {
   319  		name := f.Tags["db"]
   320  		if name != "" && name != "-" {
   321  			sf = append(sf, f)
   322  		}
   323  		return true
   324  	})
   325  	ft.cache.Store(i, sf)
   326  	return sf
   327  }
   328  
   329  func EachField(ctx context.Context, t typesx.Type, each func(*StructField) bool) {
   330  	must.BeTrue(t.Kind() == reflect.Struct)
   331  
   332  	var walk func(t typesx.Type, modelLoc []int, parents ...int)
   333  
   334  	walk = func(t typesx.Type, modelLoc []int, parents ...int) {
   335  		if t.Implements(typesx.FromReflectType(RtModel)) {
   336  			modelLoc = parents
   337  		}
   338  		for i := 0; i < t.NumField(); i++ {
   339  			fi := t.Field(i)
   340  			if !ast.IsExported(fi.Name()) {
   341  				continue
   342  			}
   343  
   344  			loc := append(parents, i)
   345  			tags := reflectx.ParseStructTag(string(fi.Tag()))
   346  			name := fi.Name()
   347  			tag, has := tags["db"]
   348  			if has {
   349  				if tagName := tag.Name(); tagName == "-" {
   350  					continue
   351  				} else {
   352  					if tagName != "" {
   353  						name = tagName
   354  					}
   355  				}
   356  			}
   357  
   358  			if !has && (fi.Anonymous() || fi.Type().Name() == fi.Name()) {
   359  				ft := fi.Type()
   360  				if !ft.Implements(typesx.FromReflectType(RtDriverValuer)) {
   361  					for ft.Kind() == reflect.Ptr {
   362  						ft = ft.Elem()
   363  					}
   364  					if ft.Kind() == reflect.Struct {
   365  						walk(ft, modelLoc, loc...)
   366  						continue
   367  					}
   368  				}
   369  			}
   370  
   371  			p := &StructField{
   372  				Name:       strings.ToLower(name),
   373  				FieldName:  fi.Name(),
   374  				Type:       fi.Type(),
   375  				Field:      fi,
   376  				Tags:       tags,
   377  				Loc:        clone.Ints(loc),
   378  				ModelLoc:   clone.Ints(modelLoc),
   379  				ColumnType: *AnalyzeColumnType(fi.Type(), string(tag)),
   380  			}
   381  			if !each(p) {
   382  				break
   383  			}
   384  		}
   385  	}
   386  
   387  	walk(t, []int{})
   388  }
   389  
// StructField describes one column candidate discovered by EachField.
type StructField struct {
	// Name is the lowercased column name: the `db` tag name when non-empty,
	// otherwise the Go field name.
	Name string
	// FieldName is the Go field name, unchanged.
	FieldName string
	// Type is the field's type.
	Type typesx.Type
	// Field is the underlying struct-field descriptor.
	Field typesx.StructField
	// Tags holds all parsed struct tags, keyed by tag key (e.g. "db",
	// "alias").
	Tags map[string]reflectx.StructTag
	// Loc is the field index path from the root struct that was walked.
	Loc []int
	// ModelLoc is the index path of the innermost enclosing struct whose type
	// implements Model. It is empty when that struct is the root or when none
	// does.
	ModelLoc []int
	// ColumnType is derived by AnalyzeColumnType from the field type and the
	// raw `db` tag.
	ColumnType ColumnType
}
   400  
   401  func fieldValue(v reflect.Value, locs []int) reflect.Value {
   402  	n := len(locs)
   403  	if n == 0 {
   404  		return v
   405  	}
   406  	if n < 0 {
   407  		return reflect.Value{}
   408  	}
   409  	v = reflectx.Indirect(v)
   410  	fv := v
   411  	for i := 0; i < n; i++ {
   412  		loc := locs[i]
   413  		fv = fv.Field(loc)
   414  		if i < n-1 {
   415  			for fv.Kind() == reflect.Ptr {
   416  				if fv.IsNil() {
   417  					fv.Set(reflectx.New(fv.Type()))
   418  				}
   419  				fv = fv.Elem()
   420  			}
   421  		}
   422  	}
   423  	return fv
   424  }
   425  
   426  func (sf *StructField) FieldValue(v reflect.Value) reflect.Value {
   427  	return fieldValue(v, sf.Loc)
   428  }
   429  
   430  func (sf *StructField) FieldModelValue(v reflect.Value) reflect.Value {
   431  	return fieldValue(v, sf.ModelLoc)
   432  }