github.com/eagleql/xray-core@v1.4.4/proxy/vmess/encoding/commands.go (about)

     1  package encoding
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/eagleql/xray-core/common"
     8  	"github.com/eagleql/xray-core/common/buf"
     9  	"github.com/eagleql/xray-core/common/net"
    10  	"github.com/eagleql/xray-core/common/protocol"
    11  	"github.com/eagleql/xray-core/common/serial"
    12  	"github.com/eagleql/xray-core/common/uuid"
    13  )
    14  
    15  var (
    16  	ErrCommandTypeMismatch = newError("Command type mismatch.")
    17  	ErrUnknownCommand      = newError("Unknown command.")
    18  	ErrCommandTooLarge     = newError("Command too large.")
    19  )
    20  
    21  func MarshalCommand(command interface{}, writer io.Writer) error {
    22  	if command == nil {
    23  		return ErrUnknownCommand
    24  	}
    25  
    26  	var cmdID byte
    27  	var factory CommandFactory
    28  	switch command.(type) {
    29  	case *protocol.CommandSwitchAccount:
    30  		factory = new(CommandSwitchAccountFactory)
    31  		cmdID = 1
    32  	default:
    33  		return ErrUnknownCommand
    34  	}
    35  
    36  	buffer := buf.New()
    37  	defer buffer.Release()
    38  
    39  	err := factory.Marshal(command, buffer)
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	auth := Authenticate(buffer.Bytes())
    45  	length := buffer.Len() + 4
    46  	if length > 255 {
    47  		return ErrCommandTooLarge
    48  	}
    49  
    50  	common.Must2(writer.Write([]byte{cmdID, byte(length), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}))
    51  	common.Must2(writer.Write(buffer.Bytes()))
    52  	return nil
    53  }
    54  
    55  func UnmarshalCommand(cmdID byte, data []byte) (protocol.ResponseCommand, error) {
    56  	if len(data) <= 4 {
    57  		return nil, newError("insufficient length")
    58  	}
    59  	expectedAuth := Authenticate(data[4:])
    60  	actualAuth := binary.BigEndian.Uint32(data[:4])
    61  	if expectedAuth != actualAuth {
    62  		return nil, newError("invalid auth")
    63  	}
    64  
    65  	var factory CommandFactory
    66  	switch cmdID {
    67  	case 1:
    68  		factory = new(CommandSwitchAccountFactory)
    69  	default:
    70  		return nil, ErrUnknownCommand
    71  	}
    72  	return factory.Unmarshal(data[4:])
    73  }
    74  
    75  type CommandFactory interface {
    76  	Marshal(command interface{}, writer io.Writer) error
    77  	Unmarshal(data []byte) (interface{}, error)
    78  }
    79  
    80  type CommandSwitchAccountFactory struct {
    81  }
    82  
    83  func (f *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
    84  	cmd, ok := command.(*protocol.CommandSwitchAccount)
    85  	if !ok {
    86  		return ErrCommandTypeMismatch
    87  	}
    88  
    89  	hostStr := ""
    90  	if cmd.Host != nil {
    91  		hostStr = cmd.Host.String()
    92  	}
    93  	common.Must2(writer.Write([]byte{byte(len(hostStr))}))
    94  
    95  	if len(hostStr) > 0 {
    96  		common.Must2(writer.Write([]byte(hostStr)))
    97  	}
    98  
    99  	common.Must2(serial.WriteUint16(writer, cmd.Port.Value()))
   100  
   101  	idBytes := cmd.ID.Bytes()
   102  	common.Must2(writer.Write(idBytes))
   103  	common.Must2(serial.WriteUint16(writer, cmd.AlterIds))
   104  	common.Must2(writer.Write([]byte{byte(cmd.Level)}))
   105  
   106  	common.Must2(writer.Write([]byte{cmd.ValidMin}))
   107  	return nil
   108  }
   109  
   110  func (f *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
   111  	cmd := new(protocol.CommandSwitchAccount)
   112  	if len(data) == 0 {
   113  		return nil, newError("insufficient length.")
   114  	}
   115  	lenHost := int(data[0])
   116  	if len(data) < lenHost+1 {
   117  		return nil, newError("insufficient length.")
   118  	}
   119  	if lenHost > 0 {
   120  		cmd.Host = net.ParseAddress(string(data[1 : 1+lenHost]))
   121  	}
   122  	portStart := 1 + lenHost
   123  	if len(data) < portStart+2 {
   124  		return nil, newError("insufficient length.")
   125  	}
   126  	cmd.Port = net.PortFromBytes(data[portStart : portStart+2])
   127  	idStart := portStart + 2
   128  	if len(data) < idStart+16 {
   129  		return nil, newError("insufficient length.")
   130  	}
   131  	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
   132  	alterIDStart := idStart + 16
   133  	if len(data) < alterIDStart+2 {
   134  		return nil, newError("insufficient length.")
   135  	}
   136  	cmd.AlterIds = binary.BigEndian.Uint16(data[alterIDStart : alterIDStart+2])
   137  	levelStart := alterIDStart + 2
   138  	if len(data) < levelStart+1 {
   139  		return nil, newError("insufficient length.")
   140  	}
   141  	cmd.Level = uint32(data[levelStart])
   142  	timeStart := levelStart + 1
   143  	if len(data) < timeStart {
   144  		return nil, newError("insufficient length.")
   145  	}
   146  	cmd.ValidMin = data[timeStart]
   147  	return cmd, nil
   148  }