github.com/cloudwego/hertz@v0.9.3/pkg/network/standard/connection_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package standard
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/tls"
    22  	"errors"
    23  	"io"
    24  	"net"
    25  	"runtime"
    26  	"strings"
    27  	"sync/atomic"
    28  	"syscall"
    29  	"testing"
    30  	"time"
    31  
    32  	. "github.com/bytedance/mockey"
    33  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    34  )
    35  
    36  func TestRead(t *testing.T) {
    37  	c := mockConn{}
    38  	conn := newConn(&c, 4096)
    39  	// test read small data
    40  	b := make([]byte, 1)
    41  	conn.Read(b)
    42  	if conn.Len() != 4095 {
    43  		t.Errorf("unexpected conn.Len: %v, expected 4095", conn.Len())
    44  	}
    45  
    46  	// test read small data again
    47  	conn.Read(b)
    48  	if conn.Len() != 4094 {
    49  		t.Errorf("unexpected conn.Len: %v, expected 4094", conn.Len())
    50  	}
    51  
    52  	// test read large data
    53  	b = make([]byte, 10000)
    54  	n, _ := conn.Read(b)
    55  	if n != 4094 {
    56  		t.Errorf("unexpected n: %v, expected 4094", n)
    57  	}
    58  
    59  	// test read large data again
    60  	n, _ = conn.Read(b)
    61  	if n != 8192 {
    62  		t.Errorf("unexpected n: %v, expected 4094", n)
    63  	}
    64  }
    65  
    66  func TestReadFromHasBufferAvailable(t *testing.T) {
    67  	preData := []byte("head data")
    68  	rawData := strings.Repeat("helloworld", 1)
    69  	tailData := []byte("tail data")
    70  	data := strings.NewReader(rawData)
    71  	c := &mockConn{}
    72  	conn := newConn(c, 4096)
    73  
    74  	// WriteBinary will malloc a buffer if no buffer available.
    75  	_, err0 := conn.WriteBinary(preData)
    76  	assert.Nil(t, err0)
    77  
    78  	reader, ok := conn.(io.ReaderFrom)
    79  	assert.True(t, ok)
    80  
    81  	l, err := reader.ReadFrom(data)
    82  	assert.Nil(t, err)
    83  	assert.DeepEqual(t, len(rawData), int(l))
    84  
    85  	_, err1 := conn.WriteBinary(tailData)
    86  	assert.Nil(t, err1)
    87  
    88  	err2 := conn.Flush()
    89  	assert.Nil(t, err2)
    90  	assert.DeepEqual(t, string(preData)+rawData+string(tailData), c.buffer.String())
    91  }
    92  
    93  func TestReadFromNoBufferAvailable(t *testing.T) {
    94  	rawData := strings.Repeat("helloworld", 1)
    95  	tailData := []byte("tail data")
    96  	data := strings.NewReader(rawData)
    97  	c := &mockConn{}
    98  	conn := newConn(c, 4096)
    99  	reader, ok := conn.(io.ReaderFrom)
   100  	assert.True(t, ok)
   101  
   102  	l, err := reader.ReadFrom(data)
   103  	assert.Nil(t, err)
   104  	assert.DeepEqual(t, len(rawData), int(l))
   105  
   106  	_, err1 := conn.WriteBinary(tailData)
   107  	assert.Nil(t, err1)
   108  
   109  	err2 := conn.Flush()
   110  	assert.Nil(t, err2)
   111  
   112  	assert.DeepEqual(t, rawData+string(tailData), c.buffer.String())
   113  }
   114  
   115  func TestPeekRelease(t *testing.T) {
   116  	c := mockConn{}
   117  	conn := newConn(&c, 4096)
   118  	b, _ := conn.Peek(1)
   119  	if len(b) != 1 {
   120  		t.Errorf("unexpected len(b): %v, expected 1", len(b))
   121  	}
   122  
   123  	b, _ = conn.Peek(10000)
   124  	if len(b) != 10000 {
   125  		t.Errorf("unexpected len(b): %v, expected 10000", len(b))
   126  	}
   127  
   128  	if conn.Len() != 12288 {
   129  		t.Errorf("unexpected conn.Len: %v, expected 12288", conn.Len())
   130  	}
   131  	err := conn.Skip(12289)
   132  	if err == nil {
   133  		t.Errorf("unexpected no error, expected link buffer skip[12289] not enough")
   134  	}
   135  	conn.Skip(12288)
   136  	if conn.Len() != 0 {
   137  		t.Errorf("unexpected conn.Len: %v, expected 2287", conn.Len())
   138  	}
   139  
   140  	// test reuse buffer
   141  	conn.Release()
   142  	b, _ = conn.Peek(1)
   143  	if len(b) != 1 {
   144  		t.Errorf("unexpected len(b): %v, expected 1", len(b))
   145  	}
   146  	if conn.Len() != 8192 {
   147  		t.Errorf("unexpected conn.Len: %v, expected 8192", conn.Len())
   148  	}
   149  
   150  	// test cross node
   151  	b, _ = conn.Peek(1000000)
   152  	if len(b) != 1000000 {
   153  		t.Errorf("unexpected len(b): %v, expected 1", len(b))
   154  	}
   155  	conn.Skip(1000000)
   156  	conn.Release()
   157  
   158  	// test maxSize
   159  	if conn.(*Conn).maxSize != 524288 {
   160  		t.Errorf("unexpected maxSize: %v, expected 524288", conn.(*Conn).maxSize)
   161  	}
   162  }
   163  
   164  func TestReadBytes(t *testing.T) {
   165  	c := mockConn{}
   166  	conn := newConn(&c, 4096)
   167  	b, _ := conn.Peek(1)
   168  	if len(b) != 1 {
   169  		t.Errorf("unexpected len(b): %v, expected 1", len(b))
   170  	}
   171  	b[0] = 'a'
   172  	peekByte, _ := conn.Peek(1)
   173  	if peekByte[0] != 'a' {
   174  		t.Errorf("unexpected bb[0]: %v, expected a", peekByte[0])
   175  	}
   176  	if conn.Len() != 4096 {
   177  		t.Errorf("unexpected conn.Len: %v, expected 4096", conn.Len())
   178  	}
   179  
   180  	readBinary, _ := conn.ReadBinary(1)
   181  	if readBinary[0] != 'a' {
   182  		t.Errorf("unexpected readBinary[0]: %v, expected a", readBinary[0])
   183  	}
   184  	b[0] = 'b'
   185  	if readBinary[0] != 'a' {
   186  		t.Errorf("unexpected readBinary[0]: %v, expected a", readBinary[0])
   187  	}
   188  	bbb, _ := conn.ReadByte()
   189  	if bbb != 0 {
   190  		t.Errorf("unexpected bbb: %v, expected nil", string(bbb))
   191  	}
   192  	if conn.Len() != 4094 {
   193  		t.Errorf("unexpected conn.Len: %v, expected 4094", conn.Len())
   194  	}
   195  }
   196  
   197  func TestWriteLogic(t *testing.T) {
   198  	c := mockConn{}
   199  	conn := newConn(&c, 4096)
   200  	conn.Malloc(8190)
   201  	connection := conn.(*Conn)
   202  	// test left buffer
   203  	if connection.outputBuffer.len != 2 {
   204  		t.Errorf("unexpected Len: %v, expected 2", connection.outputBuffer.len)
   205  	}
   206  	// test malloc next node and left buffer
   207  	conn.Malloc(8190)
   208  	if connection.outputBuffer.len != 2 {
   209  		t.Errorf("unexpected Len: %v, expected 2", connection.outputBuffer.len)
   210  	}
   211  	conn.Flush()
   212  	if connection.outputBuffer.head != connection.outputBuffer.write {
   213  		t.Errorf("outputBuffer head != outputBuffer read")
   214  	}
   215  	conn.Malloc(8190)
   216  	if connection.outputBuffer.len != 2 {
   217  		t.Errorf("unexpected Len: %v, expected 2", connection.outputBuffer.len)
   218  	}
   219  	if connection.outputBuffer.head != connection.outputBuffer.write {
   220  		t.Errorf("outputBuffer head != outputBuffer read")
   221  	}
   222  	// test readOnly
   223  	b := make([]byte, 4096)
   224  	conn.WriteBinary(b)
   225  	conn.Flush()
   226  	conn.Malloc(2)
   227  	if connection.outputBuffer.head == connection.outputBuffer.write {
   228  		t.Errorf("outputBuffer head == outputBuffer read")
   229  	}
   230  	// test reuse outputBuffer
   231  	b = make([]byte, 2)
   232  	conn.WriteBinary(b)
   233  	conn.Flush()
   234  	conn.Malloc(2)
   235  	if connection.outputBuffer.head != connection.outputBuffer.write {
   236  		t.Errorf("outputBuffer head != outputBuffer read")
   237  	}
   238  }
   239  
   240  func TestInitializeConn(t *testing.T) {
   241  	c := mockConn{
   242  		localAddr: &mockAddr{
   243  			network: "tcp",
   244  			address: "192.168.0.10:80",
   245  		},
   246  		remoteAddr: &mockAddr{
   247  			network: "tcp",
   248  			address: "192.168.0.20:80",
   249  		},
   250  	}
   251  	conn := newConn(&c, 8192)
   252  	// check the assignment
   253  	assert.DeepEqual(t, errors.New("conn: write deadline not supported"), conn.SetDeadline(time.Time{}))
   254  	assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadDeadline(time.Time{}))
   255  	assert.DeepEqual(t, errors.New("conn: write deadline not supported"), conn.SetWriteDeadline(time.Time{}))
   256  	assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadTimeout(time.Duration(1)*time.Second))
   257  	assert.DeepEqual(t, errors.New("conn: read deadline not supported"), conn.SetReadTimeout(time.Duration(-1)*time.Second))
   258  	assert.DeepEqual(t, errors.New("conn: method not supported"), conn.Close())
   259  	assert.DeepEqual(t, &mockAddr{network: "tcp", address: "192.168.0.10:80"}, conn.LocalAddr())
   260  	assert.DeepEqual(t, &mockAddr{network: "tcp", address: "192.168.0.20:80"}, conn.RemoteAddr())
   261  }
   262  
   263  func TestInitializeTLSConn(t *testing.T) {
   264  	c := mockConn{}
   265  	tlsConn := newTLSConn(&c, 8192).(*TLSConn)
   266  	assert.DeepEqual(t, errors.New("conn: method not supported"), tlsConn.Handshake())
   267  	assert.DeepEqual(t, tls.ConnectionState{}, tlsConn.ConnectionState())
   268  }
   269  
   270  func TestHandleSpecificError(t *testing.T) {
   271  	conn := &Conn{}
   272  	assert.DeepEqual(t, false, conn.HandleSpecificError(nil, ""))
   273  	assert.DeepEqual(t, true, conn.HandleSpecificError(syscall.EPIPE, ""))
   274  }
   275  
   276  type mockConn struct {
   277  	buffer        bytes.Buffer
   278  	localAddr     net.Addr
   279  	remoteAddr    net.Addr
   280  	readReturnErr bool
   281  }
   282  
   283  func (m *mockConn) Handshake() error {
   284  	return errors.New("conn: method not supported")
   285  }
   286  
   287  func (m *mockConn) ConnectionState() tls.ConnectionState {
   288  	return tls.ConnectionState{}
   289  }
   290  
   291  func (m mockConn) Read(b []byte) (n int, err error) {
   292  	length := len(b)
   293  	for i := 0; i < length; i++ {
   294  		b[i] = 0
   295  	}
   296  
   297  	if m.readReturnErr {
   298  		err = io.EOF
   299  	}
   300  	if length > 8192 {
   301  		return 8192, err
   302  	}
   303  	if len(b) < 1024 {
   304  		return 100, err
   305  	}
   306  	if len(b) < 5000 {
   307  		return 4096, err
   308  	}
   309  
   310  	return 4099, err
   311  }
   312  
   313  func (m *mockConn) Write(b []byte) (n int, err error) {
   314  	return m.buffer.Write(b)
   315  }
   316  
   317  func (m *mockConn) Close() error {
   318  	return errors.New("conn: method not supported")
   319  }
   320  
   321  func (m *mockConn) LocalAddr() net.Addr {
   322  	return m.localAddr
   323  }
   324  
   325  func (m *mockConn) RemoteAddr() net.Addr {
   326  	return m.remoteAddr
   327  }
   328  
   329  func (m *mockConn) SetDeadline(deadline time.Time) error {
   330  	if err := m.SetWriteDeadline(deadline); err != nil {
   331  		return err
   332  	}
   333  	return m.SetWriteDeadline(deadline)
   334  }
   335  
   336  func (m *mockConn) SetReadDeadline(deadline time.Time) error {
   337  	return errors.New("conn: read deadline not supported")
   338  }
   339  
   340  func (m *mockConn) SetWriteDeadline(deadline time.Time) error {
   341  	return errors.New("conn: write deadline not supported")
   342  }
   343  
   344  type mockAddr struct {
   345  	network string
   346  	address string
   347  }
   348  
   349  func (m *mockAddr) Network() string {
   350  	return m.network
   351  }
   352  
   353  func (m *mockAddr) String() string {
   354  	return m.address
   355  }
   356  
   357  var release_count uint32 = 0
   358  
   359  func mockLinkBufferNodeRelease(b *linkBufferNode) {
   360  	atomic.AddUint32(&release_count, 1)
   361  
   362  	if !b.readOnly {
   363  		free(b.buf)
   364  	}
   365  	b.readOnly = false
   366  	b.buf = nil
   367  	b.next = nil
   368  	b.malloc, b.off = 0, 0
   369  	bufferPool.Put(b)
   370  }
   371  
   372  func TestConnSetFinalizer(t *testing.T) {
   373  	runtime.GC()
   374  	time.Sleep(time.Millisecond * 100)
   375  
   376  	Mock((*linkBufferNode).Release).To(mockLinkBufferNodeRelease).Build()
   377  
   378  	atomic.StoreUint32(&release_count, 0)
   379  	_ = newConn(&mockConn{}, 4096)
   380  
   381  	runtime.GC()
   382  	time.Sleep(time.Millisecond * 100)
   383  
   384  	assert.DeepEqual(t, uint32(2), atomic.LoadUint32(&release_count))
   385  }
   386  
   387  func TestFillReturnErrAndN(t *testing.T) {
   388  	c := &mockConn{
   389  		readReturnErr: true,
   390  	}
   391  	conn := newConn(c, 4099)
   392  	b, err := conn.Peek(4099)
   393  	assert.Nil(t, err)
   394  	assert.DeepEqual(t, len(b), 4099)
   395  	conn.Skip(10)
   396  	b, err = conn.Peek(4099)
   397  	assert.DeepEqual(t, err, io.EOF)
   398  	assert.DeepEqual(t, len(b), 4089)
   399  }