github.com/dshekhar95/sub_dgraph@v0.0.0-20230424164411-6be28e40bbf1/dgraph/cmd/migrate/utils.go (about) 1 /* 2 * Copyright 2022 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package migrate 18 19 import ( 20 "bufio" 21 "database/sql" 22 "fmt" 23 "os" 24 "reflect" 25 "strings" 26 27 "github.com/go-sql-driver/mysql" 28 "github.com/pkg/errors" 29 30 "github.com/dgraph-io/dgraph/x" 31 ) 32 33 func getPool(host, port, user, password, db string) (*sql.DB, 34 error) { 35 return sql.Open("mysql", 36 fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, password, host, port, db)) 37 } 38 39 // showTables will return a slice of table names using one of the following logic 40 // 1) if the parameter tables is not empty, this function will return a slice of table names 41 // by splitting the parameter with the separate comma 42 // 2) if the parameter is empty, this function will read all the tables under the given 43 // database and then return the result 44 func showTables(pool *sql.DB, tableNames string) ([]string, error) { 45 if len(tableNames) > 0 { 46 return strings.Split(tableNames, ","), nil 47 } 48 query := "show tables" 49 rows, err := pool.Query(query) 50 if err != nil { 51 return nil, err 52 } 53 defer rows.Close() 54 55 tables := make([]string, 0) 56 for rows.Next() { 57 var table string 58 if err := rows.Scan(&table); err != nil { 59 return nil, errors.Wrapf(err, "while scanning table name") 60 } 61 tables = append(tables, table) 62 } 63 64 return tables, nil 65 } 66 67 type criteriaFunc func(info *sqlTable, column string) bool 68 69 // getColumnIndices first sort the columns in the table alphabetically, and then 70 // returns the indices of the columns satisfying the criteria function 71 func getColumnIndices(info *sqlTable, criteria criteriaFunc) []*columnIdx { 72 indices := make([]*columnIdx, 0) 73 for i, column := range info.columnNames { 74 if criteria(info, column) { 75 indices = append(indices, &columnIdx{ 76 name: column, 77 index: i, 78 }) 79 } 80 } 81 return indices 82 } 83 84 type columnIdx struct { 85 name string // the column name 86 index int // the column index 87 } 88 89 func getFileWriter(filename string) (*bufio.Writer, func(), error) { 90 output, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 91 if err != nil { 92 return nil, nil, err 93 } 94 95 return bufio.NewWriter(output), func() { _ = output.Close() }, nil 96 } 97 98 func getColumnValues(columns []string, dataTypes []dataType, 99 rows *sql.Rows) ([]interface{}, error) { 100 // ptrToValues takes a slice of pointers, deference them, and return the values referenced 101 // by these pointers 102 ptrToValues := func(ptrs []interface{}) []interface{} { 103 values := make([]interface{}, 0, len(ptrs)) 104 for _, ptr := range ptrs { 105 // dereference the pointer to get the actual value 106 v := reflect.ValueOf(ptr).Elem().Interface() 107 values = append(values, v) 108 } 109 return values 110 } 111 112 valuePtrs := make([]interface{}, 0, len(columns)) 113 for i := 0; i < len(columns); i++ { 114 switch dataTypes[i] { 115 case stringType: 116 valuePtrs = append(valuePtrs, new([]byte)) // the value can be nil 117 case intType: 118 valuePtrs = append(valuePtrs, new(sql.NullInt64)) 119 case floatType: 120 valuePtrs = append(valuePtrs, new(sql.NullFloat64)) 121 case datetimeType: 122 valuePtrs = append(valuePtrs, new(mysql.NullTime)) 123 default: 124 x.Panic(errors.Errorf("detected unsupported type %s on column %s", 125 dataTypes[i], columns[i])) 126 } 127 } 128 if err := rows.Scan(valuePtrs...); err != nil { 129 return nil, errors.Wrapf(err, "while scanning column values") 130 } 131 colValues := ptrToValues(valuePtrs) 132 return colValues, nil 133 }