go-hep.org/x/hep@v0.38.1/groot/rsql/rsqldrv/driver.go (about) 1 // Copyright ©2019 The go-hep Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 // Package rsqldrv registers a database/sql/driver.Driver implementation for ROOT files. 6 package rsqldrv // import "go-hep.org/x/hep/groot/rsql/rsqldrv" 7 8 import ( 9 "context" 10 "database/sql" 11 "database/sql/driver" 12 "errors" 13 "fmt" 14 "io" 15 "reflect" 16 "sync" 17 18 "github.com/xwb1989/sqlparser" 19 "go-hep.org/x/hep/groot/riofs" 20 "go-hep.org/x/hep/groot/rtree" 21 ) 22 23 const driverName = "root" 24 25 func init() { 26 sql.Register(driverName, &rootDriver{}) 27 } 28 29 // Open is a ROOT/SQL-driver helper function for sql.Open. 30 // 31 // It opens a database connection to the ROOT/SQL driver. 32 func Open(name string) (*sql.DB, error) { 33 return sql.Open(driverName, name) 34 } 35 36 // Create is a ROOT/SQL-driver helper function for sql.Open. 37 // 38 // It creates a new ROOT file, connected via the ROOT/SQL driver. 39 func Create(name string) (*sql.DB, error) { 40 panic("not implemented") // FIXME(sbinet) 41 } 42 43 // rootDriver implements the interface required by database/sql/driver. 44 type rootDriver struct { 45 mu sync.Mutex 46 dbs map[string]*driverConn 47 owns map[string]bool // whether the driver owns the ROOT files (and needs to close it) 48 } 49 50 func (drv *rootDriver) open(fname string) (driver.Conn, error) { 51 drv.mu.Lock() 52 defer drv.mu.Unlock() 53 if drv.dbs == nil { 54 drv.dbs = make(map[string]*driverConn) 55 } 56 if drv.owns == nil { 57 drv.owns = make(map[string]bool) 58 } 59 60 conn := drv.dbs[fname] 61 if conn == nil { 62 f, err := riofs.Open(fname) 63 if err != nil { 64 return nil, fmt.Errorf("rsqldriver: could not open file: %w", err) 65 } 66 67 conn = &driverConn{ 68 f: f, 69 // cfg: c, 70 drv: drv, 71 stop: make(map[*driverStmt]struct{}), 72 refs: 0, 73 } 74 75 drv.dbs[fname] = conn 76 drv.owns[fname] = true 77 } 78 conn.refs++ 79 80 return conn, nil 81 } 82 83 func (drv *rootDriver) connect(f *riofs.File) driver.Conn { 84 drv.mu.Lock() 85 defer drv.mu.Unlock() 86 if drv.dbs == nil { 87 drv.dbs = make(map[string]*driverConn) 88 } 89 if drv.owns == nil { 90 drv.owns = make(map[string]bool) 91 } 92 93 conn := drv.dbs[f.Name()] 94 if conn == nil { 95 conn = &driverConn{ 96 f: f, 97 //cfg: c, 98 drv: drv, 99 stop: make(map[*driverStmt]struct{}), 100 refs: 0, 101 } 102 drv.dbs[f.Name()] = conn 103 drv.owns[f.Name()] = false 104 } 105 conn.refs++ 106 107 return conn 108 } 109 110 // Open returns a new connection to the database. 111 // The name is a string in a driver-specific format. 112 // 113 // Open may return a cached connection (one previously 114 // closed), but doing so is unnecessary; the sql package 115 // maintains a pool of idle connections for efficient re-use. 116 // 117 // The returned connection is only used by one goroutine at a 118 // time. 119 func (drv *rootDriver) Open(name string) (driver.Conn, error) { 120 return drv.open(name) 121 } 122 123 type driverConn struct { 124 f *riofs.File 125 drv *rootDriver 126 stop map[*driverStmt]struct{} 127 refs int 128 } 129 130 // Prepare returns a prepared statement, bound to this connection. 131 func (conn *driverConn) Prepare(query string) (driver.Stmt, error) { 132 stmt, err := sqlparser.Parse(query) 133 if err != nil { 134 return nil, err 135 } 136 137 s := &driverStmt{conn: conn, stmt: stmt} 138 conn.stop[s] = struct{}{} 139 return s, nil 140 } 141 142 // Close invalidates and potentially stops any current 143 // prepared statements and transactions, marking this 144 // connection as no longer in use. 145 // 146 // Because the sql package maintains a free pool of 147 // connections and only calls Close when there's a surplus of 148 // idle connections, it shouldn't be necessary for drivers to 149 // do their own connection caching. 150 func (conn *driverConn) Close() error { 151 conn.drv.mu.Lock() 152 defer conn.drv.mu.Unlock() 153 154 if conn.refs > 1 { 155 conn.refs-- 156 return nil 157 } 158 159 for s := range conn.stop { 160 err := s.Close() 161 if err != nil { 162 return fmt.Errorf("rsqldrv: could not close statement %v: %w", s, err) 163 } 164 } 165 166 var err error 167 if conn.drv.owns[conn.f.Name()] { 168 err = conn.f.Close() 169 if err != nil { 170 return err 171 } 172 } 173 174 if conn.refs == 1 { 175 delete(conn.drv.dbs, conn.f.Name()) 176 } 177 conn.refs = 0 178 179 return err 180 } 181 182 // Begin starts and returns a new transaction. 183 func (conn *driverConn) Begin() (driver.Tx, error) { 184 panic("conn-begin: not implemented") 185 } 186 187 func (conn *driverConn) Commit() error { 188 panic("conn-commit: not implemented") 189 } 190 191 func (conn *driverConn) Rollback() error { 192 panic("conn-rollback: not implemented") 193 } 194 195 func (conn *driverConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { 196 stmt, err := sqlparser.Parse(query) 197 if err != nil { 198 return nil, err 199 } 200 201 return conn.exec(ctx, stmt, args) 202 } 203 204 func (conn *driverConn) exec(ctx context.Context, stmt sqlparser.Statement, args []driver.NamedValue) (driver.Result, error) { 205 panic("not implemented") 206 } 207 208 func (conn *driverConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { 209 stmt, err := sqlparser.Parse(query) 210 if err != nil { 211 return nil, err 212 } 213 return conn.query(ctx, stmt, args) 214 } 215 216 func (conn *driverConn) query(ctx context.Context, stmt sqlparser.Statement, args []driver.NamedValue) (driver.Rows, error) { 217 switch stmt := stmt.(type) { 218 case *sqlparser.Select: 219 rows, err := newDriverRows(ctx, conn, stmt, args) 220 return rows, err 221 } 222 panic("not implemented") 223 } 224 225 type driverResult struct { 226 id int64 // last inserted ID 227 rows int64 // rows affected 228 } 229 230 func (res *driverResult) LastInsertId() (int64, error) { return res.id, nil } // -golint 231 func (res *driverResult) RowsAffected() (int64, error) { return res.rows, nil } 232 233 // driverRows is an iterator over an executed query's results. 234 type driverRows struct { 235 conn *driverConn 236 args []driver.NamedValue 237 cols []string 238 types []colDescr // types of the columns 239 deps []string // names of the columns to be read 240 vars []any // values of the columns that were read 241 242 reader *rtree.Reader 243 row rowCtx 244 rows chan rowCtx 245 246 eval expression 247 filter expression 248 } 249 250 type colDescr struct { 251 Name string 252 Len int64 // -1 if no length. 253 Nullable bool 254 Type reflect.Type 255 } 256 257 func newDriverRows(ctx context.Context, conn *driverConn, stmt *sqlparser.Select, args []driver.NamedValue) (*driverRows, error) { 258 var ( 259 name = "" 260 f = conn.f 261 ) 262 263 switch len(stmt.From) { 264 case 1: 265 switch from := stmt.From[0].(type) { 266 case *sqlparser.AliasedTableExpr: 267 switch expr := from.Expr.(type) { 268 case sqlparser.TableName: 269 name = expr.Name.CompliantName() 270 default: 271 panic(fmt.Errorf("unknown FROM expression type: %#v", expr)) 272 } 273 274 default: 275 panic(fmt.Errorf("unknown table expression: %#v", from)) 276 } 277 278 default: 279 return nil, fmt.Errorf("rsqldrv: invalid number of tables (got=%d, want=1)", len(stmt.From)) 280 } 281 282 obj, err := riofs.Dir(f).Get(name) 283 if err != nil { 284 return nil, err 285 } 286 287 tree, ok := obj.(rtree.Tree) 288 if !ok { 289 return nil, fmt.Errorf("rsqldrv: object %q is not a Tree", name) 290 } 291 292 rows := &driverRows{conn: conn, args: args} 293 294 rows.cols, err = rows.extractColsFromSelect(tree, stmt, args) 295 if err != nil { 296 return nil, fmt.Errorf("could not extract columns: %w", err) 297 } 298 rows.types = make([]colDescr, len(rows.cols)) 299 for i, name := range rows.cols { 300 if name == "" { 301 rows.types[i].Type = reflect.TypeOf(new(any)).Elem() 302 continue 303 } 304 rows.types[i].Name = name 305 branch := tree.Branch(name) 306 if branch == nil { 307 rows.types[i].Type = reflect.TypeOf(new(any)).Elem() 308 continue 309 } 310 rows.types[i] = colDescrFromLeaf(branch.Leaves()[0]) // FIXME(sbinet): multi-leaves' branches 311 } 312 313 vars, err := rows.extractDepsFromSelect(tree, stmt, args) 314 if err != nil { 315 return nil, fmt.Errorf("could not extract read-vars: %w", err) 316 } 317 rows.vars = varsFrom(vars) 318 for _, v := range vars { 319 rows.deps = append(rows.deps, v.Name) 320 } 321 322 rows.reader, err = rtree.NewReader(tree, vars) 323 if err != nil { 324 return nil, err 325 } 326 327 switch expr := stmt.SelectExprs[0].(type) { // FIXME(sbinet): handle multiple select-expressions 328 case *sqlparser.AliasedExpr: 329 rows.eval, err = newExprFrom(expr.Expr, args) 330 if err != nil { 331 return nil, fmt.Errorf("could not generate row expression: %w", err) 332 } 333 case *sqlparser.StarExpr: 334 tuple := make(sqlparser.ValTuple, len(rows.cols)) 335 for i, name := range rows.cols { 336 tuple[i] = &sqlparser.ColName{Name: sqlparser.NewColIdent(name)} 337 } 338 rows.eval, err = newExprFrom(tuple, args) 339 if err != nil { 340 return nil, fmt.Errorf("could not generate row expression from 'select *': %w", err) 341 } 342 } 343 344 if stmt.Where != nil { 345 switch stmt.Where.Type { 346 case sqlparser.WhereStr: 347 rows.filter, err = newExprFrom(stmt.Where.Expr, args) 348 if err != nil { 349 return nil, err 350 } 351 default: 352 panic(fmt.Errorf("unknown 'where' type: %q", stmt.Where.Type)) 353 } 354 } 355 356 rows.start() 357 return rows, nil 358 } 359 360 func varsFrom(vars []rtree.ReadVar) []any { 361 vs := make([]any, len(vars)) 362 for i, v := range vars { 363 vs[i] = v.Value 364 } 365 return vs 366 } 367 368 // extractDepsFromSelect analyses the query and extracts the branches that need to be read 369 // for the query to be properly executed. 370 func (rows *driverRows) extractDepsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.NamedValue) ([]rtree.ReadVar, error) { 371 var ( 372 vars []rtree.ReadVar 373 374 set = make(map[string]struct{}) 375 cols []string 376 ) 377 378 markBranch := func(name string) { 379 if name != "" { 380 if _, dup := set[name]; !dup { 381 set[name] = struct{}{} 382 cols = append(cols, name) 383 } 384 } 385 } 386 387 collectCols := func(node sqlparser.SQLNode) (bool, error) { 388 switch node := node.(type) { 389 case *sqlparser.StarExpr: 390 other := node.TableName.Name.CompliantName() 391 switch other { 392 case "", tree.Name(): 393 for _, b := range tree.Branches() { 394 markBranch(b.Name()) 395 } 396 default: 397 panic(fmt.Errorf("rsqldrv: star-expression with other table name not supported")) 398 } 399 return false, nil 400 401 case sqlparser.ColIdent: 402 name := node.CompliantName() 403 markBranch(name) 404 return false, nil 405 406 default: 407 return true, nil 408 } 409 } 410 411 nodes := make([]sqlparser.SQLNode, len(stmt.SelectExprs)) 412 for i, expr := range stmt.SelectExprs { 413 nodes[i] = expr 414 } 415 416 if stmt.Where != nil { 417 nodes = append(nodes, stmt.Where.Expr) 418 } 419 420 err := sqlparser.Walk(collectCols, nodes...) 421 if err != nil { 422 return nil, err 423 } 424 425 for _, name := range cols { 426 branch := tree.Branch(name) 427 if branch == nil { 428 return nil, fmt.Errorf("rsqldrv: could not find branch/leaf %q in tree %q", name, tree.Name()) 429 } 430 leaf := branch.Leaves()[0] // FIXME(sbinet): handle sub-leaves 431 etyp := leaf.Type() 432 switch etyp.Kind() { 433 case reflect.Int8: 434 if leaf.IsUnsigned() { 435 etyp = reflect.TypeOf(uint8(0)) 436 } 437 case reflect.Int16: 438 if leaf.IsUnsigned() { 439 etyp = reflect.TypeOf(uint16(0)) 440 } 441 case reflect.Int32: 442 if leaf.IsUnsigned() { 443 etyp = reflect.TypeOf(uint32(0)) 444 } 445 case reflect.Int64: 446 if leaf.IsUnsigned() { 447 etyp = reflect.TypeOf(uint64(0)) 448 } 449 } 450 switch { 451 case leaf.LeafCount() != nil: 452 etyp = reflect.SliceOf(etyp) 453 case leaf.Len() > 1 && leaf.Kind() != reflect.String: 454 etyp = reflect.ArrayOf(leaf.Len(), etyp) 455 } 456 vars = append(vars, rtree.ReadVar{ 457 Name: branch.Name(), 458 Leaf: leaf.Name(), 459 Value: reflect.New(etyp).Interface(), 460 }) 461 } 462 463 return vars, nil 464 } 465 466 func (rows *driverRows) extractColsFromSelect(tree rtree.Tree, stmt *sqlparser.Select, args []driver.NamedValue) ([]string, error) { 467 var cols []string 468 469 collect := func(node sqlparser.SQLNode) (bool, error) { 470 switch node := node.(type) { 471 case *sqlparser.ColName: 472 return true, nil 473 case sqlparser.ColIdent: 474 cols = append(cols, node.CompliantName()) 475 return false, nil 476 case *sqlparser.ParenExpr: 477 return true, nil 478 case sqlparser.ValTuple: 479 return true, nil 480 case sqlparser.Exprs: 481 return true, nil 482 case *sqlparser.BinaryExpr: 483 // not a simple select query. 484 // add a dummy column name and stop recursion 485 cols = append(cols, "") 486 return false, nil 487 case *sqlparser.UnaryExpr: 488 return true, nil 489 case *sqlparser.SQLVal: 490 // not a simple select query. 491 // add a dummy column name and stop recursion 492 cols = append(cols, "") 493 return false, nil 494 } 495 return false, nil 496 } 497 498 switch expr := stmt.SelectExprs[0].(type) { // FIXME(sbinet): handle multiple select-expressions 499 case *sqlparser.AliasedExpr: 500 err := sqlparser.Walk(collect, expr.Expr) 501 return cols, err 502 503 case *sqlparser.StarExpr: 504 branches := make([]string, len(tree.Branches())) 505 for i, b := range tree.Branches() { 506 branches[i] = b.Name() 507 } 508 return branches, nil 509 510 default: 511 panic(fmt.Errorf("rsqldrv: invalid select-expr type %#v", expr)) 512 } 513 } 514 515 // Columns returns the names of the columns. The number of columns of the 516 // result is inferred from the length of the slice. If a particular column 517 // name isn't known, an empty string should be returned for that entry. 518 func (r *driverRows) Columns() []string { 519 cols := make([]string, len(r.cols)) 520 copy(cols, r.cols) 521 return cols 522 } 523 524 // ColumnTypeScanType returns the value type that can be used to scan types into. 525 // 526 // See database/sql/driver.RowsColumnTypeScanType. 527 func (r *driverRows) ColumnTypeScanType(i int) reflect.Type { 528 return r.types[i].Type 529 } 530 531 // ColumnTypeLength returns the column type length for variable length column types such 532 // as text and binary field types. If the type length is unbounded the value will 533 // be math.MaxInt64 (any database limits will still apply). 534 // If the column type is not variable length, such as an int, or if not supported 535 // by the driver ok is false. 536 func (r *driverRows) ColumnTypeLength(i int) (length int64, ok bool) { 537 col := r.types[i] 538 switch col.Len { 539 case -1: 540 return 0, false 541 } 542 return col.Len, true 543 } 544 545 // ColumnTypeNullable reports whether the column may be null. 546 func (r *driverRows) ColumnTypeNullable(i int) (nullable, ok bool) { 547 return r.types[i].Nullable, true 548 } 549 550 // Close closes the rows iterator. 551 func (r *driverRows) Close() error { 552 return r.reader.Close() 553 } 554 555 type rowCtx struct { 556 ctx rtree.RCtx 557 vs any 558 done chan int 559 err error 560 } 561 562 func (r *driverRows) start() { 563 r.rows = make(chan rowCtx) 564 r.row.ctx.Entry = -1 565 go func() { 566 defer close(r.rows) 567 err := r.reader.Read(func(ctx rtree.RCtx) error { 568 ectx := newExecCtx(r.conn, r.args) 569 vctx := make(map[any]any) 570 for i, v := range r.vars { 571 vctx[r.deps[i]] = reflect.Indirect(reflect.ValueOf(v)).Interface() 572 } 573 574 switch r.filter { 575 case nil: 576 // no filter 577 default: 578 ok, err := r.filter.eval(ectx, vctx) 579 if err != nil { 580 //log.Printf("filter.eval: ok=%#v err=%v", ok, err) 581 return err 582 } 583 if !ok.(bool) { 584 return nil 585 } 586 } 587 588 vs, err := r.eval.eval(ectx, vctx) 589 // log.Printf("row.eval: v=%#v, err=%v n=%d", vs, err, len(dest)) 590 if err != nil { 591 return fmt.Errorf("could not evaluate row values: %w", err) 592 } 593 594 evt := rowCtx{ 595 ctx: ctx, 596 vs: vs, 597 err: nil, 598 done: make(chan int), 599 } 600 601 r.rows <- evt 602 <-evt.done 603 return nil 604 }) 605 if err != nil { 606 r.rows <- rowCtx{err: err} 607 return 608 } 609 r.rows <- rowCtx{err: io.EOF} 610 }() 611 } 612 613 // Next is called to populate the next row of data into 614 // the provided slice. The provided slice will be the same 615 // size as the Columns() are wide. 616 // 617 // Next should return io.EOF when there are no more rows. 618 // 619 // The dest should not be written to outside of Next. Care 620 // should be taken when closing Rows not to modify 621 // a buffer held in dest. 622 func (r *driverRows) Next(dest []driver.Value) error { 623 if r.row.ctx.Entry >= 0 { 624 close(r.row.done) 625 } 626 627 row, ok := <-r.rows 628 r.row = row 629 if !ok { 630 return io.EOF 631 } 632 if row.err != nil { 633 switch { 634 case errors.Is(row.err, io.EOF): 635 return io.EOF 636 default: 637 return row.err 638 } 639 } 640 641 switch vs := row.vs.(type) { 642 case []any: 643 for i, v := range vs { 644 switch v := v.(type) { 645 case string: 646 dest[i] = []byte(v) 647 default: 648 dest[i] = v 649 } 650 } 651 case string: 652 dest[0] = []byte(vs) 653 default: 654 dest[0] = vs 655 } 656 657 return nil 658 } 659 660 type driverStmt struct { 661 conn *driverConn 662 stmt sqlparser.Statement 663 } 664 665 func (stmt *driverStmt) Close() error { 666 panic("not implemented") 667 } 668 669 func (stmt *driverStmt) NumInput() int { 670 panic("not implemented") 671 } 672 673 func (stmt *driverStmt) Exec(args []driver.Value) (driver.Result, error) { 674 panic("not implemented") 675 } 676 677 func (stmt *driverStmt) Query(args []driver.Value) (driver.Rows, error) { 678 panic("not implemented") 679 } 680 681 func newExprFrom(expr sqlparser.Expr, args []driver.NamedValue) (expression, error) { 682 switch expr := expr.(type) { 683 case *sqlparser.ComparisonExpr: 684 op := operatorFrom(expr.Operator) 685 if op == opInvalid { 686 return nil, fmt.Errorf("rsqldrv: invalid comparison operator %q", expr.Operator) 687 } 688 689 l, err := newExprFrom(expr.Left, args) 690 if err != nil { 691 return nil, err 692 } 693 r, err := newExprFrom(expr.Right, args) 694 if err != nil { 695 return nil, err 696 } 697 return newBinExpr(expr, op, l, r) 698 699 case *sqlparser.ParenExpr: 700 return newExprFrom(expr.Expr, args) 701 702 case *sqlparser.AndExpr: 703 l, err := newExprFrom(expr.Left, args) 704 if err != nil { 705 return nil, err 706 } 707 r, err := newExprFrom(expr.Right, args) 708 if err != nil { 709 return nil, err 710 } 711 return newBinExpr(expr, opAndAnd, l, r) 712 713 case *sqlparser.OrExpr: 714 l, err := newExprFrom(expr.Left, args) 715 if err != nil { 716 return nil, err 717 } 718 r, err := newExprFrom(expr.Right, args) 719 if err != nil { 720 return nil, err 721 } 722 return newBinExpr(expr, opOrOr, l, r) 723 724 case *sqlparser.ColName: 725 return &identExpr{ 726 expr: expr, 727 name: expr.Name.CompliantName(), 728 }, nil 729 730 case *sqlparser.SQLVal: 731 return newValueExpr(expr, args) 732 733 case sqlparser.BoolVal: 734 return &valueExpr{expr: expr, v: bool(expr)}, nil 735 736 case *sqlparser.BinaryExpr: 737 l, err := newExprFrom(expr.Left, args) 738 if err != nil { 739 return nil, err 740 } 741 r, err := newExprFrom(expr.Right, args) 742 if err != nil { 743 return nil, err 744 } 745 op := operatorFrom(expr.Operator) 746 if op == opInvalid { 747 return nil, fmt.Errorf("rsqldrv: invalid binary-expression operator %q", expr.Operator) 748 } 749 return newBinExpr(expr, op, l, r) 750 751 case sqlparser.ValTuple: 752 vs := make([]expression, len(expr)) 753 for i, e := range expr { 754 v, err := newExprFrom(e, args) 755 if err != nil { 756 return nil, err 757 } 758 vs[i] = v 759 } 760 return &tupleExpr{expr: expr, exprs: vs}, nil 761 } 762 return nil, fmt.Errorf("rsqldrv: invalid filter expression %#v %T", expr, expr) 763 } 764 765 var ( 766 _ driver.Driver = (*rootDriver)(nil) 767 _ driver.Conn = (*driverConn)(nil) 768 _ driver.ExecerContext = (*driverConn)(nil) 769 _ driver.QueryerContext = (*driverConn)(nil) 770 _ driver.Tx = (*driverConn)(nil) 771 772 _ driver.Result = (*driverResult)(nil) 773 _ driver.Rows = (*driverRows)(nil) 774 ) 775 776 var ( 777 _ driver.RowsColumnTypeLength = (*driverRows)(nil) 778 _ driver.RowsColumnTypeNullable = (*driverRows)(nil) 779 _ driver.RowsColumnTypeScanType = (*driverRows)(nil) 780 )