github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/socks/addr.go (about)

     1  package socks
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"strconv"
     9  )
    10  
    11  // MaxAddrLen is the maximum length of socks.Addr
    12  const MaxAddrLen = 1 + 1 + 255 + 2
    13  
    14  var (
    15  	// ErrInvalidAddrType is ...
    16  	ErrInvalidAddrType = errors.New("invalid address type")
    17  	// ErrInvalidAddrLen is ...
    18  	ErrInvalidAddrLen = errors.New("invalid address length")
    19  )
    20  
    21  const (
    22  	// AddrTypeIPv4 is ...
    23  	AddrTypeIPv4 = 1
    24  	// AddrTypeDomain is ...
    25  	AddrTypeDomain = 3
    26  	// AddrTypeIPv6 is ...
    27  	AddrTypeIPv6 = 4
    28  )
    29  
    30  // Addr is ...
    31  type Addr struct {
    32  	data []byte
    33  }
    34  
    35  // Network is ...
    36  func (*Addr) Network() string {
    37  	return "socks"
    38  }
    39  
    40  // String is ...
    41  func (addr *Addr) String() string {
    42  	switch addr.data[0] {
    43  	case AddrTypeIPv4:
    44  		host := net.IP(addr.data[1 : 1+net.IPv4len]).String()
    45  		port := strconv.Itoa(int(addr.data[1+net.IPv4len])<<8 | int(addr.data[1+net.IPv4len+1]))
    46  		return net.JoinHostPort(host, port)
    47  	case AddrTypeDomain:
    48  		host := string(addr.data[2 : 2+addr.data[1]])
    49  		port := strconv.Itoa(int(addr.data[2+addr.data[1]])<<8 | int(addr.data[2+addr.data[1]+1]))
    50  		return net.JoinHostPort(host, port)
    51  	case AddrTypeIPv6:
    52  		host := net.IP(addr.data[1 : 1+net.IPv6len]).String()
    53  		port := strconv.Itoa(int(addr.data[1+net.IPv6len])<<8 | int(addr.data[1+net.IPv6len+1]))
    54  		return net.JoinHostPort(host, port)
    55  	default:
    56  		return ""
    57  	}
    58  }
    59  
    60  // Len is ...
    61  func (addr *Addr) Len() int {
    62  	return len(addr.data)
    63  }
    64  
    65  // AppendTo is ...
    66  func (addr *Addr) AppendTo(b []byte) []byte {
    67  	return append(b, addr.data...)
    68  }
    69  
    70  // Bytes is ...
    71  func (addr *Addr) Bytes() []byte {
    72  	return addr.data
    73  }
    74  
    75  // Append is ...
    76  func (addr *Addr) Append(b []byte) []byte {
    77  	return append(addr.data, b...)
    78  }
    79  
    80  // ReadAddr is ....
    81  func ReadAddr(conn io.Reader) (*Addr, error) {
    82  	return ReadAddrBuffer(conn, make([]byte, MaxAddrLen))
    83  }
    84  
    85  // ReadAddrBuffer is ...
    86  func ReadAddrBuffer(conn io.Reader, addr []byte) (*Addr, error) {
    87  	_, err := io.ReadFull(conn, addr[:2])
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  
    92  	switch addr[0] {
    93  	case AddrTypeIPv4:
    94  		n := 1 + net.IPv4len + 2
    95  		_, err := io.ReadFull(conn, addr[2:n])
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  
   100  		return &Addr{data: addr[:n]}, nil
   101  	case AddrTypeDomain:
   102  		n := 1 + 1 + int(addr[1]) + 2
   103  		_, err := io.ReadFull(conn, addr[2:n])
   104  		if err != nil {
   105  			return nil, err
   106  		}
   107  
   108  		return &Addr{data: addr[:n]}, nil
   109  	case AddrTypeIPv6:
   110  		n := 1 + net.IPv6len + 2
   111  		_, err := io.ReadFull(conn, addr[2:n])
   112  		if err != nil {
   113  			return nil, err
   114  		}
   115  
   116  		return &Addr{data: addr[:n]}, nil
   117  	default:
   118  		return nil, ErrInvalidAddrType
   119  	}
   120  }
   121  
   122  // ParseAddr is ...
   123  func ParseAddr(addr []byte) (*Addr, error) {
   124  	if len(addr) < 1+1+1+2 {
   125  		return nil, ErrInvalidAddrLen
   126  	}
   127  
   128  	switch addr[0] {
   129  	case AddrTypeIPv4:
   130  		n := 1 + net.IPv4len + 2
   131  		if len(addr) < n {
   132  			return nil, ErrInvalidAddrLen
   133  		}
   134  
   135  		return &Addr{data: addr[:n]}, nil
   136  	case AddrTypeDomain:
   137  		n := 1 + 1 + int(addr[1]) + 2
   138  		if len(addr) < n {
   139  			return nil, ErrInvalidAddrLen
   140  		}
   141  
   142  		return &Addr{data: addr[:n]}, nil
   143  	case AddrTypeIPv6:
   144  		n := 1 + net.IPv6len + 2
   145  		if len(addr) < n {
   146  			return nil, ErrInvalidAddrLen
   147  		}
   148  
   149  		return &Addr{data: addr[:n]}, nil
   150  	default:
   151  		return nil, ErrInvalidAddrType
   152  	}
   153  }
   154  
   155  // ResolveTCPAddr is ...
   156  func ResolveTCPAddr(addr *Addr) (*net.TCPAddr, error) {
   157  	switch addr.data[0] {
   158  	case AddrTypeIPv4:
   159  		host := net.IP(addr.data[1 : 1+net.IPv4len])
   160  		port := int(addr.data[1+net.IPv4len])<<8 | int(addr.data[1+net.IPv4len+1])
   161  		return &net.TCPAddr{IP: host, Port: port}, nil
   162  	case AddrTypeDomain:
   163  		return net.ResolveTCPAddr("tcp", addr.String())
   164  	case AddrTypeIPv6:
   165  		host := net.IP(addr.data[1 : 1+net.IPv6len])
   166  		port := int(addr.data[1+net.IPv6len])<<8 | int(addr.data[1+net.IPv6len+1])
   167  		return &net.TCPAddr{IP: host, Port: port}, nil
   168  	default:
   169  		return nil, fmt.Errorf("address type (%v) error", addr.data[0])
   170  	}
   171  }
   172  
   173  // ResolveUDPAddr is ...
   174  func ResolveUDPAddr(addr *Addr) (*net.UDPAddr, error) {
   175  	switch addr.data[0] {
   176  	case AddrTypeIPv4:
   177  		host := net.IP(addr.data[1 : 1+net.IPv4len])
   178  		port := int(addr.data[1+net.IPv4len])<<8 | int(addr.data[1+net.IPv4len+1])
   179  		return &net.UDPAddr{IP: host, Port: port}, nil
   180  	case AddrTypeDomain:
   181  		return net.ResolveUDPAddr("udp", addr.String())
   182  	case AddrTypeIPv6:
   183  		host := net.IP(addr.data[1 : 1+net.IPv6len])
   184  		port := int(addr.data[1+net.IPv6len])<<8 | int(addr.data[1+net.IPv6len+1])
   185  		return &net.UDPAddr{IP: host, Port: port}, nil
   186  	default:
   187  		return nil, fmt.Errorf("address type (%v) error", addr.data[0])
   188  	}
   189  }
   190  
   191  // ResolveAddr is ...
   192  func ResolveAddr(addr net.Addr) (*Addr, error) {
   193  	if a, ok := addr.(*Addr); ok {
   194  		return a, nil
   195  	}
   196  	return ResolveAddrBuffer(addr, make([]byte, MaxAddrLen))
   197  }
   198  
   199  // ResolveAddrBuffer is ...
   200  func ResolveAddrBuffer(addr net.Addr, b []byte) (*Addr, error) {
   201  	if nAddr, ok := addr.(*net.TCPAddr); ok {
   202  		if ipv4 := nAddr.IP.To4(); ipv4 != nil {
   203  			b[0] = AddrTypeIPv4
   204  			copy(b[1:], ipv4)
   205  			b[1+net.IPv4len] = byte(nAddr.Port >> 8)
   206  			b[1+net.IPv4len+1] = byte(nAddr.Port)
   207  
   208  			return &Addr{data: b[:1+net.IPv4len+2]}, nil
   209  		}
   210  		ipv6 := nAddr.IP.To16()
   211  
   212  		b[0] = AddrTypeIPv6
   213  		copy(b[1:], ipv6)
   214  		b[1+net.IPv6len] = byte(nAddr.Port >> 8)
   215  		b[1+net.IPv6len+1] = byte(nAddr.Port)
   216  
   217  		return &Addr{data: b[:1+net.IPv6len+2]}, nil
   218  	}
   219  
   220  	if nAddr, ok := addr.(*net.UDPAddr); ok {
   221  		if ipv4 := nAddr.IP.To4(); ipv4 != nil {
   222  			b[0] = AddrTypeIPv4
   223  			copy(b[1:], ipv4)
   224  			b[1+net.IPv4len] = byte(nAddr.Port >> 8)
   225  			b[1+net.IPv4len+1] = byte(nAddr.Port)
   226  
   227  			return &Addr{data: b[:1+net.IPv4len+2]}, nil
   228  		}
   229  		ipv6 := nAddr.IP.To16()
   230  
   231  		b[0] = AddrTypeIPv6
   232  		copy(b[1:], ipv6)
   233  		b[1+net.IPv6len] = byte(nAddr.Port >> 8)
   234  		b[1+net.IPv6len+1] = byte(nAddr.Port)
   235  
   236  		return &Addr{data: b[:1+net.IPv6len+2]}, nil
   237  	}
   238  
   239  	if nAddr, ok := addr.(*Addr); ok {
   240  		copy(b, nAddr.data)
   241  		return &Addr{data: b[:len(nAddr.data)]}, nil
   242  	}
   243  
   244  	return nil, ErrInvalidAddrType
   245  }