github.com/pawelgaczynski/gain@v0.4.0-alpha.0.20230821120126-41f1e60a18da/conn.go (about) 1 // Copyright (c) 2023 Paweł Gaczyński 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package gain 16 17 import ( 18 "io" 19 "net" 20 "sync/atomic" 21 "syscall" 22 "time" 23 "unsafe" 24 25 "github.com/pawelgaczynski/gain/pkg/buffer/magicring" 26 "github.com/pawelgaczynski/gain/pkg/errors" 27 "github.com/pawelgaczynski/gain/pkg/pool/byteslice" 28 "github.com/pawelgaczynski/gain/pkg/pool/ringbuffer" 29 "github.com/pawelgaczynski/gain/pkg/socket" 30 ) 31 32 type connectionState int 33 34 const ( 35 connInvalid connectionState = iota 36 connAccept 37 connRead 38 connWrite 39 connClose 40 ) 41 42 func (s connectionState) String() string { 43 switch s { 44 case connAccept: 45 return "accept" 46 case connRead: 47 return "read" 48 case connWrite: 49 return "write" 50 case connClose: 51 return "close" 52 default: 53 return "invalid" 54 } 55 } 56 57 const ( 58 kernelSpace = iota 59 userSpace 60 ) 61 62 func connModeString(m uint32) string { 63 switch m { 64 case kernelSpace: 65 return "kernelSpace" 66 case userSpace: 67 return "userSpace" 68 default: 69 return "invalid" 70 } 71 } 72 73 const ( 74 msgControlBufferSize = 64 75 ) 76 77 const ( 78 noOp = iota 79 readOp 80 writeOp 81 closeOp 82 ) 83 84 const ( 85 tcp = iota 86 udp 87 ) 88 89 type connection struct { 90 fd int 91 key int 92 network uint32 93 94 inboundBuffer *magicring.RingBuffer 95 outboundBuffer *magicring.RingBuffer 96 state connectionState 97 mode atomic.Uint32 98 closed atomic.Bool 99 100 msgHdr *syscall.Msghdr 101 rawSockaddr *syscall.RawSockaddrAny 102 103 localAddr net.Addr 104 remoteAddr net.Addr 105 106 ctx interface{} 107 108 nextAsyncOp int 109 } 110 111 func (c *connection) outboundReadAddress() unsafe.Pointer { 112 return c.outboundBuffer.ReadAddress() 113 } 114 115 func (c *connection) inboundWriteAddress() unsafe.Pointer { 116 return c.inboundBuffer.WriteAddress() 117 } 118 119 func (c *connection) setKernelSpace() { 120 c.mode.Store(kernelSpace) 121 } 122 123 func (c *connection) setUserSpace() { 124 c.mode.Store(userSpace) 125 } 126 127 func (c *connection) Context() interface{} { 128 return c.ctx 129 } 130 131 func (c *connection) SetContext(ctx interface{}) { 132 c.ctx = ctx 133 } 134 135 func (c *connection) LocalAddr() net.Addr { 136 return c.localAddr 137 } 138 139 func (c *connection) RemoteAddr() net.Addr { 140 return c.remoteAddr 141 } 142 143 func (c *connection) Fd() int { 144 return c.fd 145 } 146 147 func (c *connection) userOpAllowed(name string) error { 148 if c.closed.Load() { 149 return errors.ErrConnectionClosed 150 } 151 152 if mode := c.mode.Load(); mode != userSpace { 153 return errors.ErrorOpNotAvailableInMode(name, connModeString(mode)) 154 } 155 156 return nil 157 } 158 159 func (c *connection) SetReadBuffer(bytes int) error { 160 err := c.userOpAllowed("setReadBuffer") 161 if err != nil { 162 return err 163 } 164 //nolint:wrapcheck 165 return socket.SetRecvBuffer(c.fd, bytes) 166 } 167 168 func (c *connection) SetWriteBuffer(bytes int) error { 169 err := c.userOpAllowed("setWriteBuffer") 170 if err != nil { 171 return err 172 } 173 //nolint:wrapcheck 174 return socket.SetSendBuffer(c.fd, bytes) 175 } 176 177 func (c *connection) SetLinger(sec int) error { 178 err := c.userOpAllowed("setLinger") 179 if err != nil { 180 return err 181 } 182 //nolint:wrapcheck 183 return socket.SetLinger(c.fd, sec) 184 } 185 186 func (c *connection) SetNoDelay(noDelay bool) error { 187 err := c.userOpAllowed("setNoDelay") 188 if err != nil { 189 return err 190 } 191 //nolint:wrapcheck 192 return socket.SetNoDelay(c.fd, boolToInt(noDelay)) 193 } 194 195 func (c *connection) SetKeepAlivePeriod(period time.Duration) error { 196 err := c.userOpAllowed("setKeepAlivePeriod") 197 if err != nil { 198 return err 199 } 200 //nolint:wrapcheck 201 return socket.SetKeepAlivePeriod(c.fd, int(period.Seconds())) 202 } 203 204 func (c *connection) onKernelRead(n int) { 205 c.inboundBuffer.AdvanceWrite(n) 206 } 207 208 func (c *connection) onKernelWrite(n int) { 209 c.outboundBuffer.AdvanceRead(n) 210 } 211 212 func (c *connection) isClosed() bool { 213 return c.closed.Load() 214 } 215 216 func (c *connection) Close() error { 217 if network := atomic.LoadUint32(&c.network); network == udp { 218 return nil 219 } 220 221 if c.closed.Load() { 222 return errors.ErrConnectionAlreadyClosed 223 } 224 225 c.closed.Store(true) 226 227 return nil 228 } 229 230 func (c *connection) Next(n int) ([]byte, error) { 231 err := c.userOpAllowed("next") 232 if err != nil { 233 return nil, err 234 } 235 236 //nolint:wrapcheck 237 return c.inboundBuffer.Next(n) 238 } 239 240 func (c *connection) Discard(n int) (int, error) { 241 err := c.userOpAllowed("discard") 242 if err != nil { 243 return 0, err 244 } 245 246 return c.inboundBuffer.Discard(n), nil 247 } 248 249 func (c *connection) Peek(n int) ([]byte, error) { 250 err := c.userOpAllowed("peek") 251 if err != nil { 252 return nil, err 253 } 254 255 return c.inboundBuffer.Peek(n), nil 256 } 257 258 func (c *connection) ReadFrom(reader io.Reader) (int64, error) { 259 err := c.userOpAllowed("readFrom") 260 if err != nil { 261 return 0, err 262 } 263 264 //nolint:wrapcheck 265 return c.outboundBuffer.ReadFrom(reader) 266 } 267 268 func (c *connection) WriteTo(writer io.Writer) (int64, error) { 269 err := c.userOpAllowed("writeTo") 270 if err != nil { 271 return 0, err 272 } 273 274 //nolint:wrapcheck 275 return c.inboundBuffer.WriteTo(writer) 276 } 277 278 func (c *connection) Read(buffer []byte) (int, error) { 279 err := c.userOpAllowed("read") 280 if err != nil { 281 return 0, err 282 } 283 284 //nolint:wrapcheck 285 return c.inboundBuffer.Read(buffer) 286 } 287 288 func (c *connection) Write(buffer []byte) (int, error) { 289 err := c.userOpAllowed("write") 290 if err != nil { 291 return 0, err 292 } 293 294 //nolint:wrapcheck 295 return c.outboundBuffer.Write(buffer) 296 } 297 298 func (c *connection) OutboundBuffered() int { 299 return c.outboundBuffer.Buffered() 300 } 301 302 func (c *connection) InboundBuffered() int { 303 return c.inboundBuffer.Buffered() 304 } 305 306 func (c *connection) setMsgHeaderWrite() { 307 c.msgHdr.Iov.Base = (*byte)(c.outboundReadAddress()) 308 c.msgHdr.Iov.SetLen(c.OutboundBuffered()) 309 } 310 311 func (c *connection) initMsgHeader() { 312 var iovec syscall.Iovec 313 iovec.Base = (*byte)(c.inboundWriteAddress()) 314 iovec.SetLen(c.inboundBuffer.Cap()) 315 316 var ( 317 msg syscall.Msghdr 318 rsa syscall.RawSockaddrAny 319 ) 320 321 msg.Name = (*byte)(unsafe.Pointer(&rsa)) 322 msg.Namelen = uint32(syscall.SizeofSockaddrAny) 323 msg.Iov = &iovec 324 msg.Iovlen = 1 325 326 controlBuffer := byteslice.Get(msgControlBufferSize) 327 msg.Control = (*byte)(unsafe.Pointer(&controlBuffer[0])) 328 msg.SetControllen(msgControlBufferSize) 329 330 c.msgHdr = &msg 331 c.rawSockaddr = &rsa 332 } 333 334 func (c *connection) fork(newConn *connection, key int, write bool) *connection { 335 newConn.inboundBuffer = c.inboundBuffer 336 newConn.outboundBuffer = c.outboundBuffer 337 newConn.msgHdr = c.msgHdr 338 newConn.rawSockaddr = c.rawSockaddr 339 newConn.state = c.state 340 newConn.fd = c.fd 341 newConn.key = key 342 newConn.network = udp 343 344 if sockAddr, err := anyToSockaddr(newConn.rawSockaddr); err == nil { 345 newConn.remoteAddr = socket.SockaddrToUDPAddr(sockAddr) 346 } 347 348 if write { 349 newConn.setMsgHeaderWrite() 350 } 351 352 c.inboundBuffer = ringbuffer.Get() 353 c.outboundBuffer = ringbuffer.Get() 354 c.initMsgHeader() 355 356 return newConn 357 } 358 359 func newConnection() *connection { 360 conn := &connection{ 361 inboundBuffer: ringbuffer.Get(), 362 outboundBuffer: ringbuffer.Get(), 363 } 364 365 return conn 366 }