github.com/vescale/zgraph@v0.0.0-20230410094002-959c02d50f95/driver.go (about)

     1  // Copyright 2023 zGraph Authors. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package zgraph
    16  
    17  import (
    18  	"context"
    19  	"database/sql"
    20  	"database/sql/driver"
    21  	"errors"
    22  	"fmt"
    23  	"io"
    24  
    25  	"github.com/vescale/zgraph/datum"
    26  	"github.com/vescale/zgraph/session"
    27  	"github.com/vescale/zgraph/types"
    28  )
    29  
    30  const driverName = "zgraph"
    31  
    32  func init() {
    33  	sql.Register(driverName, &Driver{})
    34  }
    35  
    36  var (
    37  	_ driver.Driver           = &Driver{}
    38  	_ driver.DriverContext    = &Driver{}
    39  	_ driver.Connector        = &connector{}
    40  	_ io.Closer               = &connector{}
    41  	_ driver.Conn             = &conn{}
    42  	_ driver.Stmt             = &stmt{}
    43  	_ driver.StmtExecContext  = &stmt{}
    44  	_ driver.StmtQueryContext = &stmt{}
    45  	_ driver.Rows             = &rows{}
    46  )
    47  
    48  type Driver struct{}
    49  
    50  func (d *Driver) Open(_ string) (driver.Conn, error) {
    51  	return nil, errors.New("Driver.Open should not be called as Driver.OpenConnector is implemented")
    52  }
    53  
    54  func (d *Driver) OpenConnector(dsn string) (driver.Connector, error) {
    55  	db, err := Open(dsn, nil)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  	return &connector{db: db}, nil
    60  }
    61  
    62  type connector struct {
    63  	db *DB
    64  }
    65  
    66  func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
    67  	return &conn{session: c.db.NewSession()}, nil
    68  }
    69  
    70  func (c *connector) Driver() driver.Driver {
    71  	return &Driver{}
    72  }
    73  
    74  func (c *connector) Close() error {
    75  	return c.db.Close()
    76  }
    77  
    78  type conn struct {
    79  	session *session.Session
    80  }
    81  
    82  func (c *conn) Ping(_ context.Context) error {
    83  	return nil
    84  }
    85  
    86  func (c *conn) Prepare(query string) (driver.Stmt, error) {
    87  	return c.PrepareContext(context.Background(), query)
    88  }
    89  
    90  func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) {
    91  	return &stmt{session: c.session, query: query}, nil
    92  }
    93  
    94  func (c *conn) Close() error {
    95  	c.session.Close()
    96  	return nil
    97  }
    98  
    99  func (c *conn) Begin() (driver.Tx, error) {
   100  	return c.BeginTx(context.Background(), driver.TxOptions{})
   101  }
   102  
   103  func (c *conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) {
   104  	return nil, errors.New("transactions are not supported")
   105  }
   106  
   107  type stmt struct {
   108  	session *session.Session
   109  	query   string
   110  }
   111  
   112  func (s *stmt) Close() error {
   113  	return nil
   114  }
   115  
   116  func (s *stmt) NumInput() int {
   117  	return -1
   118  }
   119  
   120  func (s *stmt) Exec(args []driver.Value) (driver.Result, error) {
   121  	if len(args) > 0 {
   122  		return nil, fmt.Errorf("placeholder arguments not supported")
   123  	}
   124  	return s.ExecContext(context.Background(), nil)
   125  }
   126  
   127  func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   128  	if len(args) > 0 {
   129  		return nil, fmt.Errorf("placeholder arguments not supported")
   130  	}
   131  	rs, err := s.session.Execute(ctx, s.query)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	if err := rs.Next(ctx); err != nil {
   136  		return nil, err
   137  	}
   138  	return driver.ResultNoRows, nil
   139  }
   140  
   141  func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
   142  	if len(args) > 0 {
   143  		return nil, fmt.Errorf("placeholder arguments not supported")
   144  	}
   145  	return s.QueryContext(context.Background(), nil)
   146  }
   147  
   148  func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   149  	if len(args) > 0 {
   150  		return nil, fmt.Errorf("placeholder arguments not supported")
   151  	}
   152  	rs, err := s.session.Execute(ctx, s.query)
   153  	if err != nil {
   154  		return nil, err
   155  	}
   156  	return &rows{ctx: ctx, rs: rs}, nil
   157  }
   158  
   159  type rows struct {
   160  	ctx context.Context
   161  	rs  session.ResultSet
   162  }
   163  
   164  func (r *rows) Columns() []string {
   165  	return r.rs.Columns()
   166  }
   167  
   168  func (r *rows) Close() error {
   169  	return r.rs.Close()
   170  }
   171  
   172  func (r *rows) Next(dest []driver.Value) error {
   173  	if err := r.rs.Next(r.ctx); err != nil {
   174  		return err
   175  	}
   176  	if !r.rs.Valid() {
   177  		return io.EOF
   178  	}
   179  	for i, d := range r.rs.Row() {
   180  		if d == datum.Null {
   181  			dest[i] = nil
   182  			continue
   183  		}
   184  		switch d.Type() {
   185  		case types.Bool:
   186  			dest[i] = datum.AsBool(d)
   187  		case types.Int:
   188  			dest[i] = datum.AsInt(d)
   189  		case types.Float:
   190  			dest[i] = datum.AsFloat(d)
   191  		default:
   192  			dest[i] = datum.AsString(d)
   193  		}
   194  	}
   195  	return nil
   196  }