github.com/moqsien/xraycore@v1.8.5/proxy/socks/server.go (about) 1 package socks 2 3 import ( 4 "context" 5 "io" 6 "time" 7 8 "github.com/moqsien/xraycore/common" 9 "github.com/moqsien/xraycore/common/buf" 10 "github.com/moqsien/xraycore/common/log" 11 "github.com/moqsien/xraycore/common/net" 12 "github.com/moqsien/xraycore/common/protocol" 13 udp_proto "github.com/moqsien/xraycore/common/protocol/udp" 14 "github.com/moqsien/xraycore/common/session" 15 "github.com/moqsien/xraycore/common/signal" 16 "github.com/moqsien/xraycore/common/task" 17 "github.com/moqsien/xraycore/core" 18 "github.com/moqsien/xraycore/features" 19 "github.com/moqsien/xraycore/features/policy" 20 "github.com/moqsien/xraycore/features/routing" 21 "github.com/moqsien/xraycore/transport/internet/stat" 22 "github.com/moqsien/xraycore/transport/internet/udp" 23 ) 24 25 // Server is a SOCKS 5 proxy server 26 type Server struct { 27 config *ServerConfig 28 policyManager policy.Manager 29 cone bool 30 } 31 32 // NewServer creates a new Server object. 33 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 34 v := core.MustFromContext(ctx) 35 s := &Server{ 36 config: config, 37 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 38 cone: ctx.Value("cone").(bool), 39 } 40 return s, nil 41 } 42 43 func (s *Server) policy() policy.Session { 44 config := s.config 45 p := s.policyManager.ForLevel(config.UserLevel) 46 if config.Timeout > 0 { 47 features.PrintDeprecatedFeatureWarning("Socks timeout") 48 } 49 if config.Timeout > 0 && config.UserLevel == 0 { 50 p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second 51 } 52 return p 53 } 54 55 // Network implements proxy.Inbound. 56 func (s *Server) Network() []net.Network { 57 list := []net.Network{net.Network_TCP} 58 if s.config.UdpEnabled { 59 list = append(list, net.Network_UDP) 60 } 61 return list 62 } 63 64 // Process implements proxy.Inbound. 65 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 66 if inbound := session.InboundFromContext(ctx); inbound != nil { 67 inbound.Name = "socks" 68 inbound.User = &protocol.MemoryUser{ 69 Level: s.config.UserLevel, 70 } 71 } 72 73 switch network { 74 case net.Network_TCP: 75 return s.processTCP(ctx, conn, dispatcher) 76 case net.Network_UDP: 77 return s.handleUDPPayload(ctx, conn, dispatcher) 78 default: 79 return newError("unknown network: ", network) 80 } 81 } 82 83 func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 84 plcy := s.policy() 85 if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil { 86 newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx)) 87 } 88 89 inbound := session.InboundFromContext(ctx) 90 if inbound == nil || !inbound.Gateway.IsValid() { 91 return newError("inbound gateway not specified") 92 } 93 94 svrSession := &ServerSession{ 95 config: s.config, 96 address: inbound.Gateway.Address, 97 port: inbound.Gateway.Port, 98 localAddress: net.IPAddress(conn.LocalAddr().(*net.TCPAddr).IP), 99 } 100 101 reader := &buf.BufferedReader{Reader: buf.NewReader(conn)} 102 request, err := svrSession.Handshake(reader, conn) 103 if err != nil { 104 if inbound != nil && inbound.Source.IsValid() { 105 log.Record(&log.AccessMessage{ 106 From: inbound.Source, 107 To: "", 108 Status: log.AccessRejected, 109 Reason: err, 110 }) 111 } 112 return newError("failed to read request").Base(err) 113 } 114 if request.User != nil { 115 inbound.User.Email = request.User.Email 116 } 117 118 if err := conn.SetReadDeadline(time.Time{}); err != nil { 119 newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx)) 120 } 121 122 if request.Command == protocol.RequestCommandTCP { 123 dest := request.Destination() 124 newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx)) 125 if inbound != nil && inbound.Source.IsValid() { 126 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 127 From: inbound.Source, 128 To: dest, 129 Status: log.AccessAccepted, 130 Reason: "", 131 }) 132 } 133 134 return s.transport(ctx, reader, conn, dest, dispatcher, inbound) 135 } 136 137 if request.Command == protocol.RequestCommandUDP { 138 return s.handleUDP(conn) 139 } 140 141 return nil 142 } 143 144 func (*Server) handleUDP(c io.Reader) error { 145 // The TCP connection closes after this method returns. We need to wait until 146 // the client closes it. 147 return common.Error2(io.Copy(buf.DiscardBytes, c)) 148 } 149 150 func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error { 151 ctx, cancel := context.WithCancel(ctx) 152 timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle) 153 154 if inbound != nil { 155 inbound.Timer = timer 156 } 157 158 plcy := s.policy() 159 ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer) 160 link, err := dispatcher.Dispatch(ctx, dest) 161 if err != nil { 162 return err 163 } 164 165 requestDone := func() error { 166 defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly) 167 if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil { 168 return newError("failed to transport all TCP request").Base(err) 169 } 170 171 return nil 172 } 173 174 responseDone := func() error { 175 defer timer.SetTimeout(plcy.Timeouts.UplinkOnly) 176 177 v2writer := buf.NewWriter(writer) 178 if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil { 179 return newError("failed to transport all TCP response").Base(err) 180 } 181 182 return nil 183 } 184 185 requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer)) 186 if err := task.Run(ctx, requestDonePost, responseDone); err != nil { 187 common.Interrupt(link.Reader) 188 common.Interrupt(link.Writer) 189 return newError("connection ends").Base(err) 190 } 191 192 return nil 193 } 194 195 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 196 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 197 payload := packet.Payload 198 newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) 199 200 request := protocol.RequestHeaderFromContext(ctx) 201 if request == nil { 202 return 203 } 204 205 if payload.UDP != nil { 206 request = &protocol.RequestHeader{ 207 User: request.User, 208 Address: payload.UDP.Address, 209 Port: payload.UDP.Port, 210 } 211 } 212 213 udpMessage, err := EncodeUDPPacket(request, payload.Bytes()) 214 payload.Release() 215 216 defer udpMessage.Release() 217 if err != nil { 218 newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx)) 219 } 220 221 conn.Write(udpMessage.Bytes()) 222 }) 223 224 inbound := session.InboundFromContext(ctx) 225 if inbound != nil && inbound.Source.IsValid() { 226 newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx)) 227 } 228 229 var dest *net.Destination 230 231 reader := buf.NewPacketReader(conn) 232 for { 233 mpayload, err := reader.ReadMultiBuffer() 234 if err != nil { 235 return err 236 } 237 238 for _, payload := range mpayload { 239 request, err := DecodeUDPPacket(payload) 240 if err != nil { 241 newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx)) 242 payload.Release() 243 continue 244 } 245 246 if payload.IsEmpty() { 247 payload.Release() 248 continue 249 } 250 251 destination := request.Destination() 252 253 currentPacketCtx := ctx 254 newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx)) 255 if inbound != nil && inbound.Source.IsValid() { 256 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 257 From: inbound.Source, 258 To: destination, 259 Status: log.AccessAccepted, 260 Reason: "", 261 }) 262 } 263 264 payload.UDP = &destination 265 266 if !s.cone || dest == nil { 267 dest = &destination 268 } 269 270 currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) 271 udpServer.Dispatch(currentPacketCtx, *dest, payload) 272 } 273 } 274 } 275 276 func init() { 277 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 278 return NewServer(ctx, config.(*ServerConfig)) 279 })) 280 }