get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/closed_conns_test.go (about) 1 // Copyright 2018-2020 The NATS Authors 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package server 15 16 import ( 17 "fmt" 18 "net" 19 "strings" 20 "testing" 21 "time" 22 23 "github.com/nats-io/nats.go" 24 ) 25 26 func checkClosedConns(t *testing.T, s *Server, num int, wait time.Duration) { 27 t.Helper() 28 checkFor(t, wait, 5*time.Millisecond, func() error { 29 if nc := s.numClosedConns(); nc != num { 30 return fmt.Errorf("Closed conns expected to be %v, got %v", num, nc) 31 } 32 return nil 33 }) 34 } 35 36 func checkTotalClosedConns(t *testing.T, s *Server, num uint64, wait time.Duration) { 37 t.Helper() 38 checkFor(t, wait, 5*time.Millisecond, func() error { 39 if nc := s.totalClosedConns(); nc != num { 40 return fmt.Errorf("Total closed conns expected to be %v, got %v", num, nc) 41 } 42 return nil 43 }) 44 } 45 46 func TestClosedConnsAccounting(t *testing.T) { 47 opts := DefaultOptions() 48 opts.MaxClosedClients = 10 49 opts.NoSystemAccount = true 50 51 s := RunServer(opts) 52 defer s.Shutdown() 53 54 wait := time.Second 55 56 nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) 57 if err != nil { 58 t.Fatalf("Error on connect: %v", err) 59 } 60 id, _ := nc.GetClientID() 61 nc.Close() 62 63 checkClosedConns(t, s, 1, wait) 64 65 conns := s.closedClients() 66 if lc := len(conns); lc != 1 { 67 t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) 68 } 69 if conns[0].Cid != id { 70 t.Fatalf("Expected CID to be %d, got %d\n", id, conns[0].Cid) 71 } 72 73 // Now create 21 more 74 for i := 0; i < 21; i++ { 75 nc, err = nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)) 76 if err != nil { 77 t.Fatalf("Error on connect: %v", err) 78 } 79 nc.Close() 80 checkTotalClosedConns(t, s, uint64(i+2), wait) 81 } 82 83 checkClosedConns(t, s, opts.MaxClosedClients, wait) 84 checkTotalClosedConns(t, s, 22, wait) 85 86 conns = s.closedClients() 87 if lc := len(conns); lc != opts.MaxClosedClients { 88 t.Fatalf("len(conns) expected to be %d, got %d\n", 89 opts.MaxClosedClients, lc) 90 } 91 92 // Set it to the start after overflow. 93 cid := uint64(22 - opts.MaxClosedClients) 94 for _, ci := range conns { 95 cid++ 96 if ci.Cid != cid { 97 t.Fatalf("Expected cid of %d, got %d\n", cid, ci.Cid) 98 } 99 } 100 } 101 102 func TestClosedConnsSubsAccounting(t *testing.T) { 103 opts := DefaultOptions() 104 s := RunServer(opts) 105 defer s.Shutdown() 106 107 url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) 108 109 nc, err := nats.Connect(url) 110 if err != nil { 111 t.Fatalf("Error on subscribe: %v", err) 112 } 113 defer nc.Close() 114 115 // Now create some subscriptions 116 numSubs := 10 117 for i := 0; i < numSubs; i++ { 118 subj := fmt.Sprintf("foo.%d", i) 119 nc.Subscribe(subj, func(m *nats.Msg) {}) 120 } 121 nc.Flush() 122 nc.Close() 123 124 checkClosedConns(t, s, 1, 20*time.Millisecond) 125 conns := s.closedClients() 126 if lc := len(conns); lc != 1 { 127 t.Fatalf("len(conns) expected to be 1, got %d\n", lc) 128 } 129 ci := conns[0] 130 131 if len(ci.subs) != numSubs { 132 t.Fatalf("Expected number of Subs to be %d, got %d\n", numSubs, len(ci.subs)) 133 } 134 } 135 136 func checkReason(t *testing.T, reason string, expected ClosedState) { 137 if !strings.Contains(reason, expected.String()) { 138 t.Fatalf("Expected closed connection with `%s` state, got `%s`\n", 139 expected, reason) 140 } 141 } 142 143 func TestClosedAuthorizationTimeout(t *testing.T) { 144 serverOptions := DefaultOptions() 145 serverOptions.Authorization = "my_token" 146 serverOptions.AuthTimeout = 0.4 147 s := RunServer(serverOptions) 148 defer s.Shutdown() 149 150 conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", serverOptions.Host, serverOptions.Port)) 151 if err != nil { 152 t.Fatalf("Error dialing server: %v\n", err) 153 } 154 defer conn.Close() 155 156 checkClosedConns(t, s, 1, 2*time.Second) 157 conns := s.closedClients() 158 if lc := len(conns); lc != 1 { 159 t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) 160 } 161 checkReason(t, conns[0].Reason, AuthenticationTimeout) 162 } 163 164 func TestClosedAuthorizationViolation(t *testing.T) { 165 serverOptions := DefaultOptions() 166 serverOptions.Authorization = "my_token" 167 s := RunServer(serverOptions) 168 defer s.Shutdown() 169 170 opts := s.getOpts() 171 url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) 172 173 nc, err := nats.Connect(url) 174 if err == nil { 175 nc.Close() 176 t.Fatal("Expected failure for connection") 177 } 178 179 checkClosedConns(t, s, 1, 2*time.Second) 180 conns := s.closedClients() 181 if lc := len(conns); lc != 1 { 182 t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) 183 } 184 checkReason(t, conns[0].Reason, AuthenticationViolation) 185 } 186 187 func TestClosedUPAuthorizationViolation(t *testing.T) { 188 serverOptions := DefaultOptions() 189 serverOptions.Username = "my_user" 190 serverOptions.Password = "my_secret" 191 s := RunServer(serverOptions) 192 defer s.Shutdown() 193 194 opts := s.getOpts() 195 url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port) 196 197 nc, err := nats.Connect(url) 198 if err == nil { 199 nc.Close() 200 t.Fatal("Expected failure for connection") 201 } 202 203 url2 := fmt.Sprintf("nats://my_user:wrong_pass@%s:%d", opts.Host, opts.Port) 204 nc, err = nats.Connect(url2) 205 if err == nil { 206 nc.Close() 207 t.Fatal("Expected failure for connection") 208 } 209 210 checkClosedConns(t, s, 2, 2*time.Second) 211 conns := s.closedClients() 212 if lc := len(conns); lc != 2 { 213 t.Fatalf("len(conns) expected to be %d, got %d\n", 2, lc) 214 } 215 checkReason(t, conns[0].Reason, AuthenticationViolation) 216 checkReason(t, conns[1].Reason, AuthenticationViolation) 217 } 218 219 func TestClosedMaxPayload(t *testing.T) { 220 serverOptions := DefaultOptions() 221 serverOptions.MaxPayload = 100 222 223 s := RunServer(serverOptions) 224 defer s.Shutdown() 225 226 opts := s.getOpts() 227 endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port) 228 229 conn, err := net.DialTimeout("tcp", endpoint, time.Second) 230 if err != nil { 231 t.Fatalf("Could not make a raw connection to the server: %v", err) 232 } 233 defer conn.Close() 234 235 // This should trigger it. 236 pub := "PUB foo.bar 1024\r\n" 237 conn.Write([]byte(pub)) 238 239 checkClosedConns(t, s, 1, 2*time.Second) 240 conns := s.closedClients() 241 if lc := len(conns); lc != 1 { 242 t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) 243 } 244 checkReason(t, conns[0].Reason, MaxPayloadExceeded) 245 } 246 247 func TestClosedTLSHandshake(t *testing.T) { 248 opts, err := ProcessConfigFile("./configs/tls.conf") 249 if err != nil { 250 t.Fatalf("Error processing config file: %v", err) 251 } 252 opts.TLSVerify = true 253 opts.NoLog = true 254 opts.NoSigs = true 255 s := RunServer(opts) 256 defer s.Shutdown() 257 258 nc, err := nats.Connect(fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port)) 259 if err == nil { 260 nc.Close() 261 t.Fatal("Expected failure for connection") 262 } 263 264 checkClosedConns(t, s, 1, 2*time.Second) 265 conns := s.closedClients() 266 if lc := len(conns); lc != 1 { 267 t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc) 268 } 269 checkReason(t, conns[0].Reason, TLSHandshakeError) 270 }