github.com/3andne/restls-client-go@v0.1.6/restls_utils.go (about)

     1  // #Restls# Begin
     2  
     3  package tls
     4  
     5  import (
     6  	"fmt"
     7  	"hash"
     8  	"math/rand"
     9  	"strings"
    10  	"sync/atomic"
    11  
    12  	"lukechampine.com/blake3"
    13  )
    14  
    15  type RestlsPlugin struct {
    16  	isClient              bool
    17  	isInbound             bool
    18  	numCipherChange       int
    19  	backupCipher          any
    20  	writingClientFinished bool
    21  	clientFinished        []byte
    22  	ConnId                int64
    23  }
    24  
    25  func (r *RestlsPlugin) initAsClientInbound(id int64) {
    26  	r.isClient = true
    27  	r.isInbound = true
    28  	r.ConnId = id
    29  }
    30  
    31  func (r *RestlsPlugin) initAsClientOutbound(id int64) {
    32  	r.isClient = true
    33  	r.isInbound = false
    34  	r.ConnId = id
    35  }
    36  
    37  var IDCounter = atomic.Int64{}
    38  
    39  func initRestlsPlugin(inPlugin *RestlsPlugin, outPlugin *RestlsPlugin) {
    40  	id := IDCounter.Add(1)
    41  	inPlugin.initAsClientInbound(id)
    42  	outPlugin.initAsClientOutbound(id)
    43  }
    44  
    45  func (r *RestlsPlugin) setBackupCipher(backupCipher ...any) {
    46  	if r.isClient && r.isInbound {
    47  		if len(backupCipher) != 1 {
    48  			panic("must provide exact 1 backup cipher")
    49  		}
    50  		r.backupCipher = backupCipher[0]
    51  	}
    52  }
    53  
    54  func (r *RestlsPlugin) changeCipher() {
    55  	debugf(nil, "[%d]RestlsPlugin changeCipher\n", r.ConnId)
    56  	r.numCipherChange += 1
    57  }
    58  
    59  func (r *RestlsPlugin) expectServerAuth(rType recordType) any {
    60  	if rType != recordTypeChangeCipherSpec && r.isClient && r.isInbound && r.numCipherChange == 1 && r.backupCipher != nil {
    61  		cipher := r.backupCipher
    62  		r.backupCipher = nil
    63  		return cipher
    64  	} else {
    65  		return nil
    66  	}
    67  }
    68  
    69  func (r *RestlsPlugin) captureClientFinished(record []byte) {
    70  	if r.isClient && !r.isInbound && r.writingClientFinished {
    71  		debugf(nil, "[%d]ClientFinished captured %v", r.ConnId, record)
    72  		r.writingClientFinished = false
    73  		r.clientFinished = append([]byte(nil), record...)
    74  	}
    75  }
    76  
    77  func (r *RestlsPlugin) WritingClientFinished() {
    78  	if !(r.isClient && !r.isInbound) {
    79  		panic("invalid operation")
    80  	}
    81  	r.writingClientFinished = true
    82  }
    83  
    84  func (r *RestlsPlugin) takeClientFinished() []byte {
    85  	if len(r.clientFinished) > 0 {
    86  		ret := r.clientFinished
    87  		r.clientFinished = nil
    88  		return ret
    89  	}
    90  	return nil
    91  }
    92  
    93  func RestlsHmac(key []byte) hash.Hash {
    94  	return blake3.New(32, key)
    95  }
    96  
    97  type Line struct {
    98  	targetLen TargetLength
    99  	command   restlsCommand
   100  }
   101  
   102  type restlsCommand interface {
   103  	toBytes() [2]byte
   104  	needInterrupt() bool
   105  }
   106  
   107  type ActResponse int8
   108  
   109  func (a ActResponse) toBytes() [2]byte {
   110  	return [2]byte{0x01, byte(a)}
   111  }
   112  
   113  func (a ActResponse) needInterrupt() bool {
   114  	return true
   115  }
   116  
   117  type ActNoop struct{}
   118  
   119  func (a ActNoop) toBytes() [2]byte {
   120  	return [2]byte{0x00, 0}
   121  }
   122  
   123  func (a ActNoop) needInterrupt() bool {
   124  	return false
   125  }
   126  
   127  func parseCommand(buf []byte) (restlsCommand, error) {
   128  	if buf[0] == 0 {
   129  		return ActNoop{}, nil
   130  	} else if buf[0] == 1 {
   131  		return ActResponse(buf[1]), nil
   132  	} else {
   133  		return nil, fmt.Errorf("unsupported restls command")
   134  	}
   135  }
   136  
   137  type TargetLength [2]int16
   138  
   139  func (t TargetLength) Len() int {
   140  	if t[1] != 0 {
   141  		return int(t[0] + int16(rand.Intn(int(t[1]))))
   142  	}
   143  	return int(t[0])
   144  }
   145  
   146  func parseRecordScript(script string) []Line {
   147  	script_split := strings.Split(strings.ReplaceAll(script, " ", ""), ",")
   148  	lines := []Line{}
   149  	for _, line_raw := range script_split {
   150  		if len(line_raw) == 0 {
   151  			continue
   152  		}
   153  		line_bytes := []byte(line_raw)
   154  		targetLen := TargetLength{getInteger(&line_bytes)}
   155  		if len(line_bytes) == 0 {
   156  			lines = append(lines, Line{targetLen, ActNoop{}})
   157  			continue
   158  		} else if line_bytes[0] == '~' || line_bytes[0] == '?' {
   159  			t := line_bytes[0]
   160  			line_bytes = line_bytes[1:]
   161  			randomRange := getInteger(&line_bytes)
   162  			if int(randomRange)+int(targetLen[0]) > 32768 {
   163  				panic("random target len > 32768")
   164  			}
   165  			targetLen[1] = randomRange
   166  			if t == '?' {
   167  				targetLen[0] = int16(targetLen.Len())
   168  				targetLen[1] = 0
   169  			}
   170  		}
   171  
   172  		if len(line_bytes) == 0 {
   173  			lines = append(lines, Line{targetLen, ActNoop{}})
   174  			continue
   175  		} else if line_bytes[0] == '<' {
   176  			line_bytes = line_bytes[1:]
   177  			numResponse := getInteger(&line_bytes)
   178  			lines = append(lines, Line{targetLen, ActResponse(numResponse)})
   179  		} else {
   180  			panic(fmt.Sprintf("invalid script %s, %v", line_raw, line_bytes))
   181  		}
   182  	}
   183  	debugf(nil, "script: %v\n", lines)
   184  	return lines
   185  }
   186  
   187  func getInteger(script *[]byte) int16 {
   188  	res := 0
   189  	i := 0
   190  	for i = 0; i < len(*script); i++ {
   191  		b := (*script)[i]
   192  		if b <= '9' && b >= '0' {
   193  			res = res*10 + int(b-'0')
   194  		} else {
   195  			break
   196  		}
   197  		if res > 32768 {
   198  			panic("target len > 32768")
   199  		}
   200  	}
   201  	*script = (*script)[i:]
   202  	return int16(res)
   203  }
   204  
   205  var curveIDMap = map[CurveID]int{
   206  	X25519:    0,
   207  	CurveP256: 1,
   208  	CurveP384: 2,
   209  }
   210  var curveIDList = []CurveID{X25519, CurveP256, CurveP384}
   211  
   212  var versionMap = map[string]versionHint{
   213  	"tls12": TLS12Hint,
   214  	"tls13": TLS13Hint,
   215  }
   216  
   217  var clientIDMap = map[string]*ClientHelloID{
   218  	"chrome":  &HelloChrome_Auto,
   219  	"firefox": &HelloFirefox_Auto,
   220  	"safari":  &HelloSafari_Auto,
   221  	"ios":     &HelloIOS_Auto,
   222  }
   223  
   224  var tls12GCMCiphers = []uint16{0xc02f, 0xc02b, 0xc030, 0xc02c}
   225  
   226  var defaultRestlsScript = "250?100<1,350~100<1,600~100,300~200,300~100"
   227  
   228  const debugLog = false
   229  
   230  func debugf(conn *Conn, format string, a ...any) {
   231  	if debugLog {
   232  		if conn != nil {
   233  			fmt.Printf("[%d]"+format, append([]any{conn.in.restlsPlugin.ConnId}, a...)...)
   234  		} else {
   235  			fmt.Printf(format, a...)
   236  		}
   237  	}
   238  }
   239  
   240  func NewRestlsConfig(serverName string, password string, versionHintString string, restlsScript string, clientIDStr string) (*Config, error) {
   241  	key := make([]byte, 32)
   242  	blake3.DeriveKey(key, "restls-traffic-key", []byte(password))
   243  	versionHint, ok := versionMap[strings.ToLower(versionHintString)]
   244  	if !ok {
   245  		return nil, fmt.Errorf("invalid version hint: should be either tls12 or tls13")
   246  	}
   247  
   248  	sessionTicketsDisabled := true
   249  	if versionHint == TLS12Hint {
   250  		sessionTicketsDisabled = false
   251  	}
   252  	if len(restlsScript) == 0 {
   253  		restlsScript = defaultRestlsScript
   254  	}
   255  	clientIDPtr, ok := clientIDMap[clientIDStr]
   256  	if !ok {
   257  		clientIDPtr = &HelloChrome_Auto
   258  	}
   259  	clientID := atomic.Pointer[ClientHelloID]{}
   260  	clientID.Store(clientIDPtr)
   261  	return &Config{RestlsSecret: key, VersionHint: versionHint, ServerName: serverName, RestlsScript: parseRecordScript(restlsScript), ClientSessionCache: NewLRUClientSessionCache(100), ClientID: &clientID, SessionTicketsDisabled: sessionTicketsDisabled}, nil
   262  }
   263  
   264  func AnyTrue[T any](vals []T, predicate func(T) bool) bool {
   265  	for _, v := range vals {
   266  		if predicate(v) {
   267  			return true
   268  		}
   269  	}
   270  	return false
   271  }
   272  
   273  // #Restls# End