github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/proxy/socks/server.go (about)

     1  // +build !confonly
     2  
     3  package socks
     4  
     5  import (
     6  	"context"
     7  	"io"
     8  	"time"
     9  
    10  	"v2ray.com/core"
    11  	"v2ray.com/core/common"
    12  	"v2ray.com/core/common/buf"
    13  	"v2ray.com/core/common/log"
    14  	"v2ray.com/core/common/net"
    15  	"v2ray.com/core/common/protocol"
    16  	udp_proto "v2ray.com/core/common/protocol/udp"
    17  	"v2ray.com/core/common/session"
    18  	"v2ray.com/core/common/signal"
    19  	"v2ray.com/core/common/task"
    20  	"v2ray.com/core/features"
    21  	"v2ray.com/core/features/policy"
    22  	"v2ray.com/core/features/routing"
    23  	"v2ray.com/core/transport/internet"
    24  	"v2ray.com/core/transport/internet/udp"
    25  )
    26  
// Server is a SOCKS 5 proxy server.
//
// It handles SOCKS handshakes on TCP connections (CONNECT and UDP ASSOCIATE
// requests) and, when UdpEnabled is set in its config, the UDP relay traffic.
type Server struct {
	// config is the inbound configuration the server was created with.
	config        *ServerConfig
	// policyManager supplies per-user-level policies (timeouts, buffer policy).
	policyManager policy.Manager
}
    32  
    33  // NewServer creates a new Server object.
    34  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    35  	v := core.MustFromContext(ctx)
    36  	s := &Server{
    37  		config:        config,
    38  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    39  	}
    40  	return s, nil
    41  }
    42  
    43  func (s *Server) policy() policy.Session {
    44  	config := s.config
    45  	p := s.policyManager.ForLevel(config.UserLevel)
    46  	if config.Timeout > 0 {
    47  		features.PrintDeprecatedFeatureWarning("Socks timeout")
    48  	}
    49  	if config.Timeout > 0 && config.UserLevel == 0 {
    50  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    51  	}
    52  	return p
    53  }
    54  
    55  // Network implements proxy.Inbound.
    56  func (s *Server) Network() []net.Network {
    57  	list := []net.Network{net.Network_TCP}
    58  	if s.config.UdpEnabled {
    59  		list = append(list, net.Network_UDP)
    60  	}
    61  	return list
    62  }
    63  
    64  // Process implements proxy.Inbound.
    65  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    66  	if inbound := session.InboundFromContext(ctx); inbound != nil {
    67  		inbound.User = &protocol.MemoryUser{
    68  			Level: s.config.UserLevel,
    69  		}
    70  	}
    71  
    72  	switch network {
    73  	case net.Network_TCP:
    74  		return s.processTCP(ctx, conn, dispatcher)
    75  	case net.Network_UDP:
    76  		return s.handleUDPPayload(ctx, conn, dispatcher)
    77  	default:
    78  		return newError("unknown network: ", network)
    79  	}
    80  }
    81  
    82  func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
    83  	plcy := s.policy()
    84  	if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil {
    85  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    86  	}
    87  
    88  	inbound := session.InboundFromContext(ctx)
    89  	if inbound == nil || !inbound.Gateway.IsValid() {
    90  		return newError("inbound gateway not specified")
    91  	}
    92  
    93  	svrSession := &ServerSession{
    94  		config: s.config,
    95  		port:   inbound.Gateway.Port,
    96  	}
    97  
    98  	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
    99  	request, err := svrSession.Handshake(reader, conn)
   100  	if err != nil {
   101  		if inbound != nil && inbound.Source.IsValid() {
   102  			log.Record(&log.AccessMessage{
   103  				From:   inbound.Source,
   104  				To:     "",
   105  				Status: log.AccessRejected,
   106  				Reason: err,
   107  			})
   108  		}
   109  		return newError("failed to read request").Base(err)
   110  	}
   111  	if request.User != nil {
   112  		inbound.User.Email = request.User.Email
   113  	}
   114  
   115  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   116  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   117  	}
   118  
   119  	if request.Command == protocol.RequestCommandTCP {
   120  		dest := request.Destination()
   121  		newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   122  		if inbound != nil && inbound.Source.IsValid() {
   123  			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   124  				From:   inbound.Source,
   125  				To:     dest,
   126  				Status: log.AccessAccepted,
   127  				Reason: "",
   128  			})
   129  		}
   130  
   131  		return s.transport(ctx, reader, conn, dest, dispatcher)
   132  	}
   133  
   134  	if request.Command == protocol.RequestCommandUDP {
   135  		return s.handleUDP(conn)
   136  	}
   137  
   138  	return nil
   139  }
   140  
   141  func (*Server) handleUDP(c io.Reader) error {
   142  	// The TCP connection closes after this method returns. We need to wait until
   143  	// the client closes it.
   144  	return common.Error2(io.Copy(buf.DiscardBytes, c))
   145  }
   146  
   147  func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher) error {
   148  	ctx, cancel := context.WithCancel(ctx)
   149  	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
   150  
   151  	plcy := s.policy()
   152  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   153  	link, err := dispatcher.Dispatch(ctx, dest)
   154  	if err != nil {
   155  		return err
   156  	}
   157  
   158  	requestDone := func() error {
   159  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   160  		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
   161  			return newError("failed to transport all TCP request").Base(err)
   162  		}
   163  
   164  		return nil
   165  	}
   166  
   167  	responseDone := func() error {
   168  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   169  
   170  		v2writer := buf.NewWriter(writer)
   171  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   172  			return newError("failed to transport all TCP response").Base(err)
   173  		}
   174  
   175  		return nil
   176  	}
   177  
   178  	var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
   179  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   180  		common.Interrupt(link.Reader)
   181  		common.Interrupt(link.Writer)
   182  		return newError("connection ends").Base(err)
   183  	}
   184  
   185  	return nil
   186  }
   187  
   188  func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
   189  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   190  		payload := packet.Payload
   191  		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   192  
   193  		request := protocol.RequestHeaderFromContext(ctx)
   194  		if request == nil {
   195  			return
   196  		}
   197  		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
   198  		payload.Release()
   199  
   200  		defer udpMessage.Release()
   201  		if err != nil {
   202  			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
   203  		}
   204  
   205  		conn.Write(udpMessage.Bytes()) // nolint: errcheck
   206  	})
   207  
   208  	if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   209  		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
   210  	}
   211  
   212  	reader := buf.NewPacketReader(conn)
   213  	for {
   214  		mpayload, err := reader.ReadMultiBuffer()
   215  		if err != nil {
   216  			return err
   217  		}
   218  
   219  		for _, payload := range mpayload {
   220  			request, err := DecodeUDPPacket(payload)
   221  
   222  			if err != nil {
   223  				newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx))
   224  				payload.Release()
   225  				continue
   226  			}
   227  
   228  			if payload.IsEmpty() {
   229  				payload.Release()
   230  				continue
   231  			}
   232  			currentPacketCtx := ctx
   233  			newError("send packet to ", request.Destination(), " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   234  			if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   235  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   236  					From:   inbound.Source,
   237  					To:     request.Destination(),
   238  					Status: log.AccessAccepted,
   239  					Reason: "",
   240  				})
   241  			}
   242  
   243  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   244  			udpServer.Dispatch(currentPacketCtx, request.Destination(), payload)
   245  		}
   246  	}
   247  }
   248  
   249  func init() {
   250  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   251  		return NewServer(ctx, config.(*ServerConfig))
   252  	}))
   253  }