github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/connection.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysqlshim 16 17 import ( 18 "fmt" 19 "sort" 20 "strings" 21 22 _ "github.com/go-sql-driver/mysql" 23 "github.com/gocraft/dbr/v2" 24 25 "github.com/dolthub/go-mysql-server/sql" 26 ) 27 28 // MySQLShim is a shim for a local MySQL server. Ensure that a MySQL instance is running prior to using this shim. Note: 29 // this may be destructive to pre-existing data, as databases and tables will be created and destroyed. 30 type MySQLShim struct { 31 conn *dbr.Connection 32 databases map[string]string 33 } 34 35 var _ sql.MutableDatabaseProvider = (*MySQLShim)(nil) 36 37 // NewMySQLShim returns a new MySQLShim. 38 func NewMySQLShim(user string, password string, host string, port int) (*MySQLShim, error) { 39 conn, err := dbr.Open("mysql", fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, password, host, port), nil) 40 if err != nil { 41 return nil, err 42 } 43 err = conn.Ping() 44 45 if err != nil { 46 return nil, err 47 } 48 return &MySQLShim{conn, make(map[string]string)}, nil 49 } 50 51 // Database implements the interface sql.MutableDatabaseProvider. 52 func (m *MySQLShim) Database(ctx *sql.Context, name string) (sql.Database, error) { 53 if dbName, ok := m.databases[strings.ToLower(name)]; ok { 54 return Database{m, dbName}, nil 55 } 56 return nil, sql.ErrDatabaseNotFound.New(name) 57 } 58 59 // HasDatabase implements the interface sql.MutableDatabaseProvider. 60 func (m *MySQLShim) HasDatabase(ctx *sql.Context, name string) bool { 61 _, ok := m.databases[strings.ToLower(name)] 62 return ok 63 } 64 65 // AllDatabases implements the interface sql.MutableDatabaseProvider. 66 func (m *MySQLShim) AllDatabases(*sql.Context) []sql.Database { 67 var dbStrings []string 68 for _, dbName := range m.databases { 69 dbStrings = append(dbStrings, dbName) 70 } 71 sort.Strings(dbStrings) 72 dbs := make([]sql.Database, len(dbStrings)) 73 for i, dbString := range dbStrings { 74 dbs[i] = Database{m, dbString} 75 } 76 return dbs 77 } 78 79 // CreateDatabase implements the interface sql.MutableDatabaseProvider. 80 func (m *MySQLShim) CreateDatabase(ctx *sql.Context, name string) error { 81 _, err := m.conn.Exec(fmt.Sprintf("CREATE DATABASE `%s` DEFAULT COLLATE %s;", name, sql.Collation_Default.String())) 82 if err != nil { 83 return err 84 } 85 m.databases[strings.ToLower(name)] = name 86 return nil 87 } 88 89 // DropDatabase implements the interface sql.MutableDatabaseProvider. 90 func (m *MySQLShim) DropDatabase(ctx *sql.Context, name string) error { 91 _, err := m.conn.Exec(fmt.Sprintf("DROP DATABASE `%s`;", name)) 92 if err != nil { 93 return err 94 } 95 delete(m.databases, strings.ToLower(name)) 96 return nil 97 } 98 99 // Close closes the shim. This will drop all databases created and accessed since this shim was created. 100 func (m *MySQLShim) Close() { 101 for dbName := range m.databases { 102 _, _ = m.conn.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS `%s`;", dbName)) 103 } 104 _ = m.conn.Close() 105 } 106 107 // Query queries the connection and return a row iterator. 108 func (m *MySQLShim) Query(db string, query string) (sql.RowIter, error) { 109 if len(db) > 0 { 110 _, err := m.conn.Exec(fmt.Sprintf("USE `%s`;", db)) 111 if err != nil { 112 return nil, err 113 } 114 } 115 rows, err := m.conn.Query(query) 116 if err != nil { 117 return nil, err 118 } 119 return newMySQLIter(rows), nil 120 } 121 122 // QueryRows queries the connection and returns the rows returned. 123 func (m *MySQLShim) QueryRows(db string, query string) ([]sql.Row, error) { 124 ctx := sql.NewEmptyContext() 125 if len(db) > 0 { 126 _, err := m.conn.Exec(fmt.Sprintf("USE `%s`;", db)) 127 if err != nil { 128 return nil, err 129 } 130 } 131 rows, err := m.conn.Query(query) 132 if err != nil { 133 return nil, err 134 } 135 iter := newMySQLIter(rows) 136 defer iter.Close(ctx) 137 allRows, err := sql.RowIterToRows(ctx, iter) 138 if err != nil { 139 return nil, err 140 } 141 return allRows, nil 142 } 143 144 // Exec executes the query on the connection. 145 func (m *MySQLShim) Exec(db string, query string) error { 146 if len(db) > 0 { 147 _, err := m.conn.Exec(fmt.Sprintf("USE `%s`;", db)) 148 if err != nil { 149 return err 150 } 151 } 152 _, err := m.conn.Exec(query) 153 return err 154 }