github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/shadowsocksr/cipher/conn.go (about)

     1  package cipher
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/rand"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  
    10  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    11  	"github.com/Asutorufa/yuhaiin/pkg/net/proxy/shadowsocks/core"
    12  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    13  )
    14  
    15  type Cipher struct {
    16  	ivSize int
    17  	key    []byte
    18  	core.Cipher
    19  }
    20  
    21  func NewCipher(method, password string) (*Cipher, error) {
    22  	if method == "none" || method == "dummy" {
    23  		return &Cipher{Cipher: dummy{}}, nil
    24  	}
    25  
    26  	if password == "" {
    27  		return nil, fmt.Errorf("password is empty")
    28  	}
    29  
    30  	if method == "" {
    31  		method = "rc4-md5"
    32  	}
    33  
    34  	ss, ok := StreamCipherMethod[method]
    35  	if !ok {
    36  		return nil, fmt.Errorf("unsupported encryption method: %v", method)
    37  	}
    38  	key := core.KDF(password, ss.KeySize)
    39  	mi := ss.Creator(key)
    40  	return &Cipher{mi.IVSize(), key, &cipherConn{mi}}, nil
    41  }
    42  func (c *Cipher) IVSize() int { return c.ivSize }
    43  func (c *Cipher) Key() []byte { return c.key }
    44  
    45  // dummy cipher does not encrypt
    46  type dummy struct{}
    47  
    48  func (dummy) StreamConn(c net.Conn) net.Conn             { return c }
    49  func (dummy) PacketConn(c net.PacketConn) net.PacketConn { return c }
    50  
    51  type cipherConn struct{ CipherFactory }
    52  
    53  func (c *cipherConn) StreamConn(conn net.Conn) net.Conn { return newStreamConn(conn, c.CipherFactory) }
    54  func (c *cipherConn) PacketConn(conn net.PacketConn) net.PacketConn {
    55  	return newPacketConn(conn, c.CipherFactory)
    56  }
    57  
    58  type packetConn struct {
    59  	net.PacketConn
    60  	CipherFactory
    61  }
    62  
    63  func newPacketConn(c net.PacketConn, cipherFactory CipherFactory) net.PacketConn {
    64  	return &packetConn{c, cipherFactory}
    65  }
    66  
    67  func (p *packetConn) WriteTo(b []byte, addr net.Addr) (int, error) {
    68  	if len(b)+p.IVSize() > nat.MaxSegmentSize {
    69  		return 0, fmt.Errorf("udp data size too large")
    70  	}
    71  
    72  	buf := pool.GetBytes(nat.MaxSegmentSize)
    73  	defer pool.PutBytes(buf)
    74  
    75  	_, err := rand.Read(buf[:p.IVSize()])
    76  	if err != nil {
    77  		return 0, err
    78  	}
    79  
    80  	s, err := p.EncryptStream(buf[:p.IVSize()])
    81  	if err != nil {
    82  		return 0, err
    83  	}
    84  
    85  	s.XORKeyStream(buf[p.IVSize():], b)
    86  
    87  	if _, err = p.PacketConn.WriteTo(buf[:p.IVSize()+len(b)], addr); err != nil {
    88  		return 0, err
    89  	}
    90  
    91  	return len(b), nil
    92  }
    93  
    94  func (p *packetConn) ReadFrom(b []byte) (int, net.Addr, error) {
    95  	n, addr, err := p.PacketConn.ReadFrom(b)
    96  	if err != nil {
    97  		return 0, nil, err
    98  	}
    99  
   100  	s, err := p.DecryptStream(b[:p.IVSize()])
   101  	if err != nil {
   102  		return 0, nil, err
   103  	}
   104  
   105  	s.XORKeyStream(b[p.IVSize():], b[p.IVSize():n])
   106  	n = copy(b, b[p.IVSize():n])
   107  
   108  	return n, addr, nil
   109  }
   110  
   111  type streamConn struct {
   112  	net.Conn
   113  	cipher CipherFactory
   114  
   115  	enc, dec        cipher.Stream
   116  	writeIV, readIV []byte
   117  }
   118  
   119  func newStreamConn(c net.Conn, cipher CipherFactory) net.Conn {
   120  	return &streamConn{Conn: c, cipher: cipher}
   121  }
   122  
   123  func (c *streamConn) WriteIV() ([]byte, error) {
   124  	if c.writeIV == nil {
   125  		c.writeIV = make([]byte, c.cipher.IVSize())
   126  		if _, err := rand.Read(c.writeIV); err != nil {
   127  			return nil, err
   128  		}
   129  	}
   130  	return c.writeIV, nil
   131  }
   132  
   133  func (c *streamConn) ReadIV() ([]byte, error) {
   134  	if c.readIV == nil {
   135  		c.readIV = make([]byte, c.cipher.IVSize())
   136  		if _, err := io.ReadFull(c.Conn, c.readIV); err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  	return c.readIV, nil
   141  }
   142  
   143  func (c *streamConn) Read(b []byte) (n int, err error) {
   144  	if c.dec == nil {
   145  		readIV, err := c.ReadIV()
   146  		if err != nil {
   147  			if e, ok := err.(net.Error); ok && e.Timeout() {
   148  				return 0, e
   149  			}
   150  			return 0, fmt.Errorf("get read iv failed: %w", err)
   151  		}
   152  		c.dec, err = c.cipher.DecryptStream(readIV)
   153  		if err != nil {
   154  			return 0, fmt.Errorf("create new decor failed: %w", err)
   155  		}
   156  	}
   157  
   158  	n, err = c.Conn.Read(b)
   159  	if err != nil {
   160  		return n, err
   161  	}
   162  	c.dec.XORKeyStream(b, b[:n])
   163  	return n, nil
   164  }
   165  
   166  func (c *streamConn) Write(b []byte) (_ int, err error) {
   167  	if c.enc == nil {
   168  		writeIV, err := c.WriteIV()
   169  		if err != nil {
   170  			return 0, fmt.Errorf("get write iv failed: %w", err)
   171  		}
   172  		c.enc, err = c.cipher.EncryptStream(writeIV)
   173  		if err != nil {
   174  			return 0, err
   175  		}
   176  
   177  		_, err = c.Conn.Write(writeIV)
   178  		if err != nil {
   179  			return 0, err
   180  		}
   181  	}
   182  
   183  	c.enc.XORKeyStream(b, b)
   184  
   185  	return c.Conn.Write(b)
   186  }