github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/ss/conn.go (about)

     1  package ss
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"strconv"
     9  )
    10  
    11  const (
    12  	OneTimeAuthMask byte = 0x10
    13  	AddrMask        byte = 0xf
    14  )
    15  
    16  type Conn struct {
    17  	net.Conn
    18  	*Cipher
    19  	readBuf  []byte
    20  	writeBuf []byte
    21  	chunkId  uint32
    22  }
    23  
    24  func NewConn(c net.Conn, cipher *Cipher) *Conn {
    25  	return &Conn{
    26  		Conn:     c,
    27  		Cipher:   cipher,
    28  		readBuf:  leakyBuf.Get(),
    29  		writeBuf: leakyBuf.Get()}
    30  }
    31  
    32  func (c *Conn) Close() error {
    33  	leakyBuf.Put(c.readBuf)
    34  	leakyBuf.Put(c.writeBuf)
    35  	return c.Conn.Close()
    36  }
    37  
    38  func RawAddr(addr string) (buf []byte, err error) {
    39  	host, portStr, err := net.SplitHostPort(addr)
    40  	if err != nil {
    41  		return nil, fmt.Errorf("ss: address error %s %v", addr, err)
    42  	}
    43  	port, err := strconv.Atoi(portStr)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("ss: invalid port %s", addr)
    46  	}
    47  
    48  	hostLen := len(host)
    49  	l := 1 + 1 + hostLen + 2 // addrType + lenByte + address + port
    50  	buf = make([]byte, l)
    51  	buf[0] = 3             // 3 means the address is domain name
    52  	buf[1] = byte(hostLen) // host address length  followed by host address
    53  	copy(buf[2:], host)
    54  	binary.BigEndian.PutUint16(buf[2+hostLen:2+hostLen+2], uint16(port))
    55  	return
    56  }
    57  
    58  // This is intended for use by users implementing a local socks proxy.
    59  // rawaddr shoud contain part of the data in socks request, starting from the
    60  // ATYP field. (Refer to rfc1928 for more information.)
    61  func DialWithRawAddr(rawConn *net.Conn, rawaddr []byte, server string, cipher *Cipher) (c *Conn, err error) {
    62  	var conn net.Conn
    63  	if rawConn == nil {
    64  		conn, err = net.Dial("tcp", server)
    65  	}
    66  	if err != nil {
    67  		return
    68  	}
    69  	if rawConn != nil {
    70  		c = NewConn(*rawConn, cipher)
    71  	} else {
    72  		c = NewConn(conn, cipher)
    73  	}
    74  
    75  	if _, err = c.write(rawaddr); err != nil {
    76  		c.Close()
    77  		return nil, err
    78  	}
    79  	return
    80  }
    81  
    82  // addr should be in the form of host:port
    83  func Dial(addr, server string, cipher *Cipher) (c *Conn, err error) {
    84  	ra, err := RawAddr(addr)
    85  	if err != nil {
    86  		return
    87  	}
    88  	return DialWithRawAddr(nil, ra, server, cipher)
    89  }
    90  
    91  func (c *Conn) GetIv() (iv []byte) {
    92  	iv = make([]byte, len(c.iv))
    93  	copy(iv, c.iv)
    94  	return
    95  }
    96  
    97  func (c *Conn) GetKey() (key []byte) {
    98  	key = make([]byte, len(c.key))
    99  	copy(key, c.key)
   100  	return
   101  }
   102  
   103  func (c *Conn) IsOta() bool {
   104  	return c.ota
   105  }
   106  
   107  func (c *Conn) GetAndIncrChunkId() (chunkId uint32) {
   108  	chunkId = c.chunkId
   109  	c.chunkId += 1
   110  	return
   111  }
   112  
   113  func (c *Conn) Read(b []byte) (n int, err error) {
   114  	if c.dec == nil {
   115  		iv := make([]byte, c.info.ivLen)
   116  		if _, err = io.ReadFull(c.Conn, iv); err != nil {
   117  			return
   118  		}
   119  		if err = c.initDecrypt(iv); err != nil {
   120  			return
   121  		}
   122  		if len(c.iv) == 0 {
   123  			c.iv = iv
   124  		}
   125  	}
   126  
   127  	cipherData := c.readBuf
   128  	if len(b) > len(cipherData) {
   129  		cipherData = make([]byte, len(b))
   130  	} else {
   131  		cipherData = cipherData[:len(b)]
   132  	}
   133  
   134  	n, err = c.Conn.Read(cipherData)
   135  	if n > 0 {
   136  		c.decrypt(b[0:n], cipherData[0:n])
   137  	}
   138  	return
   139  }
   140  
   141  func (c *Conn) Write(b []byte) (n int, err error) {
   142  	nn := len(b)
   143  
   144  	headerLen := len(b) - nn
   145  
   146  	n, err = c.write(b)
   147  	// Make sure <= 0 <= len(b), where b is the slice passed in.
   148  	if n >= headerLen {
   149  		n -= headerLen
   150  	}
   151  	return
   152  }
   153  
   154  func (c *Conn) write(b []byte) (n int, err error) {
   155  	var iv []byte
   156  	if c.enc == nil {
   157  		iv, err = c.initEncrypt()
   158  		if err != nil {
   159  			return
   160  		}
   161  	}
   162  
   163  	cipherData := c.writeBuf
   164  	dataSize := len(b) + len(iv)
   165  	if dataSize > len(cipherData) {
   166  		cipherData = make([]byte, dataSize)
   167  	} else {
   168  		cipherData = cipherData[:dataSize]
   169  	}
   170  
   171  	if iv != nil {
   172  		// Put initialization vector in buffer, do a single write to send both
   173  		// iv and data.
   174  		copy(cipherData, iv)
   175  	}
   176  
   177  	c.encrypt(cipherData[len(iv):], b)
   178  	n, err = c.Conn.Write(cipherData)
   179  	return
   180  }