github.com/dshekhar95/sub_dgraph@v0.0.0-20230424164411-6be28e40bbf1/dgraph/cmd/migrate/utils.go (about)

     1  /*
     2  * Copyright 2022 Dgraph Labs, Inc. and Contributors
     3  *
     4  * Licensed under the Apache License, Version 2.0 (the "License");
     5  * you may not use this file except in compliance with the License.
     6  * You may obtain a copy of the License at
     7  *
     8  *     http://www.apache.org/licenses/LICENSE-2.0
     9  *
    10  * Unless required by applicable law or agreed to in writing, software
    11  * distributed under the License is distributed on an "AS IS" BASIS,
    12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  * See the License for the specific language governing permissions and
    14  * limitations under the License.
    15   */
    16  
    17  package migrate
    18  
import (
	"bufio"
	"database/sql"
	"fmt"
	"net"
	"os"
	"reflect"
	"strings"

	"github.com/go-sql-driver/mysql"
	"github.com/pkg/errors"

	"github.com/dgraph-io/dgraph/x"
)
    32  
    33  func getPool(host, port, user, password, db string) (*sql.DB,
    34  	error) {
    35  	return sql.Open("mysql",
    36  		fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?parseTime=true", user, password, host, port, db))
    37  }
    38  
    39  // showTables will return a slice of table names using one of the following logic
    40  // 1) if the parameter tables is not empty, this function will return a slice of table names
    41  // by splitting the parameter with the separate comma
    42  // 2) if the parameter is empty, this function will read all the tables under the given
    43  // database and then return the result
    44  func showTables(pool *sql.DB, tableNames string) ([]string, error) {
    45  	if len(tableNames) > 0 {
    46  		return strings.Split(tableNames, ","), nil
    47  	}
    48  	query := "show tables"
    49  	rows, err := pool.Query(query)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	defer rows.Close()
    54  
    55  	tables := make([]string, 0)
    56  	for rows.Next() {
    57  		var table string
    58  		if err := rows.Scan(&table); err != nil {
    59  			return nil, errors.Wrapf(err, "while scanning table name")
    60  		}
    61  		tables = append(tables, table)
    62  	}
    63  
    64  	return tables, nil
    65  }
    66  
    67  type criteriaFunc func(info *sqlTable, column string) bool
    68  
    69  // getColumnIndices first sort the columns in the table alphabetically, and then
    70  // returns the indices of the columns satisfying the criteria function
    71  func getColumnIndices(info *sqlTable, criteria criteriaFunc) []*columnIdx {
    72  	indices := make([]*columnIdx, 0)
    73  	for i, column := range info.columnNames {
    74  		if criteria(info, column) {
    75  			indices = append(indices, &columnIdx{
    76  				name:  column,
    77  				index: i,
    78  			})
    79  		}
    80  	}
    81  	return indices
    82  }
    83  
// columnIdx pairs a column's name with its position in a table's column list.
// getColumnIndices fills index with the column's zero-based position in
// sqlTable.columnNames.
type columnIdx struct {
	name  string // the column name
	index int    // the column index (zero-based position in sqlTable.columnNames)
}
    88  
    89  func getFileWriter(filename string) (*bufio.Writer, func(), error) {
    90  	output, err := os.OpenFile(filename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
    91  	if err != nil {
    92  		return nil, nil, err
    93  	}
    94  
    95  	return bufio.NewWriter(output), func() { _ = output.Close() }, nil
    96  }
    97  
    98  func getColumnValues(columns []string, dataTypes []dataType,
    99  	rows *sql.Rows) ([]interface{}, error) {
   100  	// ptrToValues takes a slice of pointers, deference them, and return the values referenced
   101  	// by these pointers
   102  	ptrToValues := func(ptrs []interface{}) []interface{} {
   103  		values := make([]interface{}, 0, len(ptrs))
   104  		for _, ptr := range ptrs {
   105  			// dereference the pointer to get the actual value
   106  			v := reflect.ValueOf(ptr).Elem().Interface()
   107  			values = append(values, v)
   108  		}
   109  		return values
   110  	}
   111  
   112  	valuePtrs := make([]interface{}, 0, len(columns))
   113  	for i := 0; i < len(columns); i++ {
   114  		switch dataTypes[i] {
   115  		case stringType:
   116  			valuePtrs = append(valuePtrs, new([]byte)) // the value can be nil
   117  		case intType:
   118  			valuePtrs = append(valuePtrs, new(sql.NullInt64))
   119  		case floatType:
   120  			valuePtrs = append(valuePtrs, new(sql.NullFloat64))
   121  		case datetimeType:
   122  			valuePtrs = append(valuePtrs, new(mysql.NullTime))
   123  		default:
   124  			x.Panic(errors.Errorf("detected unsupported type %s on column %s",
   125  				dataTypes[i], columns[i]))
   126  		}
   127  	}
   128  	if err := rows.Scan(valuePtrs...); err != nil {
   129  		return nil, errors.Wrapf(err, "while scanning column values")
   130  	}
   131  	colValues := ptrToValues(valuePtrs)
   132  	return colValues, nil
   133  }