github.com/EagleQL/Xray-core@v1.4.3/proxy/shadowsocks/protocol.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"crypto/cipher"
     5  	"crypto/hmac"
     6  	"crypto/rand"
     7  	"crypto/sha256"
     8  	"hash/crc32"
     9  	"io"
    10  	"io/ioutil"
    11  
    12  	"github.com/xtls/xray-core/common"
    13  	"github.com/xtls/xray-core/common/buf"
    14  	"github.com/xtls/xray-core/common/crypto"
    15  	"github.com/xtls/xray-core/common/dice"
    16  	"github.com/xtls/xray-core/common/net"
    17  	"github.com/xtls/xray-core/common/protocol"
    18  )
    19  
    20  const (
    21  	Version = 1
    22  )
    23  
    24  var addrParser = protocol.NewAddressParser(
    25  	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    26  	protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    27  	protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    28  	protocol.WithAddressTypeParser(func(b byte) byte {
    29  		return b & 0x0F
    30  	}),
    31  )
    32  
    33  type FullReader struct {
    34  	reader io.Reader
    35  	buffer []byte
    36  }
    37  
    38  func (r *FullReader) Read(p []byte) (n int, err error) {
    39  	if r.buffer != nil {
    40  		n := copy(p, r.buffer)
    41  		if n == len(r.buffer) {
    42  			r.buffer = nil
    43  		} else {
    44  			r.buffer = r.buffer[n:]
    45  		}
    46  		if n == len(p) {
    47  			return n, nil
    48  		} else {
    49  			m, err := r.reader.Read(p[n:])
    50  			return n + m, err
    51  		}
    52  	}
    53  	return r.reader.Read(p)
    54  }
    55  
    56  // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
    57  func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
    58  
    59  	hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))
    60  
    61  	behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))
    62  
    63  	behaviorRand := dice.NewDeterministicDice(int64(behaviorSeed))
    64  	BaseDrainSize := behaviorRand.Roll(3266)
    65  	RandDrainMax := behaviorRand.Roll(64) + 1
    66  	RandDrainRolled := dice.Roll(RandDrainMax)
    67  	DrainSize := BaseDrainSize + 16 + 38 + RandDrainRolled
    68  	readSizeRemain := DrainSize
    69  
    70  	var r2 buf.Reader
    71  	buffer := buf.New()
    72  	defer buffer.Release()
    73  
    74  	var user *protocol.MemoryUser
    75  	var ivLen int32
    76  	var err error
    77  
    78  	count := validator.Count()
    79  	if count == 0 {
    80  		readSizeRemain -= int(buffer.Len())
    81  		DrainConnN(reader, readSizeRemain)
    82  		return nil, nil, newError("invalid user")
    83  	} else if count > 1 {
    84  		var aead cipher.AEAD
    85  
    86  		if _, err := buffer.ReadFullFrom(reader, 50); err != nil {
    87  			readSizeRemain -= int(buffer.Len())
    88  			DrainConnN(reader, readSizeRemain)
    89  			return nil, nil, newError("failed to read 50 bytes").Base(err)
    90  		}
    91  
    92  		bs := buffer.Bytes()
    93  		user, aead, _, ivLen, err = validator.Get(bs, protocol.RequestCommandTCP)
    94  
    95  		if user != nil {
    96  			reader = &FullReader{reader, bs[ivLen:]}
    97  			auth := &crypto.AEADAuthenticator{
    98  				AEAD:           aead,
    99  				NonceGenerator: crypto.GenerateInitialAEADNonce(),
   100  			}
   101  			r2 = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
   102  				Auth: auth,
   103  			}, reader, protocol.TransferTypeStream, nil)
   104  		} else {
   105  			readSizeRemain -= int(buffer.Len())
   106  			DrainConnN(reader, readSizeRemain)
   107  			return nil, nil, newError("failed to match an user").Base(err)
   108  		}
   109  	} else {
   110  		user, ivLen = validator.GetOnlyUser()
   111  		account := user.Account.(*MemoryAccount)
   112  		hashkdf.Write(account.Key)
   113  		var iv []byte
   114  		if ivLen > 0 {
   115  			if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
   116  				readSizeRemain -= int(buffer.Len())
   117  				DrainConnN(reader, readSizeRemain)
   118  				return nil, nil, newError("failed to read IV").Base(err)
   119  			}
   120  			iv = append([]byte(nil), buffer.BytesTo(ivLen)...)
   121  		}
   122  
   123  		r, err := account.Cipher.NewDecryptionReader(account.Key, iv, reader)
   124  		if err != nil {
   125  			readSizeRemain -= int(buffer.Len())
   126  			DrainConnN(reader, readSizeRemain)
   127  			return nil, nil, newError("failed to initialize decoding stream").Base(err).AtError()
   128  		}
   129  		r2 = r
   130  	}
   131  
   132  	br := &buf.BufferedReader{Reader: r2}
   133  
   134  	request := &protocol.RequestHeader{
   135  		Version: Version,
   136  		User:    user,
   137  		Command: protocol.RequestCommandTCP,
   138  	}
   139  
   140  	readSizeRemain -= int(buffer.Len())
   141  	buffer.Clear()
   142  
   143  	addr, port, err := addrParser.ReadAddressPort(buffer, br)
   144  	if err != nil {
   145  		readSizeRemain -= int(buffer.Len())
   146  		DrainConnN(reader, readSizeRemain)
   147  		return nil, nil, newError("failed to read address").Base(err)
   148  	}
   149  
   150  	request.Address = addr
   151  	request.Port = port
   152  
   153  	if request.Address == nil {
   154  		readSizeRemain -= int(buffer.Len())
   155  		DrainConnN(reader, readSizeRemain)
   156  		return nil, nil, newError("invalid remote address.")
   157  	}
   158  
   159  	return request, br, nil
   160  }
   161  
   162  func DrainConnN(reader io.Reader, n int) error {
   163  	_, err := io.CopyN(ioutil.Discard, reader, int64(n))
   164  	return err
   165  }
   166  
   167  // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
   168  func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   169  	user := request.User
   170  	account := user.Account.(*MemoryAccount)
   171  
   172  	var iv []byte
   173  	if account.Cipher.IVSize() > 0 {
   174  		iv = make([]byte, account.Cipher.IVSize())
   175  		common.Must2(rand.Read(iv))
   176  		if err := buf.WriteAllBytes(writer, iv); err != nil {
   177  			return nil, newError("failed to write IV")
   178  		}
   179  	}
   180  
   181  	w, err := account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
   182  	if err != nil {
   183  		return nil, newError("failed to create encoding stream").Base(err).AtError()
   184  	}
   185  
   186  	header := buf.New()
   187  
   188  	if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil {
   189  		return nil, newError("failed to write address").Base(err)
   190  	}
   191  
   192  	if err := w.WriteMultiBuffer(buf.MultiBuffer{header}); err != nil {
   193  		return nil, newError("failed to write header").Base(err)
   194  	}
   195  
   196  	return w, nil
   197  }
   198  
   199  func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
   200  	account := user.Account.(*MemoryAccount)
   201  
   202  	var iv []byte
   203  	if account.Cipher.IVSize() > 0 {
   204  		iv = make([]byte, account.Cipher.IVSize())
   205  		if _, err := io.ReadFull(reader, iv); err != nil {
   206  			return nil, newError("failed to read IV").Base(err)
   207  		}
   208  	}
   209  
   210  	return account.Cipher.NewDecryptionReader(account.Key, iv, reader)
   211  }
   212  
   213  func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   214  	user := request.User
   215  	account := user.Account.(*MemoryAccount)
   216  
   217  	var iv []byte
   218  	if account.Cipher.IVSize() > 0 {
   219  		iv = make([]byte, account.Cipher.IVSize())
   220  		common.Must2(rand.Read(iv))
   221  		if err := buf.WriteAllBytes(writer, iv); err != nil {
   222  			return nil, newError("failed to write IV.").Base(err)
   223  		}
   224  	}
   225  
   226  	return account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
   227  }
   228  
   229  func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
   230  	user := request.User
   231  	account := user.Account.(*MemoryAccount)
   232  
   233  	buffer := buf.New()
   234  	ivLen := account.Cipher.IVSize()
   235  	if ivLen > 0 {
   236  		common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
   237  	}
   238  
   239  	if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil {
   240  		return nil, newError("failed to write address").Base(err)
   241  	}
   242  
   243  	buffer.Write(payload)
   244  
   245  	if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil {
   246  		return nil, newError("failed to encrypt UDP payload").Base(err)
   247  	}
   248  
   249  	return buffer, nil
   250  }
   251  
   252  func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
   253  	bs := payload.Bytes()
   254  	if len(bs) <= 32 {
   255  		return nil, nil, newError("len(bs) <= 32")
   256  	}
   257  
   258  	var user *protocol.MemoryUser
   259  	var err error
   260  
   261  	count := validator.Count()
   262  	if count == 0 {
   263  		return nil, nil, newError("invalid user")
   264  	} else if count > 1 {
   265  		var d []byte
   266  		user, _, d, _, err = validator.Get(bs, protocol.RequestCommandUDP)
   267  
   268  		if user != nil {
   269  			payload.Clear()
   270  			payload.Write(d)
   271  		} else {
   272  			return nil, nil, newError("failed to decrypt UDP payload").Base(err)
   273  		}
   274  	} else {
   275  		user, _ = validator.GetOnlyUser()
   276  		account := user.Account.(*MemoryAccount)
   277  
   278  		var iv []byte
   279  		if !account.Cipher.IsAEAD() && account.Cipher.IVSize() > 0 {
   280  			// Keep track of IV as it gets removed from payload in DecodePacket.
   281  			iv = make([]byte, account.Cipher.IVSize())
   282  			copy(iv, payload.BytesTo(account.Cipher.IVSize()))
   283  		}
   284  		if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
   285  			return nil, nil, newError("failed to decrypt UDP payload").Base(err)
   286  		}
   287  	}
   288  
   289  	request := &protocol.RequestHeader{
   290  		Version: Version,
   291  		User:    user,
   292  		Command: protocol.RequestCommandUDP,
   293  	}
   294  
   295  	payload.SetByte(0, payload.Byte(0)&0x0F)
   296  
   297  	addr, port, err := addrParser.ReadAddressPort(nil, payload)
   298  	if err != nil {
   299  		return nil, nil, newError("failed to parse address").Base(err)
   300  	}
   301  
   302  	request.Address = addr
   303  	request.Port = port
   304  
   305  	return request, payload, nil
   306  }
   307  
   308  type UDPReader struct {
   309  	Reader io.Reader
   310  	User   *protocol.MemoryUser
   311  }
   312  
   313  func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   314  	buffer := buf.New()
   315  	_, err := buffer.ReadFrom(v.Reader)
   316  	if err != nil {
   317  		buffer.Release()
   318  		return nil, err
   319  	}
   320  	validator := new(Validator)
   321  	validator.Add(v.User)
   322  
   323  	u, payload, err := DecodeUDPPacket(validator, buffer)
   324  	if err != nil {
   325  		buffer.Release()
   326  		return nil, err
   327  	}
   328  	dest := u.Destination()
   329  	payload.UDP = &dest
   330  	return buf.MultiBuffer{payload}, nil
   331  }
   332  
   333  type UDPWriter struct {
   334  	Writer  io.Writer
   335  	Request *protocol.RequestHeader
   336  }
   337  
   338  func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   339  	for {
   340  		mb2, b := buf.SplitFirst(mb)
   341  		mb = mb2
   342  		if b == nil {
   343  			break
   344  		}
   345  		request := w.Request
   346  		if b.UDP != nil {
   347  			request = &protocol.RequestHeader{
   348  				User:    w.Request.User,
   349  				Address: b.UDP.Address,
   350  				Port:    b.UDP.Port,
   351  			}
   352  		}
   353  		packet, err := EncodeUDPPacket(request, b.Bytes())
   354  		b.Release()
   355  		if err != nil {
   356  			buf.ReleaseMulti(mb)
   357  			return err
   358  		}
   359  		_, err = w.Writer.Write(packet.Bytes())
   360  		packet.Release()
   361  		if err != nil {
   362  			buf.ReleaseMulti(mb)
   363  			return err
   364  		}
   365  	}
   366  	return nil
   367  }