github.com/EagleQL/Xray-core@v1.4.3/proxy/shadowsocks/server.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/xtls/xray-core/common"
     8  	"github.com/xtls/xray-core/common/buf"
     9  	"github.com/xtls/xray-core/common/log"
    10  	"github.com/xtls/xray-core/common/net"
    11  	"github.com/xtls/xray-core/common/protocol"
    12  	udp_proto "github.com/xtls/xray-core/common/protocol/udp"
    13  	"github.com/xtls/xray-core/common/session"
    14  	"github.com/xtls/xray-core/common/signal"
    15  	"github.com/xtls/xray-core/common/task"
    16  	"github.com/xtls/xray-core/core"
    17  	"github.com/xtls/xray-core/features/policy"
    18  	"github.com/xtls/xray-core/features/routing"
    19  	"github.com/xtls/xray-core/transport/internet"
    20  	"github.com/xtls/xray-core/transport/internet/udp"
    21  )
    22  
    23  type Server struct {
    24  	config        *ServerConfig
    25  	validator     *Validator
    26  	policyManager policy.Manager
    27  	cone          bool
    28  }
    29  
    30  // NewServer create a new Shadowsocks server.
    31  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    32  	validator := new(Validator)
    33  	for _, user := range config.Users {
    34  		u, err := user.ToMemoryUser()
    35  		if err != nil {
    36  			return nil, newError("failed to get shadowsocks user").Base(err).AtError()
    37  		}
    38  
    39  		if err := validator.Add(u); err != nil {
    40  			return nil, newError("failed to add user").Base(err).AtError()
    41  		}
    42  	}
    43  
    44  	v := core.MustFromContext(ctx)
    45  	s := &Server{
    46  		config:        config,
    47  		validator:     validator,
    48  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    49  		cone:          ctx.Value("cone").(bool),
    50  	}
    51  
    52  	return s, nil
    53  }
    54  
    55  // AddUser implements proxy.UserManager.AddUser().
    56  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
    57  	return s.validator.Add(u)
    58  }
    59  
    60  // RemoveUser implements proxy.UserManager.RemoveUser().
    61  func (s *Server) RemoveUser(ctx context.Context, e string) error {
    62  	return s.validator.Del(e)
    63  }
    64  
    65  func (s *Server) Network() []net.Network {
    66  	list := s.config.Network
    67  	if len(list) == 0 {
    68  		list = append(list, net.Network_TCP)
    69  	}
    70  	return list
    71  }
    72  
    73  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    74  	switch network {
    75  	case net.Network_TCP:
    76  		return s.handleConnection(ctx, conn, dispatcher)
    77  	case net.Network_UDP:
    78  		return s.handleUDPPayload(ctx, conn, dispatcher)
    79  	default:
    80  		return newError("unknown network: ", network)
    81  	}
    82  }
    83  
    84  func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
    85  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    86  		request := protocol.RequestHeaderFromContext(ctx)
    87  		if request == nil {
    88  			return
    89  		}
    90  
    91  		payload := packet.Payload
    92  
    93  		if payload.UDP != nil {
    94  			request = &protocol.RequestHeader{
    95  				User:    request.User,
    96  				Address: payload.UDP.Address,
    97  				Port:    payload.UDP.Port,
    98  			}
    99  		}
   100  
   101  		data, err := EncodeUDPPacket(request, payload.Bytes())
   102  		payload.Release()
   103  		if err != nil {
   104  			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   105  			return
   106  		}
   107  		defer data.Release()
   108  
   109  		conn.Write(data.Bytes())
   110  	})
   111  
   112  	inbound := session.InboundFromContext(ctx)
   113  	if inbound == nil {
   114  		panic("no inbound metadata")
   115  	}
   116  
   117  	if s.validator.Count() == 1 {
   118  		inbound.User, _ = s.validator.GetOnlyUser()
   119  	}
   120  
   121  	var dest *net.Destination
   122  
   123  	reader := buf.NewPacketReader(conn)
   124  	for {
   125  		mpayload, err := reader.ReadMultiBuffer()
   126  		if err != nil {
   127  			break
   128  		}
   129  
   130  		for _, payload := range mpayload {
   131  			var request *protocol.RequestHeader
   132  			var data *buf.Buffer
   133  			var err error
   134  
   135  			if inbound.User != nil {
   136  				validator := new(Validator)
   137  				validator.Add(inbound.User)
   138  				request, data, err = DecodeUDPPacket(validator, payload)
   139  			} else {
   140  				request, data, err = DecodeUDPPacket(s.validator, payload)
   141  				if err == nil {
   142  					inbound.User = request.User
   143  				}
   144  			}
   145  
   146  			if err != nil {
   147  				if inbound.Source.IsValid() {
   148  					newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx))
   149  					log.Record(&log.AccessMessage{
   150  						From:   inbound.Source,
   151  						To:     "",
   152  						Status: log.AccessRejected,
   153  						Reason: err,
   154  					})
   155  				}
   156  				payload.Release()
   157  				continue
   158  			}
   159  
   160  			destination := request.Destination()
   161  
   162  			currentPacketCtx := ctx
   163  			if inbound.Source.IsValid() {
   164  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   165  					From:   inbound.Source,
   166  					To:     destination,
   167  					Status: log.AccessAccepted,
   168  					Reason: "",
   169  					Email:  request.User.Email,
   170  				})
   171  			}
   172  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
   173  
   174  			data.UDP = &destination
   175  
   176  			if !s.cone || dest == nil {
   177  				dest = &destination
   178  			}
   179  
   180  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   181  			udpServer.Dispatch(currentPacketCtx, *dest, data)
   182  		}
   183  	}
   184  
   185  	return nil
   186  }
   187  
   188  func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
   189  	sessionPolicy := s.policyManager.ForLevel(0)
   190  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   191  		return newError("unable to set read deadline").Base(err).AtWarning()
   192  	}
   193  
   194  	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
   195  	request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader)
   196  	if err != nil {
   197  		log.Record(&log.AccessMessage{
   198  			From:   conn.RemoteAddr(),
   199  			To:     "",
   200  			Status: log.AccessRejected,
   201  			Reason: err,
   202  		})
   203  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   204  	}
   205  	conn.SetReadDeadline(time.Time{})
   206  
   207  	inbound := session.InboundFromContext(ctx)
   208  	if inbound == nil {
   209  		panic("no inbound metadata")
   210  	}
   211  	inbound.User = request.User
   212  
   213  	dest := request.Destination()
   214  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   215  		From:   conn.RemoteAddr(),
   216  		To:     dest,
   217  		Status: log.AccessAccepted,
   218  		Reason: "",
   219  		Email:  request.User.Email,
   220  	})
   221  	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   222  
   223  	sessionPolicy = s.policyManager.ForLevel(request.User.Level)
   224  	ctx, cancel := context.WithCancel(ctx)
   225  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   226  
   227  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   228  	link, err := dispatcher.Dispatch(ctx, dest)
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	responseDone := func() error {
   234  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   235  
   236  		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   237  		responseWriter, err := WriteTCPResponse(request, bufferedWriter)
   238  		if err != nil {
   239  			return newError("failed to write response").Base(err)
   240  		}
   241  
   242  		{
   243  			payload, err := link.Reader.ReadMultiBuffer()
   244  			if err != nil {
   245  				return err
   246  			}
   247  			if err := responseWriter.WriteMultiBuffer(payload); err != nil {
   248  				return err
   249  			}
   250  		}
   251  
   252  		if err := bufferedWriter.SetBuffered(false); err != nil {
   253  			return err
   254  		}
   255  
   256  		if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil {
   257  			return newError("failed to transport all TCP response").Base(err)
   258  		}
   259  
   260  		return nil
   261  	}
   262  
   263  	requestDone := func() error {
   264  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   265  
   266  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   267  			return newError("failed to transport all TCP request").Base(err)
   268  		}
   269  
   270  		return nil
   271  	}
   272  
   273  	var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
   274  	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   275  		common.Interrupt(link.Reader)
   276  		common.Interrupt(link.Writer)
   277  		return newError("connection ends").Base(err)
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  func init() {
   284  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   285  		return NewServer(ctx, config.(*ServerConfig))
   286  	}))
   287  }