github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/mysqlds.go (about)

     1  package sqx
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"strings"
     7  
     8  	"github.com/bingoohuang/gg/pkg/shellwords"
     9  	"github.com/bingoohuang/gg/pkg/ss"
    10  	"github.com/sirupsen/logrus"
    11  	"github.com/spf13/pflag"
    12  )
    13  
    14  // CompatibleMySQLDs make mysql datasource be compatible with raw, mysql or gossh host format.
    15  func CompatibleMySQLDs(s string) string {
    16  	// user:pass@tcp(localhost:3306)/sdb?charset=utf8mb4&parseTime=true&loc=Local
    17  	if strings.Contains(s, "@tcp") {
    18  		return s
    19  	}
    20  
    21  	// user:pass@localhost:3306/dbname
    22  	// https://github.com/xo/dburl
    23  	if strings.Contains(s, ":") && strings.Contains(s, "@") {
    24  		if v, ok := compatibleDBURL(s); ok {
    25  			return v
    26  		}
    27  	}
    28  
    29  	// MYSQL_PWD=8BE4 mysql -h 127.0.0.1 -P 9633 -u root
    30  	// -u, --user=name     User for login if not current user.
    31  	if strings.Contains(s, " -u") || strings.Contains(s, " --user") {
    32  		return compatibleMySQLClientCmd(s)
    33  	}
    34  
    35  	// 127.0.0.1:9633 root/8BE4 [sdb=sdb]
    36  	if strings.Contains(s, ":") || strings.Contains(s, "/") {
    37  		if v, ok := compatibleGoSSHHost(s); ok {
    38  			return v
    39  		}
    40  	}
    41  
    42  	return s
    43  }
    44  
    45  func compatibleDBURL(s string) (string, bool) {
    46  	// user:pass@localhost/dbname
    47  	// betaapiadmin:xx@123.206.185.162:3306/metrics_ump
    48  	atPos := strings.LastIndex(s, "@")
    49  	up := s[:atPos]
    50  	user, password := ss.Split2(up, ss.WithSeps(":"))
    51  
    52  	db := ""
    53  	right := s[atPos+1:]
    54  	slashPos := strings.Index(right, "/")
    55  
    56  	if slashPos > 0 {
    57  		db = right[slashPos+1:]
    58  		right = right[:slashPos]
    59  	}
    60  
    61  	askPos := strings.Index(db, "?")
    62  	if askPos > 0 {
    63  		db = db[:askPos]
    64  	}
    65  
    66  	host, port := parseHostPort(right, "3306")
    67  
    68  	if IsIPv6(host) {
    69  		host = "[" + host + "]"
    70  	}
    71  
    72  	return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
    73  		user, password, host, port, db), true
    74  }
    75  
    76  // IsIPv6 tests if the str is an IPv6 format.
    77  func IsIPv6(str string) bool {
    78  	ip := net.ParseIP(str)
    79  	return ip != nil && strings.Contains(str, ":")
    80  }
    81  
    82  func compatibleGoSSHHost(s string) (string, bool) {
    83  	// 127.0.0.1:9633 root/8BE4 [sdb=sdb]
    84  	fields := ss.FieldsX(s, "", "", 3)
    85  	if len(fields) < 2 { // nolint:gomnd
    86  		return "", false
    87  	}
    88  
    89  	host, port := parseHostPort(fields[0], "3306")
    90  	user, password := ss.Split2(fields[1], ss.WithSeps("/"))
    91  	props := parseProps(fields)
    92  	db := ""
    93  
    94  	if v, ok := props["db"]; ok {
    95  		db = v
    96  	}
    97  
    98  	if IsIPv6(host) {
    99  		host = "[" + host + "]"
   100  	}
   101  
   102  	return fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=true&loc=Local",
   103  		user, password, host, port, db), true
   104  }
   105  
   106  func parseProps(fields []string) map[string]string {
   107  	props := make(map[string]string)
   108  
   109  	for i := 2; i < len(fields); i++ {
   110  		k, v := ss.Split2(fields[i], ss.WithSeps("="))
   111  		props[k] = v
   112  	}
   113  
   114  	return props
   115  }
   116  
   117  func parseHostPort(addr, defaultPort string) (string, string) {
   118  	pos := strings.LastIndex(addr, ":")
   119  	if pos < 0 {
   120  		return addr, defaultPort
   121  	}
   122  
   123  	return addr[0:pos], addr[pos+1:]
   124  }
   125  
   126  func compatibleMySQLClientCmd(s string) string {
   127  	if pos := strings.Index(s, "MYSQL_PWD="); pos >= 0 {
   128  		s = s[0:pos] + "--" + s[pos:]
   129  	}
   130  
   131  	if pos := strings.Index(s, "mysql "); pos >= 0 {
   132  		s = s[0:pos] + "--" + s[pos:]
   133  	}
   134  
   135  	pf := pflag.NewFlagSet("ds", pflag.ExitOnError)
   136  
   137  	pf.BoolP("mysql", "", false, "mysql command")
   138  	pf.StringP("MYSQL_PWD", "", "", "MYSQL_PWD env password")
   139  	pf.StringP("database", "D", "", "Schema to use")
   140  	pf.StringP("host", "h", "", "Connect to host")
   141  	pf.IntP("port", "P", 3306, "Port number to use")
   142  	pf.StringP("user", "u", "", "User for login if not current user")
   143  	pf.StringP("password", "p", "", "Password to use when connecting to serve")
   144  
   145  	p := shellwords.NewParser()
   146  	p.ParseEnv = true
   147  	args, err := p.Parse(s)
   148  	if err != nil {
   149  		logrus.Fatalf("Fail to parse ds %s error %v", s, err)
   150  	}
   151  
   152  	if err := pf.Parse(args); err != nil {
   153  		return s
   154  	}
   155  
   156  	host, _ := pf.GetString("host")
   157  	port, _ := pf.GetInt("port")
   158  	user, _ := pf.GetString("user")
   159  	db, _ := pf.GetString("database")
   160  	password, _ := pf.GetString("password")
   161  
   162  	if password == "" {
   163  		password, _ = pf.GetString("MYSQL_PWD")
   164  	}
   165  
   166  	if IsIPv6(host) {
   167  		host = "[" + host + "]"
   168  	}
   169  
   170  	return fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local",
   171  		user, password, host, port, db)
   172  }