vitess.io/vitess@v0.16.2/go/vt/vttablet/tmclienttest/tmclienttest.go (about) 1 /* 2 Copyright 2022 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package tmclienttest 18 19 import ( 20 "os" 21 22 "github.com/spf13/pflag" 23 24 "vitess.io/vitess/go/vt/log" 25 "vitess.io/vitess/go/vt/servenv" 26 "vitess.io/vitess/go/vt/vttablet/tmclient" 27 ) 28 29 const tmclientProtocolFlagName = "tablet_manager_protocol" 30 31 // SetProtocol is a helper function to set the tmclient --tablet_manager_protocol 32 // flag value for tests. If successful, it returns a function that, when called, 33 // returns the flag to its previous value. 34 // 35 // Note that because this variable is bound to a flag, the effects of this 36 // function are global, not scoped to the calling test-case. Therefore it should 37 // not be used in conjunction with t.Parallel. 38 func SetProtocol(name string, protocol string) (reset func()) { 39 var tmp []string 40 tmp, os.Args = os.Args[:], []string{name} 41 defer func() { os.Args = tmp }() 42 43 servenv.OnParseFor(name, func(fs *pflag.FlagSet) { 44 if fs.Lookup(tmclientProtocolFlagName) != nil { 45 return 46 } 47 48 tmclient.RegisterFlags(fs) 49 }) 50 servenv.ParseFlags(name) 51 52 switch oldVal, err := pflag.CommandLine.GetString(tmclientProtocolFlagName); err { 53 case nil: 54 reset = func() { SetProtocol(name, oldVal) } 55 default: 56 log.Errorf("failed to get string value for flag %q: %v", tmclientProtocolFlagName, err) 57 reset = func() {} 58 } 59 60 if err := pflag.Set(tmclientProtocolFlagName, protocol); err != nil { 61 msg := "failed to set flag %q to %q: %v" 62 log.Errorf(msg, tmclientProtocolFlagName, protocol, err) 63 reset = func() {} 64 } 65 66 return reset 67 }