github.com/xmplusdev/xray-core@v1.8.10/proxy/shadowsocks/server.go (about) 1 package shadowsocks 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/xmplusdev/xray-core/common" 8 "github.com/xmplusdev/xray-core/common/buf" 9 "github.com/xmplusdev/xray-core/common/log" 10 "github.com/xmplusdev/xray-core/common/net" 11 "github.com/xmplusdev/xray-core/common/protocol" 12 udp_proto "github.com/xmplusdev/xray-core/common/protocol/udp" 13 "github.com/xmplusdev/xray-core/common/session" 14 "github.com/xmplusdev/xray-core/common/signal" 15 "github.com/xmplusdev/xray-core/common/task" 16 "github.com/xmplusdev/xray-core/core" 17 "github.com/xmplusdev/xray-core/features/policy" 18 "github.com/xmplusdev/xray-core/features/routing" 19 "github.com/xmplusdev/xray-core/transport/internet/stat" 20 "github.com/xmplusdev/xray-core/transport/internet/udp" 21 ) 22 23 type Server struct { 24 config *ServerConfig 25 validator *Validator 26 policyManager policy.Manager 27 cone bool 28 } 29 30 // NewServer create a new Shadowsocks server. 31 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 32 validator := new(Validator) 33 for _, user := range config.Users { 34 u, err := user.ToMemoryUser() 35 if err != nil { 36 return nil, newError("failed to get shadowsocks user").Base(err).AtError() 37 } 38 39 if err := validator.Add(u); err != nil { 40 return nil, newError("failed to add user").Base(err).AtError() 41 } 42 } 43 44 v := core.MustFromContext(ctx) 45 s := &Server{ 46 config: config, 47 validator: validator, 48 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 49 cone: ctx.Value("cone").(bool), 50 } 51 52 return s, nil 53 } 54 55 // AddUser implements proxy.UserManager.AddUser(). 56 func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { 57 return s.validator.Add(u) 58 } 59 60 // RemoveUser implements proxy.UserManager.RemoveUser(). 61 func (s *Server) RemoveUser(ctx context.Context, e string) error { 62 return s.validator.Del(e) 63 } 64 65 func (s *Server) Network() []net.Network { 66 list := s.config.Network 67 if len(list) == 0 { 68 list = append(list, net.Network_TCP) 69 } 70 return list 71 } 72 73 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 74 inbound := session.InboundFromContext(ctx) 75 inbound.Name = "shadowsocks" 76 inbound.SetCanSpliceCopy(3) 77 78 switch network { 79 case net.Network_TCP: 80 return s.handleConnection(ctx, conn, dispatcher) 81 case net.Network_UDP: 82 return s.handleUDPPayload(ctx, conn, dispatcher) 83 default: 84 return newError("unknown network: ", network) 85 } 86 } 87 88 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 89 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 90 request := protocol.RequestHeaderFromContext(ctx) 91 if request == nil { 92 return 93 } 94 95 payload := packet.Payload 96 97 if payload.UDP != nil { 98 request = &protocol.RequestHeader{ 99 User: request.User, 100 Address: payload.UDP.Address, 101 Port: payload.UDP.Port, 102 } 103 } 104 105 data, err := EncodeUDPPacket(request, payload.Bytes()) 106 payload.Release() 107 if err != nil { 108 newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 109 return 110 } 111 defer data.Release() 112 113 conn.Write(data.Bytes()) 114 }) 115 116 inbound := session.InboundFromContext(ctx) 117 var dest *net.Destination 118 reader := buf.NewPacketReader(conn) 119 for { 120 mpayload, err := reader.ReadMultiBuffer() 121 if err != nil { 122 break 123 } 124 125 for _, payload := range mpayload { 126 var request *protocol.RequestHeader 127 var data *buf.Buffer 128 var err error 129 130 if inbound.User != nil { 131 validator := new(Validator) 132 validator.Add(inbound.User) 133 request, data, err = DecodeUDPPacket(validator, payload) 134 } else { 135 request, data, err = DecodeUDPPacket(s.validator, payload) 136 if err == nil { 137 inbound.User = request.User 138 } 139 } 140 141 if err != nil { 142 if inbound.Source.IsValid() { 143 newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx)) 144 log.Record(&log.AccessMessage{ 145 From: inbound.Source, 146 To: "", 147 Status: log.AccessRejected, 148 Reason: err, 149 }) 150 } 151 payload.Release() 152 continue 153 } 154 155 destination := request.Destination() 156 157 currentPacketCtx := ctx 158 if inbound.Source.IsValid() { 159 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 160 From: inbound.Source, 161 To: destination, 162 Status: log.AccessAccepted, 163 Reason: "", 164 Email: request.User.Email, 165 }) 166 } 167 newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx)) 168 169 data.UDP = &destination 170 171 if !s.cone || dest == nil { 172 dest = &destination 173 } 174 175 currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) 176 udpServer.Dispatch(currentPacketCtx, *dest, data) 177 } 178 } 179 180 return nil 181 } 182 183 func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 184 sessionPolicy := s.policyManager.ForLevel(0) 185 if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { 186 return newError("unable to set read deadline").Base(err).AtWarning() 187 } 188 189 bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)} 190 request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader) 191 if err != nil { 192 log.Record(&log.AccessMessage{ 193 From: conn.RemoteAddr(), 194 To: "", 195 Status: log.AccessRejected, 196 Reason: err, 197 }) 198 return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) 199 } 200 conn.SetReadDeadline(time.Time{}) 201 202 inbound := session.InboundFromContext(ctx) 203 if inbound == nil { 204 panic("no inbound metadata") 205 } 206 inbound.User = request.User 207 208 dest := request.Destination() 209 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 210 From: conn.RemoteAddr(), 211 To: dest, 212 Status: log.AccessAccepted, 213 Reason: "", 214 Email: request.User.Email, 215 }) 216 newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx)) 217 218 sessionPolicy = s.policyManager.ForLevel(request.User.Level) 219 ctx, cancel := context.WithCancel(ctx) 220 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 221 222 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 223 link, err := dispatcher.Dispatch(ctx, dest) 224 if err != nil { 225 return err 226 } 227 228 responseDone := func() error { 229 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 230 231 bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) 232 responseWriter, err := WriteTCPResponse(request, bufferedWriter) 233 if err != nil { 234 return newError("failed to write response").Base(err) 235 } 236 237 { 238 payload, err := link.Reader.ReadMultiBuffer() 239 if err != nil { 240 return err 241 } 242 if err := responseWriter.WriteMultiBuffer(payload); err != nil { 243 return err 244 } 245 } 246 247 if err := bufferedWriter.SetBuffered(false); err != nil { 248 return err 249 } 250 251 if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil { 252 return newError("failed to transport all TCP response").Base(err) 253 } 254 255 return nil 256 } 257 258 requestDone := func() error { 259 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 260 261 if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { 262 return newError("failed to transport all TCP request").Base(err) 263 } 264 265 return nil 266 } 267 268 requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer)) 269 if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil { 270 common.Interrupt(link.Reader) 271 common.Interrupt(link.Writer) 272 return newError("connection ends").Base(err) 273 } 274 275 return nil 276 } 277 278 func init() { 279 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 280 return NewServer(ctx, config.(*ServerConfig)) 281 })) 282 }