github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/client.go (about)

     1  package websocket
     2  
     3  import (
     4  	"context"
     5  	"encoding/base64"
     6  	"fmt"
     7  	"math/rand/v2"
     8  	"net"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    14  	websocket "github.com/Asutorufa/yuhaiin/pkg/net/proxy/websocket/x"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    17  	ynet "github.com/Asutorufa/yuhaiin/pkg/utils/net"
    18  )
    19  
    20  type client struct {
    21  	wsConfig *websocket.Config
    22  	netapi.Proxy
    23  }
    24  
    25  func init() {
    26  	point.RegisterProtocol(NewClient)
    27  }
    28  
    29  func NewClient(cf *protocol.Protocol_Websocket) point.WrapProxy {
    30  	return func(dialer netapi.Proxy) (netapi.Proxy, error) {
    31  
    32  		return &client{
    33  			&websocket.Config{
    34  				Host: cf.Websocket.Host,
    35  				Path: getNormalizedPath(cf.Websocket.Path),
    36  			},
    37  			dialer,
    38  		}, nil
    39  	}
    40  }
    41  
    42  func (c *client) Conn(ctx context.Context, h netapi.Address) (net.Conn, error) {
    43  	conn, err := c.Proxy.Conn(ctx, h)
    44  	if err != nil {
    45  		return nil, fmt.Errorf("websocket dial failed: %w", err)
    46  	}
    47  
    48  	ctx, cancel := context.WithCancel(context.TODO())
    49  
    50  	return &earlyConn{config: c.wsConfig, Conn: conn, handshakeCtx: ctx, handshakeDone: cancel}, nil
    51  }
    52  
    53  func getNormalizedPath(path string) string {
    54  	if path == "" {
    55  		return "/"
    56  	}
    57  	if path[0] != '/' {
    58  		return "/" + path
    59  	}
    60  	return path
    61  }
    62  
    63  type earlyConn struct {
    64  	handclasp bool
    65  
    66  	net.Conn
    67  	config *websocket.Config
    68  
    69  	handshakeMu   sync.Mutex
    70  	handshakeCtx  context.Context
    71  	handshakeDone func()
    72  
    73  	deadline *time.Timer
    74  }
    75  
    76  func (e *earlyConn) Read(b []byte) (int, error) {
    77  	if !e.handclasp {
    78  		<-e.handshakeCtx.Done()
    79  	}
    80  
    81  	return e.Conn.Read(b)
    82  }
    83  
    84  func (e *earlyConn) Close() error {
    85  	e.handshakeDone()
    86  	return e.Conn.Close()
    87  }
    88  
    89  func (e *earlyConn) Write(b []byte) (int, error) {
    90  	if e.handclasp {
    91  		return e.Conn.Write(b)
    92  	}
    93  
    94  	return e.handshake(b)
    95  }
    96  
    97  func (e *earlyConn) handshake(b []byte) (int, error) {
    98  	e.handshakeMu.Lock()
    99  	defer e.handshakeMu.Unlock()
   100  
   101  	if e.handclasp {
   102  		return e.Conn.Write(b)
   103  	}
   104  
   105  	defer e.handshakeDone()
   106  
   107  	var SecWebSocketKey string
   108  	if len(b) != 0 && len(b) <= 2048 {
   109  		SecWebSocketKey = base64.RawStdEncoding.EncodeToString(b)
   110  	}
   111  
   112  	var earlyDataSupport bool
   113  	conn, err := e.config.NewClient(SecWebSocketKey, e.Conn,
   114  		func(r *http.Request) error {
   115  			r.Header.Set("User-Agent", ynet.UserAgents[rand.IntN(ynet.UserAgentLength)])
   116  			r.Header.Set("Sec-Fetch-Dest", "websocket")
   117  			r.Header.Set("Sec-Fetch-Mode", "websocket")
   118  			r.Header.Set("Pragma", "no-cache")
   119  			if SecWebSocketKey != "" {
   120  				r.Header.Set("early_data", "base64")
   121  			}
   122  			return nil
   123  		},
   124  		func(r *http.Response) error {
   125  			earlyDataSupport = r.Header.Get("early_data") == "true"
   126  			return nil
   127  		})
   128  	if err != nil {
   129  		return 0, fmt.Errorf("websocket handshake failed: %w", err)
   130  	}
   131  	e.Conn = conn
   132  
   133  	e.handclasp = true
   134  
   135  	if !earlyDataSupport {
   136  		return conn.Write(b)
   137  	}
   138  
   139  	return len(b), nil
   140  }
   141  
   142  func (c *earlyConn) SetDeadline(t time.Time) error {
   143  	c.setDeadline(t)
   144  	return c.Conn.SetDeadline(t)
   145  }
   146  
   147  func (c *earlyConn) setDeadline(t time.Time) {
   148  	if c.deadline == nil {
   149  		if !t.IsZero() {
   150  			c.deadline = time.AfterFunc(time.Until(t), func() { c.handshakeDone() })
   151  		}
   152  		return
   153  	}
   154  
   155  	if t.IsZero() {
   156  		c.deadline.Stop()
   157  	} else {
   158  		c.deadline.Reset(time.Until(t))
   159  	}
   160  }
   161  
   162  func (c *earlyConn) SetReadDeadline(t time.Time) error {
   163  	c.setDeadline(t)
   164  	return c.Conn.SetReadDeadline(t)
   165  }
   166  
   167  func (c *earlyConn) SetWriteDeadline(t time.Time) error {
   168  	c.setDeadline(t)
   169  	return c.Conn.SetWriteDeadline(t)
   170  }