vitess.io/vitess@v0.16.2/go/netutil/conn_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6      http://www.apache.org/licenses/LICENSE-2.0
     7  Unless required by applicable law or agreed to in writing, software
     8  distributed under the License is distributed on an "AS IS" BASIS,
     9  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    10  See the License for the specific language governing permissions and
    11  limitations under the License.
    12  */
    13  
    14  package netutil
    15  
    16  import (
    17  	"net"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  )
    23  
    24  func createSocketPair(t *testing.T) (net.Listener, net.Conn, net.Conn) {
    25  	// Create a listener.
    26  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    27  	if err != nil {
    28  		t.Fatalf("Listen failed: %v", err)
    29  	}
    30  	addr := listener.Addr().String()
    31  
    32  	// Dial a client, Accept a server.
    33  	wg := sync.WaitGroup{}
    34  
    35  	var clientConn net.Conn
    36  	wg.Add(1)
    37  	go func() {
    38  		defer wg.Done()
    39  		var err error
    40  		clientConn, err = net.Dial("tcp", addr)
    41  		if err != nil {
    42  			t.Errorf("Dial failed: %v", err)
    43  		}
    44  	}()
    45  
    46  	var serverConn net.Conn
    47  	wg.Add(1)
    48  	go func() {
    49  		defer wg.Done()
    50  		var err error
    51  		serverConn, err = listener.Accept()
    52  		if err != nil {
    53  			t.Errorf("Accept failed: %v", err)
    54  		}
    55  	}()
    56  
    57  	wg.Wait()
    58  
    59  	return listener, serverConn, clientConn
    60  }
    61  
    62  func TestReadTimeout(t *testing.T) {
    63  	listener, sConn, cConn := createSocketPair(t)
    64  	defer func() {
    65  		listener.Close()
    66  		sConn.Close()
    67  		cConn.Close()
    68  	}()
    69  
    70  	cConnWithTimeout := NewConnWithTimeouts(cConn, 1*time.Millisecond, 1*time.Millisecond)
    71  
    72  	c := make(chan error, 1)
    73  	go func() {
    74  		_, err := cConnWithTimeout.Read(make([]byte, 10))
    75  		c <- err
    76  	}()
    77  
    78  	select {
    79  	case err := <-c:
    80  		if err == nil {
    81  			t.Fatalf("Expected error, got nil")
    82  		}
    83  
    84  		if !strings.HasSuffix(err.Error(), "i/o timeout") {
    85  			t.Errorf("Expected error timeout, got %s", err)
    86  		}
    87  	case <-time.After(10 * time.Second):
    88  		t.Errorf("Timeout did not happen")
    89  	}
    90  }
    91  
    92  func TestWriteTimeout(t *testing.T) {
    93  	listener, sConn, cConn := createSocketPair(t)
    94  	defer func() {
    95  		listener.Close()
    96  		sConn.Close()
    97  		cConn.Close()
    98  	}()
    99  
   100  	sConnWithTimeout := NewConnWithTimeouts(sConn, 1*time.Millisecond, 1*time.Millisecond)
   101  
   102  	c := make(chan error, 1)
   103  	go func() {
   104  		// The timeout will trigger when the buffer is full, so to test this we need to write multiple times.
   105  		for {
   106  			_, err := sConnWithTimeout.Write([]byte("payload"))
   107  			if err != nil {
   108  				c <- err
   109  				return
   110  			}
   111  		}
   112  	}()
   113  
   114  	select {
   115  	case err := <-c:
   116  		if err == nil {
   117  			t.Fatalf("Expected error, got nil")
   118  		}
   119  
   120  		if !strings.HasSuffix(err.Error(), "i/o timeout") {
   121  			t.Errorf("Expected error timeout, got %s", err)
   122  		}
   123  	case <-time.After(10 * time.Second):
   124  		t.Errorf("Timeout did not happen")
   125  	}
   126  }
   127  
   128  func TestNoTimeouts(t *testing.T) {
   129  	listener, sConn, cConn := createSocketPair(t)
   130  	defer func() {
   131  		listener.Close()
   132  		sConn.Close()
   133  		cConn.Close()
   134  	}()
   135  
   136  	cConnWithTimeout := NewConnWithTimeouts(cConn, 0, 24*time.Hour)
   137  
   138  	c := make(chan error, 1)
   139  	go func() {
   140  		_, err := cConnWithTimeout.Read(make([]byte, 10))
   141  		c <- err
   142  	}()
   143  
   144  	select {
   145  	case <-c:
   146  		t.Fatalf("Connection timeout, without a timeout")
   147  	case <-time.After(100 * time.Millisecond):
   148  		// NOOP
   149  	}
   150  
   151  	c2 := make(chan error, 1)
   152  	sConnWithTimeout := NewConnWithTimeouts(sConn, 24*time.Hour, 0)
   153  	go func() {
   154  		// This should not fail as there is not timeout on write.
   155  		for {
   156  			_, err := sConnWithTimeout.Write([]byte("payload"))
   157  			if err != nil {
   158  				c2 <- err
   159  				return
   160  			}
   161  		}
   162  	}()
   163  	select {
   164  	case <-c2:
   165  		t.Fatalf("Connection timeout, without a timeout")
   166  	case <-time.After(100 * time.Millisecond):
   167  		// NOOP
   168  	}
   169  }