github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/shadowsocks/server.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/xmplusdev/xmcore/common"
     8  	"github.com/xmplusdev/xmcore/common/buf"
     9  	"github.com/xmplusdev/xmcore/common/log"
    10  	"github.com/xmplusdev/xmcore/common/net"
    11  	"github.com/xmplusdev/xmcore/common/protocol"
    12  	udp_proto "github.com/xmplusdev/xmcore/common/protocol/udp"
    13  	"github.com/xmplusdev/xmcore/common/session"
    14  	"github.com/xmplusdev/xmcore/common/signal"
    15  	"github.com/xmplusdev/xmcore/common/task"
    16  	"github.com/xmplusdev/xmcore/core"
    17  	"github.com/xmplusdev/xmcore/features/policy"
    18  	"github.com/xmplusdev/xmcore/features/routing"
    19  	"github.com/xmplusdev/xmcore/transport/internet/stat"
    20  	"github.com/xmplusdev/xmcore/transport/internet/udp"
    21  )
    22  
    23  type Server struct {
    24  	config        *ServerConfig
    25  	validator     *Validator
    26  	policyManager policy.Manager
    27  	cone          bool
    28  }
    29  
    30  // NewServer create a new Shadowsocks server.
    31  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    32  	validator := new(Validator)
    33  	for _, user := range config.Users {
    34  		u, err := user.ToMemoryUser()
    35  		if err != nil {
    36  			return nil, newError("failed to get shadowsocks user").Base(err).AtError()
    37  		}
    38  
    39  		if err := validator.Add(u); err != nil {
    40  			return nil, newError("failed to add user").Base(err).AtError()
    41  		}
    42  	}
    43  
    44  	v := core.MustFromContext(ctx)
    45  	s := &Server{
    46  		config:        config,
    47  		validator:     validator,
    48  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    49  		cone:          ctx.Value("cone").(bool),
    50  	}
    51  
    52  	return s, nil
    53  }
    54  
    55  // AddUser implements proxy.UserManager.AddUser().
    56  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
    57  	return s.validator.Add(u)
    58  }
    59  
    60  // RemoveUser implements proxy.UserManager.RemoveUser().
    61  func (s *Server) RemoveUser(ctx context.Context, e string) error {
    62  	return s.validator.Del(e)
    63  }
    64  
    65  func (s *Server) Network() []net.Network {
    66  	list := s.config.Network
    67  	if len(list) == 0 {
    68  		list = append(list, net.Network_TCP)
    69  	}
    70  	return list
    71  }
    72  
    73  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    74  	inbound := session.InboundFromContext(ctx)
    75  	inbound.Name = "shadowsocks"
    76  	inbound.SetCanSpliceCopy(3)
    77  	
    78  	switch network {
    79  	case net.Network_TCP:
    80  		return s.handleConnection(ctx, conn, dispatcher)
    81  	case net.Network_UDP:
    82  		return s.handleUDPPayload(ctx, conn, dispatcher)
    83  	default:
    84  		return newError("unknown network: ", network)
    85  	}
    86  }
    87  
    88  func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
    89  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    90  		request := protocol.RequestHeaderFromContext(ctx)
    91  		if request == nil {
    92  			return
    93  		}
    94  
    95  		payload := packet.Payload
    96  
    97  		if payload.UDP != nil {
    98  			request = &protocol.RequestHeader{
    99  				User:    request.User,
   100  				Address: payload.UDP.Address,
   101  				Port:    payload.UDP.Port,
   102  			}
   103  		}
   104  
   105  		data, err := EncodeUDPPacket(request, payload.Bytes())
   106  		payload.Release()
   107  		if err != nil {
   108  			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   109  			return
   110  		}
   111  		defer data.Release()
   112  
   113  		conn.Write(data.Bytes())
   114  	})
   115  
   116  	inbound := session.InboundFromContext(ctx)
   117  	var dest *net.Destination
   118  	reader := buf.NewPacketReader(conn)
   119  	for {
   120  		mpayload, err := reader.ReadMultiBuffer()
   121  		if err != nil {
   122  			break
   123  		}
   124  
   125  		for _, payload := range mpayload {
   126  			var request *protocol.RequestHeader
   127  			var data *buf.Buffer
   128  			var err error
   129  
   130  			if inbound.User != nil {
   131  				validator := new(Validator)
   132  				validator.Add(inbound.User)
   133  				request, data, err = DecodeUDPPacket(validator, payload)
   134  			} else {
   135  				request, data, err = DecodeUDPPacket(s.validator, payload)
   136  				if err == nil {
   137  					inbound.User = request.User
   138  				}
   139  			}
   140  
   141  			if err != nil {
   142  				if inbound.Source.IsValid() {
   143  					newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx))
   144  					log.Record(&log.AccessMessage{
   145  						From:   inbound.Source,
   146  						To:     "",
   147  						Status: log.AccessRejected,
   148  						Reason: err,
   149  					})
   150  				}
   151  				payload.Release()
   152  				continue
   153  			}
   154  
   155  			destination := request.Destination()
   156  
   157  			currentPacketCtx := ctx
   158  			if inbound.Source.IsValid() {
   159  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   160  					From:   inbound.Source,
   161  					To:     destination,
   162  					Status: log.AccessAccepted,
   163  					Reason: "",
   164  					Email:  request.User.Email,
   165  				})
   166  			}
   167  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
   168  
   169  			data.UDP = &destination
   170  
   171  			if !s.cone || dest == nil {
   172  				dest = &destination
   173  			}
   174  
   175  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   176  			udpServer.Dispatch(currentPacketCtx, *dest, data)
   177  		}
   178  	}
   179  
   180  	return nil
   181  }
   182  
   183  func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
   184  	sessionPolicy := s.policyManager.ForLevel(0)
   185  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   186  		return newError("unable to set read deadline").Base(err).AtWarning()
   187  	}
   188  
   189  	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
   190  	request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader)
   191  	if err != nil {
   192  		log.Record(&log.AccessMessage{
   193  			From:   conn.RemoteAddr(),
   194  			To:     "",
   195  			Status: log.AccessRejected,
   196  			Reason: err,
   197  		})
   198  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   199  	}
   200  	conn.SetReadDeadline(time.Time{})
   201  
   202  	inbound := session.InboundFromContext(ctx)
   203  	if inbound == nil {
   204  		panic("no inbound metadata")
   205  	}
   206  	inbound.User = request.User
   207  
   208  	dest := request.Destination()
   209  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   210  		From:   conn.RemoteAddr(),
   211  		To:     dest,
   212  		Status: log.AccessAccepted,
   213  		Reason: "",
   214  		Email:  request.User.Email,
   215  	})
   216  	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   217  
   218  	sessionPolicy = s.policyManager.ForLevel(request.User.Level)
   219  	ctx, cancel := context.WithCancel(ctx)
   220  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   221  
   222  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   223  	link, err := dispatcher.Dispatch(ctx, dest)
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	responseDone := func() error {
   229  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   230  
   231  		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   232  		responseWriter, err := WriteTCPResponse(request, bufferedWriter)
   233  		if err != nil {
   234  			return newError("failed to write response").Base(err)
   235  		}
   236  
   237  		{
   238  			payload, err := link.Reader.ReadMultiBuffer()
   239  			if err != nil {
   240  				return err
   241  			}
   242  			if err := responseWriter.WriteMultiBuffer(payload); err != nil {
   243  				return err
   244  			}
   245  		}
   246  
   247  		if err := bufferedWriter.SetBuffered(false); err != nil {
   248  			return err
   249  		}
   250  
   251  		if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil {
   252  			return newError("failed to transport all TCP response").Base(err)
   253  		}
   254  
   255  		return nil
   256  	}
   257  
   258  	requestDone := func() error {
   259  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   260  
   261  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   262  			return newError("failed to transport all TCP request").Base(err)
   263  		}
   264  
   265  		return nil
   266  	}
   267  
   268  	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   269  	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   270  		common.Interrupt(link.Reader)
   271  		common.Interrupt(link.Writer)
   272  		return newError("connection ends").Base(err)
   273  	}
   274  
   275  	return nil
   276  }
   277  
   278  func init() {
   279  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   280  		return NewServer(ctx, config.(*ServerConfig))
   281  	}))
   282  }