github.com/johnnyeven/libtools@v0.0.0-20191126065708-61829c1adf46/sqlx/transforms.go (about)

     1  package sqlx
     2  
     3  import (
     4  	"fmt"
     5  	"go/ast"
     6  	"reflect"
     7  
     8  	"github.com/johnnyeven/libtools/reflectx"
     9  	"github.com/johnnyeven/libtools/sqlx/builder"
    10  )
    11  
    12  func ForEachStructField(structType reflect.Type, fn func(structField reflect.StructField, columnName string)) {
    13  	for i := 0; i < structType.NumField(); i++ {
    14  		field := structType.Field(i)
    15  
    16  		if ast.IsExported(field.Name) {
    17  			fieldName, exists := field.Tag.Lookup("db")
    18  			if exists {
    19  				if fieldName != "-" {
    20  					fn(field, fieldName)
    21  				}
    22  			} else if field.Anonymous {
    23  				ForEachStructField(field.Type, fn)
    24  				continue
    25  			}
    26  		}
    27  	}
    28  }
    29  
    30  func ForEachStructFieldValue(rv reflect.Value, fn func(structFieldValue reflect.Value, structField reflect.StructField, columnName string)) {
    31  	rv = reflect.Indirect(rv)
    32  	structType := rv.Type()
    33  	for i := 0; i < structType.NumField(); i++ {
    34  		field := structType.Field(i)
    35  		if ast.IsExported(field.Name) {
    36  			fieldValue := rv.Field(i)
    37  
    38  			columnName, exists := field.Tag.Lookup("db")
    39  			if exists {
    40  				if columnName != "-" {
    41  					fn(fieldValue, field, columnName)
    42  				}
    43  			} else if field.Anonymous {
    44  				ForEachStructFieldValue(fieldValue, fn)
    45  				continue
    46  			}
    47  		}
    48  	}
    49  }
    50  
    51  func FieldValuesFromStructBy(structValue interface{}, fieldNames []string) (fieldValues builder.FieldValues) {
    52  	fieldValues = builder.FieldValues{}
    53  	rv := reflect.Indirect(reflect.ValueOf(structValue))
    54  	fieldMap := FieldNames(fieldNames).Map()
    55  	ForEachStructFieldValue(rv, func(structFieldValue reflect.Value, structField reflect.StructField, columnName string) {
    56  		if fieldMap != nil && fieldMap[structField.Name] {
    57  			fieldValues[structField.Name] = structFieldValue.Interface()
    58  		}
    59  	})
    60  	return fieldValues
    61  }
    62  
    63  func FieldValuesFromStructByNonZero(structValue interface{}, excludes ...string) (fieldValues builder.FieldValues) {
    64  	fieldValues = builder.FieldValues{}
    65  	rv := reflect.Indirect(reflect.ValueOf(structValue))
    66  	fieldMap := FieldNames(excludes).Map()
    67  	ForEachStructFieldValue(rv, func(structFieldValue reflect.Value, structField reflect.StructField, columnName string) {
    68  		if !reflectx.IsEmptyValue(structFieldValue) || (fieldMap != nil && fieldMap[structField.Name]) {
    69  			fieldValues[structField.Name] = structFieldValue.Interface()
    70  		}
    71  	})
    72  	return
    73  }
    74  
    75  func ScanDefToTable(rv reflect.Value, table *builder.Table) {
    76  	rv = reflect.Indirect(rv)
    77  	structType := rv.Type()
    78  
    79  	ForEachStructField(structType, func(structField reflect.StructField, columnName string) {
    80  		sqlType, exists := structField.Tag.Lookup("sql")
    81  		if !exists {
    82  			panic(fmt.Errorf("%s.%s sql tag must defined for sql type", table.Name, structField.Name))
    83  		}
    84  
    85  		col := builder.Col(table, columnName).Type(sqlType).Field(structField.Name)
    86  
    87  		if structField.Type.AssignableTo(reflect.TypeOf((*EnumTypeDescriber)(nil)).Elem()) {
    88  			enumTypeDescriber := reflect.New(structField.Type).Interface().(EnumTypeDescriber)
    89  			col = col.Enum(enumTypeDescriber.EnumType(), enumTypeDescriber.Enums())
    90  		}
    91  
    92  		finalSql := col.ColumnType.String()
    93  		if sqlType != finalSql {
    94  			panic(fmt.Errorf("%s.%s sql tag may be `%s`, current `%s`", table.Name, structField.Name, finalSql, sqlType))
    95  		}
    96  		table.Columns.Add(col)
    97  	})
    98  
    99  	if rv.CanAddr() {
   100  		addr := rv.Addr()
   101  		if addr.CanInterface() {
   102  			i := addr.Interface()
   103  
   104  			if primaryKeyHook, ok := i.(WithPrimaryKey); ok {
   105  				primaryKey := builder.PrimaryKey()
   106  				for _, fieldName := range primaryKeyHook.PrimaryKey() {
   107  					if col := table.F(fieldName); col != nil {
   108  						primaryKey = primaryKey.WithCols(col)
   109  					} else {
   110  						panic(fmt.Errorf("field %s for PrimaryKey is not defined in table model %s", fieldName, table.Name))
   111  					}
   112  				}
   113  				table.Keys.Add(primaryKey)
   114  			}
   115  
   116  			if withComments, ok := i.(WithComments); ok {
   117  				for fieldName, comment := range withComments.Comments() {
   118  					field := table.F(fieldName)
   119  					if field != nil {
   120  						field.Comment = comment
   121  					}
   122  				}
   123  			}
   124  
   125  			if indexesHook, ok := i.(WithIndexes); ok {
   126  				for name, indexes := range indexesHook.Indexes() {
   127  					idx := builder.Index(name)
   128  					for _, fieldName := range indexes {
   129  						if col := table.F(fieldName); col != nil {
   130  							idx = idx.WithCols(col)
   131  						} else {
   132  							panic(fmt.Errorf("field %s for key %s is not defined in table model %s", fieldName, name, table.Name))
   133  						}
   134  					}
   135  					table.Keys.Add(idx)
   136  				}
   137  			}
   138  
   139  			if uniqueIndexesHook, ok := i.(WithUniqueIndexes); ok {
   140  				for name, indexes := range uniqueIndexesHook.UniqueIndexes() {
   141  					idx := builder.UniqueIndex(name)
   142  					for _, indexName := range indexes {
   143  						if col := table.F(indexName); col != nil {
   144  							idx = idx.WithCols(col)
   145  						} else {
   146  							panic(fmt.Errorf("field %s for unique indexes %s is not defined in table model %s", indexName, name, table.Name))
   147  						}
   148  					}
   149  					table.Keys.Add(idx)
   150  				}
   151  			}
   152  		}
   153  	}
   154  }