github.com/goplus/yap@v0.8.1/ydb/table.go (about)

     1  /*
     2   * Copyright (c) 2024 The GoPlus Authors (goplus.org). All rights reserved.
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ydb
    18  
    19  import (
    20  	"context"
    21  	"database/sql"
    22  	"database/sql/driver"
    23  	"log"
    24  	"reflect"
    25  	"strings"
    26  	"time"
    27  	"unsafe"
    28  
    29  	"github.com/goplus/yap/reflectutil"
    30  	"github.com/qiniu/x/stringutil"
    31  )
    32  
    33  // -----------------------------------------------------------------------------
    34  
    35  type nullTime time.Time
    36  
    37  func (n *nullTime) Scan(value any) (err error) {
    38  	var ret sql.NullTime
    39  	err = ret.Scan(value)
    40  	*(*time.Time)(n) = ret.Time
    41  	return
    42  }
    43  
    44  func (n nullTime) Value() (driver.Value, error) {
    45  	if (*time.Time)(&n).IsZero() {
    46  		return nil, nil
    47  	}
    48  	return *(*time.Time)(&n), nil
    49  }
    50  
    51  // -----------------------------------------------------------------------------
    52  
// dbType is the reflect.Type of a schema struct or of one of its fields as
// declared in Go; it is what columnType and colIOType take as input.
type dbType = reflect.Type

// ioType is the reflect.Type used when exchanging a field's value with the
// database (see colIOType). Both names alias reflect.Type, so they are
// interchangeable to the compiler and serve only as documentation.
type ioType = reflect.Type
    55  
// Cached reflect.Type values for the Go types that columnType and colIOType
// compare field types against.
var (
	tyString   = reflect.TypeOf("")
	tyInt      = reflect.TypeOf(0)
	tyBool     = reflect.TypeOf(false)
	tyBlob     = reflect.TypeOf([]byte(nil))
	tyTime     = reflect.TypeOf(time.Time{})
	tyNullTime = reflect.TypeOf(nullTime{}) // I/O substitute for tyTime
	tyFloat64  = reflect.TypeOf(float64(0))
	tyFloat32  = reflect.TypeOf(float32(0))
)
    66  
    67  func columnType(fldType dbType) string {
    68  	switch fldType {
    69  	case tyString:
    70  		return "TEXT"
    71  	case tyInt:
    72  		return "INT"
    73  	case tyBool:
    74  		return "BOOL"
    75  	case tyBlob:
    76  		return "BLOB"
    77  	case tyTime:
    78  		return "DATETIME"
    79  	case tyFloat64:
    80  		return "DOUBLE"
    81  	case tyFloat32:
    82  		return "FLOAT"
    83  	}
    84  	panic("unknown column type: " + fldType.String())
    85  }
    86  
    87  func colIOType(fldType dbType) ioType {
    88  	if fldType == tyTime {
    89  		return tyNullTime
    90  	}
    91  	return fldType
    92  }
    93  
    94  // -----------------------------------------------------------------------------
    95  
// dbIndex describes one index declared through a field tag (UNIQUE or INDEX).
// The column list is resolved lazily by get.
type dbIndex struct {
	index  []*column // resolved columns of the index; nil until get is called
	col    *column   // column whose tag declared the index (always the first index column)
	params string    // text between the parentheses of the tag command: comma-separated extra column names, or ""
}
   101  
   102  func (p *dbIndex) get(tbl *Table) []*column {
   103  	if p.index == nil {
   104  		p.index = tbl.makeIndex(p.col, p.params)
   105  	}
   106  	return p.index
   107  }
   108  
// Table holds the database-side description of a schema struct: its columns
// and the unique/non-unique indexes declared in the struct's field tags.
type Table struct {
	name   string     // table name, used verbatim in the generated SQL
	ver    string     // version string given to newTable (only reported in panic messages within this file)
	schema dbType     // Go struct type the table was derived from
	cols   []*column  // columns, in struct field order (embedded structs are flattened)
	uniqs  []*dbIndex // indexes declared with UNIQUE
	idxs   []*dbIndex // indexes declared with INDEX
}
   117  
// column describes one table column together with the struct field backing it.
type column struct {
	typ  string // type in DB
	name string // column name
	fld  field  // I/O type and location of the backing struct field
}
   123  
// field locates a struct field for raw (offset-based) access: getVals builds
// a pointer of type typ at the struct's base address plus offset.
type field struct {
	typ    ioType  // field io type
	offset uintptr // offset within struct, in bytes
}
   128  
   129  func newTable(name, ver string, schema dbType) *Table {
   130  	n := schema.NumField()
   131  	cols := make([]*column, 0, n)
   132  	p := &Table{name: name, ver: ver, schema: schema, cols: cols}
   133  	p.defineCols(n, schema, 0)
   134  	return p
   135  }
   136  
   137  func getVals(vals []any, v reflect.Value, cols []field, elem bool) []any {
   138  	this := reflectutil.UnsafeAddr(v)
   139  	for _, col := range cols {
   140  		v := reflect.NewAt(col.typ, unsafe.Pointer(this+col.offset))
   141  		if elem {
   142  			v = v.Elem()
   143  		}
   144  		val := v.Interface()
   145  		vals = append(vals, val)
   146  	}
   147  	return vals
   148  }
   149  
   150  func getCols(names []string, cols []field, n int, t dbType, base uintptr) ([]string, []field) {
   151  	for i := 0; i < n; i++ {
   152  		fld := t.Field(i)
   153  		if fld.Anonymous {
   154  			fldType := fld.Type
   155  			names, cols = getCols(names, cols, fldType.NumField(), fldType, base+fld.Offset)
   156  			continue
   157  		}
   158  		if fld.IsExported() {
   159  			name := ""
   160  			if tag := string(fld.Tag); tag != "" {
   161  				if c := tag[0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case
   162  					if pos := strings.IndexByte(tag, ' '); pos > 0 {
   163  						tag = tag[:pos]
   164  					}
   165  					name = tag
   166  				}
   167  			}
   168  			if name == "" {
   169  				name = dbName(fld.Name)
   170  			}
   171  			names = append(names, name)
   172  			cols = append(cols, field{colIOType(fld.Type), base + fld.Offset})
   173  		}
   174  	}
   175  	return names, cols
   176  }
   177  
   178  func (p *Table) defineCols(n int, t dbType, base uintptr) {
   179  	for i := 0; i < n; i++ {
   180  		fld := t.Field(i)
   181  		if fld.Anonymous {
   182  			fldType := fld.Type
   183  			p.defineCols(fldType.NumField(), fldType, base+fld.Offset)
   184  			continue
   185  		}
   186  		if fld.IsExported() {
   187  			col := &column{fld: field{colIOType(fld.Type), base + fld.Offset}}
   188  			if tag := string(fld.Tag); tag != "" {
   189  				if parts := strings.Fields(tag); len(parts) > 0 {
   190  					if c := parts[0][0]; c >= 'a' && c <= 'z' { // suppose a column name is lower case
   191  						col.name = parts[0]
   192  						parts = parts[1:]
   193  					} else {
   194  						col.name = dbName(fld.Name)
   195  					}
   196  					for _, part := range parts {
   197  						cmd, params := part, "" // cmd(params)
   198  						if pos := strings.IndexByte(part, '('); pos > 0 && part[len(part)-1] == ')' {
   199  							cmd, params = part[:pos], part[pos+1:len(part)-1]
   200  						}
   201  						switch cmd {
   202  						case `UNIQUE`:
   203  							p.uniqs = append(p.uniqs, &dbIndex{nil, col, params})
   204  						case `INDEX`:
   205  							p.idxs = append(p.idxs, &dbIndex{nil, col, params})
   206  						default:
   207  							if col.typ != "" {
   208  								log.Panicf("invalid tag `%s`: multiple column types?\n", tag)
   209  							}
   210  							col.typ = part
   211  						}
   212  					}
   213  				}
   214  			}
   215  			if col.name == "" {
   216  				col.name = dbName(fld.Name)
   217  			}
   218  			if col.typ == "" {
   219  				col.typ = columnType(fld.Type)
   220  			}
   221  			p.cols = append(p.cols, col)
   222  		}
   223  	}
   224  }
   225  
   226  func (p *Table) makeIndex(col *column, params string) []*column {
   227  	if params == "" {
   228  		return []*column{col}
   229  	}
   230  	pos := strings.IndexByte(params, ',')
   231  	if pos < 0 {
   232  		return []*column{col, p.getCol(params)}
   233  	}
   234  	ret := make([]*column, 1, 4)
   235  	ret[0] = col
   236  	for {
   237  		ret = append(ret, p.getCol(params[:pos]))
   238  		params = params[pos+1:]
   239  		pos = strings.IndexByte(params, ',')
   240  		if pos < 0 {
   241  			break
   242  		}
   243  	}
   244  	return append(ret, p.getCol(params))
   245  }
   246  
   247  func (p *Table) getCol(name string) *column {
   248  	for _, col := range p.cols {
   249  		if col.name == name {
   250  			return col
   251  		}
   252  	}
   253  	log.Panicf("table `%s` doesn't have column `%s`\n", p.name, name)
   254  	return nil
   255  }
   256  
   257  // -----------------------------------------------------------------------------
   258  
   259  func (p *Table) create(ctx context.Context, sql *Sql) {
   260  	n := len(p.cols)
   261  	if n == 0 {
   262  		log.Panicln("empty table:", p.name, p.ver)
   263  	}
   264  
   265  	db := sql.db
   266  	query := make([]byte, 0, 64)
   267  	if sql.autodrop {
   268  		query = append(query, "DROP TABLE "...)
   269  		query = append(query, p.name...)
   270  		db.ExecContext(ctx, string(query))
   271  		query = query[:0]
   272  	}
   273  
   274  	query = append(query, "CREATE TABLE "...)
   275  	query = append(query, p.name...)
   276  	query = append(query, ' ', '(')
   277  	for _, c := range p.cols {
   278  		query = append(query, c.name...)
   279  		query = append(query, ' ')
   280  		query = append(query, c.typ...)
   281  		query = append(query, ',')
   282  	}
   283  	query[len(query)-1] = ')'
   284  
   285  	q := string(query)
   286  	_, err := db.ExecContext(ctx, q)
   287  	if err != nil {
   288  		log.Panicf("%s\ncreate table (%s): %v\n", q, p.name, err)
   289  	}
   290  
   291  	for _, uniq := range p.uniqs {
   292  		cols := uniq.get(p)
   293  		name := indexName(cols, "uniq_", p.name)
   294  		createIndex(sql, db, ctx, "CREATE UNIQUE INDEX ", name, p.name, cols)
   295  	}
   296  	for _, idx := range p.idxs {
   297  		cols := idx.get(p)
   298  		name := indexName(cols, "idx_", p.name)
   299  		createIndex(sql, db, ctx, "CREATE INDEX ", name, p.name, cols)
   300  	}
   301  }
   302  
   303  // prefix_tbl_name1_name2_...
   304  func indexName(cols []*column, prefix, tbl string) string {
   305  	n := len(prefix) + len(tbl)
   306  	for _, col := range cols {
   307  		n += 1 + len(col.name)
   308  	}
   309  	b := make([]byte, 0, n)
   310  	b = append(b, prefix...)
   311  	b = append(b, tbl...)
   312  	for _, col := range cols {
   313  		b = append(b, '_')
   314  		b = append(b, col.name...)
   315  	}
   316  	return stringutil.String(b)
   317  }
   318  
   319  func createIndex(sql *Sql, db *sql.DB, ctx context.Context, cmd string, name, tbl string, cols []*column) {
   320  	parts := make([]string, 0, 5+2*len(cols))
   321  	parts = append(parts, cmd, name, " ON ", tbl, "(")
   322  	for _, col := range cols {
   323  		parts = append(parts, col.name, ",")
   324  	}
   325  	parts[len(parts)-1] = ")"
   326  	query := stringutil.Concat(parts...)
   327  	if _, err := db.ExecContext(ctx, query); err != nil {
   328  		log.Panicf("%s\ncreate index `%s`: %v\n", query, name, err)
   329  	}
   330  }
   331  
   332  // -----------------------------------------------------------------------------