github.com/cloudwego/hertz@v0.9.3/pkg/common/test/mock/network.go (about) 1 /* 2 * Copyright 2022 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package mock 18 19 import ( 20 "bytes" 21 "io" 22 "net" 23 "strings" 24 "time" 25 26 errs "github.com/cloudwego/hertz/pkg/common/errors" 27 "github.com/cloudwego/hertz/pkg/network" 28 "github.com/cloudwego/netpoll" 29 ) 30 31 var ( 32 ErrReadTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "read timeout") 33 ErrWriteTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "write timeout") 34 ) 35 36 type Conn struct { 37 readTimeout time.Duration 38 zr network.Reader 39 zw network.ReadWriter 40 wroteLen int 41 } 42 43 type Recorder interface { 44 network.Reader 45 WroteLen() int 46 } 47 48 func (m *Conn) SetWriteTimeout(t time.Duration) error { 49 // TODO implement me 50 return nil 51 } 52 53 type SlowReadConn struct { 54 *Conn 55 } 56 57 func (m *SlowReadConn) SetWriteTimeout(t time.Duration) error { 58 return nil 59 } 60 61 func (m *SlowReadConn) SetReadTimeout(t time.Duration) error { 62 m.Conn.readTimeout = t 63 return nil 64 } 65 66 func SlowReadDialer(addr string) (network.Conn, error) { 67 return NewSlowReadConn(""), nil 68 } 69 70 func SlowWriteDialer(addr string) (network.Conn, error) { 71 return NewSlowWriteConn(""), nil 72 } 73 74 func (m *Conn) ReadBinary(n int) (p []byte, err error) { 75 return m.zr.(netpoll.Reader).ReadBinary(n) 76 } 77 78 func (m *Conn) Read(b []byte) (n int, err error) { 79 return netpoll.NewIOReader(m.zr.(netpoll.Reader)).Read(b) 80 } 81 82 func (m *Conn) Write(b []byte) (n int, err error) { 83 return netpoll.NewIOWriter(m.zw.(netpoll.ReadWriter)).Write(b) 84 } 85 86 func (m *Conn) Release() error { 87 return nil 88 } 89 90 func (m *Conn) Peek(i int) ([]byte, error) { 91 b, err := m.zr.Peek(i) 92 if err != nil || len(b) != i { 93 if m.readTimeout <= 0 { 94 // simulate timeout forever 95 select {} 96 } 97 time.Sleep(m.readTimeout) 98 return nil, errs.ErrTimeout 99 } 100 return b, err 101 } 102 103 func (m *Conn) Skip(n int) error { 104 return m.zr.Skip(n) 105 } 106 107 func (m *Conn) ReadByte() (byte, error) { 108 return m.zr.ReadByte() 109 } 110 111 func (m *Conn) Len() int { 112 return m.zr.Len() 113 } 114 115 func (m *Conn) Malloc(n int) (buf []byte, err error) { 116 m.wroteLen += n 117 return m.zw.Malloc(n) 118 } 119 120 func (m *Conn) WriteBinary(b []byte) (n int, err error) { 121 n, err = m.zw.WriteBinary(b) 122 m.wroteLen += n 123 return n, err 124 } 125 126 func (m *Conn) Flush() error { 127 return m.zw.Flush() 128 } 129 130 func (m *Conn) WriterRecorder() Recorder { 131 return &recorder{c: m, Reader: m.zw} 132 } 133 134 func (m *Conn) GetReadTimeout() time.Duration { 135 return m.readTimeout 136 } 137 138 type recorder struct { 139 c *Conn 140 network.Reader 141 } 142 143 func (r *recorder) WroteLen() int { 144 return r.c.wroteLen 145 } 146 147 func (m *SlowReadConn) Peek(i int) ([]byte, error) { 148 b, err := m.zr.Peek(i) 149 if m.readTimeout > 0 { 150 time.Sleep(m.readTimeout) 151 } else { 152 time.Sleep(100 * time.Millisecond) 153 } 154 if err != nil || len(b) != i { 155 return nil, ErrReadTimeout 156 } 157 return b, err 158 } 159 160 func NewConn(source string) *Conn { 161 zr := netpoll.NewReader(strings.NewReader(source)) 162 zw := netpoll.NewReadWriter(&bytes.Buffer{}) 163 164 return &Conn{ 165 zr: zr, 166 zw: zw, 167 } 168 } 169 170 type BrokenConn struct { 171 *Conn 172 } 173 174 func (o *BrokenConn) Peek(i int) ([]byte, error) { 175 return nil, io.ErrUnexpectedEOF 176 } 177 178 func (o *BrokenConn) Read(b []byte) (n int, err error) { 179 return 0, io.ErrUnexpectedEOF 180 } 181 182 func (o *BrokenConn) Flush() error { 183 return errs.ErrConnectionClosed 184 } 185 186 func NewBrokenConn(source string) *BrokenConn { 187 return &BrokenConn{Conn: NewConn(source)} 188 } 189 190 type OneTimeConn struct { 191 isRead bool 192 isFlushed bool 193 contentLength int 194 *Conn 195 } 196 197 func (o *OneTimeConn) Peek(n int) ([]byte, error) { 198 if o.isRead { 199 return nil, io.EOF 200 } 201 return o.Conn.Peek(n) 202 } 203 204 func (o *OneTimeConn) Skip(n int) error { 205 if o.isRead { 206 return io.EOF 207 } 208 o.contentLength -= n 209 210 if o.contentLength == 0 { 211 o.isRead = true 212 } 213 214 return o.Conn.Skip(n) 215 } 216 217 func (o *OneTimeConn) Flush() error { 218 if o.isFlushed { 219 return errs.ErrConnectionClosed 220 } 221 o.isFlushed = true 222 return o.Conn.Flush() 223 } 224 225 func NewOneTimeConn(source string) *OneTimeConn { 226 return &OneTimeConn{isRead: false, isFlushed: false, Conn: NewConn(source), contentLength: len(source)} 227 } 228 229 func NewSlowReadConn(source string) *SlowReadConn { 230 return &SlowReadConn{Conn: NewConn(source)} 231 } 232 233 type ErrorReadConn struct { 234 *Conn 235 errorToReturn error 236 } 237 238 func NewErrorReadConn(err error) *ErrorReadConn { 239 return &ErrorReadConn{ 240 Conn: NewConn(""), 241 errorToReturn: err, 242 } 243 } 244 245 func (er *ErrorReadConn) Peek(n int) ([]byte, error) { 246 return nil, er.errorToReturn 247 } 248 249 type SlowWriteConn struct { 250 *Conn 251 writeTimeout time.Duration 252 } 253 254 func (m *SlowWriteConn) SetWriteTimeout(t time.Duration) error { 255 m.writeTimeout = t 256 return nil 257 } 258 259 func NewSlowWriteConn(source string) *SlowWriteConn { 260 return &SlowWriteConn{NewConn(source), 0} 261 } 262 263 func (m *SlowWriteConn) Flush() error { 264 err := m.zw.Flush() 265 time.Sleep(100 * time.Millisecond) 266 if err == nil { 267 time.Sleep(m.writeTimeout) 268 return ErrWriteTimeout 269 } 270 return err 271 } 272 273 func (m *Conn) Close() error { 274 return nil 275 } 276 277 func (m *Conn) LocalAddr() net.Addr { 278 return nil 279 } 280 281 func (m *Conn) RemoteAddr() net.Addr { 282 return nil 283 } 284 285 func (m *Conn) SetDeadline(t time.Time) error { 286 panic("implement me") 287 } 288 289 func (m *Conn) SetReadDeadline(t time.Time) error { 290 m.readTimeout = -time.Since(t) 291 return nil 292 } 293 294 func (m *Conn) SetWriteDeadline(t time.Time) error { 295 panic("implement me") 296 } 297 298 func (m *Conn) Reader() network.Reader { 299 return m.zr 300 } 301 302 func (m *Conn) Writer() network.Writer { 303 return m.zw 304 } 305 306 func (m *Conn) IsActive() bool { 307 panic("implement me") 308 } 309 310 func (m *Conn) SetIdleTimeout(timeout time.Duration) error { 311 return nil 312 } 313 314 func (m *Conn) SetReadTimeout(t time.Duration) error { 315 m.readTimeout = t 316 return nil 317 } 318 319 func (m *Conn) SetOnRequest(on netpoll.OnRequest) error { 320 panic("implement me") 321 } 322 323 func (m *Conn) AddCloseCallback(callback netpoll.CloseCallback) error { 324 panic("implement me") 325 } 326 327 type StreamConn struct { 328 Data []byte 329 } 330 331 func NewStreamConn() *StreamConn { 332 return &StreamConn{ 333 Data: make([]byte, 1<<15, 1<<16), 334 } 335 } 336 337 func (m *StreamConn) Peek(n int) ([]byte, error) { 338 if len(m.Data) >= n { 339 return m.Data[:n], nil 340 } 341 if n == 1 { 342 m.Data = m.Data[:cap(m.Data)] 343 return m.Data[:1], nil 344 } 345 return nil, errs.NewPublic("not enough data") 346 } 347 348 func (m *StreamConn) Skip(n int) error { 349 if len(m.Data) >= n { 350 m.Data = m.Data[n:] 351 return nil 352 } 353 return errs.NewPublic("not enough data") 354 } 355 356 func (m *StreamConn) Release() error { 357 panic("implement me") 358 } 359 360 func (m *StreamConn) Len() int { 361 return len(m.Data) 362 } 363 364 func (m *StreamConn) ReadByte() (byte, error) { 365 panic("implement me") 366 } 367 368 func (m *StreamConn) ReadBinary(n int) (p []byte, err error) { 369 panic("implement me") 370 } 371 372 func DialerFun(addr string) (network.Conn, error) { 373 return NewConn(""), nil 374 }