github.com/1aal/kubeblocks@v0.0.0-20231107070852-e1c03e598921/pkg/lorry/engines/postgres/query.go (about)

     1  /*
     2  Copyright (C) 2022-2023 ApeCloud Co., Ltd
     3  
     4  This file is part of KubeBlocks project
     5  
     6  This program is free software: you can redistribute it and/or modify
     7  it under the terms of the GNU Affero General Public License as published by
     8  the Free Software Foundation, either version 3 of the License, or
     9  (at your option) any later version.
    10  
    11  This program is distributed in the hope that it will be useful
    12  but WITHOUT ANY WARRANTY; without even the implied warranty of
    13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    14  GNU Affero General Public License for more details.
    15  
    16  You should have received a copy of the GNU Affero General Public License
    17  along with this program.  If not, see <http://www.gnu.org/licenses/>.
    18  */
    19  
    20  package postgres
    21  
    22  import (
    23  	"context"
    24  	"encoding/json"
    25  	"fmt"
    26  
    27  	"github.com/jackc/pgx/v5"
    28  	"github.com/jackc/pgx/v5/pgconn"
    29  	"github.com/pkg/errors"
    30  	"github.com/spf13/cast"
    31  
    32  	"github.com/1aal/kubeblocks/pkg/lorry/dcs"
    33  )
    34  
    35  // Query is equivalent to QueryWithHost(ctx, sql, ""), query itself.
    36  func (mgr *Manager) Query(ctx context.Context, sql string) (result []byte, err error) {
    37  	return mgr.QueryWithHost(ctx, sql, "")
    38  }
    39  
    40  func (mgr *Manager) QueryWithHost(ctx context.Context, sql string, host string) (result []byte, err error) {
    41  	var rows pgx.Rows
    42  	// when host is empty, use manager's connection pool
    43  	if host == "" {
    44  		rows, err = mgr.Pool.Query(ctx, sql)
    45  	} else {
    46  		rows, err = mgr.QueryOthers(ctx, sql, host)
    47  	}
    48  	if err != nil {
    49  		mgr.Logger.Error(err, fmt.Sprintf("query sql:%s failed", sql))
    50  		return nil, err
    51  	}
    52  	defer func() {
    53  		rows.Close()
    54  		_ = rows.Err()
    55  	}()
    56  
    57  	result, err = parseRows(rows)
    58  	if err != nil {
    59  		mgr.Logger.Error(err, fmt.Sprintf("parse query:%s failed", sql))
    60  		return nil, err
    61  	}
    62  
    63  	return result, nil
    64  }
    65  
    66  func (mgr *Manager) QueryOthers(ctx context.Context, sql string, host string) (rows pgx.Rows, err error) {
    67  	conn, err := pgx.Connect(ctx, config.GetConnectURLWithHost(host))
    68  	if err != nil {
    69  		mgr.Logger.Error(err, fmt.Sprintf("get host:%s connection failed", host))
    70  		return nil, err
    71  	}
    72  	defer func() {
    73  		_ = conn.Close(ctx)
    74  	}()
    75  
    76  	return conn.Query(ctx, sql)
    77  }
    78  
    79  func (mgr *Manager) QueryLeader(ctx context.Context, sql string, cluster *dcs.Cluster) (result []byte, err error) {
    80  	leaderMember := cluster.GetLeaderMember()
    81  	if leaderMember == nil {
    82  		return nil, ClusterHasNoLeader
    83  	}
    84  
    85  	var host string
    86  	if leaderMember.Name != mgr.CurrentMemberName {
    87  		host = cluster.GetMemberAddr(*leaderMember)
    88  	}
    89  	return mgr.QueryWithHost(ctx, sql, host)
    90  }
    91  
    92  // Exec is equivalent to ExecWithHost(ctx, sql, ""), exec itself.
    93  func (mgr *Manager) Exec(ctx context.Context, sql string) (result int64, err error) {
    94  	return mgr.ExecWithHost(ctx, sql, "")
    95  }
    96  
    97  func (mgr *Manager) ExecWithHost(ctx context.Context, sql string, host string) (result int64, err error) {
    98  	var res pgconn.CommandTag
    99  
   100  	// when host is empty, use manager's connection pool
   101  	if host == "" {
   102  		res, err = mgr.Pool.Exec(ctx, sql)
   103  	} else {
   104  		res, err = mgr.ExecOthers(ctx, sql, host)
   105  	}
   106  	if err != nil {
   107  		return 0, errors.Errorf("exec sql:%s failed: %v", sql, err)
   108  	}
   109  
   110  	result = res.RowsAffected()
   111  	return result, nil
   112  }
   113  
   114  func (mgr *Manager) ExecOthers(ctx context.Context, sql string, host string) (resp pgconn.CommandTag, err error) {
   115  	conn, err := pgx.Connect(ctx, config.GetConnectURLWithHost(host))
   116  	if err != nil {
   117  		return resp, err
   118  	}
   119  	defer func() {
   120  		_ = conn.Close(ctx)
   121  	}()
   122  
   123  	return conn.Exec(ctx, sql)
   124  }
   125  
   126  func (mgr *Manager) ExecLeader(ctx context.Context, sql string, cluster *dcs.Cluster) (result int64, err error) {
   127  	leaderMember := cluster.GetLeaderMember()
   128  	if leaderMember == nil {
   129  		return 0, ClusterHasNoLeader
   130  	}
   131  
   132  	var host string
   133  	if leaderMember.Name != mgr.CurrentMemberName {
   134  		host = cluster.GetMemberAddr(*leaderMember)
   135  	}
   136  	return mgr.ExecWithHost(ctx, sql, host)
   137  }
   138  
   139  func (mgr *Manager) GetPgCurrentSetting(ctx context.Context, setting string) (string, error) {
   140  	sql := fmt.Sprintf(`select pg_catalog.current_setting('%s');`, setting)
   141  
   142  	resp, err := mgr.Query(ctx, sql)
   143  	if err != nil {
   144  		return "", err
   145  	}
   146  
   147  	resMap, err := ParseQuery(string(resp))
   148  	if err != nil {
   149  		return "", err
   150  	}
   151  
   152  	return cast.ToString(resMap[0]["current_setting"]), nil
   153  }
   154  
   155  func parseRows(rows pgx.Rows) (result []byte, err error) {
   156  	rs := make([]interface{}, 0)
   157  	columnTypes := rows.FieldDescriptions()
   158  	for rows.Next() {
   159  		values := make([]interface{}, len(columnTypes))
   160  		for i := range values {
   161  			values[i] = new(interface{})
   162  		}
   163  
   164  		if err = rows.Scan(values...); err != nil {
   165  			return nil, errors.Errorf("scanning row failed, err:%v", err)
   166  		}
   167  
   168  		r := map[string]interface{}{}
   169  		for i, ct := range columnTypes {
   170  			r[ct.Name] = values[i]
   171  		}
   172  		rs = append(rs, r)
   173  	}
   174  
   175  	if result, err = json.Marshal(rs); err != nil {
   176  		err = errors.Errorf("json marshal failed, err: %v", err)
   177  	}
   178  	return result, err
   179  }
   180  
   181  func ParseQuery(str string) (result []map[string]interface{}, err error) {
   182  	// Notice: in golang, json unmarshal will map all numeric types to float64.
   183  	err = json.Unmarshal([]byte(str), &result)
   184  	if err != nil || len(result) == 0 {
   185  		return nil, errors.Errorf("json unmarshal failed, err:%v", err)
   186  	}
   187  
   188  	return result, nil
   189  }