go-ml.dev/pkg/base@v0.0.0-20200610162856-60c38abac71b/tables/rdb/sql.go (about) 1 package rdb 2 3 import ( 4 "database/sql" 5 "fmt" 6 "go-ml.dev/pkg/base/fu" 7 "go-ml.dev/pkg/base/fu/lazy" 8 "go-ml.dev/pkg/base/tables" 9 "go-ml.dev/pkg/iokit" 10 "go-ml.dev/pkg/zorros" 11 "io" 12 "reflect" 13 "strings" 14 // _ "github.com/go-sql-driver/mysql" 15 // _ "github.com/lib/pq" 16 // _ "github.com/mattn/go-sqlite3" 17 ) 18 19 func Read(source interface{}, opts ...interface{}) (*tables.Table, error) { 20 return Source(source, opts...).Collect() 21 } 22 23 func Write(source interface{}, t *tables.Table, opts ...interface{}) error { 24 return t.Lazy().Drain(Sink(source, opts...)) 25 } 26 27 type dontclose bool 28 29 func connectDB(source interface{}, opts []interface{}) (db *sql.DB, o []interface{}, err error) { 30 o = opts 31 if url, ok := source.(string); ok { 32 drv, conn := splitDriver(url) 33 o = append(o, Driver(drv)) 34 db, err = sql.Open(drv, conn) 35 } else if db, ok = source.(*sql.DB); !ok { 36 err = zorros.Errorf("unknown database source %v", source) 37 } else { 38 o = append(o, dontclose(true)) 39 } 40 return 41 } 42 43 func Source(source interface{}, opts ...interface{}) tables.Lazy { 44 return func() lazy.Stream { 45 db, opts, err := connectDB(source, opts) 46 cls := io.Closer(iokit.CloserChain{}) 47 if !fu.BoolOption(dontclose(false), opts) { 48 cls = db 49 } 50 if err != nil { 51 tables.SourceError(zorros.Wrapf(err, "database connection error: %s", err.Error())) 52 } 53 drv := fu.StrOption(Driver(""), opts) 54 schema := fu.StrOption(Schema(""), opts) 55 if schema != "" { 56 switch drv { 57 case "mysql": 58 _, err = db.Exec("use " + schema) 59 case "postgres": 60 _, err = db.Exec("set search_path to " + schema) 61 } 62 } 63 if err != nil { 64 cls.Close() 65 return lazy.Error(zorros.Wrapf(err, "query error: %s", err.Error())) 66 } 67 query := fu.StrOption(Query(""), opts) 68 if query == "" { 69 table := fu.StrOption(Table(""), opts) 70 if table != "" { 71 query = "select * from " + table 72 } else { 73 panic("there is no query or table") 74 } 75 } 76 rows, err := db.Query(query) 77 if err != nil { 78 cls.Close() 79 return lazy.Error(zorros.Wrapf(err, "query error: %s", err.Error())) 80 } 81 cls = iokit.CloserChain{rows, cls} 82 tps, err := rows.ColumnTypes() 83 if err != nil { 84 cls.Close() 85 return lazy.Error(zorros.Wrapf(err, "get types error: %s", err.Error())) 86 } 87 ns, err := rows.Columns() 88 if err != nil { 89 cls.Close() 90 return lazy.Error(zorros.Wrapf(err, "get names error: %s", err.Error())) 91 } 92 x := make([]interface{}, len(ns)) 93 describe, err := Describe(ns, opts) 94 if err != nil { 95 cls.Close() 96 return lazy.Error(err) 97 } 98 names := make([]string, len(ns)) 99 for i, n := range ns { 100 var s SqlScan 101 colType, colName, _ := describe(n) 102 if colType != "" { 103 s = scanner(colType) 104 } else { 105 s = scanner(tps[i].DatabaseTypeName()) 106 } 107 x[i] = s 108 names[i] = colName 109 } 110 111 wc := fu.WaitCounter{Value: 0} 112 f := fu.AtomicFlag{Value: 0} 113 114 return func(index uint64) (reflect.Value, error) { 115 if index == lazy.STOP { 116 wc.Stop() 117 return reflect.ValueOf(false), nil 118 } 119 if wc.Wait(index) { 120 end := !rows.Next() 121 if !end { 122 rows.Scan(x...) 123 lr := fu.Struct{Names: names, Columns: make([]reflect.Value, len(ns))} 124 for i := range x { 125 y := x[i].(SqlScan) 126 v, ok := y.Value() 127 if !ok { 128 lr.Na.Set(i, true) 129 } 130 lr.Columns[i] = v 131 } 132 wc.Inc() 133 return reflect.ValueOf(lr), nil 134 } 135 wc.Stop() 136 } 137 if f.Set() { 138 cls.Close() 139 } 140 return reflect.ValueOf(false), nil 141 } 142 } 143 } 144 145 func splitDriver(url string) (string, string) { 146 q := strings.SplitN(url, ":", 2) 147 return q[0], q[1] 148 } 149 150 func scanner(q string) SqlScan { 151 switch q { 152 case "VARCHAR", "TEXT", "CHAR", "STRING": 153 return &SqlString{} 154 case "INT8", "SMALLINT", "INT2": 155 return &SqlSmall{} 156 case "INTEGER", "INT", "INT4": 157 return &SqlInteger{} 158 case "BIGINT": 159 return &SqlBigint{} 160 case "BOOLEAN": 161 return &SqlBool{} 162 case "DECIMAL", "NUMERIC", "REAL", "DOUBLE", "FLOAT8": 163 return &SqlDouble{} 164 case "FLOAT", "FLOAT4": 165 return &SqlFloat{} 166 case "DATE", "DATETIME", "TIMESTAMP": 167 return &SqlTimestamp{} 168 default: 169 if strings.Index(q, "VARCHAR(") == 0 || 170 strings.Index(q, "CHAR(") == 0 { 171 return &SqlString{} 172 } 173 if strings.Index(q, "DECIMAL(") == 0 || 174 strings.Index(q, "NUMERIC(") == 0 { 175 return &SqlDouble{} 176 } 177 } 178 panic("unknown column type " + q) 179 } 180 181 func batchInsertStmt(tx *sql.Tx, names []string, pk []bool, lines int, table string, opts []interface{}) (stmt *sql.Stmt, err error) { 182 drv := fu.StrOption(Driver(""), opts) 183 ifExists := fu.Option(ErrorIfExists, opts).Interface().(IfExists_) 184 L := len(names) 185 q1 := " values " 186 for j := 0; j < lines; j++ { 187 q1 += "(" 188 if drv == "postgres" { 189 for k := range names { 190 q1 += fmt.Sprintf("$%d,", j*L+k+1) 191 } 192 } else { 193 q1 += strings.Repeat("?,", L) 194 } 195 q1 = q1[:len(q1)-1] + ")," 196 } 197 q := "insert into " + table + "(" + strings.Join(names, ",") + ")" + q1[:len(q1)-1] 198 199 if ifExists == InsertUpdateIfExists { 200 if len(pk) > 0 { 201 q += " on duplicate key update " 202 for i, n := range names { 203 if !pk[i] { 204 q += " " + n + " = values(" + n + ")," 205 } 206 } 207 q = q[:len(q)-1] 208 } 209 } 210 stmt, err = tx.Prepare(q) 211 return 212 } 213 214 func Sink(source interface{}, opts ...interface{}) tables.Sink { 215 db, opts, err := connectDB(source, opts) 216 cls := io.Closer(iokit.CloserChain{}) 217 if !fu.BoolOption(dontclose(false), opts) { 218 cls = db 219 } 220 if err != nil { 221 return tables.SinkError(zorros.Errorf("database connection error: %w", err)) 222 } 223 drv := fu.StrOption(Driver(""), opts) 224 225 schema := fu.StrOption(Schema(""), opts) 226 if schema != "" { 227 switch drv { 228 case "mysql": 229 _, err = db.Exec("use " + schema) 230 case "postgres": 231 _, err = db.Exec("set search_path to " + schema) 232 } 233 } 234 if err != nil { 235 cls.Close() 236 return tables.SinkError(zorros.Wrapf(err, "query error: %s", err.Error())) 237 } 238 239 tx, err := db.Begin() 240 if err != nil { 241 cls.Close() 242 return tables.SinkError(zorros.Wrapf(err, "database begin transaction error: %s", err.Error())) 243 } 244 245 table := fu.StrOption(Table(""), opts) 246 if table == "" { 247 panic("there is no table") 248 } 249 if fu.Option(ErrorIfExists, opts).Interface().(IfExists_) == DropIfExists { 250 _, err := tx.Exec(sqlDropQuery(table, opts...)) 251 if err != nil { 252 cls.Close() 253 return tables.SinkError(zorros.Wrapf(err, "drop table error: %s", err.Error())) 254 } 255 } 256 257 batchLen := fu.IntOption(Batch(1), opts) 258 var stmt *sql.Stmt 259 created := false 260 batch := []interface{}{} 261 names := []string{} 262 pk := []bool{} 263 return func(val reflect.Value) (err error) { 264 var describe func(int) (string, string, bool) 265 if val.Kind() == reflect.Bool { 266 if val.Bool() { 267 if len(batch) > 0 { 268 if stmt, err = batchInsertStmt(tx, names, pk, len(batch)/len(names), table, opts); err == nil { 269 if _, err = stmt.Exec(batch...); err == nil { 270 cls = iokit.CloserChain{stmt, cls} 271 } 272 } 273 } 274 if err == nil { 275 err = tx.Commit() 276 } 277 } 278 cls.Close() 279 return 280 } 281 lr := val.Interface().(fu.Struct) 282 names = make([]string, len(lr.Names)) 283 pk = make([]bool, len(lr.Names)) 284 drv := fu.StrOption(Driver(""), opts) 285 dsx, err := Describe(lr.Names, opts) 286 if err != nil { 287 cls.Close() 288 return 289 } 290 describe = func(i int) (colType, colName string, isPk bool) { 291 v := lr.Names[i] 292 colType, colName, isPk = dsx(v) 293 if colType == "" { 294 colType = sqlTypeOf(lr.Columns[i].Type(), drv) 295 } 296 return 297 } 298 for i := range names { 299 _, names[i], pk[i] = describe(i) 300 } 301 if !created { 302 _, err = tx.Exec(sqlCreateQuery(lr, table, describe, opts)) 303 if err != nil { 304 cls.Close() 305 return zorros.Wrapf(err, "create table error: %s", err.Error()) 306 } 307 created = true 308 } 309 if len(batch)/len(names) >= batchLen { 310 if stmt == nil { 311 stmt, err = batchInsertStmt(tx, names, pk, len(batch)/len(names), table, opts) 312 if err != nil { 313 return err 314 } 315 cls = iokit.CloserChain{stmt, cls} 316 } 317 _, err = stmt.Exec(batch...) 318 if err != nil { 319 return err 320 } 321 batch = batch[:0] 322 } 323 for i := range lr.Names { 324 if lr.Na.Bit(i) { 325 batch = append(batch, nil) 326 } else { 327 batch = append(batch, lr.Columns[i].Interface()) 328 } 329 } 330 return 331 } 332 } 333 334 func sqlCreateQuery(lr fu.Struct, table string, describe func(int) (string, string, bool), opts []interface{}) string { 335 pk := []string{} 336 query := "create table " 337 338 ifExists := fu.Option(ErrorIfExists, opts).Interface().(IfExists_) 339 if ifExists != ErrorIfExists && ifExists != DropIfExists { 340 query += "if not exists " 341 } 342 343 query = query + table + "( " 344 for i := range lr.Names { 345 if i != 0 { 346 query += ", " 347 } 348 colType, colName, isPK := describe(i) 349 query = query + colName + " " + colType 350 if isPK { 351 pk = append(pk, colName) 352 } 353 } 354 355 if len(pk) > 0 { 356 query = query + ", primary key (" + strings.Join(pk, ",") + ")" 357 } 358 359 query += " )" 360 return query 361 } 362 363 func sqlDropQuery(table string, opts ...interface{}) string { 364 schema := fu.StrOption(Schema(""), opts) 365 if schema != "" { 366 schema = schema + "." 367 } 368 return "drop table if exists " + schema + table 369 } 370 371 func sqlTypeOf(tp reflect.Type, driver string) string { 372 switch tp.Kind() { 373 case reflect.String: 374 if driver == "postgres" { 375 return "VARCHAR(65535)" /* redshift TEXT == VARCHAR(256) */ 376 } 377 return "TEXT" 378 case reflect.Int8, reflect.Uint8, reflect.Int16: 379 return "SMALLINT" 380 case reflect.Uint16, reflect.Int32, reflect.Int: 381 return "INTEGER" 382 case reflect.Uint, reflect.Uint32, reflect.Int64, reflect.Uint64: 383 return "BIGINT" 384 case reflect.Float32: 385 if driver == "postgres" { 386 return "REAL" /* redshift does not FLOAT */ 387 } 388 return "FLOAT" 389 case reflect.Float64: 390 if driver == "postgres" { 391 return "DOUBLE PRECISION" /* redshift does not have DOUBLE */ 392 } 393 return "DOUBLE" 394 case reflect.Bool: 395 return "BOOLEAN" 396 default: 397 if tp == fu.Ts { 398 return "DATETIME" 399 } 400 } 401 panic("unsupported data type " + fmt.Sprintf("%v %v", tp.String(), tp.Kind())) 402 }