github.com/xraypb/xray-core@v1.6.6/proxy/shadowsocks/server.go (about) 1 package shadowsocks 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/xraypb/xray-core/common" 8 "github.com/xraypb/xray-core/common/buf" 9 "github.com/xraypb/xray-core/common/log" 10 "github.com/xraypb/xray-core/common/net" 11 "github.com/xraypb/xray-core/common/protocol" 12 udp_proto "github.com/xraypb/xray-core/common/protocol/udp" 13 "github.com/xraypb/xray-core/common/session" 14 "github.com/xraypb/xray-core/common/signal" 15 "github.com/xraypb/xray-core/common/task" 16 "github.com/xraypb/xray-core/core" 17 "github.com/xraypb/xray-core/features/policy" 18 "github.com/xraypb/xray-core/features/routing" 19 "github.com/xraypb/xray-core/transport/internet/stat" 20 "github.com/xraypb/xray-core/transport/internet/udp" 21 ) 22 23 type Server struct { 24 config *ServerConfig 25 validator *Validator 26 policyManager policy.Manager 27 cone bool 28 } 29 30 // NewServer create a new Shadowsocks server. 31 func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) { 32 validator := new(Validator) 33 for _, user := range config.Users { 34 u, err := user.ToMemoryUser() 35 if err != nil { 36 return nil, newError("failed to get shadowsocks user").Base(err).AtError() 37 } 38 39 if err := validator.Add(u); err != nil { 40 return nil, newError("failed to add user").Base(err).AtError() 41 } 42 } 43 44 v := core.MustFromContext(ctx) 45 s := &Server{ 46 config: config, 47 validator: validator, 48 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 49 cone: ctx.Value("cone").(bool), 50 } 51 52 return s, nil 53 } 54 55 // AddUser implements proxy.UserManager.AddUser(). 56 func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error { 57 return s.validator.Add(u) 58 } 59 60 // RemoveUser implements proxy.UserManager.RemoveUser(). 61 func (s *Server) RemoveUser(ctx context.Context, e string) error { 62 return s.validator.Del(e) 63 } 64 65 func (s *Server) Network() []net.Network { 66 list := s.config.Network 67 if len(list) == 0 { 68 list = append(list, net.Network_TCP) 69 } 70 return list 71 } 72 73 func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error { 74 switch network { 75 case net.Network_TCP: 76 return s.handleConnection(ctx, conn, dispatcher) 77 case net.Network_UDP: 78 return s.handleUDPPayload(ctx, conn, dispatcher) 79 default: 80 return newError("unknown network: ", network) 81 } 82 } 83 84 func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 85 udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) { 86 request := protocol.RequestHeaderFromContext(ctx) 87 if request == nil { 88 return 89 } 90 91 payload := packet.Payload 92 93 if payload.UDP != nil { 94 request = &protocol.RequestHeader{ 95 User: request.User, 96 Address: payload.UDP.Address, 97 Port: payload.UDP.Port, 98 } 99 } 100 101 data, err := EncodeUDPPacket(request, payload.Bytes()) 102 payload.Release() 103 if err != nil { 104 newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx)) 105 return 106 } 107 defer data.Release() 108 109 conn.Write(data.Bytes()) 110 }) 111 112 inbound := session.InboundFromContext(ctx) 113 if inbound == nil { 114 panic("no inbound metadata") 115 } 116 117 var dest *net.Destination 118 119 reader := buf.NewPacketReader(conn) 120 for { 121 mpayload, err := reader.ReadMultiBuffer() 122 if err != nil { 123 break 124 } 125 126 for _, payload := range mpayload { 127 var request *protocol.RequestHeader 128 var data *buf.Buffer 129 var err error 130 131 if inbound.User != nil { 132 validator := new(Validator) 133 validator.Add(inbound.User) 134 request, data, err = DecodeUDPPacket(validator, payload) 135 } else { 136 request, data, err = DecodeUDPPacket(s.validator, payload) 137 if err == nil { 138 inbound.User = request.User 139 } 140 } 141 142 if err != nil { 143 if inbound.Source.IsValid() { 144 newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx)) 145 log.Record(&log.AccessMessage{ 146 From: inbound.Source, 147 To: "", 148 Status: log.AccessRejected, 149 Reason: err, 150 }) 151 } 152 payload.Release() 153 continue 154 } 155 156 destination := request.Destination() 157 158 currentPacketCtx := ctx 159 if inbound.Source.IsValid() { 160 currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 161 From: inbound.Source, 162 To: destination, 163 Status: log.AccessAccepted, 164 Reason: "", 165 Email: request.User.Email, 166 }) 167 } 168 newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx)) 169 170 data.UDP = &destination 171 172 if !s.cone || dest == nil { 173 dest = &destination 174 } 175 176 currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request) 177 udpServer.Dispatch(currentPacketCtx, *dest, data) 178 } 179 } 180 181 return nil 182 } 183 184 func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error { 185 sessionPolicy := s.policyManager.ForLevel(0) 186 if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil { 187 return newError("unable to set read deadline").Base(err).AtWarning() 188 } 189 190 bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)} 191 request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader) 192 if err != nil { 193 log.Record(&log.AccessMessage{ 194 From: conn.RemoteAddr(), 195 To: "", 196 Status: log.AccessRejected, 197 Reason: err, 198 }) 199 return newError("failed to create request from: ", conn.RemoteAddr()).Base(err) 200 } 201 conn.SetReadDeadline(time.Time{}) 202 203 inbound := session.InboundFromContext(ctx) 204 if inbound == nil { 205 panic("no inbound metadata") 206 } 207 inbound.User = request.User 208 209 dest := request.Destination() 210 ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{ 211 From: conn.RemoteAddr(), 212 To: dest, 213 Status: log.AccessAccepted, 214 Reason: "", 215 Email: request.User.Email, 216 }) 217 newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx)) 218 219 sessionPolicy = s.policyManager.ForLevel(request.User.Level) 220 ctx, cancel := context.WithCancel(ctx) 221 timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle) 222 223 ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer) 224 link, err := dispatcher.Dispatch(ctx, dest) 225 if err != nil { 226 return err 227 } 228 229 responseDone := func() error { 230 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 231 232 bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) 233 responseWriter, err := WriteTCPResponse(request, bufferedWriter) 234 if err != nil { 235 return newError("failed to write response").Base(err) 236 } 237 238 { 239 payload, err := link.Reader.ReadMultiBuffer() 240 if err != nil { 241 return err 242 } 243 if err := responseWriter.WriteMultiBuffer(payload); err != nil { 244 return err 245 } 246 } 247 248 if err := bufferedWriter.SetBuffered(false); err != nil { 249 return err 250 } 251 252 if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil { 253 return newError("failed to transport all TCP response").Base(err) 254 } 255 256 return nil 257 } 258 259 requestDone := func() error { 260 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 261 262 if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil { 263 return newError("failed to transport all TCP request").Base(err) 264 } 265 266 return nil 267 } 268 269 requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer)) 270 if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil { 271 common.Interrupt(link.Reader) 272 common.Interrupt(link.Writer) 273 return newError("connection ends").Base(err) 274 } 275 276 return nil 277 } 278 279 func init() { 280 common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 281 return NewServer(ctx, config.(*ServerConfig)) 282 })) 283 }