github.com/JimmyHuang454/JLS-go@v0.0.0-20230831150107-90d536585ba0/tls/jls.go (about)

     1  package tls
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  
     9  	r "github.com/JimmyHuang454/JLS-go/jls"
    10  )
    11  
    12  func JLSHandler(c *Conn, tlsError error) error {
    13  	if !c.config.UseJLS {
    14  		return tlsError
    15  	}
    16  
    17  	if c.isClient {
    18  		if tlsError == nil && !c.IsValidJLS {
    19  			// it is a valid TLS Client but Not JLS,
    20  			defer c.Close()
    21  			return JLSError("invalid.")
    22  			// so we must TODO: act like a normal http request at here
    23  		}
    24  	} else if tlsError != nil && !c.IsValidJLS && c.quic == nil {
    25  		// It is not JLS. Forward at here.
    26  		// TODO: if we using sing-box, we need to use its forward method, since it may take over traffic by Tun.
    27  		defer c.conn.Close()
    28  		if c.config.ServerName != "" {
    29  			server, forwardError := net.Dial("tcp", c.config.ServerName+":443")
    30  			fmt.Println(c.config.ServerName + ":443 forwarding...")
    31  			if forwardError == nil {
    32  				defer server.Close()
    33  				server.Write(c.ClientHelloRecord)
    34  				server.Write(c.ForwardClientHello)
    35  				c.ClientHelloRecord = nil // improve memory.
    36  				c.ForwardClientHello = nil
    37  				go io.Copy(server, c.conn)
    38  				io.Copy(c.conn, server) // block until forward finish.
    39  			}
    40  		}
    41  	}
    42  	return tlsError
    43  }
    44  
    45  func BuildJLSClientHello(c *Conn, hello *clientHelloMsg) {
    46  	if !c.config.UseJLS || c.IsBuildedFakeRandom {
    47  		return
    48  	}
    49  	zeroArray := BuildZeroArray()
    50  	hello.random = zeroArray
    51  	withoutBinder, _ := hello.marshalWithoutBinders()
    52  	hello.random, _ = BuildFakeRandom(c.config, withoutBinder)
    53  	copy(hello.raw[6:], hello.random)
    54  	c.IsBuildedFakeRandom = true
    55  }
    56  
    57  func BuildJLSServerHello(c *Conn, hello *serverHelloMsg) {
    58  	if !c.config.UseJLS {
    59  		return
    60  	}
    61  
    62  	hello.random = BuildZeroArray()
    63  	hello.marshal()
    64  
    65  	hello.random, _ = BuildFakeRandom(c.config, hello.raw)
    66  	copy(hello.raw[6:], hello.random)
    67  }
    68  
    69  func CheckJLSServerHello(c *Conn, serverHello *serverHelloMsg) {
    70  	c.IsValidJLS = false
    71  	if !c.config.UseJLS {
    72  		return
    73  	}
    74  	serverHello.marshal() // init
    75  	zeroArray := BuildZeroArray()
    76  	raw := make([]byte, len(serverHello.raw))
    77  	copy(raw, serverHello.raw)
    78  	copy(raw[6:], zeroArray)
    79  
    80  	c.IsValidJLS, _ = CheckFakeRandom(c.config, raw, serverHello.random)
    81  	c.config.InsecureSkipVerify = c.IsValidJLS
    82  }
    83  
    84  // return false means need to forward.
    85  func CheckJLSClientHello(c *Conn, clientHello *clientHelloMsg) (bool, error) {
    86  	c.IsValidJLS = false
    87  	if !c.config.UseJLS {
    88  		return true, JLSError("disable JLS.") // == TLS.
    89  	}
    90  	zeroArray := BuildZeroArray()
    91  	withoutBinder, err := clientHello.marshalWithoutBinders()
    92  	c.ForwardClientHello = clientHello.raw
    93  	if err != nil {
    94  		return false, JLSError("failed to get clientHello raw bytes.")
    95  	}
    96  	raw := make([]byte, len(withoutBinder))
    97  	copy(raw, withoutBinder)
    98  	copy(raw[6:], zeroArray)
    99  
   100  	c.IsValidJLS, err = CheckFakeRandom(c.config, raw, clientHello.random)
   101  	if err != nil {
   102  		return false, JLSError("failed to check fakeRandom.")
   103  	}
   104  	if !c.IsValidJLS || c.vers != VersionTLS13 {
   105  		return false, JLSError("wrong fakeRandom.")
   106  	}
   107  	if len(clientHello.keyShares) == 0 {
   108  		fmt.Println("JLS missing keyShare can be not safty.")
   109  	}
   110  	return true, nil // valid JLS.
   111  }
   112  
   113  func BuildZeroArray() []byte {
   114  	const byteLen = 32
   115  	zeroArray := make([]byte, byteLen)
   116  	for i := 0; i < byteLen; i++ {
   117  		zeroArray[i] = 0
   118  	}
   119  	return zeroArray
   120  }
   121  
   122  func BuildFakeRandom(config *Config, AuthData []byte) ([]byte, error) {
   123  	iv := append(config.JLSIV, AuthData...)
   124  	pwd := append(config.JLSPWD, AuthData...)
   125  	fakeRandom := r.NewFakeRandom(pwd, iv)
   126  
   127  	err := fakeRandom.Build()
   128  	return fakeRandom.Random, err
   129  }
   130  
   131  func CheckFakeRandom(config *Config, AuthData []byte, random []byte) (bool, error) {
   132  	iv := append(config.JLSIV, AuthData...)
   133  	pwd := append(config.JLSPWD, AuthData...)
   134  	fakeRandom := r.NewFakeRandom(pwd, iv)
   135  
   136  	IsValid, err := fakeRandom.Check(random)
   137  	return IsValid, err
   138  }
   139  
   140  func JLSError(text string) error {
   141  	return errors.New("[JLS] " + text)
   142  }