github.com/xraypb/xray-core@v1.6.6/proxy/vmess/encoding/commands.go (about)

     1  package encoding
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/xraypb/xray-core/common"
     8  	"github.com/xraypb/xray-core/common/buf"
     9  	"github.com/xraypb/xray-core/common/net"
    10  	"github.com/xraypb/xray-core/common/protocol"
    11  	"github.com/xraypb/xray-core/common/serial"
    12  	"github.com/xraypb/xray-core/common/uuid"
    13  )
    14  
    15  var (
    16  	ErrCommandTooLarge     = newError("Command too large.")
    17  	ErrCommandTypeMismatch = newError("Command type mismatch.")
    18  	ErrInvalidAuth         = newError("Invalid auth.")
    19  	ErrInsufficientLength  = newError("Insufficient length.")
    20  	ErrUnknownCommand      = newError("Unknown command.")
    21  )
    22  
    23  func MarshalCommand(command interface{}, writer io.Writer) error {
    24  	if command == nil {
    25  		return ErrUnknownCommand
    26  	}
    27  
    28  	var cmdID byte
    29  	var factory CommandFactory
    30  	switch command.(type) {
    31  	case *protocol.CommandSwitchAccount:
    32  		factory = new(CommandSwitchAccountFactory)
    33  		cmdID = 1
    34  	default:
    35  		return ErrUnknownCommand
    36  	}
    37  
    38  	buffer := buf.New()
    39  	defer buffer.Release()
    40  
    41  	err := factory.Marshal(command, buffer)
    42  	if err != nil {
    43  		return err
    44  	}
    45  
    46  	auth := Authenticate(buffer.Bytes())
    47  	length := buffer.Len() + 4
    48  	if length > 255 {
    49  		return ErrCommandTooLarge
    50  	}
    51  
    52  	common.Must2(writer.Write([]byte{cmdID, byte(length), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}))
    53  	common.Must2(writer.Write(buffer.Bytes()))
    54  	return nil
    55  }
    56  
    57  func UnmarshalCommand(cmdID byte, data []byte) (protocol.ResponseCommand, error) {
    58  	if len(data) <= 4 {
    59  		return nil, ErrInsufficientLength
    60  	}
    61  	expectedAuth := Authenticate(data[4:])
    62  	actualAuth := binary.BigEndian.Uint32(data[:4])
    63  	if expectedAuth != actualAuth {
    64  		return nil, ErrInvalidAuth
    65  	}
    66  
    67  	var factory CommandFactory
    68  	switch cmdID {
    69  	case 1:
    70  		factory = new(CommandSwitchAccountFactory)
    71  	default:
    72  		return nil, ErrUnknownCommand
    73  	}
    74  	return factory.Unmarshal(data[4:])
    75  }
    76  
    77  type CommandFactory interface {
    78  	Marshal(command interface{}, writer io.Writer) error
    79  	Unmarshal(data []byte) (interface{}, error)
    80  }
    81  
    82  type CommandSwitchAccountFactory struct{}
    83  
    84  func (f *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
    85  	cmd, ok := command.(*protocol.CommandSwitchAccount)
    86  	if !ok {
    87  		return ErrCommandTypeMismatch
    88  	}
    89  
    90  	hostStr := ""
    91  	if cmd.Host != nil {
    92  		hostStr = cmd.Host.String()
    93  	}
    94  	common.Must2(writer.Write([]byte{byte(len(hostStr))}))
    95  
    96  	if len(hostStr) > 0 {
    97  		common.Must2(writer.Write([]byte(hostStr)))
    98  	}
    99  
   100  	common.Must2(serial.WriteUint16(writer, cmd.Port.Value()))
   101  
   102  	idBytes := cmd.ID.Bytes()
   103  	common.Must2(writer.Write(idBytes))
   104  	common.Must2(serial.WriteUint16(writer, cmd.AlterIds))
   105  	common.Must2(writer.Write([]byte{byte(cmd.Level)}))
   106  
   107  	common.Must2(writer.Write([]byte{cmd.ValidMin}))
   108  	return nil
   109  }
   110  
   111  func (f *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
   112  	cmd := new(protocol.CommandSwitchAccount)
   113  	if len(data) == 0 {
   114  		return nil, ErrInsufficientLength
   115  	}
   116  	lenHost := int(data[0])
   117  	if len(data) < lenHost+1 {
   118  		return nil, ErrInsufficientLength
   119  	}
   120  	if lenHost > 0 {
   121  		cmd.Host = net.ParseAddress(string(data[1 : 1+lenHost]))
   122  	}
   123  	portStart := 1 + lenHost
   124  	if len(data) < portStart+2 {
   125  		return nil, ErrInsufficientLength
   126  	}
   127  	cmd.Port = net.PortFromBytes(data[portStart : portStart+2])
   128  	idStart := portStart + 2
   129  	if len(data) < idStart+16 {
   130  		return nil, ErrInsufficientLength
   131  	}
   132  	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
   133  	alterIDStart := idStart + 16
   134  	if len(data) < alterIDStart+2 {
   135  		return nil, ErrInsufficientLength
   136  	}
   137  	cmd.AlterIds = binary.BigEndian.Uint16(data[alterIDStart : alterIDStart+2])
   138  	levelStart := alterIDStart + 2
   139  	if len(data) < levelStart+1 {
   140  		return nil, ErrInsufficientLength
   141  	}
   142  	cmd.Level = uint32(data[levelStart])
   143  	timeStart := levelStart + 1
   144  	if len(data) < timeStart+1 {
   145  		return nil, ErrInsufficientLength
   146  	}
   147  	cmd.ValidMin = data[timeStart]
   148  	return cmd, nil
   149  }