github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/testutil/k8s/tunnel_util.go (about) 1 /* 2 Copyright (C) 2022-2023 ApeCloud Co., Ltd 3 4 This file is part of KubeBlocks project 5 6 This program is free software: you can redistribute it and/or modify 7 it under the terms of the GNU Affero General Public License as published by 8 the Free Software Foundation, either version 3 of the License, or 9 (at your option) any later version. 10 11 This program is distributed in the hope that it will be useful 12 but WITHOUT ANY WARRANTY; without even the implied warranty of 13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 14 GNU Affero General Public License for more details. 15 16 You should have received a copy of the GNU Affero General Public License 17 along with this program. If not, see <http://www.gnu.org/licenses/>. 18 */ 19 20 package testutil 21 22 import ( 23 "context" 24 "database/sql" 25 "database/sql/driver" 26 "fmt" 27 "net" 28 "os/exec" 29 "reflect" 30 "strconv" 31 "strings" 32 "time" 33 34 "github.com/go-sql-driver/mysql" 35 "github.com/pkg/errors" 36 corev1 "k8s.io/api/core/v1" 37 ) 38 39 const ( 40 // configurations to connect to Mysql, either a data source name represent by URL. 41 connectionURLKey = "url" 42 43 // other general settings for DB connections. 44 maxIdleConnsKey = "maxIdleConns" 45 maxOpenConnsKey = "maxOpenConns" 46 connMaxLifetimeKey = "connMaxLifetime" 47 connMaxIdleTimeKey = "connMaxIdleTime" 48 ) 49 50 // Mysql represents MySQL output bindings. 51 type Mysql struct { 52 db *sql.DB 53 } 54 55 // Init initializes the MySQL binding. 56 func (m *Mysql) Init(metadata map[string]string) error { 57 p := metadata 58 url, ok := p[connectionURLKey] 59 if !ok || url == "" { 60 return fmt.Errorf("missing MySql connection string") 61 } 62 63 db, err := initDB(url) 64 if err != nil { 65 return err 66 } 67 68 err = propertyToInt(p, maxIdleConnsKey, db.SetMaxIdleConns) 69 if err != nil { 70 return err 71 } 72 73 err = propertyToInt(p, maxOpenConnsKey, db.SetMaxOpenConns) 74 if err != nil { 75 return err 76 } 77 78 err = propertyToDuration(p, connMaxIdleTimeKey, db.SetConnMaxIdleTime) 79 if err != nil { 80 return err 81 } 82 83 err = propertyToDuration(p, connMaxLifetimeKey, db.SetConnMaxLifetime) 84 if err != nil { 85 return err 86 } 87 88 err = db.Ping() 89 if err != nil { 90 return errors.Wrap(err, "unable to ping the DB") 91 } 92 93 m.db = db 94 95 return nil 96 } 97 98 // Close closes the DB. 99 func (m *Mysql) Close() error { 100 if m.db != nil { 101 return m.db.Close() 102 } 103 104 return nil 105 } 106 107 func (m *Mysql) query(ctx context.Context, sql string) ([]interface{}, error) { 108 rows, err := m.db.QueryContext(ctx, sql) 109 if err != nil { 110 return nil, errors.Wrapf(err, "error executing %s", sql) 111 } 112 113 defer func() { 114 _ = rows.Close() 115 _ = rows.Err() 116 }() 117 118 result, err := m.jsonify(rows) 119 if err != nil { 120 return nil, errors.Wrapf(err, "error marshalling query result for %s", sql) 121 } 122 123 return result, nil 124 } 125 126 func propertyToInt(props map[string]string, key string, setter func(int)) error { 127 if v, ok := props[key]; ok { 128 if i, err := strconv.Atoi(v); err == nil { 129 setter(i) 130 } else { 131 return errors.Wrapf(err, "error converitng %s:%s to int", key, v) 132 } 133 } 134 135 return nil 136 } 137 138 func propertyToDuration(props map[string]string, key string, setter func(time.Duration)) error { 139 if v, ok := props[key]; ok { 140 if d, err := time.ParseDuration(v); err == nil { 141 setter(d) 142 } else { 143 return errors.Wrapf(err, "error converitng %s:%s to time duration", key, v) 144 } 145 } 146 147 return nil 148 } 149 150 func initDB(url string) (*sql.DB, error) { 151 if _, err := mysql.ParseDSN(url); err != nil { 152 return nil, errors.Wrapf(err, "illegal Data Source Name (DNS) specified by %s", connectionURLKey) 153 } 154 155 db, err := sql.Open("mysql", url) 156 if err != nil { 157 return nil, errors.Wrap(err, "error opening DB connection") 158 } 159 160 return db, nil 161 } 162 163 func (m *Mysql) jsonify(rows *sql.Rows) ([]interface{}, error) { 164 columnTypes, err := rows.ColumnTypes() 165 if err != nil { 166 return nil, err 167 } 168 169 var ret []interface{} 170 for rows.Next() { 171 values := prepareValues(columnTypes) 172 err := rows.Scan(values...) 173 if err != nil { 174 return nil, err 175 } 176 177 r := m.convert(columnTypes, values) 178 ret = append(ret, r) 179 } 180 181 return ret, nil 182 } 183 184 func prepareValues(columnTypes []*sql.ColumnType) []interface{} { 185 types := make([]reflect.Type, len(columnTypes)) 186 for i, tp := range columnTypes { 187 types[i] = tp.ScanType() 188 } 189 190 values := make([]interface{}, len(columnTypes)) 191 for i := range values { 192 values[i] = reflect.New(types[i]).Interface() 193 } 194 195 return values 196 } 197 198 func (m *Mysql) convert(columnTypes []*sql.ColumnType, values []interface{}) map[string]interface{} { 199 r := map[string]interface{}{} 200 201 for i, ct := range columnTypes { 202 value := values[i] 203 204 switch v := values[i].(type) { 205 case driver.Valuer: 206 if vv, err := v.Value(); err == nil { 207 value = interface{}(vv) 208 } 209 case *sql.RawBytes: 210 // special case for sql.RawBytes, see https://github.com/go-sql-driver/mysql/blob/master/fields.go#L178 211 switch ct.DatabaseTypeName() { 212 case "VARCHAR", "CHAR", "TEXT", "LONGTEXT": 213 value = string(*v) 214 } 215 } 216 217 if value != nil { 218 r[ct.Name()] = value 219 } 220 } 221 222 return r 223 } 224 225 func (m *Mysql) GetRole(ctx context.Context) (string, error) { 226 const query = "select role from information_schema.wesql_cluster_local" 227 result, err := m.query(ctx, query) 228 if err != nil { 229 return "", err 230 } 231 if len(result) != 1 { 232 return "", errors.New("only one role should be observed") 233 } 234 row, ok := result[0].(map[string]interface{}) 235 if !ok { 236 return "", errors.New("query result wrong type") 237 } 238 role, ok := row["role"].(string) 239 if !ok { 240 return "", errors.New("role parsing error") 241 } 242 if len(role) == 0 { 243 return "", errors.New("got empty role") 244 } 245 role = strings.ToLower(role) 246 return role, nil 247 } 248 249 type Tunnel struct { 250 kind string 251 name string 252 ip string 253 port int32 254 } 255 256 func OpenTunnel(svc *corev1.Service) (*Tunnel, error) { 257 ip := getLocalIP() 258 t := &Tunnel{ 259 kind: "svc", 260 name: svc.Name, 261 ip: ip, 262 port: svc.Spec.Ports[0].Port, 263 } 264 err := t.startPortForward() 265 return t, err 266 } 267 268 func (t *Tunnel) Close() error { 269 return t.stopPortForward() 270 } 271 272 func (t *Tunnel) startPortForward() error { 273 portStr := strconv.Itoa(int(t.port)) 274 cmd := exec.Command("bash", "-c", "kubectl port-forward "+t.kind+"/"+t.name+" --address 0.0.0.0 "+portStr+":"+portStr+" &") 275 return cmd.Start() 276 } 277 278 func (t *Tunnel) stopPortForward() error { 279 cmd := exec.Command("bash", "-c", "ps aux | grep port-forward | grep -v grep | grep "+t.name+" | awk '{print $2}' | xargs kill -9") 280 return cmd.Run() 281 } 282 283 func (t *Tunnel) GetMySQLConn() (*Mysql, error) { 284 db := &Mysql{} 285 url := "root@tcp(" + t.ip + ":" + strconv.Itoa(int(t.port)) + ")/information_schema?allowNativePasswords=true" 286 params := map[string]string{connectionURLKey: url} 287 if err := db.Init(params); err != nil { 288 return db, err 289 } 290 return db, nil 291 } 292 293 func getLocalIP() string { 294 addrs, err := net.InterfaceAddrs() 295 if err != nil { 296 return "" 297 } 298 for _, address := range addrs { 299 // check the address type and if it is not a loopback the display it 300 if ipnet, ok := address.(*net.IPNet); ok && !ipnet.IP.IsLoopback() { 301 if ipnet.IP.To4() != nil { 302 return ipnet.IP.String() 303 } 304 } 305 } 306 return "" 307 }