github.com/ipfans/trojan-go@v0.11.0/tunnel/metadata.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"strconv"
    10  
    11  	"github.com/ipfans/trojan-go/common"
    12  )
    13  
    14  type Command byte
    15  
    16  type Metadata struct {
    17  	Command
    18  	*Address
    19  }
    20  
    21  func (r *Metadata) ReadFrom(rr io.Reader) error {
    22  	byteBuf := [1]byte{}
    23  	_, err := io.ReadFull(rr, byteBuf[:])
    24  	if err != nil {
    25  		return err
    26  	}
    27  	r.Command = Command(byteBuf[0])
    28  	r.Address = new(Address)
    29  	err = r.Address.ReadFrom(rr)
    30  	if err != nil {
    31  		return common.NewError("failed to marshal address").Base(err)
    32  	}
    33  	return nil
    34  }
    35  
    36  func (r *Metadata) WriteTo(w io.Writer) error {
    37  	buf := bytes.NewBuffer(make([]byte, 0, 64))
    38  	buf.WriteByte(byte(r.Command))
    39  	if err := r.Address.WriteTo(buf); err != nil {
    40  		return err
    41  	}
    42  	// use tcp by default
    43  	r.Address.NetworkType = "tcp"
    44  	_, err := w.Write(buf.Bytes())
    45  	return err
    46  }
    47  
    48  func (r *Metadata) Network() string {
    49  	return r.Address.Network()
    50  }
    51  
    52  func (r *Metadata) String() string {
    53  	return r.Address.String()
    54  }
    55  
    56  type AddressType byte
    57  
    58  const (
    59  	IPv4       AddressType = 1
    60  	DomainName AddressType = 3
    61  	IPv6       AddressType = 4
    62  )
    63  
    64  type Address struct {
    65  	DomainName  string
    66  	Port        int
    67  	NetworkType string
    68  	net.IP
    69  	AddressType
    70  }
    71  
    72  func (a *Address) String() string {
    73  	switch a.AddressType {
    74  	case IPv4:
    75  		return fmt.Sprintf("%s:%d", a.IP.String(), a.Port)
    76  	case IPv6:
    77  		return fmt.Sprintf("[%s]:%d", a.IP.String(), a.Port)
    78  	case DomainName:
    79  		return fmt.Sprintf("%s:%d", a.DomainName, a.Port)
    80  	default:
    81  		return "INVALID_ADDRESS_TYPE"
    82  	}
    83  }
    84  
    85  func (a *Address) Network() string {
    86  	return a.NetworkType
    87  }
    88  
    89  func (a *Address) ResolveIP() (net.IP, error) {
    90  	if a.AddressType == IPv4 || a.AddressType == IPv6 {
    91  		return a.IP, nil
    92  	}
    93  	if a.IP != nil {
    94  		return a.IP, nil
    95  	}
    96  	addr, err := net.ResolveIPAddr("ip", a.DomainName)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  	a.IP = addr.IP
   101  	return addr.IP, nil
   102  }
   103  
   104  func NewAddressFromAddr(network string, addr string) (*Address, error) {
   105  	host, portStr, err := net.SplitHostPort(addr)
   106  	if err != nil {
   107  		return nil, err
   108  	}
   109  	port, err := strconv.ParseInt(portStr, 10, 32)
   110  	common.Must(err)
   111  	return NewAddressFromHostPort(network, host, int(port)), nil
   112  }
   113  
   114  func NewAddressFromHostPort(network string, host string, port int) *Address {
   115  	if ip := net.ParseIP(host); ip != nil {
   116  		if ip.To4() != nil {
   117  			return &Address{
   118  				IP:          ip,
   119  				Port:        port,
   120  				AddressType: IPv4,
   121  				NetworkType: network,
   122  			}
   123  		}
   124  		return &Address{
   125  			IP:          ip,
   126  			Port:        port,
   127  			AddressType: IPv6,
   128  			NetworkType: network,
   129  		}
   130  	}
   131  	return &Address{
   132  		DomainName:  host,
   133  		Port:        port,
   134  		AddressType: DomainName,
   135  		NetworkType: network,
   136  	}
   137  }
   138  
   139  func (a *Address) ReadFrom(r io.Reader) error {
   140  	byteBuf := [1]byte{}
   141  	_, err := io.ReadFull(r, byteBuf[:])
   142  	if err != nil {
   143  		return common.NewError("unable to read ATYP").Base(err)
   144  	}
   145  	a.AddressType = AddressType(byteBuf[0])
   146  	switch a.AddressType {
   147  	case IPv4:
   148  		var buf [6]byte
   149  		_, err := io.ReadFull(r, buf[:])
   150  		if err != nil {
   151  			return common.NewError("failed to read IPv4").Base(err)
   152  		}
   153  		a.IP = buf[0:4]
   154  		a.Port = int(binary.BigEndian.Uint16(buf[4:6]))
   155  	case IPv6:
   156  		var buf [18]byte
   157  		_, err := io.ReadFull(r, buf[:])
   158  		if err != nil {
   159  			return common.NewError("failed to read IPv6").Base(err)
   160  		}
   161  		a.IP = buf[0:16]
   162  		a.Port = int(binary.BigEndian.Uint16(buf[16:18]))
   163  	case DomainName:
   164  		_, err := io.ReadFull(r, byteBuf[:])
   165  		length := byteBuf[0]
   166  		if err != nil {
   167  			return common.NewError("failed to read domain name length")
   168  		}
   169  		buf := make([]byte, length+2)
   170  		_, err = io.ReadFull(r, buf)
   171  		if err != nil {
   172  			return common.NewError("failed to read domain name")
   173  		}
   174  		// the fucking browser uses IP as a domain name sometimes
   175  		host := buf[0:length]
   176  		if ip := net.ParseIP(string(host)); ip != nil {
   177  			a.IP = ip
   178  			if ip.To4() != nil {
   179  				a.AddressType = IPv4
   180  			} else {
   181  				a.AddressType = IPv6
   182  			}
   183  		} else {
   184  			a.DomainName = string(host)
   185  		}
   186  		a.Port = int(binary.BigEndian.Uint16(buf[length : length+2]))
   187  	default:
   188  		return common.NewError("invalid ATYP " + strconv.FormatInt(int64(a.AddressType), 10))
   189  	}
   190  	return nil
   191  }
   192  
   193  func (a *Address) WriteTo(w io.Writer) error {
   194  	_, err := w.Write([]byte{byte(a.AddressType)})
   195  	if err != nil {
   196  		return err
   197  	}
   198  	switch a.AddressType {
   199  	case DomainName:
   200  		w.Write([]byte{byte(len(a.DomainName))})
   201  		_, err = w.Write([]byte(a.DomainName))
   202  	case IPv4:
   203  		_, err = w.Write(a.IP.To4())
   204  	case IPv6:
   205  		_, err = w.Write(a.IP.To16())
   206  	default:
   207  		return common.NewError("invalid ATYP " + strconv.FormatInt(int64(a.AddressType), 10))
   208  	}
   209  	if err != nil {
   210  		return err
   211  	}
   212  	port := [2]byte{}
   213  	binary.BigEndian.PutUint16(port[:], uint16(a.Port))
   214  	_, err = w.Write(port[:])
   215  	return err
   216  }