github.com/v2fly/v2ray-core/v4@v4.45.2/proxy/socks/server.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package socks
     5  
     6  import (
     7  	"context"
     8  	"io"
     9  	"time"
    10  
    11  	core "github.com/v2fly/v2ray-core/v4"
    12  	"github.com/v2fly/v2ray-core/v4/common"
    13  	"github.com/v2fly/v2ray-core/v4/common/buf"
    14  	"github.com/v2fly/v2ray-core/v4/common/log"
    15  	"github.com/v2fly/v2ray-core/v4/common/net"
    16  	"github.com/v2fly/v2ray-core/v4/common/protocol"
    17  	udp_proto "github.com/v2fly/v2ray-core/v4/common/protocol/udp"
    18  	"github.com/v2fly/v2ray-core/v4/common/session"
    19  	"github.com/v2fly/v2ray-core/v4/common/signal"
    20  	"github.com/v2fly/v2ray-core/v4/common/task"
    21  	"github.com/v2fly/v2ray-core/v4/features"
    22  	"github.com/v2fly/v2ray-core/v4/features/policy"
    23  	"github.com/v2fly/v2ray-core/v4/features/routing"
    24  	"github.com/v2fly/v2ray-core/v4/transport/internet"
    25  	"github.com/v2fly/v2ray-core/v4/transport/internet/udp"
    26  )
    27  
    28  // Server is a SOCKS 5 proxy server
    29  type Server struct {
    30  	config        *ServerConfig
    31  	policyManager policy.Manager
    32  }
    33  
    34  // NewServer creates a new Server object.
    35  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    36  	v := core.MustFromContext(ctx)
    37  	s := &Server{
    38  		config:        config,
    39  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    40  	}
    41  	return s, nil
    42  }
    43  
    44  func (s *Server) policy() policy.Session {
    45  	config := s.config
    46  	p := s.policyManager.ForLevel(config.UserLevel)
    47  	if config.Timeout > 0 {
    48  		features.PrintDeprecatedFeatureWarning("Socks timeout")
    49  	}
    50  	if config.Timeout > 0 && config.UserLevel == 0 {
    51  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    52  	}
    53  	return p
    54  }
    55  
    56  // Network implements proxy.Inbound.
    57  func (s *Server) Network() []net.Network {
    58  	list := []net.Network{net.Network_TCP}
    59  	if s.config.UdpEnabled {
    60  		list = append(list, net.Network_UDP)
    61  	}
    62  	return list
    63  }
    64  
    65  // Process implements proxy.Inbound.
    66  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    67  	if inbound := session.InboundFromContext(ctx); inbound != nil {
    68  		inbound.User = &protocol.MemoryUser{
    69  			Level: s.config.UserLevel,
    70  		}
    71  	}
    72  
    73  	switch network {
    74  	case net.Network_TCP:
    75  		return s.processTCP(ctx, conn, dispatcher)
    76  	case net.Network_UDP:
    77  		return s.handleUDPPayload(ctx, conn, dispatcher)
    78  	default:
    79  		return newError("unknown network: ", network)
    80  	}
    81  }
    82  
    83  func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
    84  	plcy := s.policy()
    85  	if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil {
    86  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    87  	}
    88  
    89  	inbound := session.InboundFromContext(ctx)
    90  	if inbound == nil || !inbound.Gateway.IsValid() {
    91  		return newError("inbound gateway not specified")
    92  	}
    93  
    94  	svrSession := &ServerSession{
    95  		config:        s.config,
    96  		address:       inbound.Gateway.Address,
    97  		port:          inbound.Gateway.Port,
    98  		clientAddress: inbound.Source.Address,
    99  	}
   100  
   101  	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
   102  	request, err := svrSession.Handshake(reader, conn)
   103  	if err != nil {
   104  		if inbound != nil && inbound.Source.IsValid() {
   105  			log.Record(&log.AccessMessage{
   106  				From:   inbound.Source,
   107  				To:     "",
   108  				Status: log.AccessRejected,
   109  				Reason: err,
   110  			})
   111  		}
   112  		return newError("failed to read request").Base(err)
   113  	}
   114  	if request.User != nil {
   115  		inbound.User.Email = request.User.Email
   116  	}
   117  
   118  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   119  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   120  	}
   121  
   122  	if request.Command == protocol.RequestCommandTCP {
   123  		dest := request.Destination()
   124  		newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   125  		if inbound != nil && inbound.Source.IsValid() {
   126  			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   127  				From:   inbound.Source,
   128  				To:     dest,
   129  				Status: log.AccessAccepted,
   130  				Reason: "",
   131  			})
   132  		}
   133  
   134  		return s.transport(ctx, reader, conn, dest, dispatcher)
   135  	}
   136  
   137  	if request.Command == protocol.RequestCommandUDP {
   138  		return s.handleUDP(conn)
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func (*Server) handleUDP(c io.Reader) error {
   145  	// The TCP connection closes after this method returns. We need to wait until
   146  	// the client closes it.
   147  	return common.Error2(io.Copy(buf.DiscardBytes, c))
   148  }
   149  
   150  func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher) error {
   151  	ctx, cancel := context.WithCancel(ctx)
   152  	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
   153  
   154  	plcy := s.policy()
   155  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   156  	link, err := dispatcher.Dispatch(ctx, dest)
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	requestDone := func() error {
   162  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   163  		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
   164  			return newError("failed to transport all TCP request").Base(err)
   165  		}
   166  
   167  		return nil
   168  	}
   169  
   170  	responseDone := func() error {
   171  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   172  
   173  		v2writer := buf.NewWriter(writer)
   174  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   175  			return newError("failed to transport all TCP response").Base(err)
   176  		}
   177  
   178  		return nil
   179  	}
   180  
   181  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   182  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   183  		common.Interrupt(link.Reader)
   184  		common.Interrupt(link.Writer)
   185  		return newError("connection ends").Base(err)
   186  	}
   187  
   188  	return nil
   189  }
   190  
   191  func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
   192  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   193  		payload := packet.Payload
   194  		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   195  
   196  		request := protocol.RequestHeaderFromContext(ctx)
   197  		if request == nil {
   198  			return
   199  		}
   200  		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
   201  		payload.Release()
   202  
   203  		defer udpMessage.Release()
   204  		if err != nil {
   205  			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
   206  		}
   207  
   208  		conn.Write(udpMessage.Bytes())
   209  	})
   210  
   211  	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   212  		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
   213  	}
   214  
   215  	reader := buf.NewPacketReader(conn)
   216  	for {
   217  		mpayload, err := reader.ReadMultiBuffer()
   218  		if err != nil {
   219  			return err
   220  		}
   221  
   222  		for _, payload := range mpayload {
   223  			request, err := DecodeUDPPacket(payload)
   224  			if err != nil {
   225  				newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx))
   226  				payload.Release()
   227  				continue
   228  			}
   229  
   230  			if payload.IsEmpty() {
   231  				payload.Release()
   232  				continue
   233  			}
   234  			currentPacketCtx := ctx
   235  			newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   236  			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   237  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   238  					From:   inbound.Source,
   239  					To:     request.Destination(),
   240  					Status: log.AccessAccepted,
   241  					Reason: "",
   242  				})
   243  			}
   244  
   245  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   246  			udpServer.Dispatch(currentPacketCtx, request.Destination(), payload)
   247  		}
   248  	}
   249  }
   250  
   251  func init() {
   252  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   253  		return NewServer(ctx, config.(*ServerConfig))
   254  	}))
   255  }