github.com/xraypb/xray-core@v1.6.6/proxy/socks/server.go (about)

     1  package socks
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"time"
     7  
     8  	"github.com/xraypb/xray-core/common"
     9  	"github.com/xraypb/xray-core/common/buf"
    10  	"github.com/xraypb/xray-core/common/log"
    11  	"github.com/xraypb/xray-core/common/net"
    12  	"github.com/xraypb/xray-core/common/protocol"
    13  	udp_proto "github.com/xraypb/xray-core/common/protocol/udp"
    14  	"github.com/xraypb/xray-core/common/session"
    15  	"github.com/xraypb/xray-core/common/signal"
    16  	"github.com/xraypb/xray-core/common/task"
    17  	"github.com/xraypb/xray-core/core"
    18  	"github.com/xraypb/xray-core/features"
    19  	"github.com/xraypb/xray-core/features/policy"
    20  	"github.com/xraypb/xray-core/features/routing"
    21  	"github.com/xraypb/xray-core/transport/internet/stat"
    22  	"github.com/xraypb/xray-core/transport/internet/udp"
    23  )
    24  
    25  // Server is a SOCKS 5 proxy server
    26  type Server struct {
    27  	config        *ServerConfig
    28  	policyManager policy.Manager
    29  	cone          bool
    30  }
    31  
    32  // NewServer creates a new Server object.
    33  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    34  	v := core.MustFromContext(ctx)
    35  	s := &Server{
    36  		config:        config,
    37  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    38  		cone:          ctx.Value("cone").(bool),
    39  	}
    40  	return s, nil
    41  }
    42  
    43  func (s *Server) policy() policy.Session {
    44  	config := s.config
    45  	p := s.policyManager.ForLevel(config.UserLevel)
    46  	if config.Timeout > 0 {
    47  		features.PrintDeprecatedFeatureWarning("Socks timeout")
    48  	}
    49  	if config.Timeout > 0 && config.UserLevel == 0 {
    50  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    51  	}
    52  	return p
    53  }
    54  
    55  // Network implements proxy.Inbound.
    56  func (s *Server) Network() []net.Network {
    57  	list := []net.Network{net.Network_TCP}
    58  	if s.config.UdpEnabled {
    59  		list = append(list, net.Network_UDP)
    60  	}
    61  	return list
    62  }
    63  
    64  // Process implements proxy.Inbound.
    65  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    66  	if inbound := session.InboundFromContext(ctx); inbound != nil {
    67  		inbound.User = &protocol.MemoryUser{
    68  			Level: s.config.UserLevel,
    69  		}
    70  	}
    71  
    72  	switch network {
    73  	case net.Network_TCP:
    74  		return s.processTCP(ctx, conn, dispatcher)
    75  	case net.Network_UDP:
    76  		return s.handleUDPPayload(ctx, conn, dispatcher)
    77  	default:
    78  		return newError("unknown network: ", network)
    79  	}
    80  }
    81  
    82  func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
    83  	plcy := s.policy()
    84  	if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil {
    85  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    86  	}
    87  
    88  	inbound := session.InboundFromContext(ctx)
    89  	if inbound == nil || !inbound.Gateway.IsValid() {
    90  		return newError("inbound gateway not specified")
    91  	}
    92  
    93  	svrSession := &ServerSession{
    94  		config:       s.config,
    95  		address:      inbound.Gateway.Address,
    96  		port:         inbound.Gateway.Port,
    97  		localAddress: net.IPAddress(conn.LocalAddr().(*net.TCPAddr).IP),
    98  	}
    99  
   100  	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
   101  	request, err := svrSession.Handshake(reader, conn)
   102  	if err != nil {
   103  		if inbound != nil && inbound.Source.IsValid() {
   104  			log.Record(&log.AccessMessage{
   105  				From:   inbound.Source,
   106  				To:     "",
   107  				Status: log.AccessRejected,
   108  				Reason: err,
   109  			})
   110  		}
   111  		return newError("failed to read request").Base(err)
   112  	}
   113  	if request.User != nil {
   114  		inbound.User.Email = request.User.Email
   115  	}
   116  
   117  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   118  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   119  	}
   120  
   121  	if request.Command == protocol.RequestCommandTCP {
   122  		dest := request.Destination()
   123  		newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   124  		if inbound != nil && inbound.Source.IsValid() {
   125  			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   126  				From:   inbound.Source,
   127  				To:     dest,
   128  				Status: log.AccessAccepted,
   129  				Reason: "",
   130  			})
   131  		}
   132  
   133  		return s.transport(ctx, reader, conn, dest, dispatcher, inbound)
   134  	}
   135  
   136  	if request.Command == protocol.RequestCommandUDP {
   137  		return s.handleUDP(conn)
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (*Server) handleUDP(c io.Reader) error {
   144  	// The TCP connection closes after this method returns. We need to wait until
   145  	// the client closes it.
   146  	return common.Error2(io.Copy(buf.DiscardBytes, c))
   147  }
   148  
   149  func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
   150  	ctx, cancel := context.WithCancel(ctx)
   151  	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
   152  
   153  	if inbound != nil {
   154  		inbound.Timer = timer
   155  	}
   156  
   157  	plcy := s.policy()
   158  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   159  	link, err := dispatcher.Dispatch(ctx, dest)
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	requestDone := func() error {
   165  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   166  		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
   167  			return newError("failed to transport all TCP request").Base(err)
   168  		}
   169  
   170  		return nil
   171  	}
   172  
   173  	responseDone := func() error {
   174  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   175  
   176  		v2writer := buf.NewWriter(writer)
   177  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   178  			return newError("failed to transport all TCP response").Base(err)
   179  		}
   180  
   181  		return nil
   182  	}
   183  
   184  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   185  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   186  		common.Interrupt(link.Reader)
   187  		common.Interrupt(link.Writer)
   188  		return newError("connection ends").Base(err)
   189  	}
   190  
   191  	return nil
   192  }
   193  
   194  func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
   195  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   196  		payload := packet.Payload
   197  		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   198  
   199  		request := protocol.RequestHeaderFromContext(ctx)
   200  		if request == nil {
   201  			return
   202  		}
   203  
   204  		if payload.UDP != nil {
   205  			request = &protocol.RequestHeader{
   206  				User:    request.User,
   207  				Address: payload.UDP.Address,
   208  				Port:    payload.UDP.Port,
   209  			}
   210  		}
   211  
   212  		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
   213  		payload.Release()
   214  
   215  		defer udpMessage.Release()
   216  		if err != nil {
   217  			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
   218  		}
   219  
   220  		conn.Write(udpMessage.Bytes())
   221  	})
   222  
   223  	inbound := session.InboundFromContext(ctx)
   224  	if inbound != nil && inbound.Source.IsValid() {
   225  		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
   226  	}
   227  
   228  	var dest *net.Destination
   229  
   230  	reader := buf.NewPacketReader(conn)
   231  	for {
   232  		mpayload, err := reader.ReadMultiBuffer()
   233  		if err != nil {
   234  			return err
   235  		}
   236  
   237  		for _, payload := range mpayload {
   238  			request, err := DecodeUDPPacket(payload)
   239  			if err != nil {
   240  				newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx))
   241  				payload.Release()
   242  				continue
   243  			}
   244  
   245  			if payload.IsEmpty() {
   246  				payload.Release()
   247  				continue
   248  			}
   249  
   250  			destination := request.Destination()
   251  
   252  			currentPacketCtx := ctx
   253  			newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   254  			if inbound != nil && inbound.Source.IsValid() {
   255  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   256  					From:   inbound.Source,
   257  					To:     destination,
   258  					Status: log.AccessAccepted,
   259  					Reason: "",
   260  				})
   261  			}
   262  
   263  			payload.UDP = &destination
   264  
   265  			if !s.cone || dest == nil {
   266  				dest = &destination
   267  			}
   268  
   269  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   270  			udpServer.Dispatch(currentPacketCtx, *dest, payload)
   271  		}
   272  	}
   273  }
   274  
   275  func init() {
   276  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   277  		return NewServer(ctx, config.(*ServerConfig))
   278  	}))
   279  }