github.com/xraypb/xray-core@v1.6.6/proxy/shadowsocks/server.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/xraypb/xray-core/common"
     8  	"github.com/xraypb/xray-core/common/buf"
     9  	"github.com/xraypb/xray-core/common/log"
    10  	"github.com/xraypb/xray-core/common/net"
    11  	"github.com/xraypb/xray-core/common/protocol"
    12  	udp_proto "github.com/xraypb/xray-core/common/protocol/udp"
    13  	"github.com/xraypb/xray-core/common/session"
    14  	"github.com/xraypb/xray-core/common/signal"
    15  	"github.com/xraypb/xray-core/common/task"
    16  	"github.com/xraypb/xray-core/core"
    17  	"github.com/xraypb/xray-core/features/policy"
    18  	"github.com/xraypb/xray-core/features/routing"
    19  	"github.com/xraypb/xray-core/transport/internet/stat"
    20  	"github.com/xraypb/xray-core/transport/internet/udp"
    21  )
    22  
// Server is the Shadowsocks inbound server. It authenticates incoming TCP
// sessions and UDP packets against its validator and forwards the decoded
// traffic to a routing dispatcher.
type Server struct {
	// config is the configuration the server was created from; Network reads
	// the accepted networks from it.
	config *ServerConfig
	// validator holds the registered users and is used to decode and
	// authenticate incoming requests.
	validator *Validator
	// policyManager supplies the per-level session policy (timeouts, buffer
	// settings) used by TCP sessions.
	policyManager policy.Manager
	// cone, when true, makes a UDP session keep dispatching to the first
	// destination it saw instead of switching destination per packet
	// (presumably full-cone NAT behaviour — confirm against the core setting).
	cone bool
}
    29  
    30  // NewServer create a new Shadowsocks server.
    31  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    32  	validator := new(Validator)
    33  	for _, user := range config.Users {
    34  		u, err := user.ToMemoryUser()
    35  		if err != nil {
    36  			return nil, newError("failed to get shadowsocks user").Base(err).AtError()
    37  		}
    38  
    39  		if err := validator.Add(u); err != nil {
    40  			return nil, newError("failed to add user").Base(err).AtError()
    41  		}
    42  	}
    43  
    44  	v := core.MustFromContext(ctx)
    45  	s := &Server{
    46  		config:        config,
    47  		validator:     validator,
    48  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    49  		cone:          ctx.Value("cone").(bool),
    50  	}
    51  
    52  	return s, nil
    53  }
    54  
    55  // AddUser implements proxy.UserManager.AddUser().
    56  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
    57  	return s.validator.Add(u)
    58  }
    59  
    60  // RemoveUser implements proxy.UserManager.RemoveUser().
    61  func (s *Server) RemoveUser(ctx context.Context, e string) error {
    62  	return s.validator.Del(e)
    63  }
    64  
    65  func (s *Server) Network() []net.Network {
    66  	list := s.config.Network
    67  	if len(list) == 0 {
    68  		list = append(list, net.Network_TCP)
    69  	}
    70  	return list
    71  }
    72  
    73  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    74  	switch network {
    75  	case net.Network_TCP:
    76  		return s.handleConnection(ctx, conn, dispatcher)
    77  	case net.Network_UDP:
    78  		return s.handleUDPPayload(ctx, conn, dispatcher)
    79  	default:
    80  		return newError("unknown network: ", network)
    81  	}
    82  }
    83  
    84  func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
    85  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    86  		request := protocol.RequestHeaderFromContext(ctx)
    87  		if request == nil {
    88  			return
    89  		}
    90  
    91  		payload := packet.Payload
    92  
    93  		if payload.UDP != nil {
    94  			request = &protocol.RequestHeader{
    95  				User:    request.User,
    96  				Address: payload.UDP.Address,
    97  				Port:    payload.UDP.Port,
    98  			}
    99  		}
   100  
   101  		data, err := EncodeUDPPacket(request, payload.Bytes())
   102  		payload.Release()
   103  		if err != nil {
   104  			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   105  			return
   106  		}
   107  		defer data.Release()
   108  
   109  		conn.Write(data.Bytes())
   110  	})
   111  
   112  	inbound := session.InboundFromContext(ctx)
   113  	if inbound == nil {
   114  		panic("no inbound metadata")
   115  	}
   116  
   117  	var dest *net.Destination
   118  
   119  	reader := buf.NewPacketReader(conn)
   120  	for {
   121  		mpayload, err := reader.ReadMultiBuffer()
   122  		if err != nil {
   123  			break
   124  		}
   125  
   126  		for _, payload := range mpayload {
   127  			var request *protocol.RequestHeader
   128  			var data *buf.Buffer
   129  			var err error
   130  
   131  			if inbound.User != nil {
   132  				validator := new(Validator)
   133  				validator.Add(inbound.User)
   134  				request, data, err = DecodeUDPPacket(validator, payload)
   135  			} else {
   136  				request, data, err = DecodeUDPPacket(s.validator, payload)
   137  				if err == nil {
   138  					inbound.User = request.User
   139  				}
   140  			}
   141  
   142  			if err != nil {
   143  				if inbound.Source.IsValid() {
   144  					newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx))
   145  					log.Record(&log.AccessMessage{
   146  						From:   inbound.Source,
   147  						To:     "",
   148  						Status: log.AccessRejected,
   149  						Reason: err,
   150  					})
   151  				}
   152  				payload.Release()
   153  				continue
   154  			}
   155  
   156  			destination := request.Destination()
   157  
   158  			currentPacketCtx := ctx
   159  			if inbound.Source.IsValid() {
   160  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   161  					From:   inbound.Source,
   162  					To:     destination,
   163  					Status: log.AccessAccepted,
   164  					Reason: "",
   165  					Email:  request.User.Email,
   166  				})
   167  			}
   168  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
   169  
   170  			data.UDP = &destination
   171  
   172  			if !s.cone || dest == nil {
   173  				dest = &destination
   174  			}
   175  
   176  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   177  			udpServer.Dispatch(currentPacketCtx, *dest, data)
   178  		}
   179  	}
   180  
   181  	return nil
   182  }
   183  
   184  func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
   185  	sessionPolicy := s.policyManager.ForLevel(0)
   186  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   187  		return newError("unable to set read deadline").Base(err).AtWarning()
   188  	}
   189  
   190  	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
   191  	request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader)
   192  	if err != nil {
   193  		log.Record(&log.AccessMessage{
   194  			From:   conn.RemoteAddr(),
   195  			To:     "",
   196  			Status: log.AccessRejected,
   197  			Reason: err,
   198  		})
   199  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   200  	}
   201  	conn.SetReadDeadline(time.Time{})
   202  
   203  	inbound := session.InboundFromContext(ctx)
   204  	if inbound == nil {
   205  		panic("no inbound metadata")
   206  	}
   207  	inbound.User = request.User
   208  
   209  	dest := request.Destination()
   210  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   211  		From:   conn.RemoteAddr(),
   212  		To:     dest,
   213  		Status: log.AccessAccepted,
   214  		Reason: "",
   215  		Email:  request.User.Email,
   216  	})
   217  	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   218  
   219  	sessionPolicy = s.policyManager.ForLevel(request.User.Level)
   220  	ctx, cancel := context.WithCancel(ctx)
   221  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   222  
   223  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   224  	link, err := dispatcher.Dispatch(ctx, dest)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	responseDone := func() error {
   230  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   231  
   232  		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   233  		responseWriter, err := WriteTCPResponse(request, bufferedWriter)
   234  		if err != nil {
   235  			return newError("failed to write response").Base(err)
   236  		}
   237  
   238  		{
   239  			payload, err := link.Reader.ReadMultiBuffer()
   240  			if err != nil {
   241  				return err
   242  			}
   243  			if err := responseWriter.WriteMultiBuffer(payload); err != nil {
   244  				return err
   245  			}
   246  		}
   247  
   248  		if err := bufferedWriter.SetBuffered(false); err != nil {
   249  			return err
   250  		}
   251  
   252  		if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil {
   253  			return newError("failed to transport all TCP response").Base(err)
   254  		}
   255  
   256  		return nil
   257  	}
   258  
   259  	requestDone := func() error {
   260  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   261  
   262  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   263  			return newError("failed to transport all TCP request").Base(err)
   264  		}
   265  
   266  		return nil
   267  	}
   268  
   269  	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   270  	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   271  		common.Interrupt(link.Reader)
   272  		common.Interrupt(link.Writer)
   273  		return newError("connection ends").Base(err)
   274  	}
   275  
   276  	return nil
   277  }
   278  
   279  func init() {
   280  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   281  		return NewServer(ctx, config.(*ServerConfig))
   282  	}))
   283  }