github.com/cloudwego/hertz@v0.9.3/pkg/common/test/mock/network.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package mock
    18  
    19  import (
    20  	"bytes"
    21  	"io"
    22  	"net"
    23  	"strings"
    24  	"time"
    25  
    26  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    27  	"github.com/cloudwego/hertz/pkg/network"
    28  	"github.com/cloudwego/netpoll"
    29  )
    30  
    31  var (
    32  	ErrReadTimeout  = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read timeout")
    33  	ErrWriteTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "write timeout")
    34  )
    35  
// Conn is an in-memory mock connection (DialerFun hands it out as a
// network.Conn).
//
// Fields:
//   - readTimeout: how long Peek waits before reporting a timeout when the
//     source cannot satisfy a request; a value <= 0 makes Peek block forever.
//   - zr: the read side; Peek, Skip, ReadByte, Len, ReadBinary and Read use it.
//   - zw: the write side; Malloc, WriteBinary, Write and Flush use it, and
//     WriterRecorder reads back from it.
//   - wroteLen: running byte count updated by Malloc and WriteBinary.
type Conn struct {
	readTimeout time.Duration
	zr          network.Reader
	zw          network.ReadWriter
	wroteLen    int
}
    42  
// Recorder gives tests a view of what was written to a mock connection.
// The value returned by Conn.WriterRecorder reads from the connection's write
// buffer through the embedded network.Reader, and WroteLen reports the byte
// counter kept by that Conn.
type Recorder interface {
	network.Reader
	WroteLen() int
}
    47  
    48  func (m *Conn) SetWriteTimeout(t time.Duration) error {
    49  	// TODO implement me
    50  	return nil
    51  }
    52  
// SlowReadConn wraps Conn so that every Peek is delayed. It sleeps for the
// configured read timeout (or 100ms when none is set) before returning, and
// reports ErrReadTimeout when the source cannot supply the requested bytes.
type SlowReadConn struct {
	*Conn
}
    56  
    57  func (m *SlowReadConn) SetWriteTimeout(t time.Duration) error {
    58  	return nil
    59  }
    60  
    61  func (m *SlowReadConn) SetReadTimeout(t time.Duration) error {
    62  	m.Conn.readTimeout = t
    63  	return nil
    64  }
    65  
    66  func SlowReadDialer(addr string) (network.Conn, error) {
    67  	return NewSlowReadConn(""), nil
    68  }
    69  
    70  func SlowWriteDialer(addr string) (network.Conn, error) {
    71  	return NewSlowWriteConn(""), nil
    72  }
    73  
    74  func (m *Conn) ReadBinary(n int) (p []byte, err error) {
    75  	return m.zr.(netpoll.Reader).ReadBinary(n)
    76  }
    77  
    78  func (m *Conn) Read(b []byte) (n int, err error) {
    79  	return netpoll.NewIOReader(m.zr.(netpoll.Reader)).Read(b)
    80  }
    81  
    82  func (m *Conn) Write(b []byte) (n int, err error) {
    83  	return netpoll.NewIOWriter(m.zw.(netpoll.ReadWriter)).Write(b)
    84  }
    85  
    86  func (m *Conn) Release() error {
    87  	return nil
    88  }
    89  
    90  func (m *Conn) Peek(i int) ([]byte, error) {
    91  	b, err := m.zr.Peek(i)
    92  	if err != nil || len(b) != i {
    93  		if m.readTimeout <= 0 {
    94  			// simulate timeout forever
    95  			select {}
    96  		}
    97  		time.Sleep(m.readTimeout)
    98  		return nil, errs.ErrTimeout
    99  	}
   100  	return b, err
   101  }
   102  
   103  func (m *Conn) Skip(n int) error {
   104  	return m.zr.Skip(n)
   105  }
   106  
   107  func (m *Conn) ReadByte() (byte, error) {
   108  	return m.zr.ReadByte()
   109  }
   110  
   111  func (m *Conn) Len() int {
   112  	return m.zr.Len()
   113  }
   114  
   115  func (m *Conn) Malloc(n int) (buf []byte, err error) {
   116  	m.wroteLen += n
   117  	return m.zw.Malloc(n)
   118  }
   119  
   120  func (m *Conn) WriteBinary(b []byte) (n int, err error) {
   121  	n, err = m.zw.WriteBinary(b)
   122  	m.wroteLen += n
   123  	return n, err
   124  }
   125  
   126  func (m *Conn) Flush() error {
   127  	return m.zw.Flush()
   128  }
   129  
   130  func (m *Conn) WriterRecorder() Recorder {
   131  	return &recorder{c: m, Reader: m.zw}
   132  }
   133  
   134  func (m *Conn) GetReadTimeout() time.Duration {
   135  	return m.readTimeout
   136  }
   137  
// recorder implements Recorder. When built by Conn.WriterRecorder, the
// embedded network.Reader is c's write buffer (zw) and c supplies the byte
// count returned by WroteLen.
type recorder struct {
	c *Conn
	network.Reader
}
   142  
   143  func (r *recorder) WroteLen() int {
   144  	return r.c.wroteLen
   145  }
   146  
   147  func (m *SlowReadConn) Peek(i int) ([]byte, error) {
   148  	b, err := m.zr.Peek(i)
   149  	if m.readTimeout > 0 {
   150  		time.Sleep(m.readTimeout)
   151  	} else {
   152  		time.Sleep(100 * time.Millisecond)
   153  	}
   154  	if err != nil || len(b) != i {
   155  		return nil, ErrReadTimeout
   156  	}
   157  	return b, err
   158  }
   159  
   160  func NewConn(source string) *Conn {
   161  	zr := netpoll.NewReader(strings.NewReader(source))
   162  	zw := netpoll.NewReadWriter(&bytes.Buffer{})
   163  
   164  	return &Conn{
   165  		zr: zr,
   166  		zw: zw,
   167  	}
   168  }
   169  
// BrokenConn is a Conn whose Peek and Read always fail with
// io.ErrUnexpectedEOF and whose Flush always fails with
// errs.ErrConnectionClosed. All other methods fall through to the embedded
// Conn.
type BrokenConn struct {
	*Conn
}
   173  
   174  func (o *BrokenConn) Peek(i int) ([]byte, error) {
   175  	return nil, io.ErrUnexpectedEOF
   176  }
   177  
   178  func (o *BrokenConn) Read(b []byte) (n int, err error) {
   179  	return 0, io.ErrUnexpectedEOF
   180  }
   181  
   182  func (o *BrokenConn) Flush() error {
   183  	return errs.ErrConnectionClosed
   184  }
   185  
   186  func NewBrokenConn(source string) *BrokenConn {
   187  	return &BrokenConn{Conn: NewConn(source)}
   188  }
   189  
// OneTimeConn is a Conn meant to be consumed once. After Skip has marked the
// source as consumed, Peek and Skip return io.EOF. Only the first Flush
// reaches the underlying writer; later calls return errs.ErrConnectionClosed.
type OneTimeConn struct {
	isRead        bool // source marked as consumed by Skip
	isFlushed     bool // set by the first Flush
	contentLength int  // bytes of source not yet skipped
	*Conn
}
   196  
   197  func (o *OneTimeConn) Peek(n int) ([]byte, error) {
   198  	if o.isRead {
   199  		return nil, io.EOF
   200  	}
   201  	return o.Conn.Peek(n)
   202  }
   203  
   204  func (o *OneTimeConn) Skip(n int) error {
   205  	if o.isRead {
   206  		return io.EOF
   207  	}
   208  	o.contentLength -= n
   209  
   210  	if o.contentLength == 0 {
   211  		o.isRead = true
   212  	}
   213  
   214  	return o.Conn.Skip(n)
   215  }
   216  
   217  func (o *OneTimeConn) Flush() error {
   218  	if o.isFlushed {
   219  		return errs.ErrConnectionClosed
   220  	}
   221  	o.isFlushed = true
   222  	return o.Conn.Flush()
   223  }
   224  
   225  func NewOneTimeConn(source string) *OneTimeConn {
   226  	return &OneTimeConn{isRead: false, isFlushed: false, Conn: NewConn(source), contentLength: len(source)}
   227  }
   228  
   229  func NewSlowReadConn(source string) *SlowReadConn {
   230  	return &SlowReadConn{Conn: NewConn(source)}
   231  }
   232  
// ErrorReadConn is a Conn whose Peek always returns errorToReturn; all other
// methods fall through to the embedded Conn.
type ErrorReadConn struct {
	*Conn
	errorToReturn error
}
   237  
   238  func NewErrorReadConn(err error) *ErrorReadConn {
   239  	return &ErrorReadConn{
   240  		Conn:          NewConn(""),
   241  		errorToReturn: err,
   242  	}
   243  }
   244  
   245  func (er *ErrorReadConn) Peek(n int) ([]byte, error) {
   246  	return nil, er.errorToReturn
   247  }
   248  
// SlowWriteConn is a Conn whose Flush is slow. It always sleeps 100ms after
// flushing the underlying writer. If that flush succeeded, it also sleeps for
// writeTimeout and returns ErrWriteTimeout; otherwise it returns the flush
// error.
type SlowWriteConn struct {
	*Conn
	writeTimeout time.Duration
}
   253  
   254  func (m *SlowWriteConn) SetWriteTimeout(t time.Duration) error {
   255  	m.writeTimeout = t
   256  	return nil
   257  }
   258  
   259  func NewSlowWriteConn(source string) *SlowWriteConn {
   260  	return &SlowWriteConn{NewConn(source), 0}
   261  }
   262  
   263  func (m *SlowWriteConn) Flush() error {
   264  	err := m.zw.Flush()
   265  	time.Sleep(100 * time.Millisecond)
   266  	if err == nil {
   267  		time.Sleep(m.writeTimeout)
   268  		return ErrWriteTimeout
   269  	}
   270  	return err
   271  }
   272  
   273  func (m *Conn) Close() error {
   274  	return nil
   275  }
   276  
   277  func (m *Conn) LocalAddr() net.Addr {
   278  	return nil
   279  }
   280  
   281  func (m *Conn) RemoteAddr() net.Addr {
   282  	return nil
   283  }
   284  
   285  func (m *Conn) SetDeadline(t time.Time) error {
   286  	panic("implement me")
   287  }
   288  
   289  func (m *Conn) SetReadDeadline(t time.Time) error {
   290  	m.readTimeout = -time.Since(t)
   291  	return nil
   292  }
   293  
   294  func (m *Conn) SetWriteDeadline(t time.Time) error {
   295  	panic("implement me")
   296  }
   297  
   298  func (m *Conn) Reader() network.Reader {
   299  	return m.zr
   300  }
   301  
   302  func (m *Conn) Writer() network.Writer {
   303  	return m.zw
   304  }
   305  
   306  func (m *Conn) IsActive() bool {
   307  	panic("implement me")
   308  }
   309  
   310  func (m *Conn) SetIdleTimeout(timeout time.Duration) error {
   311  	return nil
   312  }
   313  
   314  func (m *Conn) SetReadTimeout(t time.Duration) error {
   315  	m.readTimeout = t
   316  	return nil
   317  }
   318  
   319  func (m *Conn) SetOnRequest(on netpoll.OnRequest) error {
   320  	panic("implement me")
   321  }
   322  
   323  func (m *Conn) AddCloseCallback(callback netpoll.CloseCallback) error {
   324  	panic("implement me")
   325  }
   326  
// StreamConn is a minimal mock reader over an in-memory byte slice. Peek and
// Skip operate on Data, Len reports len(Data), and Release, ReadByte and
// ReadBinary are unimplemented and panic.
// NOTE(review): presumably this stands in for a network.Reader; confirm
// against callers.
type StreamConn struct {
	Data []byte
}
   330  
   331  func NewStreamConn() *StreamConn {
   332  	return &StreamConn{
   333  		Data: make([]byte, 1<<15, 1<<16),
   334  	}
   335  }
   336  
   337  func (m *StreamConn) Peek(n int) ([]byte, error) {
   338  	if len(m.Data) >= n {
   339  		return m.Data[:n], nil
   340  	}
   341  	if n == 1 {
   342  		m.Data = m.Data[:cap(m.Data)]
   343  		return m.Data[:1], nil
   344  	}
   345  	return nil, errs.NewPublic("not enough data")
   346  }
   347  
   348  func (m *StreamConn) Skip(n int) error {
   349  	if len(m.Data) >= n {
   350  		m.Data = m.Data[n:]
   351  		return nil
   352  	}
   353  	return errs.NewPublic("not enough data")
   354  }
   355  
   356  func (m *StreamConn) Release() error {
   357  	panic("implement me")
   358  }
   359  
   360  func (m *StreamConn) Len() int {
   361  	return len(m.Data)
   362  }
   363  
   364  func (m *StreamConn) ReadByte() (byte, error) {
   365  	panic("implement me")
   366  }
   367  
   368  func (m *StreamConn) ReadBinary(n int) (p []byte, err error) {
   369  	panic("implement me")
   370  }
   371  
   372  func DialerFun(addr string) (network.Conn, error) {
   373  	return NewConn(""), nil
   374  }