github.com/sagernet/sing-box@v1.2.7/common/mux/service.go (about) 1 package mux 2 3 import ( 4 "context" 5 "encoding/binary" 6 "net" 7 8 "github.com/sagernet/sing-box/adapter" 9 "github.com/sagernet/sing-box/log" 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 "github.com/sagernet/sing/common/bufio" 13 E "github.com/sagernet/sing/common/exceptions" 14 M "github.com/sagernet/sing/common/metadata" 15 N "github.com/sagernet/sing/common/network" 16 "github.com/sagernet/sing/common/rw" 17 "github.com/sagernet/sing/common/task" 18 ) 19 20 func NewConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, conn net.Conn, metadata adapter.InboundContext) error { 21 request, err := ReadRequest(conn) 22 if err != nil { 23 return err 24 } 25 session, err := request.Protocol.newServer(conn) 26 if err != nil { 27 return err 28 } 29 var group task.Group 30 group.Append0(func(ctx context.Context) error { 31 var stream net.Conn 32 for { 33 stream, err = session.Accept() 34 if err != nil { 35 return err 36 } 37 go newConnection(ctx, router, errorHandler, logger, stream, metadata) 38 } 39 }) 40 group.Cleanup(func() { 41 session.Close() 42 }) 43 return group.Run(ctx) 44 } 45 46 func newConnection(ctx context.Context, router adapter.Router, errorHandler E.Handler, logger log.ContextLogger, stream net.Conn, metadata adapter.InboundContext) { 47 stream = &wrapStream{stream} 48 request, err := ReadStreamRequest(stream) 49 if err != nil { 50 logger.ErrorContext(ctx, err) 51 return 52 } 53 metadata.Destination = request.Destination 54 if request.Network == N.NetworkTCP { 55 logger.InfoContext(ctx, "inbound multiplex connection to ", metadata.Destination) 56 hErr := router.RouteConnection(ctx, &ServerConn{ExtendedConn: bufio.NewExtendedConn(stream)}, metadata) 57 stream.Close() 58 if hErr != nil { 59 errorHandler.NewError(ctx, hErr) 60 } 61 } else { 62 var packetConn N.PacketConn 63 if !request.PacketAddr { 64 logger.InfoContext(ctx, "inbound multiplex packet connection to ", metadata.Destination) 65 packetConn = &ServerPacketConn{ExtendedConn: bufio.NewExtendedConn(stream), destination: request.Destination} 66 } else { 67 logger.InfoContext(ctx, "inbound multiplex packet connection") 68 packetConn = &ServerPacketAddrConn{ExtendedConn: bufio.NewExtendedConn(stream)} 69 } 70 hErr := router.RoutePacketConnection(ctx, packetConn, metadata) 71 stream.Close() 72 if hErr != nil { 73 errorHandler.NewError(ctx, hErr) 74 } 75 } 76 } 77 78 var _ N.HandshakeConn = (*ServerConn)(nil) 79 80 type ServerConn struct { 81 N.ExtendedConn 82 responseWrite bool 83 } 84 85 func (c *ServerConn) HandshakeFailure(err error) error { 86 errMessage := err.Error() 87 _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) 88 defer common.KeepAlive(_buffer) 89 buffer := common.Dup(_buffer) 90 defer buffer.Release() 91 common.Must( 92 buffer.WriteByte(statusError), 93 rw.WriteVString(_buffer, errMessage), 94 ) 95 return c.ExtendedConn.WriteBuffer(buffer) 96 } 97 98 func (c *ServerConn) Write(b []byte) (n int, err error) { 99 if c.responseWrite { 100 return c.ExtendedConn.Write(b) 101 } 102 _buffer := buf.StackNewSize(1 + len(b)) 103 defer common.KeepAlive(_buffer) 104 buffer := common.Dup(_buffer) 105 defer buffer.Release() 106 common.Must( 107 buffer.WriteByte(statusSuccess), 108 common.Error(buffer.Write(b)), 109 ) 110 _, err = c.ExtendedConn.Write(buffer.Bytes()) 111 if err != nil { 112 return 113 } 114 c.responseWrite = true 115 return len(b), nil 116 } 117 118 func (c *ServerConn) WriteBuffer(buffer *buf.Buffer) error { 119 if c.responseWrite { 120 return c.ExtendedConn.WriteBuffer(buffer) 121 } 122 buffer.ExtendHeader(1)[0] = statusSuccess 123 c.responseWrite = true 124 return c.ExtendedConn.WriteBuffer(buffer) 125 } 126 127 func (c *ServerConn) FrontHeadroom() int { 128 if !c.responseWrite { 129 return 1 130 } 131 return 0 132 } 133 134 func (c *ServerConn) NeedAdditionalReadDeadline() bool { 135 return true 136 } 137 138 func (c *ServerConn) Upstream() any { 139 return c.ExtendedConn 140 } 141 142 var ( 143 _ N.HandshakeConn = (*ServerPacketConn)(nil) 144 _ N.PacketConn = (*ServerPacketConn)(nil) 145 ) 146 147 type ServerPacketConn struct { 148 N.ExtendedConn 149 destination M.Socksaddr 150 responseWrite bool 151 } 152 153 func (c *ServerPacketConn) HandshakeFailure(err error) error { 154 errMessage := err.Error() 155 _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) 156 defer common.KeepAlive(_buffer) 157 buffer := common.Dup(_buffer) 158 defer buffer.Release() 159 common.Must( 160 buffer.WriteByte(statusError), 161 rw.WriteVString(_buffer, errMessage), 162 ) 163 return c.ExtendedConn.WriteBuffer(buffer) 164 } 165 166 func (c *ServerPacketConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 167 var length uint16 168 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 169 if err != nil { 170 return 171 } 172 _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 173 if err != nil { 174 return 175 } 176 destination = c.destination 177 return 178 } 179 180 func (c *ServerPacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 181 pLen := buffer.Len() 182 common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) 183 if !c.responseWrite { 184 buffer.ExtendHeader(1)[0] = statusSuccess 185 c.responseWrite = true 186 } 187 return c.ExtendedConn.WriteBuffer(buffer) 188 } 189 190 func (c *ServerPacketConn) NeedAdditionalReadDeadline() bool { 191 return true 192 } 193 194 func (c *ServerPacketConn) Upstream() any { 195 return c.ExtendedConn 196 } 197 198 func (c *ServerPacketConn) FrontHeadroom() int { 199 if !c.responseWrite { 200 return 3 201 } 202 return 2 203 } 204 205 var ( 206 _ N.HandshakeConn = (*ServerPacketAddrConn)(nil) 207 _ N.PacketConn = (*ServerPacketAddrConn)(nil) 208 ) 209 210 type ServerPacketAddrConn struct { 211 N.ExtendedConn 212 responseWrite bool 213 } 214 215 func (c *ServerPacketAddrConn) HandshakeFailure(err error) error { 216 errMessage := err.Error() 217 _buffer := buf.StackNewSize(1 + rw.UVariantLen(uint64(len(errMessage))) + len(errMessage)) 218 defer common.KeepAlive(_buffer) 219 buffer := common.Dup(_buffer) 220 defer buffer.Release() 221 common.Must( 222 buffer.WriteByte(statusError), 223 rw.WriteVString(_buffer, errMessage), 224 ) 225 return c.ExtendedConn.WriteBuffer(buffer) 226 } 227 228 func (c *ServerPacketAddrConn) ReadPacket(buffer *buf.Buffer) (destination M.Socksaddr, err error) { 229 destination, err = M.SocksaddrSerializer.ReadAddrPort(c.ExtendedConn) 230 if err != nil { 231 return 232 } 233 var length uint16 234 err = binary.Read(c.ExtendedConn, binary.BigEndian, &length) 235 if err != nil { 236 return 237 } 238 _, err = buffer.ReadFullFrom(c.ExtendedConn, int(length)) 239 if err != nil { 240 return 241 } 242 return 243 } 244 245 func (c *ServerPacketAddrConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 246 pLen := buffer.Len() 247 common.Must(binary.Write(buf.With(buffer.ExtendHeader(2)), binary.BigEndian, uint16(pLen))) 248 common.Must(M.SocksaddrSerializer.WriteAddrPort(buf.With(buffer.ExtendHeader(M.SocksaddrSerializer.AddrPortLen(destination))), destination)) 249 if !c.responseWrite { 250 buffer.ExtendHeader(1)[0] = statusSuccess 251 c.responseWrite = true 252 } 253 return c.ExtendedConn.WriteBuffer(buffer) 254 } 255 256 func (c *ServerPacketAddrConn) NeedAdditionalReadDeadline() bool { 257 return true 258 } 259 260 func (c *ServerPacketAddrConn) Upstream() any { 261 return c.ExtendedConn 262 } 263 264 func (c *ServerPacketAddrConn) FrontHeadroom() int { 265 if !c.responseWrite { 266 return 3 + M.MaxSocksaddrLength 267 } 268 return 2 + M.MaxSocksaddrLength 269 }