github.com/xraypb/xray-core@v1.6.6/proxy/shadowsocks/protocol.go (about)

     1  package shadowsocks
     2  
     3  import (
     4  	"crypto/hmac"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"hash/crc32"
     8  	"io"
     9  
    10  	"github.com/xraypb/xray-core/common"
    11  	"github.com/xraypb/xray-core/common/buf"
    12  	"github.com/xraypb/xray-core/common/crypto"
    13  	"github.com/xraypb/xray-core/common/drain"
    14  	"github.com/xraypb/xray-core/common/net"
    15  	"github.com/xraypb/xray-core/common/protocol"
    16  )
    17  
    18  const (
    19  	Version = 1
    20  )
    21  
    22  var addrParser = protocol.NewAddressParser(
    23  	protocol.AddressFamilyByte(0x01, net.AddressFamilyIPv4),
    24  	protocol.AddressFamilyByte(0x04, net.AddressFamilyIPv6),
    25  	protocol.AddressFamilyByte(0x03, net.AddressFamilyDomain),
    26  	protocol.WithAddressTypeParser(func(b byte) byte {
    27  		return b & 0x0F
    28  	}),
    29  )
    30  
    31  type FullReader struct {
    32  	reader io.Reader
    33  	buffer []byte
    34  }
    35  
    36  func (r *FullReader) Read(p []byte) (n int, err error) {
    37  	if r.buffer != nil {
    38  		n := copy(p, r.buffer)
    39  		if n == len(r.buffer) {
    40  			r.buffer = nil
    41  		} else {
    42  			r.buffer = r.buffer[n:]
    43  		}
    44  		if n == len(p) {
    45  			return n, nil
    46  		} else {
    47  			m, err := r.reader.Read(p[n:])
    48  			return n + m, err
    49  		}
    50  	}
    51  	return r.reader.Read(p)
    52  }
    53  
    54  // ReadTCPSession reads a Shadowsocks TCP session from the given reader, returns its header and remaining parts.
    55  func ReadTCPSession(validator *Validator, reader io.Reader) (*protocol.RequestHeader, buf.Reader, error) {
    56  	behaviorSeed := validator.GetBehaviorSeed()
    57  	drainer, errDrain := drain.NewBehaviorSeedLimitedDrainer(int64(behaviorSeed), 16+38, 3266, 64)
    58  
    59  	if errDrain != nil {
    60  		return nil, nil, newError("failed to initialize drainer").Base(errDrain)
    61  	}
    62  
    63  	var r buf.Reader
    64  	buffer := buf.New()
    65  	defer buffer.Release()
    66  
    67  	if _, err := buffer.ReadFullFrom(reader, 50); err != nil {
    68  		drainer.AcknowledgeReceive(int(buffer.Len()))
    69  		return nil, nil, drain.WithError(drainer, reader, newError("failed to read 50 bytes").Base(err))
    70  	}
    71  
    72  	bs := buffer.Bytes()
    73  	user, aead, _, ivLen, err := validator.Get(bs, protocol.RequestCommandTCP)
    74  
    75  	switch err {
    76  	case ErrNotFound:
    77  		drainer.AcknowledgeReceive(int(buffer.Len()))
    78  		return nil, nil, drain.WithError(drainer, reader, newError("failed to match an user").Base(err))
    79  	case ErrIVNotUnique:
    80  		drainer.AcknowledgeReceive(int(buffer.Len()))
    81  		return nil, nil, drain.WithError(drainer, reader, newError("failed iv check").Base(err))
    82  	default:
    83  		reader = &FullReader{reader, bs[ivLen:]}
    84  		drainer.AcknowledgeReceive(int(ivLen))
    85  
    86  		if aead != nil {
    87  			auth := &crypto.AEADAuthenticator{
    88  				AEAD:           aead,
    89  				NonceGenerator: crypto.GenerateAEADNonceWithSize(aead.NonceSize()),
    90  			}
    91  			r = crypto.NewAuthenticationReader(auth, &crypto.AEADChunkSizeParser{
    92  				Auth: auth,
    93  			}, reader, protocol.TransferTypeStream, nil)
    94  		} else {
    95  			account := user.Account.(*MemoryAccount)
    96  			iv := append([]byte(nil), buffer.BytesTo(ivLen)...)
    97  			r, err = account.Cipher.NewDecryptionReader(account.Key, iv, reader)
    98  			if err != nil {
    99  				return nil, nil, drain.WithError(drainer, reader, newError("failed to initialize decoding stream").Base(err).AtError())
   100  			}
   101  		}
   102  	}
   103  
   104  	br := &buf.BufferedReader{Reader: r}
   105  
   106  	request := &protocol.RequestHeader{
   107  		Version: Version,
   108  		User:    user,
   109  		Command: protocol.RequestCommandTCP,
   110  	}
   111  
   112  	buffer.Clear()
   113  
   114  	addr, port, err := addrParser.ReadAddressPort(buffer, br)
   115  	if err != nil {
   116  		drainer.AcknowledgeReceive(int(buffer.Len()))
   117  		return nil, nil, drain.WithError(drainer, reader, newError("failed to read address").Base(err))
   118  	}
   119  
   120  	request.Address = addr
   121  	request.Port = port
   122  
   123  	if request.Address == nil {
   124  		drainer.AcknowledgeReceive(int(buffer.Len()))
   125  		return nil, nil, drain.WithError(drainer, reader, newError("invalid remote address."))
   126  	}
   127  
   128  	return request, br, nil
   129  }
   130  
   131  // WriteTCPRequest writes Shadowsocks request into the given writer, and returns a writer for body.
   132  func WriteTCPRequest(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   133  	user := request.User
   134  	account := user.Account.(*MemoryAccount)
   135  
   136  	var iv []byte
   137  	if account.Cipher.IVSize() > 0 {
   138  		iv = make([]byte, account.Cipher.IVSize())
   139  		common.Must2(rand.Read(iv))
   140  		if ivError := account.CheckIV(iv); ivError != nil {
   141  			return nil, newError("failed to mark outgoing iv").Base(ivError)
   142  		}
   143  		if err := buf.WriteAllBytes(writer, iv, nil); err != nil {
   144  			return nil, newError("failed to write IV")
   145  		}
   146  	}
   147  
   148  	w, err := account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
   149  	if err != nil {
   150  		return nil, newError("failed to create encoding stream").Base(err).AtError()
   151  	}
   152  
   153  	header := buf.New()
   154  
   155  	if err := addrParser.WriteAddressPort(header, request.Address, request.Port); err != nil {
   156  		return nil, newError("failed to write address").Base(err)
   157  	}
   158  
   159  	if err := w.WriteMultiBuffer(buf.MultiBuffer{header}); err != nil {
   160  		return nil, newError("failed to write header").Base(err)
   161  	}
   162  
   163  	return w, nil
   164  }
   165  
   166  func ReadTCPResponse(user *protocol.MemoryUser, reader io.Reader) (buf.Reader, error) {
   167  	account := user.Account.(*MemoryAccount)
   168  
   169  	hashkdf := hmac.New(sha256.New, []byte("SSBSKDF"))
   170  	hashkdf.Write(account.Key)
   171  
   172  	behaviorSeed := crc32.ChecksumIEEE(hashkdf.Sum(nil))
   173  
   174  	drainer, err := drain.NewBehaviorSeedLimitedDrainer(int64(behaviorSeed), 16+38, 3266, 64)
   175  	if err != nil {
   176  		return nil, newError("failed to initialize drainer").Base(err)
   177  	}
   178  
   179  	var iv []byte
   180  	if account.Cipher.IVSize() > 0 {
   181  		iv = make([]byte, account.Cipher.IVSize())
   182  		if n, err := io.ReadFull(reader, iv); err != nil {
   183  			return nil, newError("failed to read IV").Base(err)
   184  		} else { // nolint: golint
   185  			drainer.AcknowledgeReceive(n)
   186  		}
   187  	}
   188  
   189  	if ivError := account.CheckIV(iv); ivError != nil {
   190  		return nil, drain.WithError(drainer, reader, newError("failed iv check").Base(ivError))
   191  	}
   192  
   193  	return account.Cipher.NewDecryptionReader(account.Key, iv, reader)
   194  }
   195  
   196  func WriteTCPResponse(request *protocol.RequestHeader, writer io.Writer) (buf.Writer, error) {
   197  	user := request.User
   198  	account := user.Account.(*MemoryAccount)
   199  
   200  	var iv []byte
   201  	if account.Cipher.IVSize() > 0 {
   202  		iv = make([]byte, account.Cipher.IVSize())
   203  		common.Must2(rand.Read(iv))
   204  		if ivError := account.CheckIV(iv); ivError != nil {
   205  			return nil, newError("failed to mark outgoing iv").Base(ivError)
   206  		}
   207  		if err := buf.WriteAllBytes(writer, iv, nil); err != nil {
   208  			return nil, newError("failed to write IV.").Base(err)
   209  		}
   210  	}
   211  
   212  	return account.Cipher.NewEncryptionWriter(account.Key, iv, writer)
   213  }
   214  
   215  func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buffer, error) {
   216  	user := request.User
   217  	account := user.Account.(*MemoryAccount)
   218  
   219  	buffer := buf.New()
   220  	ivLen := account.Cipher.IVSize()
   221  	if ivLen > 0 {
   222  		common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
   223  	}
   224  
   225  	if err := addrParser.WriteAddressPort(buffer, request.Address, request.Port); err != nil {
   226  		return nil, newError("failed to write address").Base(err)
   227  	}
   228  
   229  	buffer.Write(payload)
   230  
   231  	if err := account.Cipher.EncodePacket(account.Key, buffer); err != nil {
   232  		return nil, newError("failed to encrypt UDP payload").Base(err)
   233  	}
   234  
   235  	return buffer, nil
   236  }
   237  
   238  func DecodeUDPPacket(validator *Validator, payload *buf.Buffer) (*protocol.RequestHeader, *buf.Buffer, error) {
   239  	bs := payload.Bytes()
   240  	if len(bs) <= 32 {
   241  		return nil, nil, newError("len(bs) <= 32")
   242  	}
   243  
   244  	user, _, d, _, err := validator.Get(bs, protocol.RequestCommandUDP)
   245  	switch err {
   246  	case ErrIVNotUnique:
   247  		return nil, nil, newError("failed iv check").Base(err)
   248  	case ErrNotFound:
   249  		return nil, nil, newError("failed to match an user").Base(err)
   250  	default:
   251  		account := user.Account.(*MemoryAccount)
   252  		if account.Cipher.IsAEAD() {
   253  			payload.Clear()
   254  			payload.Write(d)
   255  		} else {
   256  			if account.Cipher.IVSize() > 0 {
   257  				iv := make([]byte, account.Cipher.IVSize())
   258  				copy(iv, payload.BytesTo(account.Cipher.IVSize()))
   259  			}
   260  			if err = account.Cipher.DecodePacket(account.Key, payload); err != nil {
   261  				return nil, nil, newError("failed to decrypt UDP payload").Base(err)
   262  			}
   263  		}
   264  	}
   265  
   266  	request := &protocol.RequestHeader{
   267  		Version: Version,
   268  		User:    user,
   269  		Command: protocol.RequestCommandUDP,
   270  	}
   271  
   272  	payload.SetByte(0, payload.Byte(0)&0x0F)
   273  
   274  	addr, port, err := addrParser.ReadAddressPort(nil, payload)
   275  	if err != nil {
   276  		return nil, nil, newError("failed to parse address").Base(err)
   277  	}
   278  
   279  	request.Address = addr
   280  	request.Port = port
   281  
   282  	return request, payload, nil
   283  }
   284  
   285  type UDPReader struct {
   286  	Reader io.Reader
   287  	User   *protocol.MemoryUser
   288  }
   289  
   290  func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
   291  	buffer := buf.New()
   292  	_, err := buffer.ReadFrom(v.Reader)
   293  	if err != nil {
   294  		buffer.Release()
   295  		return nil, err
   296  	}
   297  	validator := new(Validator)
   298  	validator.Add(v.User)
   299  
   300  	u, payload, err := DecodeUDPPacket(validator, buffer)
   301  	if err != nil {
   302  		buffer.Release()
   303  		return nil, err
   304  	}
   305  	dest := u.Destination()
   306  	payload.UDP = &dest
   307  	return buf.MultiBuffer{payload}, nil
   308  }
   309  
   310  type UDPWriter struct {
   311  	Writer  io.Writer
   312  	Request *protocol.RequestHeader
   313  }
   314  
   315  func (w *UDPWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
   316  	for {
   317  		mb2, b := buf.SplitFirst(mb)
   318  		mb = mb2
   319  		if b == nil {
   320  			break
   321  		}
   322  		request := w.Request
   323  		if b.UDP != nil {
   324  			request = &protocol.RequestHeader{
   325  				User:    w.Request.User,
   326  				Address: b.UDP.Address,
   327  				Port:    b.UDP.Port,
   328  			}
   329  		}
   330  		packet, err := EncodeUDPPacket(request, b.Bytes())
   331  		b.Release()
   332  		if err != nil {
   333  			buf.ReleaseMulti(mb)
   334  			return err
   335  		}
   336  		_, err = w.Writer.Write(packet.Bytes())
   337  		packet.Release()
   338  		if err != nil {
   339  			buf.ReleaseMulti(mb)
   340  			return err
   341  		}
   342  	}
   343  	return nil
   344  }