github.com/imannamdari/v2ray-core/v5@v5.0.5/proxy/shadowsocks/server.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"context"
     5  	"time"
     6  
     7  	core "github.com/imannamdari/v2ray-core/v5"
     8  	"github.com/imannamdari/v2ray-core/v5/common"
     9  	"github.com/imannamdari/v2ray-core/v5/common/buf"
    10  	"github.com/imannamdari/v2ray-core/v5/common/log"
    11  	"github.com/imannamdari/v2ray-core/v5/common/net"
    12  	"github.com/imannamdari/v2ray-core/v5/common/net/packetaddr"
    13  	"github.com/imannamdari/v2ray-core/v5/common/protocol"
    14  	udp_proto "github.com/imannamdari/v2ray-core/v5/common/protocol/udp"
    15  	"github.com/imannamdari/v2ray-core/v5/common/session"
    16  	"github.com/imannamdari/v2ray-core/v5/common/signal"
    17  	"github.com/imannamdari/v2ray-core/v5/common/task"
    18  	"github.com/imannamdari/v2ray-core/v5/features/policy"
    19  	"github.com/imannamdari/v2ray-core/v5/features/routing"
    20  	"github.com/imannamdari/v2ray-core/v5/transport/internet"
    21  	"github.com/imannamdari/v2ray-core/v5/transport/internet/udp"
    22  )
    23  
// Server is a Shadowsocks inbound handler that serves the single user
// configured in its ServerConfig over TCP and, optionally, UDP.
type Server struct {
	config        *ServerConfig        // settings the server was created from
	user          *protocol.MemoryUser // the only accepted user, parsed from config.User
	policyManager policy.Manager       // source of per-level timeouts and buffer policies
}
    29  
    30  // NewServer create a new Shadowsocks server.
    31  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    32  	if config.GetUser() == nil {
    33  		return nil, newError("user is not specified")
    34  	}
    35  
    36  	mUser, err := config.User.ToMemoryUser()
    37  	if err != nil {
    38  		return nil, newError("failed to parse user account").Base(err)
    39  	}
    40  
    41  	v := core.MustFromContext(ctx)
    42  	s := &Server{
    43  		config:        config,
    44  		user:          mUser,
    45  		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
    46  	}
    47  
    48  	return s, nil
    49  }
    50  
    51  func (s *Server) Network() []net.Network {
    52  	list := s.config.Network
    53  	if len(list) == 0 {
    54  		list = append(list, net.Network_TCP)
    55  	}
    56  	if s.config.UdpEnabled {
    57  		list = append(list, net.Network_UDP)
    58  	}
    59  	return list
    60  }
    61  
    62  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    63  	switch network {
    64  	case net.Network_TCP:
    65  		return s.handleConnection(ctx, conn, dispatcher)
    66  	case net.Network_UDP:
    67  		return s.handlerUDPPayload(ctx, conn, dispatcher)
    68  	default:
    69  		return newError("unknown network: ", network)
    70  	}
    71  }
    72  
    73  func (s *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
    74  	udpDispatcherConstructor := udp.NewSplitDispatcher
    75  	switch s.config.PacketEncoding {
    76  	case packetaddr.PacketAddrType_None:
    77  		break
    78  	case packetaddr.PacketAddrType_Packet:
    79  		packetAddrDispatcherFactory := udp.NewPacketAddrDispatcherCreator(ctx)
    80  		udpDispatcherConstructor = packetAddrDispatcherFactory.NewPacketAddrDispatcher
    81  	}
    82  
    83  	udpServer := udpDispatcherConstructor(dispatcher, func(ctx context.Context, packet *udp_proto.Packet) {
    84  		request := protocol.RequestHeaderFromContext(ctx)
    85  		if request == nil {
    86  			request = &protocol.RequestHeader{
    87  				Port:    packet.Source.Port,
    88  				Address: packet.Source.Address,
    89  				User:    s.user,
    90  			}
    91  		}
    92  
    93  		payload := packet.Payload
    94  		data, err := EncodeUDPPacket(request, payload.Bytes())
    95  		payload.Release()
    96  		if err != nil {
    97  			newError("failed to encode UDP packet").Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
    98  			return
    99  		}
   100  		defer data.Release()
   101  
   102  		conn.Write(data.Bytes())
   103  	})
   104  
   105  	inbound := session.InboundFromContext(ctx)
   106  	if inbound == nil {
   107  		panic("no inbound metadata")
   108  	}
   109  	inbound.User = s.user
   110  
   111  	reader := buf.NewPacketReader(conn)
   112  	for {
   113  		mpayload, err := reader.ReadMultiBuffer()
   114  		if err != nil {
   115  			break
   116  		}
   117  
   118  		for _, payload := range mpayload {
   119  			request, data, err := DecodeUDPPacket(s.user, payload)
   120  			if err != nil {
   121  				if inbound := session.InboundFromContext(ctx); inbound != nil && inbound.Source.IsValid() {
   122  					newError("dropping invalid UDP packet from: ", inbound.Source).Base(err).WriteToLog(session.ExportIDToError(ctx))
   123  					log.Record(&log.AccessMessage{
   124  						From:   inbound.Source,
   125  						To:     "",
   126  						Status: log.AccessRejected,
   127  						Reason: err,
   128  					})
   129  				}
   130  				payload.Release()
   131  				continue
   132  			}
   133  
   134  			currentPacketCtx := ctx
   135  			dest := request.Destination()
   136  			if inbound.Source.IsValid() {
   137  				currentPacketCtx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   138  					From:   inbound.Source,
   139  					To:     dest,
   140  					Status: log.AccessAccepted,
   141  					Reason: "",
   142  					Email:  request.User.Email,
   143  				})
   144  			}
   145  			newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(currentPacketCtx))
   146  
   147  			currentPacketCtx = protocol.ContextWithRequestHeader(currentPacketCtx, request)
   148  			udpServer.Dispatch(currentPacketCtx, dest, data)
   149  		}
   150  	}
   151  
   152  	return nil
   153  }
   154  
   155  func (s *Server) handleConnection(ctx context.Context, conn internet.Connection, dispatcher routing.Dispatcher) error {
   156  	sessionPolicy := s.policyManager.ForLevel(s.user.Level)
   157  	conn.SetReadDeadline(time.Now().Add(sessionPolicy.Timeouts.Handshake))
   158  
   159  	bufferedReader := buf.BufferedReader{Reader: buf.NewReader(conn)}
   160  	request, bodyReader, err := ReadTCPSession(s.user, &bufferedReader)
   161  	if err != nil {
   162  		log.Record(&log.AccessMessage{
   163  			From:   conn.RemoteAddr(),
   164  			To:     "",
   165  			Status: log.AccessRejected,
   166  			Reason: err,
   167  		})
   168  		return newError("failed to create request from: ", conn.RemoteAddr()).Base(err)
   169  	}
   170  	conn.SetReadDeadline(time.Time{})
   171  
   172  	inbound := session.InboundFromContext(ctx)
   173  	if inbound == nil {
   174  		panic("no inbound metadata")
   175  	}
   176  	inbound.User = s.user
   177  
   178  	dest := request.Destination()
   179  	ctx = log.ContextWithAccessMessage(ctx, &log.AccessMessage{
   180  		From:   conn.RemoteAddr(),
   181  		To:     dest,
   182  		Status: log.AccessAccepted,
   183  		Reason: "",
   184  		Email:  request.User.Email,
   185  	})
   186  	newError("tunnelling request to ", dest).WriteToLog(session.ExportIDToError(ctx))
   187  
   188  	ctx, cancel := context.WithCancel(ctx)
   189  	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
   190  
   191  	ctx = policy.ContextWithBufferPolicy(ctx, sessionPolicy.Buffer)
   192  	link, err := dispatcher.Dispatch(ctx, dest)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	responseDone := func() error {
   198  		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
   199  
   200  		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
   201  		responseWriter, err := WriteTCPResponse(request, bufferedWriter)
   202  		if err != nil {
   203  			return newError("failed to write response").Base(err)
   204  		}
   205  
   206  		{
   207  			payload, err := link.Reader.ReadMultiBuffer()
   208  			if err != nil {
   209  				return err
   210  			}
   211  			if err := responseWriter.WriteMultiBuffer(payload); err != nil {
   212  				return err
   213  			}
   214  		}
   215  
   216  		if err := bufferedWriter.SetBuffered(false); err != nil {
   217  			return err
   218  		}
   219  
   220  		if err := buf.Copy(link.Reader, responseWriter, buf.UpdateActivity(timer)); err != nil {
   221  			return newError("failed to transport all TCP response").Base(err)
   222  		}
   223  
   224  		return nil
   225  	}
   226  
   227  	requestDone := func() error {
   228  		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
   229  
   230  		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
   231  			return newError("failed to transport all TCP request").Base(err)
   232  		}
   233  
   234  		return nil
   235  	}
   236  
   237  	requestDoneAndCloseWriter := task.OnSuccess(requestDone, task.Close(link.Writer))
   238  	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
   239  		common.Interrupt(link.Reader)
   240  		common.Interrupt(link.Writer)
   241  		return newError("connection ends").Base(err)
   242  	}
   243  
   244  	return nil
   245  }
   246  
   247  func init() {
   248  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   249  		return NewServer(ctx, config.(*ServerConfig))
   250  	}))
   251  }