github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/metadata/serializer.go (about)

     1  package metadata
     2  
     3  import (
     4  	"encoding/binary"
     5  	"io"
     6  	"net/netip"
     7  
     8  	"github.com/sagernet/sing/common"
     9  	"github.com/sagernet/sing/common/buf"
    10  	E "github.com/sagernet/sing/common/exceptions"
    11  	"github.com/sagernet/sing/common/rw"
    12  )
    13  
    14  const (
    15  	MaxSocksaddrLength   = 2 + 255 + 2
    16  	MaxIPSocksaddrLength = 1 + 16 + 2
    17  )
    18  
// SerializerOption configures a Serializer. NewSerializer applies options
// in the order they are passed.
type SerializerOption func(*Serializer)
    20  
    21  func AddressFamilyByte(b byte, f Family) SerializerOption {
    22  	return func(s *Serializer) {
    23  		s.familyMap[b] = f
    24  		s.familyByteMap[f] = b
    25  	}
    26  }
    27  
    28  func PortThenAddress() SerializerOption {
    29  	return func(s *Serializer) {
    30  		s.portFirst = true
    31  	}
    32  }
    33  
// Serializer encodes and decodes addresses, optionally followed or preceded
// by a port, using a configurable mapping between wire family bytes and
// Family values. Build one with NewSerializer.
type Serializer struct {
	familyMap     map[byte]Family // wire byte -> Family, consulted by ReadAddress
	familyByteMap map[Family]byte // Family -> wire byte, consulted by WriteAddress
	portFirst     bool            // set by PortThenAddress: port precedes address
}
    39  
    40  func NewSerializer(options ...SerializerOption) *Serializer {
    41  	s := &Serializer{
    42  		familyMap:     make(map[byte]Family),
    43  		familyByteMap: make(map[Family]byte),
    44  	}
    45  	for _, option := range options {
    46  		option(s)
    47  	}
    48  	return s
    49  }
    50  
    51  func (s *Serializer) WriteAddress(buffer *buf.Buffer, addr Socksaddr) error {
    52  	var af Family
    53  	if !addr.IsValid() {
    54  		af = AddressFamilyEmpty
    55  	} else if addr.IsIPv4() {
    56  		af = AddressFamilyIPv4
    57  	} else if addr.IsIPv6() {
    58  		af = AddressFamilyIPv6
    59  	} else {
    60  		af = AddressFamilyFqdn
    61  	}
    62  	afByte, loaded := s.familyByteMap[af]
    63  	if !loaded {
    64  		return E.New("unsupported address")
    65  	}
    66  	err := buffer.WriteByte(afByte)
    67  	if err != nil {
    68  		return err
    69  	}
    70  	switch af {
    71  	case AddressFamilyIPv4, AddressFamilyIPv6:
    72  		_, err = buffer.Write(addr.Addr.AsSlice())
    73  	case AddressFamilyFqdn:
    74  		err = WriteSocksString(buffer, addr.Fqdn)
    75  	}
    76  	return err
    77  }
    78  
    79  func (s *Serializer) AddressLen(addr Socksaddr) int {
    80  	if !addr.IsValid() {
    81  		return 1
    82  	} else if addr.IsIPv4() {
    83  		return 5
    84  	} else if addr.IsIPv6() {
    85  		return 17
    86  	} else {
    87  		return 2 + len(addr.Fqdn)
    88  	}
    89  }
    90  
    91  func (s *Serializer) WritePort(writer io.Writer, port uint16) error {
    92  	return binary.Write(writer, binary.BigEndian, port)
    93  }
    94  
    95  func (s *Serializer) WriteAddrPort(writer io.Writer, destination Socksaddr) error {
    96  	buffer, isBuffer := writer.(*buf.Buffer)
    97  	if !isBuffer {
    98  		buffer = buf.NewSize(s.AddrPortLen(destination))
    99  		defer buffer.Release()
   100  	}
   101  	var err error
   102  	if !s.portFirst {
   103  		err = s.WriteAddress(buffer, destination)
   104  	} else {
   105  		err = s.WritePort(buffer, destination.Port)
   106  	}
   107  	if err != nil {
   108  		return err
   109  	}
   110  	if s.portFirst {
   111  		err = s.WriteAddress(buffer, destination)
   112  	} else if destination.IsValid() {
   113  		err = s.WritePort(buffer, destination.Port)
   114  	}
   115  	if err != nil {
   116  		return err
   117  	}
   118  	if !isBuffer {
   119  		err = rw.WriteBytes(writer, buffer.Bytes())
   120  	}
   121  	return err
   122  }
   123  
   124  func (s *Serializer) AddrPortLen(destination Socksaddr) int {
   125  	if destination.IsValid() {
   126  		return s.AddressLen(destination) + 2
   127  	} else {
   128  		return s.AddressLen(destination)
   129  	}
   130  }
   131  
   132  func (s *Serializer) ReadAddress(reader io.Reader) (Socksaddr, error) {
   133  	af, err := rw.ReadByte(reader)
   134  	if err != nil {
   135  		return Socksaddr{}, err
   136  	}
   137  	family := s.familyMap[af]
   138  	switch family {
   139  	case AddressFamilyFqdn:
   140  		fqdn, err := ReadSockString(reader)
   141  		if err != nil {
   142  			return Socksaddr{}, E.Cause(err, "read fqdn")
   143  		}
   144  		return ParseSocksaddrHostPort(fqdn, 0), nil
   145  	case AddressFamilyIPv4:
   146  		var addr [4]byte
   147  		_, err = io.ReadFull(reader, addr[:])
   148  		if err != nil {
   149  			return Socksaddr{}, E.Cause(err, "read ipv4 address")
   150  		}
   151  		return Socksaddr{Addr: netip.AddrFrom4(addr)}, nil
   152  	case AddressFamilyIPv6:
   153  		var addr [16]byte
   154  		_, err = io.ReadFull(reader, addr[:])
   155  		if err != nil {
   156  			return Socksaddr{}, E.Cause(err, "read ipv6 address")
   157  		}
   158  		return Socksaddr{Addr: netip.AddrFrom16(addr)}.Unwrap(), nil
   159  	case AddressFamilyEmpty:
   160  		return Socksaddr{}, nil
   161  	default:
   162  		return Socksaddr{}, E.New("unknown address family: ", af)
   163  	}
   164  }
   165  
   166  func (s *Serializer) ReadPort(reader io.Reader) (uint16, error) {
   167  	port, err := rw.ReadBytes(reader, 2)
   168  	if err != nil {
   169  		return 0, E.Cause(err, "read port")
   170  	}
   171  	return binary.BigEndian.Uint16(port), nil
   172  }
   173  
   174  func (s *Serializer) ReadAddrPort(reader io.Reader) (destination Socksaddr, err error) {
   175  	var addr Socksaddr
   176  	var port uint16
   177  	if !s.portFirst {
   178  		addr, err = s.ReadAddress(reader)
   179  	} else {
   180  		port, err = s.ReadPort(reader)
   181  	}
   182  	if err != nil {
   183  		return
   184  	}
   185  	if s.portFirst {
   186  		addr, err = s.ReadAddress(reader)
   187  	} else if addr.IsValid() {
   188  		port, err = s.ReadPort(reader)
   189  	}
   190  	if err != nil {
   191  		return
   192  	}
   193  	addr.Port = port
   194  	return addr, nil
   195  }
   196  
   197  func ReadSockString(reader io.Reader) (string, error) {
   198  	strLen, err := rw.ReadByte(reader)
   199  	if err != nil {
   200  		return "", err
   201  	}
   202  	return rw.ReadString(reader, int(strLen))
   203  }
   204  
   205  func WriteSocksString(buffer *buf.Buffer, str string) error {
   206  	strLen := len(str)
   207  	if strLen > 255 {
   208  		return E.New("fqdn too long")
   209  	}
   210  	err := buffer.WriteByte(byte(strLen))
   211  	if err != nil {
   212  		return err
   213  	}
   214  	return common.Error(buffer.WriteString(str))
   215  }