github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/ctl/common/util.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package common
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"fmt"
    20  	"os"
    21  	"reflect"
    22  	"regexp"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/gogo/protobuf/jsonpb"
    29  	"github.com/golang/protobuf/proto"
    30  	"github.com/pingcap/errors"
    31  	"github.com/pingcap/failpoint"
    32  	toolutils "github.com/pingcap/tidb-tools/pkg/utils"
    33  	"github.com/pingcap/tidb/pkg/parser"
    34  	"github.com/pingcap/tiflow/dm/config"
    35  	"github.com/pingcap/tiflow/dm/config/security"
    36  	"github.com/pingcap/tiflow/dm/pb"
    37  	"github.com/pingcap/tiflow/dm/pkg/log"
    38  	parserpkg "github.com/pingcap/tiflow/dm/pkg/parser"
    39  	"github.com/pingcap/tiflow/dm/pkg/terror"
    40  	"github.com/pingcap/tiflow/dm/pkg/utils"
    41  	"github.com/spf13/cobra"
    42  	clientv3 "go.etcd.io/etcd/client/v3"
    43  	"go.uber.org/zap"
    44  	"google.golang.org/grpc"
    45  	"google.golang.org/grpc/codes"
    46  	"google.golang.org/grpc/status"
    47  )
    48  
    49  var (
    50  	globalConfig = &Config{}
    51  	// GlobalCtlClient is the globally used CtlClient in this package. Exposed to be used in test.
    52  	GlobalCtlClient = &CtlClient{}
    53  
    54  	re = regexp.MustCompile(`grpc: received message larger than max \((\d+) vs. (\d+)\)`)
    55  )
    56  
    57  // CtlClient used to get master client for dmctl.
    58  type CtlClient struct {
    59  	mu           sync.RWMutex
    60  	tls          *toolutils.TLS
    61  	conn         *grpc.ClientConn
    62  	MasterClient pb.MasterClient  // exposed to be used in test
    63  	EtcdClient   *clientv3.Client // exposed to be used in export config
    64  }
    65  
    66  func (c *CtlClient) updateMasterClient() error {
    67  	var (
    68  		err  error
    69  		conn *grpc.ClientConn
    70  	)
    71  
    72  	c.mu.Lock()
    73  	defer c.mu.Unlock()
    74  
    75  	if c.conn != nil {
    76  		c.conn.Close()
    77  	}
    78  
    79  	endpoints := c.EtcdClient.Endpoints()
    80  	for _, endpoint := range endpoints {
    81  		//nolint:staticcheck
    82  		conn, err = grpc.Dial(utils.UnwrapScheme(endpoint), c.tls.ToGRPCDialOption(), grpc.WithBackoffMaxDelay(3*time.Second), grpc.WithBlock(), grpc.WithTimeout(3*time.Second))
    83  		if err == nil {
    84  			c.conn = conn
    85  			c.MasterClient = pb.NewMasterClient(conn)
    86  			return nil
    87  		}
    88  	}
    89  	return terror.ErrCtlGRPCCreateConn.AnnotateDelegate(err, "can't connect to %s", strings.Join(endpoints, ","))
    90  }
    91  
    92  func (c *CtlClient) sendRequest(
    93  	ctx context.Context,
    94  	reqName string,
    95  	req interface{},
    96  	respPointer interface{},
    97  	opts ...interface{},
    98  ) error {
    99  	c.mu.RLock()
   100  	defer c.mu.RUnlock()
   101  
   102  	params := []reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(req)}
   103  	for _, o := range opts {
   104  		params = append(params, reflect.ValueOf(o))
   105  	}
   106  	results := reflect.ValueOf(c.MasterClient).MethodByName(reqName).Call(params)
   107  
   108  	reflect.ValueOf(respPointer).Elem().Set(results[0])
   109  	errInterface := results[1].Interface()
   110  	// nil can't pass type conversion, so we handle it separately
   111  	if errInterface == nil {
   112  		return nil
   113  	}
   114  	return errInterface.(error)
   115  }
   116  
// SendRequest send request to master.
//
// It calls the dm-master RPC named reqName (a pb.MasterClient method name)
// through GlobalCtlClient and stores the response into respPointer. If the
// first attempt fails with codes.ResourceExhausted or codes.Unavailable, the
// master client is re-dialed and the request is retried exactly once. Any
// other error is returned unchanged. For ResourceExhausted, when the error
// text matches `re`, the retry raises the maximum receive size via
// grpc.MaxCallRecvMsgSize to the first number captured from that text.
func SendRequest(ctx context.Context, reqName string, req interface{}, respPointer interface{}) error {
	err := GlobalCtlClient.sendRequest(ctx, reqName, req, respPointer)
	if err == nil {
		return nil
	}
	// opts collects extra call options for the single retry below.
	var opts []interface{}
	switch status.Code(err) {
	case codes.ResourceExhausted:
		// Parse the "(X vs. Y)" sizes from the gRPC error and retry with the
		// receive limit raised to X.
		matches := re.FindStringSubmatch(err.Error())
		if len(matches) == 3 {
			msgSize, err2 := strconv.Atoi(matches[1])
			if err2 == nil {
				log.L().Info("increase gRPC maximum message size", zap.Int("size", msgSize))
				opts = append(opts, grpc.MaxCallRecvMsgSize(msgSize))
			}
		}
	case codes.Unavailable:
		// No extra call options; leave the switch to reconnect and retry.
	default:
		return err
	}

	// Test hook: when the "SkipUpdateMasterClient" failpoint is enabled, jump to
	// the "bypass" label and retry on the current client without re-dialing.
	failpoint.Inject("SkipUpdateMasterClient", func() {
		failpoint.Goto("bypass")
	})
	// update master client
	err = GlobalCtlClient.updateMasterClient()
	if err != nil {
		return err
	}
	failpoint.Label("bypass")

	// sendRequest again
	return GlobalCtlClient.sendRequest(ctx, reqName, req, respPointer, opts...)
}
   152  
   153  // InitUtils inits necessary dmctl utils.
   154  func InitUtils(cfg *Config) error {
   155  	globalConfig = cfg
   156  	return errors.Trace(InitClient(cfg.MasterAddr, cfg.Security))
   157  }
   158  
   159  // InitClient initializes dm-master client.
   160  func InitClient(addr string, securityCfg security.Security) error {
   161  	tls, err := toolutils.NewTLS(securityCfg.SSLCA, securityCfg.SSLCert, securityCfg.SSLKey, "", securityCfg.CertAllowedCN)
   162  	if err != nil {
   163  		return terror.ErrCtlInvalidTLSCfg.Delegate(err)
   164  	}
   165  
   166  	endpoints := strings.Split(addr, ",")
   167  	etcdClient, err := clientv3.New(clientv3.Config{
   168  		Endpoints:            endpoints,
   169  		DialTimeout:          dialTimeout,
   170  		DialKeepAliveTime:    keepaliveTime,
   171  		DialKeepAliveTimeout: keepaliveTimeout,
   172  		TLS:                  tls.TLSConfig(),
   173  	})
   174  	if err != nil {
   175  		return err
   176  	}
   177  
   178  	GlobalCtlClient = &CtlClient{
   179  		tls:        tls,
   180  		EtcdClient: etcdClient,
   181  	}
   182  
   183  	return GlobalCtlClient.updateMasterClient()
   184  }
   185  
   186  // GlobalConfig returns global dmctl config.
   187  func GlobalConfig() *Config {
   188  	return globalConfig
   189  }
   190  
   191  // PrintLinesf adds a wrap to support `\n` within `chzyer/readline`.
   192  func PrintLinesf(format string, a ...interface{}) {
   193  	fmt.Println(fmt.Sprintf(format, a...))
   194  }
   195  
   196  // PrettyPrintResponse prints a PRC response prettily.
   197  func PrettyPrintResponse(resp proto.Message) {
   198  	s, err := marshResponseToString(resp)
   199  	if err != nil {
   200  		PrintLinesf("%v", err)
   201  	} else {
   202  		fmt.Println(s)
   203  	}
   204  }
   205  
   206  // PrettyPrintInterface prints an interface through encoding/json prettily.
   207  func PrettyPrintInterface(resp interface{}) {
   208  	s, err := json.MarshalIndent(resp, "", "    ")
   209  	if err != nil {
   210  		PrintLinesf("%v", err)
   211  	} else {
   212  		fmt.Println(string(s))
   213  	}
   214  }
   215  
   216  func marshResponseToString(resp proto.Message) (string, error) {
   217  	// encoding/json does not support proto Enum well
   218  	mar := jsonpb.Marshaler{EmitDefaults: true, Indent: "    "}
   219  	s, err := mar.MarshalToString(resp)
   220  	return s, errors.Trace(err)
   221  }
   222  
   223  // PrettyPrintResponseWithCheckTask prints a RPC response may contain response Msg with check-task's response prettily.
   224  // check-task's response may contain json-string when checking fail in `detail` field.
   225  // ugly code, but it is a little hard to refine this because needing to convert type.
   226  func PrettyPrintResponseWithCheckTask(resp proto.Message, subStr string) bool {
   227  	var (
   228  		err          error
   229  		found        bool
   230  		replacedStr  string
   231  		marshaledStr string
   232  		placeholder  = "PLACEHOLDER"
   233  	)
   234  	switch chr := resp.(type) {
   235  	case *pb.StartTaskResponse:
   236  		if strings.Contains(chr.CheckResult, subStr) {
   237  			found = true
   238  			rawMsg := chr.CheckResult
   239  			chr.CheckResult = placeholder // replace Msg with placeholder
   240  			marshaledStr, err = marshResponseToString(chr)
   241  			if err == nil {
   242  				replacedStr = strings.Replace(marshaledStr, placeholder, rawMsg, 1)
   243  			}
   244  		}
   245  	case *pb.UpdateTaskResponse:
   246  		if strings.Contains(chr.CheckResult, subStr) {
   247  			found = true
   248  			rawMsg := chr.CheckResult
   249  			chr.CheckResult = placeholder // replace Msg with placeholder
   250  			marshaledStr, err = marshResponseToString(chr)
   251  			if err == nil {
   252  				replacedStr = strings.Replace(marshaledStr, placeholder, rawMsg, 1)
   253  			}
   254  		}
   255  	case *pb.CheckTaskResponse:
   256  		if strings.Contains(chr.Msg, subStr) {
   257  			found = true
   258  			rawMsg := chr.Msg
   259  			chr.Msg = placeholder // replace Msg with placeholder
   260  			marshaledStr, err = marshResponseToString(chr)
   261  			if err == nil {
   262  				replacedStr = strings.Replace(marshaledStr, placeholder, rawMsg, 1)
   263  			}
   264  		}
   265  
   266  	default:
   267  		return false
   268  	}
   269  
   270  	if !found {
   271  		return found
   272  	}
   273  
   274  	if err != nil {
   275  		PrintLinesf("%v", err)
   276  	} else {
   277  		// add indent to make it prettily.
   278  		replacedStr = strings.Replace(replacedStr, "detail: {", "   \tdetail: {", 1)
   279  		fmt.Println(replacedStr)
   280  	}
   281  	return found
   282  }
   283  
   284  // GetFileContent reads and returns file's content.
   285  func GetFileContent(fpath string) ([]byte, error) {
   286  	content, err := os.ReadFile(fpath)
   287  	if err != nil {
   288  		return nil, errors.Annotate(err, "error in get file content")
   289  	}
   290  	return content, nil
   291  }
   292  
   293  // GetSourceArgs extracts sources from cmd.
   294  func GetSourceArgs(cmd *cobra.Command) ([]string, error) {
   295  	ret, err := cmd.Flags().GetStringSlice("source")
   296  	if err != nil {
   297  		PrintLinesf("error in parse `-s` / `--source`")
   298  	}
   299  	return ret, err
   300  }
   301  
   302  // ExtractSQLsFromArgs extract multiple sql from args.
   303  func ExtractSQLsFromArgs(args []string) ([]string, error) {
   304  	if len(args) == 0 {
   305  		return nil, errors.New("args is empty")
   306  	}
   307  
   308  	concat := strings.TrimSpace(strings.Join(args, " "))
   309  	concat = utils.TrimQuoteMark(concat)
   310  
   311  	parser := parser.New()
   312  	nodes, err := parserpkg.Parse(parser, concat, "", "")
   313  	if err != nil {
   314  		return nil, errors.Annotatef(err, "invalid sql '%s'", concat)
   315  	}
   316  	realSQLs := make([]string, 0, len(nodes))
   317  	for _, node := range nodes {
   318  		realSQLs = append(realSQLs, node.Text())
   319  	}
   320  	if len(realSQLs) == 0 {
   321  		return nil, errors.New("no valid SQLs")
   322  	}
   323  
   324  	return realSQLs, nil
   325  }
   326  
   327  // GetTaskNameFromArgOrFile tries to retrieve name from the file if arg is yaml-filename-like, otherwise returns arg directly.
   328  func GetTaskNameFromArgOrFile(arg string) string {
   329  	if !(strings.HasSuffix(arg, ".yaml") || strings.HasSuffix(arg, ".yml")) {
   330  		return arg
   331  	}
   332  	var (
   333  		content []byte
   334  		err     error
   335  	)
   336  	if content, err = GetFileContent(arg); err != nil {
   337  		return arg
   338  	}
   339  	cfg := config.NewTaskConfig()
   340  	if err := cfg.FromYaml(string(content)); err != nil {
   341  		return arg
   342  	}
   343  	return cfg.Name
   344  }
   345  
   346  // PrintCmdUsage prints the usage of the command.
   347  func PrintCmdUsage(cmd *cobra.Command) {
   348  	if err := cmd.Usage(); err != nil {
   349  		fmt.Println("can't output command's usage:", err)
   350  	}
   351  }