github.com/cloudwego/hertz@v0.9.3/pkg/common/test/mock/network_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mock
    18  
    19  import (
    20  	"context"
    21  	"io"
    22  	"testing"
    23  	"time"
    24  
    25  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    26  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    27  	"github.com/cloudwego/netpoll"
    28  )
    29  
    30  func TestConn(t *testing.T) {
    31  	t.Run("TestReader", func(t *testing.T) {
    32  		s1 := "abcdef4343"
    33  		conn1 := NewConn(s1)
    34  		assert.Nil(t, conn1.SetWriteTimeout(1))
    35  		err := conn1.SetReadDeadline(time.Now().Add(time.Millisecond * 100))
    36  		assert.DeepEqual(t, nil, err)
    37  		err = conn1.SetReadTimeout(time.Millisecond * 100)
    38  		assert.DeepEqual(t, nil, err)
    39  		assert.DeepEqual(t, time.Millisecond*100, conn1.GetReadTimeout())
    40  
    41  		// Peek Skip Read
    42  		b, _ := conn1.Peek(1)
    43  		assert.DeepEqual(t, []byte{'a'}, b)
    44  		conn1.Skip(1)
    45  		readByte, _ := conn1.ReadByte()
    46  		assert.DeepEqual(t, byte('b'), readByte)
    47  
    48  		p := make([]byte, 100)
    49  		n, err := conn1.Read(p)
    50  		assert.DeepEqual(t, nil, err)
    51  		assert.DeepEqual(t, s1[2:], string(p[:n]))
    52  
    53  		_, err = conn1.Peek(1)
    54  		assert.DeepEqual(t, errs.ErrTimeout, err)
    55  
    56  		conn2 := NewConn(s1)
    57  		p, _ = conn2.ReadBinary(len(s1))
    58  		assert.DeepEqual(t, s1, string(p))
    59  		assert.DeepEqual(t, 0, conn2.Len())
    60  		// Reader
    61  		assert.DeepEqual(t, conn2.zr, conn2.Reader())
    62  	})
    63  
    64  	t.Run("TestReadWriter", func(t *testing.T) {
    65  		s1 := "abcdef4343"
    66  		conn := NewConn(s1)
    67  		p, err := conn.ReadBinary(len(s1))
    68  		assert.DeepEqual(t, nil, err)
    69  		assert.DeepEqual(t, s1, string(p))
    70  
    71  		wr := conn.WriterRecorder()
    72  		s2 := "efghljk"
    73  		// WriteBinary
    74  		n, err := conn.WriteBinary([]byte(s2))
    75  		assert.DeepEqual(t, nil, err)
    76  		assert.DeepEqual(t, len(s2), n)
    77  		assert.DeepEqual(t, len(s2), wr.WroteLen())
    78  
    79  		// Flush
    80  		p, _ = wr.ReadBinary(len(s2))
    81  		assert.DeepEqual(t, len(p), 0)
    82  
    83  		conn.Flush()
    84  		p, _ = wr.ReadBinary(len(s2))
    85  		assert.DeepEqual(t, s2, string(p))
    86  
    87  		// Write
    88  		s3 := "foobarbaz"
    89  		n, err = conn.Write([]byte(s3))
    90  		assert.DeepEqual(t, nil, err)
    91  		assert.DeepEqual(t, len(s3), n)
    92  		p, _ = wr.ReadBinary(len(s3))
    93  		assert.DeepEqual(t, s3, string(p))
    94  
    95  		// Malloc
    96  		buf, _ := conn.Malloc(10)
    97  		assert.DeepEqual(t, 10, len(buf))
    98  		// Writer
    99  		assert.DeepEqual(t, conn.zw, conn.Writer())
   100  
   101  		_, err = DialerFun("")
   102  		assert.DeepEqual(t, nil, err)
   103  	})
   104  
   105  	t.Run("TestNotImplement", func(t *testing.T) {
   106  		conn := NewConn("")
   107  		t1 := time.Now().Add(time.Millisecond)
   108  		du1 := time.Second
   109  		assert.DeepEqual(t, nil, conn.Release())
   110  		assert.DeepEqual(t, nil, conn.Close())
   111  		assert.DeepEqual(t, nil, conn.LocalAddr())
   112  		assert.DeepEqual(t, nil, conn.RemoteAddr())
   113  		assert.DeepEqual(t, nil, conn.SetIdleTimeout(du1))
   114  		assert.Panic(t, func() {
   115  			conn.SetDeadline(t1)
   116  		})
   117  		assert.Panic(t, func() {
   118  			conn.SetWriteDeadline(t1)
   119  		})
   120  		assert.Panic(t, func() {
   121  			conn.IsActive()
   122  		})
   123  		assert.Panic(t, func() {
   124  			conn.SetOnRequest(func(ctx context.Context, connection netpoll.Connection) error {
   125  				return nil
   126  			})
   127  		})
   128  		assert.Panic(t, func() {
   129  			conn.AddCloseCallback(func(connection netpoll.Connection) error {
   130  				return nil
   131  			})
   132  		})
   133  	})
   134  }
   135  
   136  func TestSlowConn(t *testing.T) {
   137  	t.Run("TestSlowReadConn", func(t *testing.T) {
   138  		s1 := "abcdefg"
   139  		conn := NewSlowReadConn(s1)
   140  		assert.Nil(t, conn.SetWriteTimeout(1))
   141  		assert.Nil(t, conn.SetReadTimeout(1))
   142  		assert.DeepEqual(t, time.Duration(1), conn.readTimeout)
   143  
   144  		b, err := conn.Peek(4)
   145  		assert.DeepEqual(t, nil, err)
   146  		assert.DeepEqual(t, s1[:4], string(b))
   147  		conn.Skip(len(s1))
   148  		_, err = conn.Peek(1)
   149  		assert.DeepEqual(t, ErrReadTimeout, err)
   150  		_, err = SlowReadDialer("")
   151  		assert.DeepEqual(t, nil, err)
   152  	})
   153  
   154  	t.Run("TestSlowWriteConn", func(t *testing.T) {
   155  		conn, err := SlowWriteDialer("")
   156  		assert.DeepEqual(t, nil, err)
   157  		conn.SetWriteTimeout(time.Millisecond * 100)
   158  		err = conn.Flush()
   159  		assert.DeepEqual(t, ErrWriteTimeout, err)
   160  	})
   161  }
   162  
   163  func TestStreamConn(t *testing.T) {
   164  	t.Run("TestStreamConn", func(t *testing.T) {
   165  		conn := NewStreamConn()
   166  		_, err := conn.Peek(10)
   167  		assert.DeepEqual(t, nil, err)
   168  		conn.Skip(conn.Len())
   169  		assert.DeepEqual(t, 0, conn.Len())
   170  		_, err = conn.Peek(10)
   171  		assert.DeepEqual(t, "not enough data", err.Error())
   172  		_, err = conn.Peek(1)
   173  		assert.DeepEqual(t, nil, err)
   174  		assert.DeepEqual(t, cap(conn.Data), conn.Len())
   175  		err = conn.Skip(conn.Len() + 1)
   176  		assert.DeepEqual(t, "not enough data", err.Error())
   177  	})
   178  
   179  	t.Run("TestNotImplement", func(t *testing.T) {
   180  		conn := NewStreamConn()
   181  		assert.Panic(t, func() {
   182  			conn.Release()
   183  		})
   184  		assert.Panic(t, func() {
   185  			conn.ReadByte()
   186  		})
   187  		assert.Panic(t, func() {
   188  			conn.ReadBinary(10)
   189  		})
   190  	})
   191  }
   192  
   193  func TestBrokenConn_Flush(t *testing.T) {
   194  	conn := NewBrokenConn("")
   195  	n, err := conn.Writer().WriteBinary([]byte("Foo"))
   196  	assert.DeepEqual(t, 3, n)
   197  	assert.Nil(t, err)
   198  	assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
   199  }
   200  
   201  func TestBrokenConn_Peek(t *testing.T) {
   202  	conn := NewBrokenConn("Foo")
   203  	buf, err := conn.Peek(3)
   204  	assert.Nil(t, buf)
   205  	assert.DeepEqual(t, io.ErrUnexpectedEOF, err)
   206  }
   207  
   208  func TestOneTimeConn_Flush(t *testing.T) {
   209  	conn := NewOneTimeConn("")
   210  	n, err := conn.Writer().WriteBinary([]byte("Foo"))
   211  	assert.DeepEqual(t, 3, n)
   212  	assert.Nil(t, err)
   213  	assert.Nil(t, conn.Flush())
   214  	n, err = conn.Writer().WriteBinary([]byte("Bar"))
   215  	assert.DeepEqual(t, 3, n)
   216  	assert.Nil(t, err)
   217  	assert.DeepEqual(t, errs.ErrConnectionClosed, conn.Flush())
   218  }
   219  
   220  func TestOneTimeConn_Skip(t *testing.T) {
   221  	conn := NewOneTimeConn("FooBar")
   222  	buf, err := conn.Peek(3)
   223  	assert.DeepEqual(t, "Foo", string(buf))
   224  	assert.Nil(t, err)
   225  	assert.Nil(t, conn.Skip(3))
   226  	assert.DeepEqual(t, 3, conn.contentLength)
   227  
   228  	buf, err = conn.Peek(3)
   229  	assert.DeepEqual(t, "Bar", string(buf))
   230  	assert.Nil(t, err)
   231  	assert.Nil(t, conn.Skip(3))
   232  	assert.DeepEqual(t, 0, conn.contentLength)
   233  
   234  	buf, err = conn.Peek(3)
   235  	assert.DeepEqual(t, 0, len(buf))
   236  	assert.DeepEqual(t, io.EOF, err)
   237  	assert.DeepEqual(t, io.EOF, conn.Skip(3))
   238  	assert.DeepEqual(t, 0, conn.contentLength)
   239  }