github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/testutil/k8s/tunnel_util.go (about)

     1  /*
     2  Copyright (C) 2022-2023 ApeCloud Co., Ltd
     3  
     4  This file is part of KubeBlocks project
     5  
     6  This program is free software: you can redistribute it and/or modify
     7  it under the terms of the GNU Affero General Public License as published by
     8  the Free Software Foundation, either version 3 of the License, or
     9  (at your option) any later version.
    10  
    11  This program is distributed in the hope that it will be useful
    12  but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  GNU Affero General Public License for more details.
    15  
    16  You should have received a copy of the GNU Affero General Public License
    17  along with this program.  If not, see <http://www.gnu.org/licenses/>.
    18  */
    19  
    20  package testutil
    21  
    22  import (
    23  	"context"
    24  	"database/sql"
    25  	"database/sql/driver"
    26  	"fmt"
    27  	"net"
    28  	"os/exec"
    29  	"reflect"
    30  	"strconv"
    31  	"strings"
    32  	"time"
    33  
    34  	"github.com/go-sql-driver/mysql"
    35  	"github.com/pkg/errors"
    36  	corev1 "k8s.io/api/core/v1"
    37  )
    38  
    39  const (
    40  	// configurations to connect to Mysql, either a data source name represent by URL.
    41  	connectionURLKey = "url"
    42  
    43  	// other general settings for DB connections.
    44  	maxIdleConnsKey    = "maxIdleConns"
    45  	maxOpenConnsKey    = "maxOpenConns"
    46  	connMaxLifetimeKey = "connMaxLifetime"
    47  	connMaxIdleTimeKey = "connMaxIdleTime"
    48  )
    49  
    50  // Mysql represents MySQL output bindings.
    51  type Mysql struct {
    52  	db *sql.DB
    53  }
    54  
    55  // Init initializes the MySQL binding.
    56  func (m *Mysql) Init(metadata map[string]string) error {
    57  	p := metadata
    58  	url, ok := p[connectionURLKey]
    59  	if !ok || url == "" {
    60  		return fmt.Errorf("missing MySql connection string")
    61  	}
    62  
    63  	db, err := initDB(url)
    64  	if err != nil {
    65  		return err
    66  	}
    67  
    68  	err = propertyToInt(p, maxIdleConnsKey, db.SetMaxIdleConns)
    69  	if err != nil {
    70  		return err
    71  	}
    72  
    73  	err = propertyToInt(p, maxOpenConnsKey, db.SetMaxOpenConns)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	err = propertyToDuration(p, connMaxIdleTimeKey, db.SetConnMaxIdleTime)
    79  	if err != nil {
    80  		return err
    81  	}
    82  
    83  	err = propertyToDuration(p, connMaxLifetimeKey, db.SetConnMaxLifetime)
    84  	if err != nil {
    85  		return err
    86  	}
    87  
    88  	err = db.Ping()
    89  	if err != nil {
    90  		return errors.Wrap(err, "unable to ping the DB")
    91  	}
    92  
    93  	m.db = db
    94  
    95  	return nil
    96  }
    97  
    98  // Close closes the DB.
    99  func (m *Mysql) Close() error {
   100  	if m.db != nil {
   101  		return m.db.Close()
   102  	}
   103  
   104  	return nil
   105  }
   106  
   107  func (m *Mysql) query(ctx context.Context, sql string) ([]interface{}, error) {
   108  	rows, err := m.db.QueryContext(ctx, sql)
   109  	if err != nil {
   110  		return nil, errors.Wrapf(err, "error executing %s", sql)
   111  	}
   112  
   113  	defer func() {
   114  		_ = rows.Close()
   115  		_ = rows.Err()
   116  	}()
   117  
   118  	result, err := m.jsonify(rows)
   119  	if err != nil {
   120  		return nil, errors.Wrapf(err, "error marshalling query result for %s", sql)
   121  	}
   122  
   123  	return result, nil
   124  }
   125  
   126  func propertyToInt(props map[string]string, key string, setter func(int)) error {
   127  	if v, ok := props[key]; ok {
   128  		if i, err := strconv.Atoi(v); err == nil {
   129  			setter(i)
   130  		} else {
   131  			return errors.Wrapf(err, "error converitng %s:%s to int", key, v)
   132  		}
   133  	}
   134  
   135  	return nil
   136  }
   137  
   138  func propertyToDuration(props map[string]string, key string, setter func(time.Duration)) error {
   139  	if v, ok := props[key]; ok {
   140  		if d, err := time.ParseDuration(v); err == nil {
   141  			setter(d)
   142  		} else {
   143  			return errors.Wrapf(err, "error converitng %s:%s to time duration", key, v)
   144  		}
   145  	}
   146  
   147  	return nil
   148  }
   149  
   150  func initDB(url string) (*sql.DB, error) {
   151  	if _, err := mysql.ParseDSN(url); err != nil {
   152  		return nil, errors.Wrapf(err, "illegal Data Source Name (DNS) specified by %s", connectionURLKey)
   153  	}
   154  
   155  	db, err := sql.Open("mysql", url)
   156  	if err != nil {
   157  		return nil, errors.Wrap(err, "error opening DB connection")
   158  	}
   159  
   160  	return db, nil
   161  }
   162  
   163  func (m *Mysql) jsonify(rows *sql.Rows) ([]interface{}, error) {
   164  	columnTypes, err := rows.ColumnTypes()
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	var ret []interface{}
   170  	for rows.Next() {
   171  		values := prepareValues(columnTypes)
   172  		err := rows.Scan(values...)
   173  		if err != nil {
   174  			return nil, err
   175  		}
   176  
   177  		r := m.convert(columnTypes, values)
   178  		ret = append(ret, r)
   179  	}
   180  
   181  	return ret, nil
   182  }
   183  
   184  func prepareValues(columnTypes []*sql.ColumnType) []interface{} {
   185  	types := make([]reflect.Type, len(columnTypes))
   186  	for i, tp := range columnTypes {
   187  		types[i] = tp.ScanType()
   188  	}
   189  
   190  	values := make([]interface{}, len(columnTypes))
   191  	for i := range values {
   192  		values[i] = reflect.New(types[i]).Interface()
   193  	}
   194  
   195  	return values
   196  }
   197  
   198  func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} {
   199  	r := map[string]interface{}{}
   200  
   201  	for i, ct := range columnTypes {
   202  		value := values[i]
   203  
   204  		switch v := values[i].(type) {
   205  		case driver.Valuer:
   206  			if vv, err := v.Value(); err == nil {
   207  				value = interface{}(vv)
   208  			}
   209  		case *sql.RawBytes:
   210  			// special case for sql.RawBytes, see https://github.com/go-sql-driver/mysql/blob/master/fields.go#L178
   211  			switch ct.DatabaseTypeName() {
   212  			case "VARCHAR", "CHAR", "TEXT", "LONGTEXT":
   213  				value = string(*v)
   214  			}
   215  		}
   216  
   217  		if value != nil {
   218  			r[ct.Name()] = value
   219  		}
   220  	}
   221  
   222  	return r
   223  }
   224  
   225  func (m *Mysql) GetRole(ctx context.Context) (string, error) {
   226  	const query = "select role from information_schema.wesql_cluster_local"
   227  	result, err := m.query(ctx, query)
   228  	if err != nil {
   229  		return "", err
   230  	}
   231  	if len(result) != 1 {
   232  		return "", errors.New("only one role should be observed")
   233  	}
   234  	row, ok := result[0].(map[string]interface{})
   235  	if !ok {
   236  		return "", errors.New("query result wrong type")
   237  	}
   238  	role, ok := row["role"].(string)
   239  	if !ok {
   240  		return "", errors.New("role parsing error")
   241  	}
   242  	if len(role) == 0 {
   243  		return "", errors.New("got empty role")
   244  	}
   245  	role = strings.ToLower(role)
   246  	return role, nil
   247  }
   248  
   249  type Tunnel struct {
   250  	kind string
   251  	name string
   252  	ip   string
   253  	port int32
   254  }
   255  
   256  func OpenTunnel(svc *corev1.Service) (*Tunnel, error) {
   257  	ip := getLocalIP()
   258  	t := &Tunnel{
   259  		kind: "svc",
   260  		name: svc.Name,
   261  		ip:   ip,
   262  		port: svc.Spec.Ports[0].Port,
   263  	}
   264  	err := t.startPortForward()
   265  	return t, err
   266  }
   267  
   268  func (t *Tunnel) Close() error {
   269  	return t.stopPortForward()
   270  }
   271  
   272  func (t *Tunnel) startPortForward() error {
   273  	portStr := strconv.Itoa(int(t.port))
   274  	cmd := exec.Command("bash", "-c", "kubectl port-forward "+t.kind+"/"+t.name+" --address 0.0.0.0 "+portStr+":"+portStr+" &")
   275  	return cmd.Start()
   276  }
   277  
   278  func (t *Tunnel) stopPortForward() error {
   279  	cmd := exec.Command("bash", "-c", "ps aux | grep port-forward | grep -v grep | grep "+t.name+" | awk '{print $2}' | xargs kill -9")
   280  	return cmd.Run()
   281  }
   282  
   283  func (t *Tunnel) GetMySQLConn() (*Mysql, error) {
   284  	db := &Mysql{}
   285  	url := "root@tcp(" + t.ip + ":" + strconv.Itoa(int(t.port)) + ")/information_schema?allowNativePasswords=true"
   286  	params := map[string]string{connectionURLKey: url}
   287  	if err := db.Init(params); err != nil {
   288  		return db, err
   289  	}
   290  	return db, nil
   291  }
   292  
   293  func getLocalIP() string {
   294  	addrs, err := net.InterfaceAddrs()
   295  	if err != nil {
   296  		return ""
   297  	}
   298  	for _, address := range addrs {
   299  		// check the address type and if it is not a loopback the display it
   300  		if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
   301  			if ipnet.IP.To4() != nil {
   302  				return ipnet.IP.String()
   303  			}
   304  		}
   305  	}
   306  	return ""
   307  }