github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/common/protocol/address.go (about)

     1  package protocol
     2  
     3  import (
     4  	"io"
     5  
     6  	"github.com/xtls/xray-core/common"
     7  	"github.com/xtls/xray-core/common/buf"
     8  	"github.com/xtls/xray-core/common/net"
     9  	"github.com/xtls/xray-core/common/serial"
    10  )
    11  
// AddressOption mutates the internal option set from which NewAddressParser
// builds its parser. Options are applied in the order they are passed.
type AddressOption func(*option)
    13  
    14  func PortThenAddress() AddressOption {
    15  	return func(p *option) {
    16  		p.portFirst = true
    17  	}
    18  }
    19  
    20  func AddressFamilyByte(b byte, f net.AddressFamily) AddressOption {
    21  	if b >= 16 {
    22  		panic("address family byte too big")
    23  	}
    24  	return func(p *option) {
    25  		p.addrTypeMap[b] = f
    26  		p.addrByteMap[f] = b
    27  	}
    28  }
    29  
// AddressTypeParser maps the raw address type byte read from the input to the
// byte that is then looked up in the address type table.
type AddressTypeParser func(byte) byte
    31  
    32  func WithAddressTypeParser(atp AddressTypeParser) AddressOption {
    33  	return func(p *option) {
    34  		p.typeParser = atp
    35  	}
    36  }
    37  
// AddressSerializer reads and writes an (address, port) pair in a wire format
// whose field order and type bytes are chosen when the serializer is built.
type AddressSerializer interface {
	// ReadAddressPort reads one address and one port from input, using buffer
	// to hold the bytes read. The implementations in this file accept a nil
	// buffer and then use a temporary one that is released before returning.
	ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error)

	// WriteAddressPort writes addr and port to writer.
	WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error
}
    43  
// afInvalid is the sentinel stored in unassigned slots of the address
// type/byte lookup tables.
const afInvalid = 255
    45  
// option collects the settings applied by AddressOption functions before
// NewAddressParser turns them into a parser.
type option struct {
	// addrTypeMap maps a type byte (0-15) to its address family.
	addrTypeMap [16]net.AddressFamily
	// addrByteMap maps an address family (used as index) to its type byte.
	addrByteMap [16]byte
	// portFirst selects the port-before-address serializer when true.
	portFirst bool
	// typeParser, when non-nil, transforms the type byte before lookup.
	typeParser AddressTypeParser
}
    52  
    53  // NewAddressParser creates a new AddressParser
    54  func NewAddressParser(options ...AddressOption) AddressSerializer {
    55  	var o option
    56  	for i := range o.addrByteMap {
    57  		o.addrByteMap[i] = afInvalid
    58  	}
    59  	for i := range o.addrTypeMap {
    60  		o.addrTypeMap[i] = net.AddressFamily(afInvalid)
    61  	}
    62  	for _, opt := range options {
    63  		opt(&o)
    64  	}
    65  
    66  	ap := &addressParser{
    67  		addrByteMap: o.addrByteMap,
    68  		addrTypeMap: o.addrTypeMap,
    69  	}
    70  
    71  	if o.typeParser != nil {
    72  		ap.typeParser = o.typeParser
    73  	}
    74  
    75  	if o.portFirst {
    76  		return portFirstAddressParser{ap: ap}
    77  	}
    78  
    79  	return portLastAddressParser{ap: ap}
    80  }
    81  
// portFirstAddressParser is the AddressSerializer that handles the port
// before the address, delegating the address part to ap.
type portFirstAddressParser struct {
	ap *addressParser
}
    85  
    86  func (p portFirstAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
    87  	if buffer == nil {
    88  		buffer = buf.New()
    89  		defer buffer.Release()
    90  	}
    91  
    92  	port, err := readPort(buffer, input)
    93  	if err != nil {
    94  		return nil, 0, err
    95  	}
    96  
    97  	addr, err := p.ap.readAddress(buffer, input)
    98  	if err != nil {
    99  		return nil, 0, err
   100  	}
   101  	return addr, port, nil
   102  }
   103  
   104  func (p portFirstAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
   105  	if err := writePort(writer, port); err != nil {
   106  		return err
   107  	}
   108  
   109  	return p.ap.writeAddress(writer, addr)
   110  }
   111  
// portLastAddressParser is the AddressSerializer that handles the address
// before the port, delegating the address part to ap.
type portLastAddressParser struct {
	ap *addressParser
}
   115  
   116  func (p portLastAddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (net.Address, net.Port, error) {
   117  	if buffer == nil {
   118  		buffer = buf.New()
   119  		defer buffer.Release()
   120  	}
   121  
   122  	addr, err := p.ap.readAddress(buffer, input)
   123  	if err != nil {
   124  		return nil, 0, err
   125  	}
   126  
   127  	port, err := readPort(buffer, input)
   128  	if err != nil {
   129  		return nil, 0, err
   130  	}
   131  
   132  	return addr, port, nil
   133  }
   134  
   135  func (p portLastAddressParser) WriteAddressPort(writer io.Writer, addr net.Address, port net.Port) error {
   136  	if err := p.ap.writeAddress(writer, addr); err != nil {
   137  		return err
   138  	}
   139  
   140  	return writePort(writer, port)
   141  }
   142  
   143  func readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
   144  	if _, err := b.ReadFullFrom(reader, 2); err != nil {
   145  		return 0, err
   146  	}
   147  	return net.PortFromBytes(b.BytesFrom(-2)), nil
   148  }
   149  
   150  func writePort(writer io.Writer, port net.Port) error {
   151  	return common.Error2(serial.WriteUint16(writer, port.Value()))
   152  }
   153  
   154  func maybeIPPrefix(b byte) bool {
   155  	return b == '[' || (b >= '0' && b <= '9')
   156  }
   157  
   158  func isValidDomain(d string) bool {
   159  	for _, c := range d {
   160  		if !((c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '-' || c == '.' || c == '_') {
   161  			return false
   162  		}
   163  	}
   164  	return true
   165  }
   166  
// addressParser reads and writes a single address using the type-byte tables
// configured through NewAddressParser.
type addressParser struct {
	// addrTypeMap maps a type byte (0-15) to its address family; unassigned
	// slots hold the afInvalid sentinel.
	addrTypeMap [16]net.AddressFamily
	// addrByteMap maps an address family (used as index) to its type byte;
	// unassigned slots hold afInvalid.
	addrByteMap [16]byte
	// typeParser, when non-nil, transforms the type byte before lookup.
	typeParser AddressTypeParser
}
   172  
   173  func (p *addressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
   174  	if _, err := b.ReadFullFrom(reader, 1); err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	addrType := b.Byte(b.Len() - 1)
   179  	if p.typeParser != nil {
   180  		addrType = p.typeParser(addrType)
   181  	}
   182  
   183  	if addrType >= 16 {
   184  		return nil, newError("unknown address type: ", addrType)
   185  	}
   186  
   187  	addrFamily := p.addrTypeMap[addrType]
   188  	if addrFamily == net.AddressFamily(afInvalid) {
   189  		return nil, newError("unknown address type: ", addrType)
   190  	}
   191  
   192  	switch addrFamily {
   193  	case net.AddressFamilyIPv4:
   194  		if _, err := b.ReadFullFrom(reader, 4); err != nil {
   195  			return nil, err
   196  		}
   197  		return net.IPAddress(b.BytesFrom(-4)), nil
   198  	case net.AddressFamilyIPv6:
   199  		if _, err := b.ReadFullFrom(reader, 16); err != nil {
   200  			return nil, err
   201  		}
   202  		return net.IPAddress(b.BytesFrom(-16)), nil
   203  	case net.AddressFamilyDomain:
   204  		if _, err := b.ReadFullFrom(reader, 1); err != nil {
   205  			return nil, err
   206  		}
   207  		domainLength := int32(b.Byte(b.Len() - 1))
   208  		if _, err := b.ReadFullFrom(reader, domainLength); err != nil {
   209  			return nil, err
   210  		}
   211  		domain := string(b.BytesFrom(-domainLength))
   212  		if maybeIPPrefix(domain[0]) {
   213  			addr := net.ParseAddress(domain)
   214  			if addr.Family().IsIP() {
   215  				return addr, nil
   216  			}
   217  		}
   218  		if !isValidDomain(domain) {
   219  			return nil, newError("invalid domain name: ", domain)
   220  		}
   221  		return net.DomainAddress(domain), nil
   222  	default:
   223  		panic("impossible case")
   224  	}
   225  }
   226  
   227  func (p *addressParser) writeAddress(writer io.Writer, address net.Address) error {
   228  	tb := p.addrByteMap[address.Family()]
   229  	if tb == afInvalid {
   230  		return newError("unknown address family", address.Family())
   231  	}
   232  
   233  	switch address.Family() {
   234  	case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
   235  		if _, err := writer.Write([]byte{tb}); err != nil {
   236  			return err
   237  		}
   238  		if _, err := writer.Write(address.IP()); err != nil {
   239  			return err
   240  		}
   241  	case net.AddressFamilyDomain:
   242  		domain := address.Domain()
   243  		if isDomainTooLong(domain) {
   244  			return newError("Super long domain is not supported: ", domain)
   245  		}
   246  
   247  		if _, err := writer.Write([]byte{tb, byte(len(domain))}); err != nil {
   248  			return err
   249  		}
   250  		if _, err := writer.Write([]byte(domain)); err != nil {
   251  			return err
   252  		}
   253  	default:
   254  		panic("Unknown family type.")
   255  	}
   256  
   257  	return nil
   258  }