github.com/movsb/taorm@v0.0.0-20201209183410-91bafb0b22a6/taorm/registry.go (about) 1 package taorm 2 3 import ( 4 "fmt" 5 "reflect" 6 "strings" 7 "sync" 8 "unsafe" 9 ) 10 11 // _FieldInfo stores info about a field in a struct. 12 type _FieldInfo struct { 13 offset uintptr // the memory offset of the field 14 _type reflect.Type // the reflection type of the field 15 } 16 17 // StructInfo stores info about a struct. 18 type _StructInfo struct { 19 tableName string // The database model name for this struct 20 fields map[string]_FieldInfo // struct member info 21 fieldstr string // fields for inserting 22 insertstr string // for insert 23 updatestr string // for update 24 insertFields []_FieldInfo // offsets of member to insert 25 pkeyField _FieldInfo 26 } 27 28 func newStructInfo() *_StructInfo { 29 return &_StructInfo{ 30 fields: make(map[string]_FieldInfo), 31 } 32 } 33 34 func (s *_StructInfo) baseOf(out interface{}) uintptr { 35 return uintptr((*_EmptyEface)(unsafe.Pointer(&out)).ptr) 36 } 37 38 func (s *_StructInfo) valueOf(out interface{}, field _FieldInfo) reflect.Value { 39 addr := unsafe.Pointer(s.baseOf(out) + field.offset) 40 return reflect.NewAt(field._type, addr).Elem() 41 } 42 43 func (s *_StructInfo) addrOf(out interface{}, field _FieldInfo) interface{} { 44 addr := unsafe.Pointer(s.baseOf(out) + field.offset) 45 return reflect.NewAt(field._type, addr).Interface() 46 } 47 48 func (s *_StructInfo) ptrsOf(out interface{}, fields []string) ([]interface{}, error) { 49 ptrs := make([]interface{}, 0, len(fields)) 50 for _, field := range fields { 51 fi, ok := s.fields[field] 52 if !ok { 53 return nil, &NoPlaceToSaveFieldError{field} 54 } 55 addr := s.addrOf(out, fi) 56 ptrs = append(ptrs, addr) 57 } 58 return ptrs, nil 59 } 60 61 func (s *_StructInfo) ifacesOf(out interface{}) []interface{} { 62 values := make([]interface{}, len(s.insertFields)) 63 base := s.baseOf(out) 64 for i, f := range s.insertFields { 65 addr := unsafe.Pointer(base + f.offset) 66 values[i] = reflect.NewAt(f._type, addr).Elem().Interface() 67 } 68 return values 69 } 70 71 func (s *_StructInfo) setPrimaryKey(out interface{}, id int64) { 72 pkey := s.valueOf(out, s.pkeyField) 73 switch s.pkeyField._type.Kind() { 74 case reflect.Uint, reflect.Uint64: 75 pkey.SetUint(uint64(id)) 76 case reflect.Int, reflect.Int64: 77 pkey.SetInt(id) 78 default: 79 panic("cannot set primary key") 80 } 81 } 82 83 func (s *_StructInfo) getPrimaryKey(out interface{}) (interface{}, bool) { 84 zero := reflect.Zero(s.pkeyField._type).Interface() 85 pkv := s.valueOf(out, s.pkeyField).Interface() 86 return pkv, pkv != zero 87 } 88 89 // structs maps struct type name to its info. 90 var structs = make(map[string]*_StructInfo) 91 var rwLock = &sync.RWMutex{} 92 93 // register ... 94 func register(ty reflect.Type) (*_StructInfo, error) { 95 rwLock.Lock() 96 defer rwLock.Unlock() 97 98 typeName := structName(ty) 99 tableName, err := getTableNameFromType(ty) 100 if err != nil { 101 return nil, err 102 } 103 104 // TODO check name validity 105 // TODO name can be empty because of auto-generated table info. 106 107 if si, ok := structs[typeName]; ok { 108 return si, nil 109 } 110 111 structInfo := newStructInfo() 112 structInfo.tableName = tableName 113 fieldNames := []string{} 114 115 addStructFields(structInfo, ty, &fieldNames) 116 117 structInfo.fieldstr = strings.Join(fieldNames, ",") 118 { 119 query := fmt.Sprintf(`INSERT INTO %s `, tableName) 120 query += fmt.Sprintf(`(%s) VALUES (%s)`, 121 structInfo.fieldstr, 122 createSQLInMarks(len(fieldNames)), 123 ) 124 structInfo.insertstr = query 125 } 126 { 127 query := fmt.Sprintf(`UPDATE %s SET `, tableName) 128 pairs := []string{} 129 for _, name := range fieldNames { 130 pairs = append(pairs, name+"=?") 131 } 132 query += strings.Join(pairs, ",") 133 structInfo.updatestr = query 134 } 135 structs[typeName] = structInfo 136 //fmt.Printf("taorm: registered: %s\n", typeName) 137 return structInfo, nil 138 } 139 140 func addStructFields(info *_StructInfo, ty reflect.Type, fieldNames *[]string) { 141 for i := 0; i < ty.NumField(); i++ { 142 f := ty.Field(i) 143 if isColumnField(f) { 144 columnName := getColumnName(f) 145 if columnName == "" { 146 continue 147 } 148 if columnName != "id" { 149 *fieldNames = append(*fieldNames, columnName) 150 } 151 fieldInfo := _FieldInfo{ 152 offset: f.Offset, 153 _type: f.Type, 154 } 155 info.fields[columnName] = fieldInfo 156 if columnName != "id" { 157 info.insertFields = append(info.insertFields, fieldInfo) 158 } else { 159 info.pkeyField = fieldInfo 160 } 161 } else if f.Anonymous { 162 addStructFields(info, f.Type, fieldNames) 163 } 164 } 165 } 166 167 // _struct can be any struct-related types. 168 // e.g.: struct{}, *struct{}, **struct{}, []struct{}, []*struct, []*struct{}, *[]strcut{}, *[]*struct{} ... 169 func structType(_struct interface{}) (reflect.Type, error) { 170 ty := reflect.TypeOf(_struct) 171 if ty == nil { 172 return nil, &NotStructError{} 173 } 174 for ty.Kind() == reflect.Ptr || ty.Kind() == reflect.Slice { 175 ty = ty.Elem() 176 } 177 if ty.Kind() != reflect.Struct { 178 return nil, &NotStructError{ty.Kind()} 179 } 180 return ty, nil 181 } 182 183 var tableNamerType = reflect.TypeOf((*TableNamer)(nil)).Elem() 184 185 // getTableName gets the table name for a specific type. 186 // The type must implement TableNamer. 187 func getTableNameFromType(ty reflect.Type) (string, error) { 188 if ty == nil { 189 return ``, &NotStructError{} 190 } 191 192 for ty.Kind() == reflect.Ptr || ty.Kind() == reflect.Slice { 193 ty = ty.Elem() 194 } 195 196 return getTableNameFromValue(reflect.New(ty).Interface()) 197 } 198 199 func getTableNameFromValue(value interface{}) (string, error) { 200 if i, ok := value.(TableNamer); ok { 201 return i.TableName(), nil 202 } 203 return ``, nil 204 } 205 206 func getRegistered(_struct interface{}) (*_StructInfo, error) { 207 ty, err := structType(_struct) 208 if err != nil { 209 return nil, err 210 } 211 name := structName(ty) 212 213 rwLock.RLock() 214 if si, ok := structs[name]; ok { 215 rwLock.RUnlock() 216 return si, nil 217 } 218 rwLock.RUnlock() 219 return register(ty) 220 }