github.com/moqsien/xraycore@v1.8.5/proxy/shadowsocks/server.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	"github.com/moqsien/xraycore/common"
     8  	"github.com/moqsien/xraycore/common/buf"
     9  	"github.com/moqsien/xraycore/common/log"
    10  	"github.com/moqsien/xraycore/common/net"
    11  	"github.com/moqsien/xraycore/common/protocol"
    12  	udp_proto "github.com/moqsien/xraycore/common/protocol/udp"
    13  	"github.com/moqsien/xraycore/common/session"
    14  	"github.com/moqsien/xraycore/common/signal"
    15  	"github.com/moqsien/xraycore/common/task"
    16  	"github.com/moqsien/xraycore/core"
    17  	"github.com/moqsien/xraycore/features/policy"
    18  	"github.com/moqsien/xraycore/features/routing"
    19  	"github.com/moqsien/xraycore/transport/internet/stat"
    20  	"github.com/moqsien/xraycore/transport/internet/udp"
    21  )
    22  
    23  type Server struct {
    24  	config        *ServerConfig
    25  	validator     *Validator
    26  	policyManager policy.Manager
    27  	cone          bool
    28  }
    29  
    30  // NewServer create a new Shadowsocks server.
    31  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    32  	validator := new(Validator)
    33  	for _, user := range config.Users {
    34  		u, err := user.ToMemoryUser()
    35  		if err != nil {
    36  			return nil, newError("failed to get shadowsocks user").Base(err).AtError()
    37  		}
    38  
    39  		if err := validator.Add(u); err != nil {
    40  			return nil, newError("failed to add user").Base(err).AtError()
    41  		}
    42  	}
    43  
    44  	v := core.MustFromContext(ctx)
    45  	s := &Server{
    46  		config:        config,
    47  		validator:     validator,
    48  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    49  		cone:          ctx.Value("cone").(bool),
    50  	}
    51  
    52  	return s, nil
    53  }
    54  
    55  // AddUser implements proxy.UserManager.AddUser().
    56  func (s *Server) AddUser(ctx context.Context, u *protocol.MemoryUser) error {
    57  	return s.validator.Add(u)
    58  }
    59  
    60  // RemoveUser implements proxy.UserManager.RemoveUser().
    61  func (s *Server) RemoveUser(ctx context.Context, e string) error {
    62  	return s.validator.Del(e)
    63  }
    64  
    65  func (s *Server) Network() []net.Network {
    66  	list := s.config.Network
    67  	if len(list) == 0 {
    68  		list = append(list, net.Network_TCP)
    69  	}
    70  	return list
    71  }
    72  
    73  func (s *Server) Process(ctx context.Context, network net.Network, conn stat.Connection, dispatcher routing.Dispatcher) error {
    74  	switch network {
    75  	case net.Network_TCP:
    76  		return s.handleConnection(ctx, conn, dispatcher)
    77  	case net.Network_UDP:
    78  		return s.handleUDPPayload(ctx, conn, dispatcher)
    79  	default:
    80  		return newError("unknown network: ", network)
    81  	}
    82  }
    83  
    84  func (s *Server) handleUDPPayload(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
    85  	udpServer := udp.NewDispatcher(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    86  		request := protocol.RequestHeaderFromContext(ctx)
    87  		if request == nil {
    88  			return
    89  		}
    90  
    91  		payload := packet.Payload
    92  
    93  		if payload.UDP != nil {
    94  			request = &protocol.RequestHeader{
    95  				User:    request.User,
    96  				Address: payload.UDP.Address,
    97  				Port:    payload.UDP.Port,
    98  			}
    99  		}
   100  
   101  		data, err := EncodeUDPPacket(request, payload.Bytes())
   102  		payload.Release()
   103  		if err != nil {
   104  			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
   105  			return
   106  		}
   107  		defer data.Release()
   108  
   109  		conn.Write(data.Bytes())
   110  	})
   111  
   112  	inbound := session.InboundFromContext(ctx)
   113  	if inbound == nil {
   114  		panic("no inbound metadata")
   115  	}
   116  	inbound.Name = "shadowsocks"
   117  
   118  	var dest *net.Destination
   119  
   120  	reader := buf.NewPacketReader(conn)
   121  	for {
   122  		mpayload, err := reader.ReadMultiBuffer()
   123  		if err != nil {
   124  			break
   125  		}
   126  
   127  		for _, payload := range mpayload {
   128  			var request *protocol.RequestHeader
   129  			var data *buf.Buffer
   130  			var err error
   131  
   132  			if inbound.User != nil {
   133  				validator := new(Validator)
   134  				validator.Add(inbound.User)
   135  				request, data, err = DecodeUDPPacket(validator, payload)
   136  			} else {
   137  				request, data, err = DecodeUDPPacket(s.validator, payload)
   138  				if err == nil {
   139  					inbound.User = request.User
   140  				}
   141  			}
   142  
   143  			if err != nil {
   144  				if inbound.Source.IsValid() {
   145  					newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx))
   146  					log.Record(&log.AccessMessage{
   147  						From:   inbound.Source,
   148  						To:     "",
   149  						Status: log.AccessRejected,
   150  						Reason: err,
   151  					})
   152  				}
   153  				payload.Release()
   154  				continue
   155  			}
   156  
   157  			destination := request.Destination()
   158  
   159  			currentPacketCtx := ctx
   160  			if inbound.Source.IsValid() {
   161  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   162  					From:   inbound.Source,
   163  					To:     destination,
   164  					Status: log.AccessAccepted,
   165  					Reason: "",
   166  					Email:  request.User.Email,
   167  				})
   168  			}
   169  			newError("tunnelling request to ", destination).WriteToLog(session.ExportIDToError(currentPacketCtx))
   170  
   171  			data.UDP = &destination
   172  
   173  			if !s.cone || dest == nil {
   174  				dest = &destination
   175  			}
   176  
   177  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   178  			udpServer.Dispatch(currentPacketCtx, *dest, data)
   179  		}
   180  	}
   181  
   182  	return nil
   183  }
   184  
   185  func (s *Server) handleConnection(ctx context.Context, conn stat.Connection, dispatcher routing.Dispatcher) error {
   186  	sessionPolicy := s.policyManager.ForLevel(0)
   187  	if err := conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake)); err != nil {
   188  		return newError("unable to set read deadline").Base(err).AtWarning()
   189  	}
   190  
   191  	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
   192  	request, bodyReader, err := ReadTCPSession(s.validator, &bufferedReader)
   193  	if err != nil {
   194  		log.Record(&log.AccessMessage{
   195  			From:   conn.RemoteAddr(),
   196  			To:     "",
   197  			Status: log.AccessRejected,
   198  			Reason: err,
   199  		})
   200  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   201  	}
   202  	conn.SetReadDeadline(time.Time{})
   203  
   204  	inbound := session.InboundFromContext(ctx)
   205  	if inbound == nil {
   206  		panic("no inbound metadata")
   207  	}
   208  	inbound.User = request.User
   209  
   210  	dest := request.Destination()
   211  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   212  		From:   conn.RemoteAddr(),
   213  		To:     dest,
   214  		Status: log.AccessAccepted,
   215  		Reason: "",
   216  		Email:  request.User.Email,
   217  	})
   218  	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   219  
   220  	sessionPolicy = s.policyManager.ForLevel(request.User.Level)
   221  	ctx, cancel := context.WithCancel(ctx)
   222  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   223  
   224  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   225  	link, err := dispatcher.Dispatch(ctx, dest)
   226  	if err != nil {
   227  		return err
   228  	}
   229  
   230  	responseDone := func() error {
   231  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   232  
   233  		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   234  		responseWriter, err := WriteTCPResponse(request, bufferedWriter)
   235  		if err != nil {
   236  			return newError("failed to write response").Base(err)
   237  		}
   238  
   239  		{
   240  			payload, err := link.Reader.ReadMultiBuffer()
   241  			if err != nil {
   242  				return err
   243  			}
   244  			if err := responseWriter.WriteMultiBuffer(payload); err != nil {
   245  				return err
   246  			}
   247  		}
   248  
   249  		if err := bufferedWriter.SetBuffered(false); err != nil {
   250  			return err
   251  		}
   252  
   253  		if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil {
   254  			return newError("failed to transport all TCP response").Base(err)
   255  		}
   256  
   257  		return nil
   258  	}
   259  
   260  	requestDone := func() error {
   261  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   262  
   263  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   264  			return newError("failed to transport all TCP request").Base(err)
   265  		}
   266  
   267  		return nil
   268  	}
   269  
   270  	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   271  	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   272  		common.Interrupt(link.Reader)
   273  		common.Interrupt(link.Writer)
   274  		return newError("connection ends").Base(err)
   275  	}
   276  
   277  	return nil
   278  }
   279  
   280  func init() {
   281  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   282  		return NewServer(ctx, config.(*ServerConfig))
   283  	}))
   284  }