github.com/yaling888/clash@v1.53.0/transport/shadowsocks/shadowstream/stream.go (about)

     1  package shadowstream
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/rand"
     6  	"io"
     7  	"net"
     8  )
     9  
    10  const bufSize = 2048
    11  
    12  type Writer struct {
    13  	io.Writer
    14  	cipher.Stream
    15  	buf [bufSize]byte
    16  }
    17  
    18  // NewWriter wraps an io.Writer with stream cipher encryption.
    19  func NewWriter(w io.Writer, s cipher.Stream) *Writer { return &Writer{Writer: w, Stream: s} }
    20  
    21  func (w *Writer) Write(p []byte) (n int, err error) {
    22  	buf := w.buf[:]
    23  	for nw := 0; n < len(p) && err == nil; n += nw {
    24  		end := n + len(buf)
    25  		if end > len(p) {
    26  			end = len(p)
    27  		}
    28  		w.XORKeyStream(buf, p[n:end])
    29  		nw, err = w.Writer.Write(buf[:end-n])
    30  	}
    31  	return
    32  }
    33  
    34  func (w *Writer) ReadFrom(r io.Reader) (n int64, err error) {
    35  	buf := w.buf[:]
    36  	for {
    37  		nr, er := r.Read(buf)
    38  		n += int64(nr)
    39  		b := buf[:nr]
    40  		w.XORKeyStream(b, b)
    41  		if _, err = w.Writer.Write(b); err != nil {
    42  			return
    43  		}
    44  		if er != nil {
    45  			if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
    46  				err = er
    47  			}
    48  			return
    49  		}
    50  	}
    51  }
    52  
    53  type Reader struct {
    54  	io.Reader
    55  	cipher.Stream
    56  	buf [bufSize]byte
    57  }
    58  
    59  // NewReader wraps an io.Reader with stream cipher decryption.
    60  func NewReader(r io.Reader, s cipher.Stream) *Reader { return &Reader{Reader: r, Stream: s} }
    61  
    62  func (r *Reader) Read(p []byte) (n int, err error) {
    63  	n, err = r.Reader.Read(p)
    64  	if err != nil {
    65  		return 0, err
    66  	}
    67  	r.XORKeyStream(p, p[:n])
    68  	return
    69  }
    70  
    71  func (r *Reader) WriteTo(w io.Writer) (n int64, err error) {
    72  	buf := r.buf[:]
    73  	for {
    74  		nr, er := r.Reader.Read(buf)
    75  		if nr > 0 {
    76  			r.XORKeyStream(buf, buf[:nr])
    77  			nw, ew := w.Write(buf[:nr])
    78  			n += int64(nw)
    79  			if ew != nil {
    80  				err = ew
    81  				return
    82  			}
    83  		}
    84  		if er != nil {
    85  			if er != io.EOF { // ignore EOF as per io.Copy contract (using src.WriteTo shortcut)
    86  				err = er
    87  			}
    88  			return
    89  		}
    90  	}
    91  }
    92  
    93  // A Conn represents a Shadowsocks connection. It implements the net.Conn interface.
    94  type Conn struct {
    95  	net.Conn
    96  	Cipher
    97  	r       *Reader
    98  	w       *Writer
    99  	readIV  []byte
   100  	writeIV []byte
   101  }
   102  
   103  // NewConn wraps a stream-oriented net.Conn with stream cipher encryption/decryption.
   104  func NewConn(c net.Conn, ciph Cipher) *Conn { return &Conn{Conn: c, Cipher: ciph} }
   105  
   106  func (c *Conn) initReader() error {
   107  	if c.r == nil {
   108  		iv, err := c.ObtainReadIV()
   109  		if err != nil {
   110  			return err
   111  		}
   112  		c.r = NewReader(c.Conn, c.Decrypter(iv))
   113  	}
   114  	return nil
   115  }
   116  
   117  func (c *Conn) Read(b []byte) (int, error) {
   118  	if c.r == nil {
   119  		if err := c.initReader(); err != nil {
   120  			return 0, err
   121  		}
   122  	}
   123  	return c.r.Read(b)
   124  }
   125  
   126  func (c *Conn) WriteTo(w io.Writer) (int64, error) {
   127  	if c.r == nil {
   128  		if err := c.initReader(); err != nil {
   129  			return 0, err
   130  		}
   131  	}
   132  	return c.r.WriteTo(w)
   133  }
   134  
   135  func (c *Conn) initWriter() error {
   136  	if c.w == nil {
   137  		iv, err := c.ObtainWriteIV()
   138  		if err != nil {
   139  			return err
   140  		}
   141  		if _, err := c.Conn.Write(iv); err != nil {
   142  			return err
   143  		}
   144  		c.w = NewWriter(c.Conn, c.Encrypter(iv))
   145  	}
   146  	return nil
   147  }
   148  
   149  func (c *Conn) Write(b []byte) (int, error) {
   150  	if c.w == nil {
   151  		if err := c.initWriter(); err != nil {
   152  			return 0, err
   153  		}
   154  	}
   155  	return c.w.Write(b)
   156  }
   157  
   158  func (c *Conn) ReadFrom(r io.Reader) (int64, error) {
   159  	if c.w == nil {
   160  		if err := c.initWriter(); err != nil {
   161  			return 0, err
   162  		}
   163  	}
   164  	return c.w.ReadFrom(r)
   165  }
   166  
   167  func (c *Conn) ObtainWriteIV() ([]byte, error) {
   168  	if len(c.writeIV) == c.IVSize() {
   169  		return c.writeIV, nil
   170  	}
   171  
   172  	iv := make([]byte, c.IVSize())
   173  
   174  	if _, err := rand.Read(iv); err != nil {
   175  		return nil, err
   176  	}
   177  
   178  	c.writeIV = iv
   179  
   180  	return iv, nil
   181  }
   182  
   183  func (c *Conn) ObtainReadIV() ([]byte, error) {
   184  	if len(c.readIV) == c.IVSize() {
   185  		return c.readIV, nil
   186  	}
   187  
   188  	iv := make([]byte, c.IVSize())
   189  
   190  	if _, err := io.ReadFull(c.Conn, iv); err != nil {
   191  		return nil, err
   192  	}
   193  
   194  	c.readIV = iv
   195  
   196  	return iv, nil
   197  }