github.com/eagleql/xray-core@v1.4.4/proxy/mtproto/server.go (about)

     1  package mtproto
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"time"
     7  
     8  	"github.com/eagleql/xray-core/common"
     9  	"github.com/eagleql/xray-core/common/buf"
    10  	"github.com/eagleql/xray-core/common/crypto"
    11  	"github.com/eagleql/xray-core/common/net"
    12  	"github.com/eagleql/xray-core/common/protocol"
    13  	"github.com/eagleql/xray-core/common/session"
    14  	"github.com/eagleql/xray-core/common/signal"
    15  	"github.com/eagleql/xray-core/common/task"
    16  	"github.com/eagleql/xray-core/core"
    17  	"github.com/eagleql/xray-core/features/policy"
    18  	"github.com/eagleql/xray-core/features/routing"
    19  	"github.com/eagleql/xray-core/transport/internet"
    20  )
    21  
    22  var (
    23  	dcList = []net.Address{
    24  		net.ParseAddress("149.154.175.50"),
    25  		net.ParseAddress("149.154.167.51"),
    26  		net.ParseAddress("149.154.175.100"),
    27  		net.ParseAddress("149.154.167.91"),
    28  		net.ParseAddress("149.154.171.5"),
    29  	}
    30  )
    31  
// Server is the MTProto inbound handler. It serves the single user chosen by
// NewServer and relays each accepted connection to an address from dcList.
type Server struct {
	user    *protocol.User // first configured user; its Level selects the policy in Process
	account *Account       // MTProto account of user; its Secret is applied to every handshake
	policy  policy.Manager // provides the per-level timeouts and buffer policy
}
    37  
    38  func NewServer(ctx context.Context, config *ServerConfig) (*Server, error) {
    39  	if len(config.User) == 0 {
    40  		return nil, newError("no user configured.")
    41  	}
    42  
    43  	user := config.User[0]
    44  	rawAccount, err := config.User[0].GetTypedAccount()
    45  	if err != nil {
    46  		return nil, newError("invalid account").Base(err)
    47  	}
    48  	account, ok := rawAccount.(*Account)
    49  	if !ok {
    50  		return nil, newError("not a MTProto account")
    51  	}
    52  
    53  	v := core.MustFromContext(ctx)
    54  
    55  	return &Server{
    56  		user:    user,
    57  		account: account,
    58  		policy:  v.GetFeature(policy.ManagerType()).(policy.Manager),
    59  	}, nil
    60  }
    61  
    62  func (s *Server) Network() []net.Network {
    63  	return []net.Network{net.Network_TCP}
    64  }
    65  
    66  var ctype1 = []byte{0xef, 0xef, 0xef, 0xef}
    67  var ctype2 = []byte{0xee, 0xee, 0xee, 0xee}
    68  
    69  func isValidConnectionType(c [4]byte) bool {
    70  	if bytes.Equal(c[:], ctype1) {
    71  		return true
    72  	}
    73  	if bytes.Equal(c[:], ctype2) {
    74  		return true
    75  	}
    76  	return false
    77  }
    78  
    79  func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher routing.Dispatcher) error {
    80  	sPolicy := s.policy.ForLevel(s.user.Level)
    81  
    82  	if err := conn.SetDeadline(time.Now().Add(sPolicy.Timeouts.Handshake)); err != nil {
    83  		newError("failed to set deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    84  	}
    85  	auth, err := ReadAuthentication(conn)
    86  	if err != nil {
    87  		return newError("failed to read authentication header").Base(err)
    88  	}
    89  	defer putAuthenticationObject(auth)
    90  
    91  	if err := conn.SetDeadline(time.Time{}); err != nil {
    92  		newError("failed to clear deadline").Base(err).WriteToLog(session.ExportIDToError(ctx))
    93  	}
    94  
    95  	auth.ApplySecret(s.account.Secret)
    96  
    97  	decryptor := crypto.NewAesCTRStream(auth.DecodingKey[:], auth.DecodingNonce[:])
    98  	decryptor.XORKeyStream(auth.Header[:], auth.Header[:])
    99  
   100  	ct := auth.ConnectionType()
   101  	if !isValidConnectionType(ct) {
   102  		return newError("invalid connection type: ", ct)
   103  	}
   104  
   105  	dcID := auth.DataCenterID()
   106  	if dcID >= uint16(len(dcList)) {
   107  		return newError("invalid datacenter id: ", dcID)
   108  	}
   109  
   110  	dest := net.Destination{
   111  		Network: net.Network_TCP,
   112  		Address: dcList[dcID],
   113  		Port:    net.Port(443),
   114  	}
   115  
   116  	ctx, cancel := context.WithCancel(ctx)
   117  	timer := signal.CancelAfterInactivity(ctx, cancel, sPolicy.Timeouts.ConnectionIdle)
   118  	ctx = policy.ContextWithBufferPolicy(ctx, sPolicy.Buffer)
   119  
   120  	sc := SessionContext{
   121  		ConnectionType: ct,
   122  		DataCenterID:   dcID,
   123  	}
   124  	ctx = ContextWithSessionContext(ctx, sc)
   125  
   126  	link, err := dispatcher.Dispatch(ctx, dest)
   127  	if err != nil {
   128  		return newError("failed to dispatch request to: ", dest).Base(err)
   129  	}
   130  
   131  	request := func() error {
   132  		defer timer.SetTimeout(sPolicy.Timeouts.DownlinkOnly)
   133  
   134  		reader := buf.NewReader(crypto.NewCryptionReader(decryptor, conn))
   135  		return buf.Copy(reader, link.Writer, buf.UpdateActivity(timer))
   136  	}
   137  
   138  	response := func() error {
   139  		defer timer.SetTimeout(sPolicy.Timeouts.UplinkOnly)
   140  
   141  		encryptor := crypto.NewAesCTRStream(auth.EncodingKey[:], auth.EncodingNonce[:])
   142  		writer := buf.NewWriter(crypto.NewCryptionWriter(encryptor, conn))
   143  		return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
   144  	}
   145  
   146  	var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
   147  	if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
   148  		common.Interrupt(link.Reader)
   149  		common.Interrupt(link.Writer)
   150  		return newError("connection ends").Base(err)
   151  	}
   152  
   153  	return nil
   154  }
   155  
   156  func init() {
   157  	common.Must(common.RegisterConfig((*ServerConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
   158  		return NewServer(ctx, config.(*ServerConfig))
   159  	}))
   160  }