github.com/jbronn/packer@v0.1.6-0.20140120165540-8a1364dbd817/packer/rpc/muxconn_test.go (about)

     1  package rpc
     2  
     3  import (
     4  	"io"
     5  	"net"
     6  	"sync"
     7  	"testing"
     8  )
     9  
    10  func readStream(t *testing.T, s io.Reader) string {
    11  	var data [1024]byte
    12  	n, err := s.Read(data[:])
    13  	if err != nil {
    14  		t.Fatalf("err: %s", err)
    15  	}
    16  
    17  	return string(data[0:n])
    18  }
    19  
    20  func testMux(t *testing.T) (client *MuxConn, server *MuxConn) {
    21  	l, err := net.Listen("tcp", ":0")
    22  	if err != nil {
    23  		t.Fatalf("err: %s", err)
    24  	}
    25  
    26  	// Server side
    27  	doneCh := make(chan struct{})
    28  	go func() {
    29  		defer close(doneCh)
    30  		conn, err := l.Accept()
    31  		l.Close()
    32  		if err != nil {
    33  			t.Fatalf("err: %s", err)
    34  		}
    35  
    36  		server = NewMuxConn(conn)
    37  	}()
    38  
    39  	// Client side
    40  	conn, err := net.Dial("tcp", l.Addr().String())
    41  	if err != nil {
    42  		t.Fatalf("err: %s", err)
    43  	}
    44  	client = NewMuxConn(conn)
    45  
    46  	// Wait for the server
    47  	<-doneCh
    48  
    49  	return
    50  }
    51  
    52  func TestMuxConn(t *testing.T) {
    53  	client, server := testMux(t)
    54  	defer client.Close()
    55  	defer server.Close()
    56  
    57  	// When the server is done
    58  	doneCh := make(chan struct{})
    59  
    60  	// The server side
    61  	go func() {
    62  		defer close(doneCh)
    63  
    64  		s0, err := server.Accept(0)
    65  		if err != nil {
    66  			t.Fatalf("err: %s", err)
    67  		}
    68  
    69  		s1, err := server.Dial(1)
    70  		if err != nil {
    71  			t.Fatalf("err: %s", err)
    72  		}
    73  
    74  		var wg sync.WaitGroup
    75  		wg.Add(2)
    76  
    77  		go func() {
    78  			defer wg.Done()
    79  			data := readStream(t, s1)
    80  			if data != "another" {
    81  				t.Fatalf("bad: %#v", data)
    82  			}
    83  		}()
    84  
    85  		go func() {
    86  			defer wg.Done()
    87  			data := readStream(t, s0)
    88  			if data != "hello" {
    89  				t.Fatalf("bad: %#v", data)
    90  			}
    91  		}()
    92  
    93  		wg.Wait()
    94  	}()
    95  
    96  	s0, err := client.Dial(0)
    97  	if err != nil {
    98  		t.Fatalf("err: %s", err)
    99  	}
   100  
   101  	s1, err := client.Accept(1)
   102  	if err != nil {
   103  		t.Fatalf("err: %s", err)
   104  	}
   105  
   106  	if _, err := s0.Write([]byte("hello")); err != nil {
   107  		t.Fatalf("err: %s", err)
   108  	}
   109  	if _, err := s1.Write([]byte("another")); err != nil {
   110  		t.Fatalf("err: %s", err)
   111  	}
   112  
   113  	// Wait for the server to be done
   114  	<-doneCh
   115  }
   116  
   117  // This tests that even when the client end is closed, data can be
   118  // read from the server.
   119  func TestMuxConn_clientCloseRead(t *testing.T) {
   120  	client, server := testMux(t)
   121  	defer client.Close()
   122  	defer server.Close()
   123  
   124  	// This channel will be closed when we close
   125  	waitCh := make(chan struct{})
   126  
   127  	go func() {
   128  		conn, err := server.Accept(0)
   129  		if err != nil {
   130  			t.Fatalf("err: %s", err)
   131  		}
   132  
   133  		<-waitCh
   134  
   135  		_, err = conn.Write([]byte("foo"))
   136  		if err != nil {
   137  			t.Fatalf("err: %s", err)
   138  		}
   139  
   140  		conn.Close()
   141  	}()
   142  
   143  	s0, err := client.Dial(0)
   144  	if err != nil {
   145  		t.Fatalf("err: %s", err)
   146  	}
   147  
   148  	if err := s0.Close(); err != nil {
   149  		t.Fatalf("bad: %s", err)
   150  	}
   151  
   152  	// Close this to continue on on the server-side
   153  	close(waitCh)
   154  
   155  	var data [1024]byte
   156  	n, err := s0.Read(data[:])
   157  	if string(data[:n]) != "foo" {
   158  		t.Fatalf("bad: %#v", string(data[:n]))
   159  	}
   160  }
   161  
   162  func TestMuxConn_socketClose(t *testing.T) {
   163  	client, server := testMux(t)
   164  	defer client.Close()
   165  	defer server.Close()
   166  
   167  	go func() {
   168  		_, err := server.Accept(0)
   169  		if err != nil {
   170  			t.Fatalf("err: %s", err)
   171  		}
   172  
   173  		server.rwc.Close()
   174  	}()
   175  
   176  	s0, err := client.Dial(0)
   177  	if err != nil {
   178  		t.Fatalf("err: %s", err)
   179  	}
   180  
   181  	var data [1024]byte
   182  	_, err = s0.Read(data[:])
   183  	if err != io.EOF {
   184  		t.Fatalf("err: %s", err)
   185  	}
   186  }
   187  
   188  func TestMuxConn_clientClosesStreams(t *testing.T) {
   189  	client, server := testMux(t)
   190  	defer client.Close()
   191  	defer server.Close()
   192  
   193  	go func() {
   194  		conn, err := server.Accept(0)
   195  		if err != nil {
   196  			t.Fatalf("err: %s", err)
   197  		}
   198  		conn.Close()
   199  	}()
   200  
   201  	s0, err := client.Dial(0)
   202  	if err != nil {
   203  		t.Fatalf("err: %s", err)
   204  	}
   205  
   206  	var data [1024]byte
   207  	_, err = s0.Read(data[:])
   208  	if err != io.EOF {
   209  		t.Fatalf("err: %s", err)
   210  	}
   211  }
   212  
   213  func TestMuxConn_serverClosesStreams(t *testing.T) {
   214  	client, server := testMux(t)
   215  	defer client.Close()
   216  	defer server.Close()
   217  	go server.Accept(0)
   218  
   219  	s0, err := client.Dial(0)
   220  	if err != nil {
   221  		t.Fatalf("err: %s", err)
   222  	}
   223  
   224  	if err := server.Close(); err != nil {
   225  		t.Fatalf("err: %s", err)
   226  	}
   227  
   228  	// This should block forever since we never write onto this stream.
   229  	var data [1024]byte
   230  	_, err = s0.Read(data[:])
   231  	if err != io.EOF {
   232  		t.Fatalf("err: %s", err)
   233  	}
   234  }
   235  
   236  func TestMuxConnNextId(t *testing.T) {
   237  	client, server := testMux(t)
   238  	defer client.Close()
   239  	defer server.Close()
   240  
   241  	a := client.NextId()
   242  	b := client.NextId()
   243  
   244  	if a != 1 || b != 2 {
   245  		t.Fatalf("IDs should increment")
   246  	}
   247  
   248  	a = server.NextId()
   249  	b = server.NextId()
   250  
   251  	if a != 1 || b != 2 {
   252  		t.Fatalf("IDs should increment: %d %d", a, b)
   253  	}
   254  }