github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/socks/server.go (about)

     1  package socks
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"time"
     7  
     8  	"github.com/xmplusdev/xmcore/common"
     9  	"github.com/xmplusdev/xmcore/common/buf"
    10  	"github.com/xmplusdev/xmcore/common/log"
    11  	"github.com/xmplusdev/xmcore/common/net"
    12  	"github.com/xmplusdev/xmcore/common/protocol"
    13  	udp_proto "github.com/xmplusdev/xmcore/common/protocol/udp"
    14  	"github.com/xmplusdev/xmcore/common/session"
    15  	"github.com/xmplusdev/xmcore/common/signal"
    16  	"github.com/xmplusdev/xmcore/common/task"
    17  	"github.com/xmplusdev/xmcore/core"
    18  	"github.com/xmplusdev/xmcore/features"
    19  	"github.com/xmplusdev/xmcore/features/policy"
    20  	"github.com/xmplusdev/xmcore/features/routing"
    21  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    22  	"github.com/xmplusdev/xmcore/transport/internet/udp"
    23  )
    24  
    25  // Server is a SOCKS 5 proxy server
    26  type Server struct {
    27  	config        *ServerConfig
    28  	policyManager policy.Manager
    29  	cone          bool
    30  }
    31  
    32  // NewServer creates a new Server object.
    33  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    34  	v := core.MustFromContext(ctx)
    35  	s := &Server{
    36  		config:        config,
    37  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    38  		cone:          ctx.Value("cone").(bool),
    39  	}
    40  	return s, nil
    41  }
    42  
    43  func (s *Server) policy() policy.Session {
    44  	config := s.config
    45  	p := s.policyManager.ForLevel(config.UserLevel)
    46  	if config.Timeout > 0 {
    47  		features.PrintDeprecatedFeatureWarning("Socks timeout")
    48  	}
    49  	if config.Timeout > 0 && config.UserLevel == 0 {
    50  		p.Timeouts.ConnectionIdle = time.Duration(config.Timeout) * time.Second
    51  	}
    52  	return p
    53  }
    54  
    55  // Network implements proxy.Inbound.
    56  func (s *Server) Network() []net.Network {
    57  	list := []net.Network{net.Network_TCP}
    58  	if s.config.UdpEnabled {
    59  		list = append(list, net.Network_UDP)
    60  	}
    61  	return list
    62  }
    63  
    64  // Process implements proxy.Inbound.
    65  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    66  	inbound := session.InboundFromContext(ctx)
    67  	inbound.Name = "socks"
    68  	inbound.SetCanSpliceCopy(2)
    69  	inbound.User = &protocol.MemoryUser{
    70  		Level: s.config.UserLevel,
    71  	}
    72  
    73  	switch network {
    74  	case net.Network_TCP:
    75  		return s.processTCP(ctx, conn, dispatcher)
    76  	case net.Network_UDP:
    77  		return s.handleUDPPayload(ctx, conn, dispatcher)
    78  	default:
    79  		return newError("unknown network: ", network)
    80  	}
    81  }
    82  
    83  func (s *Server) processTCP(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
    84  	plcy := s.policy()
    85  	if err := conn.SetReadDeadline(time.Now().Add(plcy.Timeouts.Handshake)); err != nil {
    86  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    87  	}
    88  
    89  	inbound := session.InboundFromContext(ctx)
    90  	if inbound == nil || !inbound.Gateway.IsValid() {
    91  		return newError("inbound gateway not specified")
    92  	}
    93  
    94  	svrSession := &ServerSession{
    95  		config:       s.config,
    96  		address:      inbound.Gateway.Address,
    97  		port:         inbound.Gateway.Port,
    98  		localAddress: net.IPAddress(conn.LocalAddr().(*net.TCPAddr).IP),
    99  	}
   100  
   101  	reader := &buf.BufferedReader{Reader: buf.NewReader(conn)}
   102  	request, err := svrSession.Handshake(reader, conn)
   103  	if err != nil {
   104  		if inbound.Source.IsValid() {
   105  			log.Record(&log.AccessMessage{
   106  				From:   inbound.Source,
   107  				To:     "",
   108  				Status: log.AccessRejected,
   109  				Reason: err,
   110  			})
   111  		}
   112  		return newError("failed to read request").Base(err)
   113  	}
   114  	if request.User != nil {
   115  		inbound.User.Email = request.User.Email
   116  	}
   117  
   118  	if err := conn.SetReadDeadline(time.Time{}); err != nil {
   119  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
   120  	}
   121  
   122  	if request.Command == protocol.RequestCommandTCP {
   123  		dest := request.Destination()
   124  		newError("TCP Connect request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   125  		if inbound.Source.IsValid() {
   126  			ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   127  				From:   inbound.Source,
   128  				To:     dest,
   129  				Status: log.AccessAccepted,
   130  				Reason: "",
   131  			})
   132  		}
   133  
   134  		return s.transport(ctx, reader, conn, dest, dispatcher, inbound)
   135  	}
   136  
   137  	if request.Command == protocol.RequestCommandUDP {
   138  		return s.handleUDP(conn)
   139  	}
   140  
   141  	return nil
   142  }
   143  
   144  func (*Server) handleUDP(c io.Reader) error {
   145  	// The TCP connection closes after this method returns. We need to wait until
   146  	// the client closes it.
   147  	return common.Error2(io.Copy(buf.DiscardBytes, c))
   148  }
   149  
   150  func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher routing.Dispatcher, inbound *session.Inbound) error {
   151  	ctx, cancel := context.WithCancel(ctx)
   152  	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
   153  
   154  	if inbound != nil {
   155  		inbound.Timer = timer
   156  	}
   157  
   158  	plcy := s.policy()
   159  	ctx = policy.ContextWithBufferPolicy(ctx, plcy.Buffer)
   160  	link, err := dispatcher.Dispatch(ctx, dest)
   161  	if err != nil {
   162  		return err
   163  	}
   164  
   165  	requestDone := func() error {
   166  		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
   167  		if err := buf.Copy(buf.NewReader(reader), link.Writer, buf.UpdateActivity(timer)); err != nil {
   168  			return newError("failed to transport all TCP request").Base(err)
   169  		}
   170  
   171  		return nil
   172  	}
   173  
   174  	responseDone := func() error {
   175  		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
   176  
   177  		v2writer := buf.NewWriter(writer)
   178  		if err := buf.Copy(link.Reader, v2writer, buf.UpdateActivity(timer)); err != nil {
   179  			return newError("failed to transport all TCP response").Base(err)
   180  		}
   181  
   182  		return nil
   183  	}
   184  
   185  	requestDonePost := task.OnSuccess(requestDone, task.Close(link.Writer))
   186  	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
   187  		common.Interrupt(link.Reader)
   188  		common.Interrupt(link.Writer)
   189  		return newError("connection ends").Base(err)
   190  	}
   191  
   192  	return nil
   193  }
   194  
   195  func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
   196  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
   197  		payload := packet.Payload
   198  		newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   199  
   200  		request := protocol.RequestHeaderFromContext(ctx)
   201  		if request == nil {
   202  			return
   203  		}
   204  
   205  		if payload.UDP != nil {
   206  			request = &protocol.RequestHeader{
   207  				User:    request.User,
   208  				Address: payload.UDP.Address,
   209  				Port:    payload.UDP.Port,
   210  			}
   211  		}
   212  
   213  		udpMessage, err := EncodeUDPPacket(request, payload.Bytes())
   214  		payload.Release()
   215  
   216  		defer udpMessage.Release()
   217  		if err != nil {
   218  			newError("failed to write UDP response").AtWarning().Base(err).WriteToLog(session.ExportIDToError(ctx))
   219  		}
   220  
   221  		conn.Write(udpMessage.Bytes())
   222  	})
   223  
   224  	inbound := session.InboundFromContext(ctx)
   225  	if inbound != nil && inbound.Source.IsValid() {
   226  		newError("client UDP connection from ", inbound.Source).WriteToLog(session.ExportIDToError(ctx))
   227  	}
   228  
   229  	var dest *net.Destination
   230  
   231  	reader := buf.NewPacketReader(conn)
   232  	for {
   233  		mpayload, err := reader.ReadMultiBuffer()
   234  		if err != nil {
   235  			return err
   236  		}
   237  
   238  		for _, payload := range mpayload {
   239  			request, err := DecodeUDPPacket(payload)
   240  			if err != nil {
   241  				newError("failed to parse UDP request").Base(err).WriteToLog(session.ExportIDToError(ctx))
   242  				payload.Release()
   243  				continue
   244  			}
   245  
   246  			if payload.IsEmpty() {
   247  				payload.Release()
   248  				continue
   249  			}
   250  
   251  			destination := request.Destination()
   252  
   253  			currentPacketCtx := ctx
   254  			newError("send packet to ", destination, " with ", payload.Len(), " bytes").AtDebug().WriteToLog(session.ExportIDToError(ctx))
   255  			if inbound != nil && inbound.Source.IsValid() {
   256  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   257  					From:   inbound.Source,
   258  					To:     destination,
   259  					Status: log.AccessAccepted,
   260  					Reason: "",
   261  				})
   262  			}
   263  
   264  			payload.UDP = &destination
   265  
   266  			if !s.cone || dest == nil {
   267  				dest = &destination
   268  			}
   269  
   270  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   271  			udpServer.Dispatch(currentPacketCtx, *dest, payload)
   272  		}
   273  	}
   274  }
   275  
   276  func init() {
   277  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   278  		return NewServer(ctx, config.(*ServerConfig))
   279  	}))
   280  }