github.com/v2fly/v2ray-core/v4@v4.45.2/proxy/vmess/encoding/commands.go (about)

     1  package encoding
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  
     7  	"github.com/v2fly/v2ray-core/v4/common"
     8  	"github.com/v2fly/v2ray-core/v4/common/buf"
     9  	"github.com/v2fly/v2ray-core/v4/common/net"
    10  	"github.com/v2fly/v2ray-core/v4/common/protocol"
    11  	"github.com/v2fly/v2ray-core/v4/common/serial"
    12  	"github.com/v2fly/v2ray-core/v4/common/uuid"
    13  )
    14  
    15  var (
    16  	ErrCommandTypeMismatch = newError("Command type mismatch.")
    17  	ErrUnknownCommand      = newError("Unknown command.")
    18  	ErrCommandTooLarge     = newError("Command too large.")
    19  )
    20  
    21  func MarshalCommand(command interface{}, writer io.Writer) error {
    22  	if command == nil {
    23  		return ErrUnknownCommand
    24  	}
    25  
    26  	var cmdID byte
    27  	var factory CommandFactory
    28  	switch command.(type) {
    29  	case *protocol.CommandSwitchAccount:
    30  		factory = new(CommandSwitchAccountFactory)
    31  		cmdID = 1
    32  	default:
    33  		return ErrUnknownCommand
    34  	}
    35  
    36  	buffer := buf.New()
    37  	defer buffer.Release()
    38  
    39  	err := factory.Marshal(command, buffer)
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	auth := Authenticate(buffer.Bytes())
    45  	length := buffer.Len() + 4
    46  	if length > 255 {
    47  		return ErrCommandTooLarge
    48  	}
    49  
    50  	common.Must2(writer.Write([]byte{cmdID, byte(length), byte(auth >> 24), byte(auth >> 16), byte(auth >> 8), byte(auth)}))
    51  	common.Must2(writer.Write(buffer.Bytes()))
    52  	return nil
    53  }
    54  
    55  func UnmarshalCommand(cmdID byte, data []byte) (protocol.ResponseCommand, error) {
    56  	if len(data) <= 4 {
    57  		return nil, newError("insufficient length")
    58  	}
    59  	expectedAuth := Authenticate(data[4:])
    60  	actualAuth := binary.BigEndian.Uint32(data[:4])
    61  	if expectedAuth != actualAuth {
    62  		return nil, newError("invalid auth")
    63  	}
    64  
    65  	var factory CommandFactory
    66  	switch cmdID {
    67  	case 1:
    68  		factory = new(CommandSwitchAccountFactory)
    69  	default:
    70  		return nil, ErrUnknownCommand
    71  	}
    72  	return factory.Unmarshal(data[4:])
    73  }
    74  
    75  type CommandFactory interface {
    76  	Marshal(command interface{}, writer io.Writer) error
    77  	Unmarshal(data []byte) (interface{}, error)
    78  }
    79  
    80  type CommandSwitchAccountFactory struct{}
    81  
    82  func (f *CommandSwitchAccountFactory) Marshal(command interface{}, writer io.Writer) error {
    83  	cmd, ok := command.(*protocol.CommandSwitchAccount)
    84  	if !ok {
    85  		return ErrCommandTypeMismatch
    86  	}
    87  
    88  	hostStr := ""
    89  	if cmd.Host != nil {
    90  		hostStr = cmd.Host.String()
    91  	}
    92  	common.Must2(writer.Write([]byte{byte(len(hostStr))}))
    93  
    94  	if len(hostStr) > 0 {
    95  		common.Must2(writer.Write([]byte(hostStr)))
    96  	}
    97  
    98  	common.Must2(serial.WriteUint16(writer, cmd.Port.Value()))
    99  
   100  	idBytes := cmd.ID.Bytes()
   101  	common.Must2(writer.Write(idBytes))
   102  	common.Must2(serial.WriteUint16(writer, cmd.AlterIds))
   103  	common.Must2(writer.Write([]byte{byte(cmd.Level)}))
   104  
   105  	common.Must2(writer.Write([]byte{cmd.ValidMin}))
   106  	return nil
   107  }
   108  
   109  func (f *CommandSwitchAccountFactory) Unmarshal(data []byte) (interface{}, error) {
   110  	cmd := new(protocol.CommandSwitchAccount)
   111  	if len(data) == 0 {
   112  		return nil, newError("insufficient length.")
   113  	}
   114  	lenHost := int(data[0])
   115  	if len(data) < lenHost+1 {
   116  		return nil, newError("insufficient length.")
   117  	}
   118  	if lenHost > 0 {
   119  		cmd.Host = net.ParseAddress(string(data[1 : 1+lenHost]))
   120  	}
   121  	portStart := 1 + lenHost
   122  	if len(data) < portStart+2 {
   123  		return nil, newError("insufficient length.")
   124  	}
   125  	cmd.Port = net.PortFromBytes(data[portStart : portStart+2])
   126  	idStart := portStart + 2
   127  	if len(data) < idStart+16 {
   128  		return nil, newError("insufficient length.")
   129  	}
   130  	cmd.ID, _ = uuid.ParseBytes(data[idStart : idStart+16])
   131  	alterIDStart := idStart + 16
   132  	if len(data) < alterIDStart+2 {
   133  		return nil, newError("insufficient length.")
   134  	}
   135  	cmd.AlterIds = binary.BigEndian.Uint16(data[alterIDStart : alterIDStart+2])
   136  	levelStart := alterIDStart + 2
   137  	if len(data) < levelStart+1 {
   138  		return nil, newError("insufficient length.")
   139  	}
   140  	cmd.Level = uint32(data[levelStart])
   141  	timeStart := levelStart + 1
   142  	if len(data) < timeStart+1 {
   143  		return nil, newError("insufficient length.")
   144  	}
   145  	cmd.ValidMin = data[timeStart]
   146  	return cmd, nil
   147  }