tractor.dev/toolkit-go@v0.0.0-20241010005851-214d91207d07/duplex/mux/session_test.go (about)

     1  package mux
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io/ioutil"
     8  	"net"
     9  	"testing"
    10  	"time"
    11  )
    12  
    13  func init() {
    14  	openTimeout = 100 * time.Millisecond
    15  }
    16  
    17  func fatal(err error, t *testing.T) {
    18  	t.Helper()
    19  	if err != nil {
    20  		t.Fatal(err)
    21  	}
    22  }
    23  
    24  func TestQmux(t *testing.T) {
    25  	l, err := net.Listen("tcp", "127.0.0.1:0")
    26  	fatal(err, t)
    27  	defer l.Close()
    28  
    29  	testComplete := make(chan struct{})
    30  	sessionClosed := make(chan struct{})
    31  
    32  	go func() {
    33  		conn, err := l.Accept()
    34  		fatal(err, t)
    35  		defer conn.Close()
    36  
    37  		sess := New(conn)
    38  
    39  		ch, err := sess.Open(context.Background())
    40  		fatal(err, t)
    41  		b, err := ioutil.ReadAll(ch)
    42  		fatal(err, t)
    43  		ch.Close() // should already be closed by other end
    44  
    45  		ch, err = sess.Accept()
    46  		fatal(err, t)
    47  		_, err = ch.Write(b)
    48  		fatal(err, t)
    49  		err = ch.CloseWrite()
    50  		fatal(err, t)
    51  
    52  		<-testComplete
    53  		err = sess.Close()
    54  		fatal(err, t)
    55  		close(sessionClosed)
    56  	}()
    57  
    58  	conn, err := net.Dial("tcp", l.Addr().String())
    59  	fatal(err, t)
    60  	defer conn.Close()
    61  
    62  	sess := New(conn)
    63  
    64  	var ch Channel
    65  	t.Run("session accept", func(t *testing.T) {
    66  		ch, err = sess.Accept()
    67  		fatal(err, t)
    68  	})
    69  
    70  	t.Run("channel write", func(t *testing.T) {
    71  		_, err = ch.Write([]byte("Hello world"))
    72  		fatal(err, t)
    73  		err = ch.Close()
    74  		fatal(err, t)
    75  	})
    76  
    77  	t.Run("session open", func(t *testing.T) {
    78  		ch, err = sess.Open(context.Background())
    79  		fatal(err, t)
    80  	})
    81  
    82  	var b []byte
    83  	t.Run("channel read", func(t *testing.T) {
    84  		b, err = ioutil.ReadAll(ch)
    85  		fatal(err, t)
    86  		ch.Close() // should already be closed by other end
    87  	})
    88  
    89  	if !bytes.Equal(b, []byte("Hello world")) {
    90  		t.Fatalf("unexpected bytes: %s", b)
    91  	}
    92  	close(testComplete)
    93  	<-sessionClosed
    94  }
    95  
    96  func TestSessionOpenClientTimeout(t *testing.T) {
    97  	l, err := net.Listen("tcp", "127.0.0.1:0")
    98  	fatal(err, t)
    99  	defer l.Close()
   100  
   101  	conn, err := net.Dial("tcp", l.Addr().String())
   102  	fatal(err, t)
   103  	defer conn.Close()
   104  
   105  	sess := New(conn)
   106  
   107  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond)
   108  	defer cancel()
   109  
   110  	ch, err := sess.Open(ctx)
   111  	if err != context.DeadlineExceeded {
   112  		t.Fatalf("expected DeadlineExceeded, but got: %v", err)
   113  	}
   114  	if ch != nil {
   115  		ch.Close()
   116  	}
   117  }
   118  
   119  func TestSessionOpenServerTimeout(t *testing.T) {
   120  	l, err := net.Listen("tcp", "127.0.0.1:0")
   121  	fatal(err, t)
   122  	defer l.Close()
   123  
   124  	errCh := make(chan error)
   125  	go func() {
   126  		conn, err := net.Dial("tcp", l.Addr().String())
   127  		fatal(err, t)
   128  		defer conn.Close()
   129  
   130  		sess := New(conn)
   131  		defer sess.Close()
   132  
   133  		_, err = sess.Open(context.Background())
   134  		errCh <- err
   135  	}()
   136  
   137  	conn, err := l.Accept()
   138  	fatal(err, t)
   139  	defer conn.Close()
   140  
   141  	sess := New(conn)
   142  	defer sess.Close()
   143  
   144  	if <-errCh == nil {
   145  		t.Errorf("expected open to fail when listener doesn't call Accept")
   146  	}
   147  	fatal(sess.Close(), t)
   148  }
   149  
   150  func TestSessionWait(t *testing.T) {
   151  	l, err := net.Listen("tcp", "127.0.0.1:0")
   152  	fatal(err, t)
   153  	defer l.Close()
   154  
   155  	conn, err := net.Dial("tcp", l.Addr().String())
   156  	fatal(err, t)
   157  	defer conn.Close()
   158  
   159  	sess := New(conn)
   160  	fatal(sess.Close(), t)
   161  	// wait should return immediately since the connection was closed
   162  	err = sess.Wait()
   163  	var netErr net.Error
   164  	if !errors.As(err, &netErr) {
   165  		t.Fatalf("expected a network error, but got: %v", err)
   166  	}
   167  }