github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/tls/tls.go (about)

     1  package tls
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"math/rand/v2"
     7  	"net"
     8  
     9  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    10  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    11  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    12  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    13  )
    14  
    15  type Tls struct {
    16  	netapi.EmptyDispatch
    17  
    18  	tlsConfig []*tls.Config
    19  	dialer    netapi.Proxy
    20  }
    21  
// init registers NewClient through point.RegisterProtocol so the node
// layer can build TLS client wrappers (presumably keyed by NewClient's
// *protocol.Protocol_Tls parameter type — confirm in package point).
func init() {
	point.RegisterProtocol(NewClient)
}
    25  
    26  func NewClient(c *protocol.Protocol_Tls) point.WrapProxy {
    27  	return func(p netapi.Proxy) (netapi.Proxy, error) {
    28  		var tlsConfigs []*tls.Config
    29  		tls := point.ParseTLSConfig(c.Tls)
    30  		if tls != nil {
    31  			// if !tls.InsecureSkipVerify && tls.ServerName == "" {
    32  			// 	tls.ServerName = c.Simple.GetHost()
    33  			// }
    34  
    35  			tlsConfigs = append(tlsConfigs, tls)
    36  
    37  			if len(c.Tls.ServerNames) > 1 {
    38  				for _, v := range c.Tls.ServerNames[1:] {
    39  					tc := tls.Clone()
    40  					tc.ServerName = v
    41  
    42  					tlsConfigs = append(tlsConfigs, tc)
    43  				}
    44  			}
    45  		}
    46  
    47  		return &Tls{
    48  			tlsConfig: tlsConfigs,
    49  			dialer:    p,
    50  		}, nil
    51  	}
    52  }
    53  
    54  func (t *Tls) Conn(ctx context.Context, addr netapi.Address) (net.Conn, error) {
    55  	c, err := t.dialer.Conn(ctx, addr)
    56  	if err != nil {
    57  		return nil, err
    58  	}
    59  
    60  	length := len(t.tlsConfig)
    61  	if length == 0 {
    62  		return c, nil
    63  	}
    64  
    65  	return tls.Client(c, t.tlsConfig[rand.IntN(length)]), nil
    66  }
    67  
    68  func (t *Tls) PacketConn(ctx context.Context, addr netapi.Address) (net.PacketConn, error) {
    69  	return t.dialer.PacketConn(ctx, addr)
    70  }
    71  
// init registers NewServer through listener.RegisterTransport so listener
// configurations can build the TLS server transport (presumably selected by
// NewServer's *listener.Transport_Tls parameter type — confirm in package
// listener).
func init() {
	listener.RegisterTransport(NewServer)
}
    75  
    76  func NewServer(c *listener.Transport_Tls) func(netapi.Listener) (netapi.Listener, error) {
    77  	config, err := listener.ParseTLS(c.Tls.Tls)
    78  	if err != nil {
    79  		return listener.ErrorTransportFunc(err)
    80  	}
    81  
    82  	return func(ii netapi.Listener) (netapi.Listener, error) {
    83  		lis, err := ii.Stream(context.TODO())
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  		return netapi.PatchStream(tls.NewListener(lis, config), ii), nil
    88  	}
    89  }