github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/sessionTicket_test.go (about)

     1  /*
     2   * Copyright (c) 2016, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package psiphon
    21  
    22  import (
    23  	"crypto/rand"
    24  	"crypto/rsa"
    25  	"crypto/sha1"
    26  	"crypto/x509"
    27  	"crypto/x509/pkix"
    28  	"encoding/pem"
    29  	std_errors "errors"
    30  	"io"
    31  	"math/big"
    32  	"net"
    33  	"testing"
    34  	"time"
    35  
    36  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
    37  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
    38  	tris "github.com/Psiphon-Labs/tls-tris"
    39  	utls "github.com/Psiphon-Labs/utls"
    40  )
    41  
    42  func TestObfuscatedSessionTicket(t *testing.T) {
    43  
    44  	tlsProfiles := []string{
    45  		protocol.TLS_PROFILE_CHROME_58,
    46  		protocol.TLS_PROFILE_FIREFOX_55,
    47  		protocol.TLS_PROFILE_RANDOMIZED,
    48  	}
    49  
    50  	for _, tlsProfile := range tlsProfiles {
    51  		t.Run(tlsProfile, func(t *testing.T) {
    52  			runObfuscatedSessionTicket(t, tlsProfile)
    53  		})
    54  	}
    55  }
    56  
    57  func runObfuscatedSessionTicket(t *testing.T, tlsProfile string) {
    58  
    59  	params, err := parameters.NewParameters(nil)
    60  	if err != nil {
    61  		t.Fatalf("NewParameters failed: %s\n", err)
    62  	}
    63  
    64  	var standardSessionTicketKey [32]byte
    65  	rand.Read(standardSessionTicketKey[:])
    66  
    67  	var obfuscatedSessionTicketSharedSecret [32]byte
    68  	rand.Read(obfuscatedSessionTicketSharedSecret[:])
    69  
    70  	clientConfig := &utls.Config{
    71  		InsecureSkipVerify: true,
    72  	}
    73  
    74  	certificate, err := generateCertificate()
    75  	if err != nil {
    76  		t.Fatalf("generateCertificate failed: %s", err)
    77  	}
    78  
    79  	serverConfig := &tris.Config{
    80  		Certificates:                []tris.Certificate{*certificate},
    81  		NextProtos:                  []string{"http/1.1"},
    82  		MinVersion:                  utls.VersionTLS12,
    83  		UseExtendedMasterSecret:     true,
    84  		UseObfuscatedSessionTickets: true,
    85  	}
    86  
    87  	// Note: SessionTicketKey needs to be set, or else, it appears,
    88  	// tris.Config.serverInit() will clobber the value set by
    89  	// SetSessionTicketKeys.
    90  	serverConfig.SessionTicketKey = obfuscatedSessionTicketSharedSecret
    91  	serverConfig.SetSessionTicketKeys([][32]byte{
    92  		standardSessionTicketKey, obfuscatedSessionTicketSharedSecret})
    93  
    94  	testMessage := "test"
    95  
    96  	result := make(chan error, 1)
    97  
    98  	report := func(err error) {
    99  		select {
   100  		case result <- err:
   101  		default:
   102  		}
   103  	}
   104  
   105  	listening := make(chan string, 1)
   106  
   107  	go func() {
   108  
   109  		listener, err := tris.Listen("tcp", ":0", serverConfig)
   110  		if err != nil {
   111  			report(err)
   112  			return
   113  		}
   114  		defer listener.Close()
   115  
   116  		listening <- listener.Addr().String()
   117  
   118  		for i := 0; i < 2; i++ {
   119  			conn, err := listener.Accept()
   120  			if err != nil {
   121  				report(err)
   122  				return
   123  			}
   124  
   125  			recv := make([]byte, len(testMessage))
   126  			_, err = io.ReadFull(conn, recv)
   127  			if err == nil && string(recv) != testMessage {
   128  				err = std_errors.New("unexpected payload")
   129  			}
   130  			conn.Close()
   131  			if err != nil {
   132  				report(err)
   133  				return
   134  			}
   135  		}
   136  
   137  		// Sends nil on success
   138  		report(nil)
   139  	}()
   140  
   141  	go func() {
   142  
   143  		serverAddress := <-listening
   144  
   145  		clientSessionCache := utls.NewLRUClientSessionCache(0)
   146  
   147  		for i := 0; i < 2; i++ {
   148  
   149  			tcpConn, err := net.Dial("tcp", serverAddress)
   150  			if err != nil {
   151  				report(err)
   152  				return
   153  			}
   154  			defer tcpConn.Close()
   155  
   156  			utlsClientHelloID, _, err := getUTLSClientHelloID(
   157  				params.Get(), tlsProfile)
   158  			if err != nil {
   159  				report(err)
   160  				return
   161  			}
   162  
   163  			tlsConn := utls.UClient(tcpConn, clientConfig, utlsClientHelloID)
   164  
   165  			tlsConn.SetSessionCache(clientSessionCache)
   166  
   167  			// The first connection will use an obfuscated session ticket and the
   168  			// second connection will use a real session ticket issued by the server.
   169  			var clientSessionState *utls.ClientSessionState
   170  			if i == 0 {
   171  				obfuscatedSessionState, err := tris.NewObfuscatedClientSessionState(
   172  					obfuscatedSessionTicketSharedSecret)
   173  				if err != nil {
   174  					report(err)
   175  					return
   176  				}
   177  				clientSessionState = utls.MakeClientSessionState(
   178  					obfuscatedSessionState.SessionTicket,
   179  					obfuscatedSessionState.Vers,
   180  					obfuscatedSessionState.CipherSuite,
   181  					obfuscatedSessionState.MasterSecret,
   182  					nil,
   183  					nil)
   184  				tlsConn.SetSessionState(clientSessionState)
   185  			}
   186  
   187  			if protocol.TLSProfileIsRandomized(tlsProfile) {
   188  				for {
   189  					err = tlsConn.BuildHandshakeState()
   190  					if err != nil {
   191  						report(err)
   192  						return
   193  					}
   194  
   195  					isTLS13 := false
   196  					for _, v := range tlsConn.HandshakeState.Hello.SupportedVersions {
   197  						if v == utls.VersionTLS13 {
   198  							isTLS13 = true
   199  							break
   200  						}
   201  					}
   202  
   203  					if !isTLS13 && tris.ContainsObfuscatedSessionTicketCipherSuite(
   204  						tlsConn.HandshakeState.Hello.CipherSuites) {
   205  						break
   206  					}
   207  
   208  					utlsClientHelloID.Seed, _ = utls.NewPRNGSeed()
   209  					tlsConn = utls.UClient(tcpConn, clientConfig, utlsClientHelloID)
   210  					tlsConn.SetSessionCache(clientSessionCache)
   211  					if i == 0 {
   212  						tlsConn.SetSessionState(clientSessionState)
   213  					}
   214  				}
   215  			}
   216  
   217  			err = tlsConn.Handshake()
   218  			if err != nil {
   219  				report(err)
   220  				return
   221  			}
   222  
   223  			if len(tlsConn.ConnectionState().PeerCertificates) > 0 {
   224  				report(std_errors.New("unexpected certificate in handshake"))
   225  				return
   226  			}
   227  
   228  			_, err = tlsConn.Write([]byte(testMessage))
   229  			if err != nil {
   230  				report(err)
   231  				return
   232  			}
   233  		}
   234  	}()
   235  
   236  	err = <-result
   237  	if err != nil {
   238  		t.Fatalf("connect failed: %s", err)
   239  	}
   240  }
   241  
   242  func generateCertificate() (*tris.Certificate, error) {
   243  
   244  	rsaKey, err := rsa.GenerateKey(rand.Reader, 2048)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	publicKeyBytes, err := x509.MarshalPKIXPublicKey(rsaKey.Public())
   250  	if err != nil {
   251  		return nil, err
   252  	}
   253  	subjectKeyID := sha1.Sum(publicKeyBytes)
   254  
   255  	template := x509.Certificate{
   256  		SerialNumber:          big.NewInt(1),
   257  		Subject:               pkix.Name{CommonName: "www.example.org"},
   258  		NotBefore:             time.Now().Add(-1 * time.Hour).UTC(),
   259  		NotAfter:              time.Now().Add(time.Hour).UTC(),
   260  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
   261  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   262  		BasicConstraintsValid: true,
   263  		IsCA:                  true,
   264  		SubjectKeyId:          subjectKeyID[:],
   265  		MaxPathLen:            1,
   266  		Version:               2,
   267  	}
   268  
   269  	derCert, err := x509.CreateCertificate(
   270  		rand.Reader,
   271  		&template,
   272  		&template,
   273  		rsaKey.Public(),
   274  		rsaKey)
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  
   279  	certificate := pem.EncodeToMemory(
   280  		&pem.Block{
   281  			Type:  "CERTIFICATE",
   282  			Bytes: derCert,
   283  		},
   284  	)
   285  
   286  	privateKey := pem.EncodeToMemory(
   287  		&pem.Block{
   288  			Type:  "RSA PRIVATE KEY",
   289  			Bytes: x509.MarshalPKCS1PrivateKey(rsaKey),
   290  		},
   291  	)
   292  
   293  	keyPair, err := tris.X509KeyPair(certificate, privateKey)
   294  	if err != nil {
   295  		return nil, err
   296  	}
   297  
   298  	return &keyPair, nil
   299  }