github.com/goplus/yap@v0.8.1/ydb/table.go (about) 1 /* 2 * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package ydb 18 19 import ( 20 "context" 21 "database/sql" 22 "database/sql/driver" 23 "log" 24 "reflect" 25 "strings" 26 "time" 27 "unsafe" 28 29 "github.com/goplus/yap/reflectutil" 30 "github.com/qiniu/x/stringutil" 31 ) 32 33 // ----------------------------------------------------------------------------- 34 35 type nullTime time.Time 36 37 func (n *nullTime) Scan(value any) (err error) { 38 var ret sql.NullTime 39 err = ret.Scan(value) 40 *(*time.Time)(n) = ret.Time 41 return 42 } 43 44 func (n nullTime) Value() (driver.Value, error) { 45 if (*time.Time)(&n).IsZero() { 46 return nil, nil 47 } 48 return *(*time.Time)(&n), nil 49 } 50 51 // ----------------------------------------------------------------------------- 52 53 type dbType = reflect.Type 54 type ioType = reflect.Type 55 56 var ( 57 tyString = reflect.TypeOf("") 58 tyInt = reflect.TypeOf(0) 59 tyBool = reflect.TypeOf(false) 60 tyBlob = reflect.TypeOf([]byte(nil)) 61 tyTime = reflect.TypeOf(time.Time{}) 62 tyNullTime = reflect.TypeOf(nullTime{}) 63 tyFloat64 = reflect.TypeOf(float64(0)) 64 tyFloat32 = reflect.TypeOf(float32(0)) 65 ) 66 67 func columnType(fldType dbType) string { 68 switch fldType { 69 case tyString: 70 return "TEXT" 71 case tyInt: 72 return "INT" 73 case tyBool: 74 return "BOOL" 75 case tyBlob: 76 return "BLOB" 77 case tyTime: 78 return "DATETIME" 79 case tyFloat64: 80 return "DOUBLE" 81 case tyFloat32: 82 return "FLOAT" 83 } 84 panic("unknown column type: " + fldType.String()) 85 } 86 87 func colIOType(fldType dbType) ioType { 88 if fldType == tyTime { 89 return tyNullTime 90 } 91 return fldType 92 } 93 94 // ----------------------------------------------------------------------------- 95 96 type dbIndex struct { 97 index []*column 98 col *column 99 params string 100 } 101 102 func (p *dbIndex) get(tbl *Table) []*column { 103 if p.index == nil { 104 p.index = tbl.makeIndex(p.col, p.params) 105 } 106 return p.index 107 } 108 109 type Table struct { 110 name string 111 ver string 112 schema dbType 113 cols []*column 114 uniqs []*dbIndex 115 idxs []*dbIndex 116 } 117 118 type column struct { 119 typ string // type in DB 120 name string // column name 121 fld field 122 } 123 124 type field struct { 125 typ ioType // field io type 126 offset uintptr // offset within struct, in bytes 127 } 128 129 func newTable(name, ver string, schema dbType) *Table { 130 n := schema.NumField() 131 cols := make([]*column, 0, n) 132 p := &Table{name: name, ver: ver, schema: schema, cols: cols} 133 p.defineCols(n, schema, 0) 134 return p 135 } 136 137 func getVals(vals []any, v reflect.Value, cols []field, elem bool) []any { 138 this := reflectutil.UnsafeAddr(v) 139 for _, col := range cols { 140 v := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset)) 141 if elem { 142 v = v.Elem() 143 } 144 val := v.Interface() 145 vals = append(vals, val) 146 } 147 return vals 148 } 149 150 func getCols(names []string, cols []field, n int, t dbType, base uintptr) ([]string, []field) { 151 for i := 0; i < n; i++ { 152 fld := t.Field(i) 153 if fld.Anonymous { 154 fldType := fld.Type 155 names, cols = getCols(names, cols, fldType.NumField(), fldType, base+fld.Offset) 156 continue 157 } 158 if fld.IsExported() { 159 name := "" 160 if tag := string(fld.Tag); tag != "" { 161 if c := tag[0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case 162 if pos := strings.IndexByte(tag, ' '); pos > 0 { 163 tag = tag[:pos] 164 } 165 name = tag 166 } 167 } 168 if name == "" { 169 name = dbName(fld.Name) 170 } 171 names = append(names, name) 172 cols = append(cols, field{colIOType(fld.Type), base + fld.Offset}) 173 } 174 } 175 return names, cols 176 } 177 178 func (p *Table) defineCols(n int, t dbType, base uintptr) { 179 for i := 0; i < n; i++ { 180 fld := t.Field(i) 181 if fld.Anonymous { 182 fldType := fld.Type 183 p.defineCols(fldType.NumField(), fldType, base+fld.Offset) 184 continue 185 } 186 if fld.IsExported() { 187 col := &column{fld: field{colIOType(fld.Type), base + fld.Offset}} 188 if tag := string(fld.Tag); tag != "" { 189 if parts := strings.Fields(tag); len(parts) > 0 { 190 if c := parts[0][0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case 191 col.name = parts[0] 192 parts = parts[1:] 193 } else { 194 col.name = dbName(fld.Name) 195 } 196 for _, part := range parts { 197 cmd, params := part, "" // cmd(params) 198 if pos := strings.IndexByte(part, '('); pos > 0 && part[len(part)-1] == ')' { 199 cmd, params = part[:pos], part[pos+1:len(part)-1] 200 } 201 switch cmd { 202 case `UNIQUE`: 203 p.uniqs = append(p.uniqs, &dbIndex{nil, col, params}) 204 case `INDEX`: 205 p.idxs = append(p.idxs, &dbIndex{nil, col, params}) 206 default: 207 if col.typ != "" { 208 log.Panicf("invalid tag `%s`: multiple column types?\n", tag) 209 } 210 col.typ = part 211 } 212 } 213 } 214 } 215 if col.name == "" { 216 col.name = dbName(fld.Name) 217 } 218 if col.typ == "" { 219 col.typ = columnType(fld.Type) 220 } 221 p.cols = append(p.cols, col) 222 } 223 } 224 } 225 226 func (p *Table) makeIndex(col *column, params string) []*column { 227 if params == "" { 228 return []*column{col} 229 } 230 pos := strings.IndexByte(params, ',') 231 if pos < 0 { 232 return []*column{col, p.getCol(params)} 233 } 234 ret := make([]*column, 1, 4) 235 ret[0] = col 236 for { 237 ret = append(ret, p.getCol(params[:pos])) 238 params = params[pos+1:] 239 pos = strings.IndexByte(params, ',') 240 if pos < 0 { 241 break 242 } 243 } 244 return append(ret, p.getCol(params)) 245 } 246 247 func (p *Table) getCol(name string) *column { 248 for _, col := range p.cols { 249 if col.name == name { 250 return col 251 } 252 } 253 log.Panicf("table `%s` doesn't have column `%s`\n", p.name, name) 254 return nil 255 } 256 257 // ----------------------------------------------------------------------------- 258 259 func (p *Table) create(ctx context.Context, sql *Sql) { 260 n := len(p.cols) 261 if n == 0 { 262 log.Panicln("empty table:", p.name, p.ver) 263 } 264 265 db := sql.db 266 query := make([]byte, 0, 64) 267 if sql.autodrop { 268 query = append(query, "DROP TABLE "...) 269 query = append(query, p.name...) 270 db.ExecContext(ctx, string(query)) 271 query = query[:0] 272 } 273 274 query = append(query, "CREATE TABLE "...) 275 query = append(query, p.name...) 276 query = append(query, ' ', '(') 277 for _, c := range p.cols { 278 query = append(query, c.name...) 279 query = append(query, ' ') 280 query = append(query, c.typ...) 281 query = append(query, ',') 282 } 283 query[len(query)-1] = ')' 284 285 q := string(query) 286 _, err := db.ExecContext(ctx, q) 287 if err != nil { 288 log.Panicf("%s\ncreate table (%s): %v\n", q, p.name, err) 289 } 290 291 for _, uniq := range p.uniqs { 292 cols := uniq.get(p) 293 name := indexName(cols, "uniq_", p.name) 294 createIndex(sql, db, ctx, "CREATE UNIQUE INDEX ", name, p.name, cols) 295 } 296 for _, idx := range p.idxs { 297 cols := idx.get(p) 298 name := indexName(cols, "idx_", p.name) 299 createIndex(sql, db, ctx, "CREATE INDEX ", name, p.name, cols) 300 } 301 } 302 303 // prefix_tbl_name1_name2_... 304 func indexName(cols []*column, prefix, tbl string) string { 305 n := len(prefix) + len(tbl) 306 for _, col := range cols { 307 n += 1 + len(col.name) 308 } 309 b := make([]byte, 0, n) 310 b = append(b, prefix...) 311 b = append(b, tbl...) 312 for _, col := range cols { 313 b = append(b, '_') 314 b = append(b, col.name...) 315 } 316 return stringutil.String(b) 317 } 318 319 func createIndex(sql *Sql, db *sql.DB, ctx context.Context, cmd string, name, tbl string, cols []*column) { 320 parts := make([]string, 0, 5+2*len(cols)) 321 parts = append(parts, cmd, name, " ON ", tbl, "(") 322 for _, col := range cols { 323 parts = append(parts, col.name, ",") 324 } 325 parts[len(parts)-1] = ")" 326 query := stringutil.Concat(parts...) 327 if _, err := db.ExecContext(ctx, query); err != nil { 328 log.Panicf("%s\ncreate index `%s`: %v\n", query, name, err) 329 } 330 } 331 332 // -----------------------------------------------------------------------------