github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/registry.go (about)

     1  package taorm
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strings"
     7  	"sync"
     8  	"unsafe"
     9  )
    10  
    11  // _FieldInfo stores info about a field in a struct.
    12  type _FieldInfo struct {
    13  	offset uintptr      // the memory offset of the field
    14  	_type  reflect.Type // the reflection type of the field
    15  }
    16  
    17  // StructInfo stores info about a struct.
    18  type _StructInfo struct {
    19  	tableName    string                // The database model name for this struct
    20  	fields       map[string]_FieldInfo // struct member info
    21  	fieldstr     string                // fields for inserting
    22  	insertstr    string                // for insert
    23  	updatestr    string                // for update
    24  	insertFields []_FieldInfo          // offsets of member to insert
    25  	pkeyField    _FieldInfo
    26  }
    27  
    28  func newStructInfo() *_StructInfo {
    29  	return &_StructInfo{
    30  		fields: make(map[string]_FieldInfo),
    31  	}
    32  }
    33  
    34  func (s *_StructInfo) baseOf(out interface{}) uintptr {
    35  	return uintptr((*_EmptyEface)(unsafe.Pointer(&out)).ptr)
    36  }
    37  
    38  func (s *_StructInfo) valueOf(out interface{}, field _FieldInfo) reflect.Value {
    39  	addr := unsafe.Pointer(s.baseOf(out) + field.offset)
    40  	return reflect.NewAt(field._type, addr).Elem()
    41  }
    42  
    43  func (s *_StructInfo) addrOf(out interface{}, field _FieldInfo) interface{} {
    44  	addr := unsafe.Pointer(s.baseOf(out) + field.offset)
    45  	return reflect.NewAt(field._type, addr).Interface()
    46  }
    47  
    48  func (s *_StructInfo) ptrsOf(out interface{}, fields []string) ([]interface{}, error) {
    49  	ptrs := make([]interface{}, 0, len(fields))
    50  	for _, field := range fields {
    51  		fi, ok := s.fields[field]
    52  		if !ok {
    53  			return nil, &NoPlaceToSaveFieldError{field}
    54  		}
    55  		addr := s.addrOf(out, fi)
    56  		ptrs = append(ptrs, addr)
    57  	}
    58  	return ptrs, nil
    59  }
    60  
    61  func (s *_StructInfo) ifacesOf(out interface{}) []interface{} {
    62  	values := make([]interface{}, len(s.insertFields))
    63  	base := s.baseOf(out)
    64  	for i, f := range s.insertFields {
    65  		addr := unsafe.Pointer(base + f.offset)
    66  		values[i] = reflect.NewAt(f._type, addr).Elem().Interface()
    67  	}
    68  	return values
    69  }
    70  
    71  func (s *_StructInfo) setPrimaryKey(out interface{}, id int64) {
    72  	pkey := s.valueOf(out, s.pkeyField)
    73  	switch s.pkeyField._type.Kind() {
    74  	case reflect.Uint, reflect.Uint64:
    75  		pkey.SetUint(uint64(id))
    76  	case reflect.Int, reflect.Int64:
    77  		pkey.SetInt(id)
    78  	default:
    79  		panic("cannot set primary key")
    80  	}
    81  }
    82  
    83  func (s *_StructInfo) getPrimaryKey(out interface{}) (interface{}, bool) {
    84  	zero := reflect.Zero(s.pkeyField._type).Interface()
    85  	pkv := s.valueOf(out, s.pkeyField).Interface()
    86  	return pkv, pkv != zero
    87  }
    88  
    89  // structs maps struct type name to its info.
    90  var structs = make(map[string]*_StructInfo)
    91  var rwLock = &sync.RWMutex{}
    92  
    93  // register ...
    94  func register(ty reflect.Type) (*_StructInfo, error) {
    95  	rwLock.Lock()
    96  	defer rwLock.Unlock()
    97  
    98  	typeName := structName(ty)
    99  	tableName, err := getTableNameFromType(ty)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	// TODO check name validity
   105  	// TODO name can be empty because of auto-generated table info.
   106  
   107  	if si, ok := structs[typeName]; ok {
   108  		return si, nil
   109  	}
   110  
   111  	structInfo := newStructInfo()
   112  	structInfo.tableName = tableName
   113  	fieldNames := []string{}
   114  
   115  	addStructFields(structInfo, ty, &fieldNames)
   116  
   117  	structInfo.fieldstr = strings.Join(fieldNames, ",")
   118  	{
   119  		query := fmt.Sprintf(`INSERT INTO %s `, tableName)
   120  		query += fmt.Sprintf(`(%s) VALUES (%s)`,
   121  			structInfo.fieldstr,
   122  			createSQLInMarks(len(fieldNames)),
   123  		)
   124  		structInfo.insertstr = query
   125  	}
   126  	{
   127  		query := fmt.Sprintf(`UPDATE %s SET `, tableName)
   128  		pairs := []string{}
   129  		for _, name := range fieldNames {
   130  			pairs = append(pairs, name+"=?")
   131  		}
   132  		query += strings.Join(pairs, ",")
   133  		structInfo.updatestr = query
   134  	}
   135  	structs[typeName] = structInfo
   136  	//fmt.Printf("taorm: registered: %s\n", typeName)
   137  	return structInfo, nil
   138  }
   139  
   140  func addStructFields(info *_StructInfo, ty reflect.Type, fieldNames *[]string) {
   141  	for i := 0; i < ty.NumField(); i++ {
   142  		f := ty.Field(i)
   143  		if isColumnField(f) {
   144  			columnName := getColumnName(f)
   145  			if columnName == "" {
   146  				continue
   147  			}
   148  			if columnName != "id" {
   149  				*fieldNames = append(*fieldNames, columnName)
   150  			}
   151  			fieldInfo := _FieldInfo{
   152  				offset: f.Offset,
   153  				_type:  f.Type,
   154  			}
   155  			info.fields[columnName] = fieldInfo
   156  			if columnName != "id" {
   157  				info.insertFields = append(info.insertFields, fieldInfo)
   158  			} else {
   159  				info.pkeyField = fieldInfo
   160  			}
   161  		} else if f.Anonymous {
   162  			addStructFields(info, f.Type, fieldNames)
   163  		}
   164  	}
   165  }
   166  
   167  // _struct can be any struct-related types.
   168  // e.g.: struct{}, *struct{}, **struct{}, []struct{}, []*struct, []*struct{}, *[]strcut{}, *[]*struct{} ...
   169  func structType(_struct interface{}) (reflect.Type, error) {
   170  	ty := reflect.TypeOf(_struct)
   171  	if ty == nil {
   172  		return nil, &NotStructError{}
   173  	}
   174  	for ty.Kind() == reflect.Ptr || ty.Kind() == reflect.Slice {
   175  		ty = ty.Elem()
   176  	}
   177  	if ty.Kind() != reflect.Struct {
   178  		return nil, &NotStructError{ty.Kind()}
   179  	}
   180  	return ty, nil
   181  }
   182  
   183  var tableNamerType = reflect.TypeOf((*TableNamer)(nil)).Elem()
   184  
   185  // getTableName gets the table name for a specific type.
   186  // The type must implement TableNamer.
   187  func getTableNameFromType(ty reflect.Type) (string, error) {
   188  	if ty == nil {
   189  		return ``, &NotStructError{}
   190  	}
   191  
   192  	for ty.Kind() == reflect.Ptr || ty.Kind() == reflect.Slice {
   193  		ty = ty.Elem()
   194  	}
   195  
   196  	return getTableNameFromValue(reflect.New(ty).Interface())
   197  }
   198  
   199  func getTableNameFromValue(value interface{}) (string, error) {
   200  	if i, ok := value.(TableNamer); ok {
   201  		return i.TableName(), nil
   202  	}
   203  	return ``, nil
   204  }
   205  
   206  func getRegistered(_struct interface{}) (*_StructInfo, error) {
   207  	ty, err := structType(_struct)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   211  	name := structName(ty)
   212  
   213  	rwLock.RLock()
   214  	if si, ok := structs[name]; ok {
   215  		rwLock.RUnlock()
   216  		return si, nil
   217  	}
   218  	rwLock.RUnlock()
   219  	return register(ty)
   220  }