github.com/datastax/go-cassandra-native-protocol@v0.0.0-20220706104457-5e8aad05cf90/client/main_test.go (about) 1 // Copyright 2020 DataStax 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package client_test 16 17 import ( 18 "flag" 19 "github.com/datastax/go-cassandra-native-protocol/client" 20 "github.com/datastax/go-cassandra-native-protocol/primitive" 21 "github.com/rs/zerolog" 22 "github.com/rs/zerolog/log" 23 "math" 24 "os" 25 "sync/atomic" 26 "testing" 27 ) 28 29 var remoteAvailable bool 30 var logLevel int 31 32 func TestMain(m *testing.M) { 33 parseFlags() 34 setLogLevel() 35 createStreamIdGenerators() 36 os.Exit(m.Run()) 37 } 38 39 func parseFlags() { 40 flag.IntVar( 41 &logLevel, 42 "logLevel", 43 int(zerolog.ErrorLevel), 44 "the log level to use (default: info)", 45 ) 46 flag.BoolVar( 47 &remoteAvailable, 48 "remote", 49 false, 50 "whether a remote cluster is available on localhost:9042", 51 ) 52 flag.Parse() 53 } 54 55 func setLogLevel() { 56 zerolog.SetGlobalLevel(zerolog.Level(logLevel)) 57 log.Logger = log.Output(zerolog.ConsoleWriter{ 58 Out: os.Stderr, 59 TimeFormat: zerolog.TimeFormatUnix, 60 }) 61 } 62 63 var compressions = []primitive.Compression{primitive.CompressionNone, primitive.CompressionLz4, primitive.CompressionSnappy} 64 65 var streamIdGenerators map[string]func(int, primitive.ProtocolVersion) int16 66 67 func createStreamIdGenerators() { 68 var managed = func(clientId int, version primitive.ProtocolVersion) int16 { 69 return client.ManagedStreamId 70 } 71 var fixed = func(clientId int, version primitive.ProtocolVersion) int16 { 72 if int16(clientId) == client.ManagedStreamId { 73 panic("stream id 0") 74 } 75 return int16(clientId) 76 } 77 counter := uint32(1) 78 var incremental = func(clientId int, version primitive.ProtocolVersion) int16 { 79 var max uint32 80 if version <= primitive.ProtocolVersion2 { 81 max = math.MaxInt8 82 } else { 83 max = math.MaxInt16 84 } 85 for { 86 current := counter 87 next := current + 1 88 if next > max { 89 next = 1 90 } 91 if atomic.CompareAndSwapUint32(&counter, current, next) { 92 return int16(next) 93 } 94 } 95 } 96 streamIdGenerators = map[string]func(int, primitive.ProtocolVersion) int16{ 97 "managed": managed, 98 "fixed": fixed, 99 "incremental": incremental, 100 } 101 }