github.com/EagleQL/Xray-core@v1.4.3/proxy/shadowsocks/server.go (about) 1 package shadowsocks 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/xtls/xray-core/common" 8 "github.com/xtls/xray-core/common/buf" 9 "github.com/xtls/xray-core/common/log" 10 "github.com/xtls/xray-core/common/net" 11 "github.com/xtls/xray-core/common/protocol" 12 udp_proto "github.com/xtls/xray-core/common/protocol/udp" 13 "github.com/xtls/xray-core/common/session" 14 "github.com/xtls/xray-core/common/signal" 15 "github.com/xtls/xray-core/common/task" 16 "github.com/xtls/xray-core/core" 17 "github.com/xtls/xray-core/features/policy" 18 "github.com/xtls/xray-core/features/routing" 19 "github.com/xtls/xray-core/transport/internet" 20 "github.com/xtls/xray-core/transport/internet/udp" 21 ) 22 23 type Server struct { 24 config *ServerConfig 25 validator *Validator 26 policyManager policy.Manager 27 cone bool 28 } 29 30 // NewServer create a new Shadowsocks server. 31 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 32 validator := new(Validator) 33 for _, user := range config.Users { 34 u, err := user.ToMemoryUser() 35 if err != nil { 36 return nil, newError("failed to get shadowsocks user").Base(err).AtError() 37 } 38 39 if err := validator.Add(u); err != nil { 40 return nil, newError("failed to add user").Base(err).AtError() 41 } 42 } 43 44 v := core.MustFromContext(ctx) 45 s := &Server{ 46 config: config, 47 validator: validator, 48 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 49 cone: ctx.Value("cone").(bool), 50 } 51 52 return s, nil 53 } 54 55 // AddUser implements proxy.UserManager.AddUser(). 56 func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { 57 return s.validator.Add(u) 58 } 59 60 // RemoveUser implements proxy.UserManager.RemoveUser(). 61 func (s *Server) RemoveUser(ctx context.Context, e string) error { 62 return s.validator.Del(e) 63 } 64 65 func (s *Server) Network() []net.Network { 66 list := s.config.Network 67 if len(list) == 0 { 68 list = append(list, net.Network_TCP) 69 } 70 return list 71 } 72 73 func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error { 74 switch network { 75 case net.Network_TCP: 76 return s.handleConnection(ctx, conn, dispatcher) 77 case net.Network_UDP: 78 return s.handleUDPPayload(ctx, conn, dispatcher) 79 default: 80 return newError("unknown network: ", network) 81 } 82 } 83 84 func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { 85 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 86 request := protocol.RequestHeaderFromContext(ctx) 87 if request == nil { 88 return 89 } 90 91 payload := packet.Payload 92 93 if payload.UDP != nil { 94 request = &protocol.RequestHeader{ 95 User: request.User, 96 Address: payload.UDP.Address, 97 Port: payload.UDP.Port, 98 } 99 } 100 101 data, err := EncodeUDPPacket(request, payload.Bytes()) 102 payload.Release() 103 if err != nil { 104 newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 105 return 106 } 107 defer data.Release() 108 109 conn.Write(data.Bytes()) 110 }) 111 112 inbound := session.InboundFromContext(ctx) 113 if inbound == nil { 114 panic("no inbound metadata") 115 } 116 117 if s.validator.Count() == 1 { 118 inbound.User, _ = s.validator.GetOnlyUser() 119 } 120 121 var dest *net.Destination 122 123 reader := buf.NewPacketReader(conn) 124 for { 125 mpayload, err := reader.ReadMultiBuffer() 126 if err != nil { 127 break 128 } 129 130 for _, payload := range mpayload { 131 var request *protocol.RequestHeader 132 var data *buf.Buffer 133 var err error 134 135 if inbound.User != nil { 136 validator := new(Validator) 137 validator.Add(inbound.User) 138 request, data, err = DecodeUDPPacket(validator, payload) 139 } else { 140 request, data, err = DecodeUDPPacket(s.validator, payload) 141 if err == nil { 142 inbound.User = request.User 143 } 144 } 145 146 if err != nil { 147 if inbound.Source.IsValid() { 148 newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx)) 149 log.Record(&log.AccessMessage{ 150 From: inbound.Source, 151 To: "", 152 Status: log.AccessRejected, 153 Reason: err, 154 }) 155 } 156 payload.Release() 157 continue 158 } 159 160 destination := request.Destination() 161 162 currentPacketCtx := ctx 163 if inbound.Source.IsValid() { 164 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 165 From: inbound.Source, 166 To: destination, 167 Status: log.AccessAccepted, 168 Reason: "", 169 Email: request.User.Email, 170 }) 171 } 172 newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx)) 173 174 data.UDP = &destination 175 176 if !s.cone || dest == nil { 177 dest = &destination 178 } 179 180 currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) 181 udpServer.Dispatch(currentPacketCtx, *dest, data) 182 } 183 } 184 185 return nil 186 } 187 188 func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error { 189 sessionPolicy := s.policyManager.ForLevel(0) 190 if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { 191 return newError("unable to set read deadline").Base(err).AtWarning() 192 } 193 194 bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)} 195 request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader) 196 if err != nil { 197 log.Record(&log.AccessMessage{ 198 From: conn.RemoteAddr(), 199 To: "", 200 Status: log.AccessRejected, 201 Reason: err, 202 }) 203 return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) 204 } 205 conn.SetReadDeadline(time.Time{}) 206 207 inbound := session.InboundFromContext(ctx) 208 if inbound == nil { 209 panic("no inbound metadata") 210 } 211 inbound.User = request.User 212 213 dest := request.Destination() 214 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 215 From: conn.RemoteAddr(), 216 To: dest, 217 Status: log.AccessAccepted, 218 Reason: "", 219 Email: request.User.Email, 220 }) 221 newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx)) 222 223 sessionPolicy = s.policyManager.ForLevel(request.User.Level) 224 ctx, cancel := context.WithCancel(ctx) 225 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 226 227 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 228 link, err := dispatcher.Dispatch(ctx, dest) 229 if err != nil { 230 return err 231 } 232 233 responseDone := func() error { 234 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 235 236 bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) 237 responseWriter, err := WriteTCPResponse(request, bufferedWriter) 238 if err != nil { 239 return newError("failed to write response").Base(err) 240 } 241 242 { 243 payload, err := link.Reader.ReadMultiBuffer() 244 if err != nil { 245 return err 246 } 247 if err := responseWriter.WriteMultiBuffer(payload); err != nil { 248 return err 249 } 250 } 251 252 if err := bufferedWriter.SetBuffered(false); err != nil { 253 return err 254 } 255 256 if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil { 257 return newError("failed to transport all TCP response").Base(err) 258 } 259 260 return nil 261 } 262 263 requestDone := func() error { 264 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 265 266 if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { 267 return newError("failed to transport all TCP request").Base(err) 268 } 269 270 return nil 271 } 272 273 var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer)) 274 if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil { 275 common.Interrupt(link.Reader) 276 common.Interrupt(link.Writer) 277 return newError("connection ends").Base(err) 278 } 279 280 return nil 281 } 282 283 func init() { 284 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 285 return NewServer(ctx, config.(*ServerConfig)) 286 })) 287 }