github.com/xraypb/xray-core@v1.6.6/proxy/mtproto/auth.go (about)

     1  package mtproto
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"crypto/sha256"
     7  	"io"
     8  	"sync"
     9  
    10  	"github.com/xraypb/xray-core/common"
    11  )
    12  
    13  const (
    14  	HeaderSize = 64
    15  )
    16  
    17  type SessionContext struct {
    18  	ConnectionType [4]byte
    19  	DataCenterID   uint16
    20  }
    21  
    22  func DefaultSessionContext() SessionContext {
    23  	return SessionContext{
    24  		ConnectionType: [4]byte{0xef, 0xef, 0xef, 0xef},
    25  		DataCenterID:   0,
    26  	}
    27  }
    28  
    29  type contextKey int32
    30  
    31  const (
    32  	sessionContextKey contextKey = iota
    33  )
    34  
    35  func ContextWithSessionContext(ctx context.Context, c SessionContext) context.Context {
    36  	return context.WithValue(ctx, sessionContextKey, c)
    37  }
    38  
    39  func SessionContextFromContext(ctx context.Context) SessionContext {
    40  	if c := ctx.Value(sessionContextKey); c != nil {
    41  		return c.(SessionContext)
    42  	}
    43  	return DefaultSessionContext()
    44  }
    45  
    46  type Authentication struct {
    47  	Header        [HeaderSize]byte
    48  	DecodingKey   [32]byte
    49  	EncodingKey   [32]byte
    50  	DecodingNonce [16]byte
    51  	EncodingNonce [16]byte
    52  }
    53  
    54  func (a *Authentication) DataCenterID() uint16 {
    55  	x := ((int16(a.Header[61]) << 8) | int16(a.Header[60]))
    56  	if x < 0 {
    57  		x = -x
    58  	}
    59  	return uint16(x) - 1
    60  }
    61  
    62  func (a *Authentication) ConnectionType() [4]byte {
    63  	var x [4]byte
    64  	copy(x[:], a.Header[56:60])
    65  	return x
    66  }
    67  
    68  func (a *Authentication) ApplySecret(b []byte) {
    69  	a.DecodingKey = sha256.Sum256(append(a.DecodingKey[:], b...))
    70  	a.EncodingKey = sha256.Sum256(append(a.EncodingKey[:], b...))
    71  }
    72  
    73  func generateRandomBytes(random []byte, connType [4]byte) {
    74  	for {
    75  		common.Must2(rand.Read(random))
    76  
    77  		if random[0] == 0xef {
    78  			continue
    79  		}
    80  
    81  		val := (uint32(random[3]) << 24) | (uint32(random[2]) << 16) | (uint32(random[1]) << 8) | uint32(random[0])
    82  		if val == 0x44414548 || val == 0x54534f50 || val == 0x20544547 || val == 0x4954504f || val == 0xeeeeeeee {
    83  			continue
    84  		}
    85  
    86  		if (uint32(random[7])<<24)|(uint32(random[6])<<16)|(uint32(random[5])<<8)|uint32(random[4]) == 0x00000000 {
    87  			continue
    88  		}
    89  
    90  		copy(random[56:60], connType[:])
    91  
    92  		return
    93  	}
    94  }
    95  
    96  func NewAuthentication(sc SessionContext) *Authentication {
    97  	auth := getAuthenticationObject()
    98  	random := auth.Header[:]
    99  	generateRandomBytes(random, sc.ConnectionType)
   100  	copy(auth.EncodingKey[:], random[8:])
   101  	copy(auth.EncodingNonce[:], random[8+32:])
   102  	keyivInverse := Inverse(random[8 : 8+32+16])
   103  	copy(auth.DecodingKey[:], keyivInverse)
   104  	copy(auth.DecodingNonce[:], keyivInverse[32:])
   105  	return auth
   106  }
   107  
   108  func ReadAuthentication(reader io.Reader) (*Authentication, error) {
   109  	auth := getAuthenticationObject()
   110  
   111  	if _, err := io.ReadFull(reader, auth.Header[:]); err != nil {
   112  		putAuthenticationObject(auth)
   113  		return nil, err
   114  	}
   115  
   116  	copy(auth.DecodingKey[:], auth.Header[8:])
   117  	copy(auth.DecodingNonce[:], auth.Header[8+32:])
   118  	keyivInverse := Inverse(auth.Header[8 : 8+32+16])
   119  	copy(auth.EncodingKey[:], keyivInverse)
   120  	copy(auth.EncodingNonce[:], keyivInverse[32:])
   121  
   122  	return auth, nil
   123  }
   124  
   125  // Inverse returns a new byte array. It is a sequence of bytes when the input is read from end to beginning.Inverse
   126  // Visible for testing only.
   127  func Inverse(b []byte) []byte {
   128  	lenb := len(b)
   129  	b2 := make([]byte, lenb)
   130  	for i, v := range b {
   131  		b2[lenb-i-1] = v
   132  	}
   133  	return b2
   134  }
   135  
   136  var authPool = sync.Pool{
   137  	New: func() interface{} {
   138  		return new(Authentication)
   139  	},
   140  }
   141  
   142  func getAuthenticationObject() *Authentication {
   143  	return authPool.Get().(*Authentication)
   144  }
   145  
   146  func putAuthenticationObject(auth *Authentication) {
   147  	authPool.Put(auth)
   148  }