go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tables/rdb/sql.go (about)

     1  package rdb
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"go-ml.dev/pkg/base/fu"
     7  	"go-ml.dev/pkg/base/fu/lazy"
     8  	"go-ml.dev/pkg/base/tables"
     9  	"go-ml.dev/pkg/iokit"
    10  	"go-ml.dev/pkg/zorros"
    11  	"io"
    12  	"reflect"
    13  	"strings"
    14  	//	_ "github.com/go-sql-driver/mysql"
    15  	//	_ "github.com/lib/pq"
    16  	//	_ "github.com/mattn/go-sqlite3"
    17  )
    18  
    19  func Read(source interface{}, opts ...interface{}) (*tables.Table, error) {
    20  	return Source(source, opts...).Collect()
    21  }
    22  
    23  func Write(source interface{}, t *tables.Table, opts ...interface{}) error {
    24  	return t.Lazy().Drain(Sink(source, opts...))
    25  }
    26  
// dontclose is an internal option flag. connectDB appends dontclose(true) to
// the option list when the caller passed in an existing *sql.DB; Source and
// Sink read the flag back to decide whether the handle goes into their closer.
type dontclose bool
    28  
    29  func connectDB(source interface{}, opts []interface{}) (db *sql.DB, o []interface{}, err error) {
    30  	o = opts
    31  	if url, ok := source.(string); ok {
    32  		drv, conn := splitDriver(url)
    33  		o = append(o, Driver(drv))
    34  		db, err = sql.Open(drv, conn)
    35  	} else if db, ok = source.(*sql.DB); !ok {
    36  		err = zorros.Errorf("unknown database source %v", source)
    37  	} else {
    38  		o = append(o, dontclose(true))
    39  	}
    40  	return
    41  }
    42  
    43  func Source(source interface{}, opts ...interface{}) tables.Lazy {
    44  	return func() lazy.Stream {
    45  		db, opts, err := connectDB(source, opts)
    46  		cls := io.Closer(iokit.CloserChain{})
    47  		if !fu.BoolOption(dontclose(false), opts) {
    48  			cls = db
    49  		}
    50  		if err != nil {
    51  			tables.SourceError(zorros.Wrapf(err, "database connection error: %s", err.Error()))
    52  		}
    53  		drv := fu.StrOption(Driver(""), opts)
    54  		schema := fu.StrOption(Schema(""), opts)
    55  		if schema != "" {
    56  			switch drv {
    57  			case "mysql":
    58  				_, err = db.Exec("use " + schema)
    59  			case "postgres":
    60  				_, err = db.Exec("set search_path to " + schema)
    61  			}
    62  		}
    63  		if err != nil {
    64  			cls.Close()
    65  			return lazy.Error(zorros.Wrapf(err, "query error: %s", err.Error()))
    66  		}
    67  		query := fu.StrOption(Query(""), opts)
    68  		if query == "" {
    69  			table := fu.StrOption(Table(""), opts)
    70  			if table != "" {
    71  				query = "select * from " + table
    72  			} else {
    73  				panic("there is no query or table")
    74  			}
    75  		}
    76  		rows, err := db.Query(query)
    77  		if err != nil {
    78  			cls.Close()
    79  			return lazy.Error(zorros.Wrapf(err, "query error: %s", err.Error()))
    80  		}
    81  		cls = iokit.CloserChain{rows, cls}
    82  		tps, err := rows.ColumnTypes()
    83  		if err != nil {
    84  			cls.Close()
    85  			return lazy.Error(zorros.Wrapf(err, "get types error: %s", err.Error()))
    86  		}
    87  		ns, err := rows.Columns()
    88  		if err != nil {
    89  			cls.Close()
    90  			return lazy.Error(zorros.Wrapf(err, "get names error: %s", err.Error()))
    91  		}
    92  		x := make([]interface{}, len(ns))
    93  		describe, err := Describe(ns, opts)
    94  		if err != nil {
    95  			cls.Close()
    96  			return lazy.Error(err)
    97  		}
    98  		names := make([]string, len(ns))
    99  		for i, n := range ns {
   100  			var s SqlScan
   101  			colType, colName, _ := describe(n)
   102  			if colType != "" {
   103  				s = scanner(colType)
   104  			} else {
   105  				s = scanner(tps[i].DatabaseTypeName())
   106  			}
   107  			x[i] = s
   108  			names[i] = colName
   109  		}
   110  
   111  		wc := fu.WaitCounter{Value: 0}
   112  		f := fu.AtomicFlag{Value: 0}
   113  
   114  		return func(index uint64) (reflect.Value, error) {
   115  			if index == lazy.STOP {
   116  				wc.Stop()
   117  				return reflect.ValueOf(false), nil
   118  			}
   119  			if wc.Wait(index) {
   120  				end := !rows.Next()
   121  				if !end {
   122  					rows.Scan(x...)
   123  					lr := fu.Struct{Names: names, Columns: make([]reflect.Value, len(ns))}
   124  					for i := range x {
   125  						y := x[i].(SqlScan)
   126  						v, ok := y.Value()
   127  						if !ok {
   128  							lr.Na.Set(i, true)
   129  						}
   130  						lr.Columns[i] = v
   131  					}
   132  					wc.Inc()
   133  					return reflect.ValueOf(lr), nil
   134  				}
   135  				wc.Stop()
   136  			}
   137  			if f.Set() {
   138  				cls.Close()
   139  			}
   140  			return reflect.ValueOf(false), nil
   141  		}
   142  	}
   143  }
   144  
   145  func splitDriver(url string) (string, string) {
   146  	q := strings.SplitN(url, ":", 2)
   147  	return q[0], q[1]
   148  }
   149  
   150  func scanner(q string) SqlScan {
   151  	switch q {
   152  	case "VARCHAR", "TEXT", "CHAR", "STRING":
   153  		return &SqlString{}
   154  	case "INT8", "SMALLINT", "INT2":
   155  		return &SqlSmall{}
   156  	case "INTEGER", "INT", "INT4":
   157  		return &SqlInteger{}
   158  	case "BIGINT":
   159  		return &SqlBigint{}
   160  	case "BOOLEAN":
   161  		return &SqlBool{}
   162  	case "DECIMAL", "NUMERIC", "REAL", "DOUBLE", "FLOAT8":
   163  		return &SqlDouble{}
   164  	case "FLOAT", "FLOAT4":
   165  		return &SqlFloat{}
   166  	case "DATE", "DATETIME", "TIMESTAMP":
   167  		return &SqlTimestamp{}
   168  	default:
   169  		if strings.Index(q, "VARCHAR(") == 0 ||
   170  			strings.Index(q, "CHAR(") == 0 {
   171  			return &SqlString{}
   172  		}
   173  		if strings.Index(q, "DECIMAL(") == 0 ||
   174  			strings.Index(q, "NUMERIC(") == 0 {
   175  			return &SqlDouble{}
   176  		}
   177  	}
   178  	panic("unknown column type " + q)
   179  }
   180  
   181  func batchInsertStmt(tx *sql.Tx, names []string, pk []bool, lines int, table string, opts []interface{}) (stmt *sql.Stmt, err error) {
   182  	drv := fu.StrOption(Driver(""), opts)
   183  	ifExists := fu.Option(ErrorIfExists, opts).Interface().(IfExists_)
   184  	L := len(names)
   185  	q1 := " values "
   186  	for j := 0; j < lines; j++ {
   187  		q1 += "("
   188  		if drv == "postgres" {
   189  			for k := range names {
   190  				q1 += fmt.Sprintf("$%d,", j*L+k+1)
   191  			}
   192  		} else {
   193  			q1 += strings.Repeat("?,", L)
   194  		}
   195  		q1 = q1[:len(q1)-1] + "),"
   196  	}
   197  	q := "insert into " + table + "(" + strings.Join(names, ",") + ")" + q1[:len(q1)-1]
   198  
   199  	if ifExists == InsertUpdateIfExists {
   200  		if len(pk) > 0 {
   201  			q += " on duplicate key update "
   202  			for i, n := range names {
   203  				if !pk[i] {
   204  					q += " " + n + " = values(" + n + "),"
   205  				}
   206  			}
   207  			q = q[:len(q)-1]
   208  		}
   209  	}
   210  	stmt, err = tx.Prepare(q)
   211  	return
   212  }
   213  
   214  func Sink(source interface{}, opts ...interface{}) tables.Sink {
   215  	db, opts, err := connectDB(source, opts)
   216  	cls := io.Closer(iokit.CloserChain{})
   217  	if !fu.BoolOption(dontclose(false), opts) {
   218  		cls = db
   219  	}
   220  	if err != nil {
   221  		return tables.SinkError(zorros.Errorf("database connection error: %w", err))
   222  	}
   223  	drv := fu.StrOption(Driver(""), opts)
   224  
   225  	schema := fu.StrOption(Schema(""), opts)
   226  	if schema != "" {
   227  		switch drv {
   228  		case "mysql":
   229  			_, err = db.Exec("use " + schema)
   230  		case "postgres":
   231  			_, err = db.Exec("set search_path to " + schema)
   232  		}
   233  	}
   234  	if err != nil {
   235  		cls.Close()
   236  		return tables.SinkError(zorros.Wrapf(err, "query error: %s", err.Error()))
   237  	}
   238  
   239  	tx, err := db.Begin()
   240  	if err != nil {
   241  		cls.Close()
   242  		return tables.SinkError(zorros.Wrapf(err, "database begin transaction error: %s", err.Error()))
   243  	}
   244  
   245  	table := fu.StrOption(Table(""), opts)
   246  	if table == "" {
   247  		panic("there is no table")
   248  	}
   249  	if fu.Option(ErrorIfExists, opts).Interface().(IfExists_) == DropIfExists {
   250  		_, err := tx.Exec(sqlDropQuery(table, opts...))
   251  		if err != nil {
   252  			cls.Close()
   253  			return tables.SinkError(zorros.Wrapf(err, "drop table error: %s", err.Error()))
   254  		}
   255  	}
   256  
   257  	batchLen := fu.IntOption(Batch(1), opts)
   258  	var stmt *sql.Stmt
   259  	created := false
   260  	batch := []interface{}{}
   261  	names := []string{}
   262  	pk := []bool{}
   263  	return func(val reflect.Value) (err error) {
   264  		var describe func(int) (string, string, bool)
   265  		if val.Kind() == reflect.Bool {
   266  			if val.Bool() {
   267  				if len(batch) > 0 {
   268  					if stmt, err = batchInsertStmt(tx, names, pk, len(batch)/len(names), table, opts); err == nil {
   269  						if _, err = stmt.Exec(batch...); err == nil {
   270  							cls = iokit.CloserChain{stmt, cls}
   271  						}
   272  					}
   273  				}
   274  				if err == nil {
   275  					err = tx.Commit()
   276  				}
   277  			}
   278  			cls.Close()
   279  			return
   280  		}
   281  		lr := val.Interface().(fu.Struct)
   282  		names = make([]string, len(lr.Names))
   283  		pk = make([]bool, len(lr.Names))
   284  		drv := fu.StrOption(Driver(""), opts)
   285  		dsx, err := Describe(lr.Names, opts)
   286  		if err != nil {
   287  			cls.Close()
   288  			return
   289  		}
   290  		describe = func(i int) (colType, colName string, isPk bool) {
   291  			v := lr.Names[i]
   292  			colType, colName, isPk = dsx(v)
   293  			if colType == "" {
   294  				colType = sqlTypeOf(lr.Columns[i].Type(), drv)
   295  			}
   296  			return
   297  		}
   298  		for i := range names {
   299  			_, names[i], pk[i] = describe(i)
   300  		}
   301  		if !created {
   302  			_, err = tx.Exec(sqlCreateQuery(lr, table, describe, opts))
   303  			if err != nil {
   304  				cls.Close()
   305  				return zorros.Wrapf(err, "create table error: %s", err.Error())
   306  			}
   307  			created = true
   308  		}
   309  		if len(batch)/len(names) >= batchLen {
   310  			if stmt == nil {
   311  				stmt, err = batchInsertStmt(tx, names, pk, len(batch)/len(names), table, opts)
   312  				if err != nil {
   313  					return err
   314  				}
   315  				cls = iokit.CloserChain{stmt, cls}
   316  			}
   317  			_, err = stmt.Exec(batch...)
   318  			if err != nil {
   319  				return err
   320  			}
   321  			batch = batch[:0]
   322  		}
   323  		for i := range lr.Names {
   324  			if lr.Na.Bit(i) {
   325  				batch = append(batch, nil)
   326  			} else {
   327  				batch = append(batch, lr.Columns[i].Interface())
   328  			}
   329  		}
   330  		return
   331  	}
   332  }
   333  
   334  func sqlCreateQuery(lr fu.Struct, table string, describe func(int) (string, string, bool), opts []interface{}) string {
   335  	pk := []string{}
   336  	query := "create table "
   337  
   338  	ifExists := fu.Option(ErrorIfExists, opts).Interface().(IfExists_)
   339  	if ifExists != ErrorIfExists && ifExists != DropIfExists {
   340  		query += "if not exists "
   341  	}
   342  
   343  	query = query + table + "( "
   344  	for i := range lr.Names {
   345  		if i != 0 {
   346  			query += ", "
   347  		}
   348  		colType, colName, isPK := describe(i)
   349  		query = query + colName + " " + colType
   350  		if isPK {
   351  			pk = append(pk, colName)
   352  		}
   353  	}
   354  
   355  	if len(pk) > 0 {
   356  		query = query + ", primary key (" + strings.Join(pk, ",") + ")"
   357  	}
   358  
   359  	query += " )"
   360  	return query
   361  }
   362  
   363  func sqlDropQuery(table string, opts ...interface{}) string {
   364  	schema := fu.StrOption(Schema(""), opts)
   365  	if schema != "" {
   366  		schema = schema + "."
   367  	}
   368  	return "drop table if exists " + schema + table
   369  }
   370  
   371  func sqlTypeOf(tp reflect.Type, driver string) string {
   372  	switch tp.Kind() {
   373  	case reflect.String:
   374  		if driver == "postgres" {
   375  			return "VARCHAR(65535)" /* redshift TEXT == VARCHAR(256) */
   376  		}
   377  		return "TEXT"
   378  	case reflect.Int8, reflect.Uint8, reflect.Int16:
   379  		return "SMALLINT"
   380  	case reflect.Uint16, reflect.Int32, reflect.Int:
   381  		return "INTEGER"
   382  	case reflect.Uint, reflect.Uint32, reflect.Int64, reflect.Uint64:
   383  		return "BIGINT"
   384  	case reflect.Float32:
   385  		if driver == "postgres" {
   386  			return "REAL" /* redshift does not FLOAT */
   387  		}
   388  		return "FLOAT"
   389  	case reflect.Float64:
   390  		if driver == "postgres" {
   391  			return "DOUBLE PRECISION" /* redshift does not have DOUBLE */
   392  		}
   393  		return "DOUBLE"
   394  	case reflect.Bool:
   395  		return "BOOLEAN"
   396  	default:
   397  		if tp == fu.Ts {
   398  			return "DATETIME"
   399  		}
   400  	}
   401  	panic("unsupported data type " + fmt.Sprintf("%v %v", tp.String(), tp.Kind()))
   402  }