get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/closed_conns_test.go (about)

     1  // Copyright 2018-2020 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"fmt"
    18  	"net"
    19  	"strings"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/nats-io/nats.go"
    24  )
    25  
    26  func checkClosedConns(t *testing.T, s *Server, num int, wait time.Duration) {
    27  	t.Helper()
    28  	checkFor(t, wait, 5*time.Millisecond, func() error {
    29  		if nc := s.numClosedConns(); nc != num {
    30  			return fmt.Errorf("Closed conns expected to be %v, got %v", num, nc)
    31  		}
    32  		return nil
    33  	})
    34  }
    35  
    36  func checkTotalClosedConns(t *testing.T, s *Server, num uint64, wait time.Duration) {
    37  	t.Helper()
    38  	checkFor(t, wait, 5*time.Millisecond, func() error {
    39  		if nc := s.totalClosedConns(); nc != num {
    40  			return fmt.Errorf("Total closed conns expected to be %v, got %v", num, nc)
    41  		}
    42  		return nil
    43  	})
    44  }
    45  
    46  func TestClosedConnsAccounting(t *testing.T) {
    47  	opts := DefaultOptions()
    48  	opts.MaxClosedClients = 10
    49  	opts.NoSystemAccount = true
    50  
    51  	s := RunServer(opts)
    52  	defer s.Shutdown()
    53  
    54  	wait := time.Second
    55  
    56  	nc, err := nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port))
    57  	if err != nil {
    58  		t.Fatalf("Error on connect: %v", err)
    59  	}
    60  	id, _ := nc.GetClientID()
    61  	nc.Close()
    62  
    63  	checkClosedConns(t, s, 1, wait)
    64  
    65  	conns := s.closedClients()
    66  	if lc := len(conns); lc != 1 {
    67  		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
    68  	}
    69  	if conns[0].Cid != id {
    70  		t.Fatalf("Expected CID to be %d, got %d\n", id, conns[0].Cid)
    71  	}
    72  
    73  	// Now create 21 more
    74  	for i := 0; i < 21; i++ {
    75  		nc, err = nats.Connect(fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port))
    76  		if err != nil {
    77  			t.Fatalf("Error on connect: %v", err)
    78  		}
    79  		nc.Close()
    80  		checkTotalClosedConns(t, s, uint64(i+2), wait)
    81  	}
    82  
    83  	checkClosedConns(t, s, opts.MaxClosedClients, wait)
    84  	checkTotalClosedConns(t, s, 22, wait)
    85  
    86  	conns = s.closedClients()
    87  	if lc := len(conns); lc != opts.MaxClosedClients {
    88  		t.Fatalf("len(conns) expected to be %d, got %d\n",
    89  			opts.MaxClosedClients, lc)
    90  	}
    91  
    92  	// Set it to the start after overflow.
    93  	cid := uint64(22 - opts.MaxClosedClients)
    94  	for _, ci := range conns {
    95  		cid++
    96  		if ci.Cid != cid {
    97  			t.Fatalf("Expected cid of %d, got %d\n", cid, ci.Cid)
    98  		}
    99  	}
   100  }
   101  
   102  func TestClosedConnsSubsAccounting(t *testing.T) {
   103  	opts := DefaultOptions()
   104  	s := RunServer(opts)
   105  	defer s.Shutdown()
   106  
   107  	url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
   108  
   109  	nc, err := nats.Connect(url)
   110  	if err != nil {
   111  		t.Fatalf("Error on subscribe: %v", err)
   112  	}
   113  	defer nc.Close()
   114  
   115  	// Now create some subscriptions
   116  	numSubs := 10
   117  	for i := 0; i < numSubs; i++ {
   118  		subj := fmt.Sprintf("foo.%d", i)
   119  		nc.Subscribe(subj, func(m *nats.Msg) {})
   120  	}
   121  	nc.Flush()
   122  	nc.Close()
   123  
   124  	checkClosedConns(t, s, 1, 20*time.Millisecond)
   125  	conns := s.closedClients()
   126  	if lc := len(conns); lc != 1 {
   127  		t.Fatalf("len(conns) expected to be 1, got %d\n", lc)
   128  	}
   129  	ci := conns[0]
   130  
   131  	if len(ci.subs) != numSubs {
   132  		t.Fatalf("Expected number of Subs to be %d, got %d\n", numSubs, len(ci.subs))
   133  	}
   134  }
   135  
   136  func checkReason(t *testing.T, reason string, expected ClosedState) {
   137  	if !strings.Contains(reason, expected.String()) {
   138  		t.Fatalf("Expected closed connection with `%s` state, got `%s`\n",
   139  			expected, reason)
   140  	}
   141  }
   142  
   143  func TestClosedAuthorizationTimeout(t *testing.T) {
   144  	serverOptions := DefaultOptions()
   145  	serverOptions.Authorization = "my_token"
   146  	serverOptions.AuthTimeout = 0.4
   147  	s := RunServer(serverOptions)
   148  	defer s.Shutdown()
   149  
   150  	conn, err := net.Dial("tcp", fmt.Sprintf("%s:%d", serverOptions.Host, serverOptions.Port))
   151  	if err != nil {
   152  		t.Fatalf("Error dialing server: %v\n", err)
   153  	}
   154  	defer conn.Close()
   155  
   156  	checkClosedConns(t, s, 1, 2*time.Second)
   157  	conns := s.closedClients()
   158  	if lc := len(conns); lc != 1 {
   159  		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
   160  	}
   161  	checkReason(t, conns[0].Reason, AuthenticationTimeout)
   162  }
   163  
   164  func TestClosedAuthorizationViolation(t *testing.T) {
   165  	serverOptions := DefaultOptions()
   166  	serverOptions.Authorization = "my_token"
   167  	s := RunServer(serverOptions)
   168  	defer s.Shutdown()
   169  
   170  	opts := s.getOpts()
   171  	url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
   172  
   173  	nc, err := nats.Connect(url)
   174  	if err == nil {
   175  		nc.Close()
   176  		t.Fatal("Expected failure for connection")
   177  	}
   178  
   179  	checkClosedConns(t, s, 1, 2*time.Second)
   180  	conns := s.closedClients()
   181  	if lc := len(conns); lc != 1 {
   182  		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
   183  	}
   184  	checkReason(t, conns[0].Reason, AuthenticationViolation)
   185  }
   186  
   187  func TestClosedUPAuthorizationViolation(t *testing.T) {
   188  	serverOptions := DefaultOptions()
   189  	serverOptions.Username = "my_user"
   190  	serverOptions.Password = "my_secret"
   191  	s := RunServer(serverOptions)
   192  	defer s.Shutdown()
   193  
   194  	opts := s.getOpts()
   195  	url := fmt.Sprintf("nats://%s:%d", opts.Host, opts.Port)
   196  
   197  	nc, err := nats.Connect(url)
   198  	if err == nil {
   199  		nc.Close()
   200  		t.Fatal("Expected failure for connection")
   201  	}
   202  
   203  	url2 := fmt.Sprintf("nats://my_user:wrong_pass@%s:%d", opts.Host, opts.Port)
   204  	nc, err = nats.Connect(url2)
   205  	if err == nil {
   206  		nc.Close()
   207  		t.Fatal("Expected failure for connection")
   208  	}
   209  
   210  	checkClosedConns(t, s, 2, 2*time.Second)
   211  	conns := s.closedClients()
   212  	if lc := len(conns); lc != 2 {
   213  		t.Fatalf("len(conns) expected to be %d, got %d\n", 2, lc)
   214  	}
   215  	checkReason(t, conns[0].Reason, AuthenticationViolation)
   216  	checkReason(t, conns[1].Reason, AuthenticationViolation)
   217  }
   218  
   219  func TestClosedMaxPayload(t *testing.T) {
   220  	serverOptions := DefaultOptions()
   221  	serverOptions.MaxPayload = 100
   222  
   223  	s := RunServer(serverOptions)
   224  	defer s.Shutdown()
   225  
   226  	opts := s.getOpts()
   227  	endpoint := fmt.Sprintf("%s:%d", opts.Host, opts.Port)
   228  
   229  	conn, err := net.DialTimeout("tcp", endpoint, time.Second)
   230  	if err != nil {
   231  		t.Fatalf("Could not make a raw connection to the server: %v", err)
   232  	}
   233  	defer conn.Close()
   234  
   235  	// This should trigger it.
   236  	pub := "PUB foo.bar 1024\r\n"
   237  	conn.Write([]byte(pub))
   238  
   239  	checkClosedConns(t, s, 1, 2*time.Second)
   240  	conns := s.closedClients()
   241  	if lc := len(conns); lc != 1 {
   242  		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
   243  	}
   244  	checkReason(t, conns[0].Reason, MaxPayloadExceeded)
   245  }
   246  
   247  func TestClosedTLSHandshake(t *testing.T) {
   248  	opts, err := ProcessConfigFile("./configs/tls.conf")
   249  	if err != nil {
   250  		t.Fatalf("Error processing config file: %v", err)
   251  	}
   252  	opts.TLSVerify = true
   253  	opts.NoLog = true
   254  	opts.NoSigs = true
   255  	s := RunServer(opts)
   256  	defer s.Shutdown()
   257  
   258  	nc, err := nats.Connect(fmt.Sprintf("tls://%s:%d", opts.Host, opts.Port))
   259  	if err == nil {
   260  		nc.Close()
   261  		t.Fatal("Expected failure for connection")
   262  	}
   263  
   264  	checkClosedConns(t, s, 1, 2*time.Second)
   265  	conns := s.closedClients()
   266  	if lc := len(conns); lc != 1 {
   267  		t.Fatalf("len(conns) expected to be %d, got %d\n", 1, lc)
   268  	}
   269  	checkReason(t, conns[0].Reason, TLSHandshakeError)
   270  }