github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/proxy/socks/server.go (about)

     1  package socks
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"time"
     7  
     8  	core "github.com/v2fly/v2ray-core/v5"
     9  	"github.com/v2fly/v2ray-core/v5/common"
    10  	"github.com/v2fly/v2ray-core/v5/common/buf"
    11  	"github.com/v2fly/v2ray-core/v5/common/log"
    12  	"github.com/v2fly/v2ray-core/v5/common/net"
    13  	"github.com/v2fly/v2ray-core/v5/common/net/packetaddr"
    14  	"github.com/v2fly/v2ray-core/v5/common/protocol"
    15  	udp_proto "github.com/v2fly/v2ray-core/v5/common/protocol/udp"
    16  	"github.com/v2fly/v2ray-core/v5/common/session"
    17  	"github.com/v2fly/v2ray-core/v5/common/signal"
    18  	"github.com/v2fly/v2ray-core/v5/common/task"
    19  	"github.com/v2fly/v2ray-core/v5/features"
    20  	"github.com/v2fly/v2ray-core/v5/features/policy"
    21  	"github.com/v2fly/v2ray-core/v5/features/routing"
    22  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    23  	"github.com/v2fly/v2ray-core/v5/transport/internet/udp"
    24  )
    25  
    26  // Server is a SOCKS 5 proxy server
    27  type Server struct {
    28  	config        *ServerConfig
    29  	policyManager policy.Manager
    30  }
    31  
    32  // NewServer creates a new Server object.
    33  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    34  	v := core.MustFromContext(ctx)
    35  	s := &Server{
    36  		config:        config,
    37  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    38  	}
    39  	return s, nil
    40  }
    41  
    42  func (s *Server) policy() policy.Session {
    43  	config := s.config
    44  	p := s.policyManager.ForLevel(config.UserLevel)
    45  	if config.Timeout > 0 {
    46  		features.PrintDeprecatedFeatureWarning("Socks timeout")
    47  	}
    48  	if config.Timeout > 0 && config.UserLevel == 0 {
    49  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    50  	}
    51  	return p
    52  }
    53  
    54  // Network implements proxy.Inbound.
    55  func (s *Server) Network() []net.Network {
    56  	list := []net.Network{net.Network_TCP}
    57  	if s.config.UdpEnabled {
    58  		list = append(list, net.Network_UDP)
    59  	}
    60  	return list
    61  }
    62  
    63  // Process implements proxy.Inbound.
    64  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    65  	if inbound := session.InboundFromContext(ctx); inbound != nil {
    66  		inbound.User = &protocol.MemoryUser{
    67  			Level: s.config.UserLevel,
    68  		}
    69  	}
    70  
    71  	switch network {
    72  	case net.Network_TCP:
    73  		return s.processTCP(ctx, conn, dispatcher)
    74  	case net.Network_UDP:
    75  		return s.handleUDPPayload(ctx, conn, dispatcher)
    76  	default:
    77  		return newError("unknown network: ", network)
    78  	}
    79  }
    80  
    81  func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
    82  	plcy := s.policy()
    83  	if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil {
    84  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    85  	}
    86  
    87  	inbound := session.InboundFromContext(ctx)
    88  	if inbound == nil || !inbound.Gateway.IsValid() {
    89  		return newError("inbound gateway not specified")
    90  	}
    91  
    92  	svrSession := &ServerSession{
    93  		config:        s.config,
    94  		address:       inbound.Gateway.Address,
    95  		port:          inbound.Gateway.Port,
    96  		clientAddress: inbound.Source.Address,
    97  	}
    98  
    99  	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
   100  	request, err := svrSession.Handshake(reader, conn)
   101  	if err != nil {
   102  		if inbound != nil && inbound.Source.IsValid() {
   103  			log.Record(&log.AccessMessage{
   104  				From:   inbound.Source,
   105  				To:     "",
   106  				Status: log.AccessRejected,
   107  				Reason: err,
   108  			})
   109  		}
   110  		return newError("failed to read request").Base(err)
   111  	}
   112  	if request.User != nil {
   113  		inbound.User.Email = request.User.Email
   114  	}
   115  
   116  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   117  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   118  	}
   119  
   120  	if request.Command == protocol.RequestCommandTCP {
   121  		dest := request.Destination()
   122  		newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   123  		if inbound != nil && inbound.Source.IsValid() {
   124  			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   125  				From:   inbound.Source,
   126  				To:     dest,
   127  				Status: log.AccessAccepted,
   128  				Reason: "",
   129  			})
   130  		}
   131  
   132  		return s.transport(ctx, reader, conn, dest, dispatcher)
   133  	}
   134  
   135  	if request.Command == protocol.RequestCommandUDP {
   136  		return s.handleUDP(conn)
   137  	}
   138  
   139  	return nil
   140  }
   141  
   142  func (*Server) handleUDP(c io.Reader) error {
   143  	// The TCP connection closes after this method returns. We need to wait until
   144  	// the client closes it.
   145  	return common.Error2(io.Copy(buf.DiscardBytes, c))
   146  }
   147  
   148  func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher) error {
   149  	ctx, cancel := context.WithCancel(ctx)
   150  	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
   151  
   152  	plcy := s.policy()
   153  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   154  	link, err := dispatcher.Dispatch(ctx, dest)
   155  	if err != nil {
   156  		return err
   157  	}
   158  
   159  	requestDone := func() error {
   160  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   161  		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
   162  			return newError("failed to transport all TCP request").Base(err)
   163  		}
   164  
   165  		return nil
   166  	}
   167  
   168  	responseDone := func() error {
   169  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   170  
   171  		v2writer := buf.NewWriter(writer)
   172  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   173  			return newError("failed to transport all TCP response").Base(err)
   174  		}
   175  
   176  		return nil
   177  	}
   178  
   179  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   180  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   181  		common.Interrupt(link.Reader)
   182  		common.Interrupt(link.Writer)
   183  		return newError("connection ends").Base(err)
   184  	}
   185  
   186  	return nil
   187  }
   188  
   189  func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
   190  	udpDispatcherConstructor := udp.NewSplitDispatcher
   191  	switch s.config.PacketEncoding {
   192  	case packetaddr.PacketAddrType_None:
   193  		break
   194  	case packetaddr.PacketAddrType_Packet:
   195  		packetAddrDispatcherFactory := udp.NewPacketAddrDispatcherCreator(ctx)
   196  		udpDispatcherConstructor = packetAddrDispatcherFactory.NewPacketAddrDispatcher
   197  	}
   198  	udpServer := udpDispatcherConstructor(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   199  		payload := packet.Payload
   200  		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   201  
   202  		request := protocol.RequestHeaderFromContext(ctx)
   203  		var packetSource net.Destination
   204  		if request == nil {
   205  			packetSource = packet.Source
   206  		} else {
   207  			packetSource = net.UDPDestination(request.Address, request.Port)
   208  		}
   209  		udpMessage, err := EncodeUDPPacketFromAddress(packetSource, payload.Bytes())
   210  		payload.Release()
   211  
   212  		defer udpMessage.Release()
   213  		if err != nil {
   214  			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
   215  		}
   216  
   217  		conn.Write(udpMessage.Bytes())
   218  	})
   219  
   220  	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   221  		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
   222  	}
   223  
   224  	reader := buf.NewPacketReader(conn)
   225  	for {
   226  		mpayload, err := reader.ReadMultiBuffer()
   227  		if err != nil {
   228  			return err
   229  		}
   230  
   231  		for _, payload := range mpayload {
   232  			request, err := DecodeUDPPacket(payload)
   233  			if err != nil {
   234  				newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx))
   235  				payload.Release()
   236  				continue
   237  			}
   238  
   239  			if payload.IsEmpty() {
   240  				payload.Release()
   241  				continue
   242  			}
   243  			currentPacketCtx := ctx
   244  			newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   245  			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   246  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   247  					From:   inbound.Source,
   248  					To:     request.Destination(),
   249  					Status: log.AccessAccepted,
   250  					Reason: "",
   251  				})
   252  			}
   253  
   254  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   255  			udpServer.Dispatch(currentPacketCtx, request.Destination(), payload)
   256  		}
   257  	}
   258  }
   259  
   260  func init() {
   261  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   262  		return NewServer(ctx, config.(*ServerConfig))
   263  	}))
   264  }