vitess.io/vitess@v0.16.2/go/vt/vttablet/tmclienttest/tmclienttest.go (about)

     1  /*
     2  Copyright 2022 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package tmclienttest
    18  
    19  import (
    20  	"os"
    21  
    22  	"github.com/spf13/pflag"
    23  
    24  	"vitess.io/vitess/go/vt/log"
    25  	"vitess.io/vitess/go/vt/servenv"
    26  	"vitess.io/vitess/go/vt/vttablet/tmclient"
    27  )
    28  
    29  const tmclientProtocolFlagName = "tablet_manager_protocol"
    30  
    31  // SetProtocol is a helper function to set the tmclient --tablet_manager_protocol
    32  // flag value for tests. If successful, it returns a function that, when called,
    33  // returns the flag to its previous value.
    34  //
    35  // Note that because this variable is bound to a flag, the effects of this
    36  // function are global, not scoped to the calling test-case. Therefore it should
    37  // not be used in conjunction with t.Parallel.
    38  func SetProtocol(name string, protocol string) (reset func()) {
    39  	var tmp []string
    40  	tmp, os.Args = os.Args[:], []string{name}
    41  	defer func() { os.Args = tmp }()
    42  
    43  	servenv.OnParseFor(name, func(fs *pflag.FlagSet) {
    44  		if fs.Lookup(tmclientProtocolFlagName) != nil {
    45  			return
    46  		}
    47  
    48  		tmclient.RegisterFlags(fs)
    49  	})
    50  	servenv.ParseFlags(name)
    51  
    52  	switch oldVal, err := pflag.CommandLine.GetString(tmclientProtocolFlagName); err {
    53  	case nil:
    54  		reset = func() { SetProtocol(name, oldVal) }
    55  	default:
    56  		log.Errorf("failed to get string value for flag %q: %v", tmclientProtocolFlagName, err)
    57  		reset = func() {}
    58  	}
    59  
    60  	if err := pflag.Set(tmclientProtocolFlagName, protocol); err != nil {
    61  		msg := "failed to set flag %q to %q: %v"
    62  		log.Errorf(msg, tmclientProtocolFlagName, protocol, err)
    63  		reset = func() {}
    64  	}
    65  
    66  	return reset
    67  }