github.com/ipfans/trojan-go@v0.11.0/tunnel/mux/client.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"math/rand"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xtaci/smux"
    11  
    12  	"github.com/ipfans/trojan-go/common"
    13  	"github.com/ipfans/trojan-go/config"
    14  	"github.com/ipfans/trojan-go/log"
    15  	"github.com/ipfans/trojan-go/tunnel"
    16  )
    17  
// muxID is the key under which a smux session is stored in Client.clientPool.
// Values are produced by generateMuxID.
type muxID uint32
    19  
    20  func generateMuxID() muxID {
    21  	return muxID(rand.Uint32())
    22  }
    23  
// smuxClientInfo groups one smux session with the bookkeeping the client pool
// keeps for it.
type smuxClientInfo struct {
	// id is the key of this entry in Client.clientPool.
	id             muxID
	// client is the smux session that streams are opened on.
	client         *smux.Session
	// lastActiveTime is set when the entry is created and on every attempt to
	// open a stream; cleanLoop compares it against the idle timeout.
	lastActiveTime time.Time
	// underlayConn is the connection the smux session was created over.
	underlayConn   tunnel.Conn
}
    30  
// Client is a smux client. It keeps a pool of smux sessions, each running over
// its own underlay connection, and hands out streams from them in DialConn.
type Client struct {
	// clientPoolLock guards clientPool.
	clientPoolLock sync.Mutex
	// clientPool holds the sessions currently known to the client, keyed by id.
	clientPool     map[muxID]*smuxClientInfo
	// underlay dials the connection that each new smux session runs over.
	underlay       tunnel.Client
	// concurrency is the number of streams a session may carry before DialConn
	// moves on to another session; a value <= 0 disables the limit.
	concurrency    int
	// timeout is how long a session with no streams may stay idle before
	// cleanLoop closes it.
	timeout        time.Duration
	// ctx stops cleanLoop when it is done; cancel is invoked by Close.
	ctx            context.Context
	cancel         context.CancelFunc
}
    41  
    42  func (c *Client) Close() error {
    43  	c.cancel()
    44  	c.clientPoolLock.Lock()
    45  	defer c.clientPoolLock.Unlock()
    46  	for id, info := range c.clientPool {
    47  		info.client.Close()
    48  		log.Debug("mux client", id, "closed")
    49  	}
    50  	return nil
    51  }
    52  
    53  func (c *Client) cleanLoop() {
    54  	var checkDuration time.Duration
    55  	if c.timeout <= 0 {
    56  		checkDuration = time.Second * 10
    57  		log.Warn("negative mux timeout")
    58  	} else {
    59  		checkDuration = c.timeout / 4
    60  	}
    61  	log.Debug("check duration:", checkDuration.Seconds(), "s")
    62  	for {
    63  		select {
    64  		case <-time.After(checkDuration):
    65  			c.clientPoolLock.Lock()
    66  			for id, info := range c.clientPool {
    67  				if info.client.IsClosed() {
    68  					info.client.Close()
    69  					info.underlayConn.Close()
    70  					delete(c.clientPool, id)
    71  					log.Info("mux client", id, "is dead")
    72  				} else if info.client.NumStreams() == 0 && time.Since(info.lastActiveTime) > c.timeout {
    73  					info.client.Close()
    74  					info.underlayConn.Close()
    75  					delete(c.clientPool, id)
    76  					log.Info("mux client", id, "is closed due to inactivity")
    77  				}
    78  			}
    79  			log.Debug("current mux clients: ", len(c.clientPool))
    80  			for id, info := range c.clientPool {
    81  				log.Debug(fmt.Sprintf("  - %x: %d/%d", id, info.client.NumStreams(), c.concurrency))
    82  			}
    83  			c.clientPoolLock.Unlock()
    84  		case <-c.ctx.Done():
    85  			log.Debug("shutting down mux cleaner..")
    86  			c.clientPoolLock.Lock()
    87  			for id, info := range c.clientPool {
    88  				info.client.Close()
    89  				info.underlayConn.Close()
    90  				delete(c.clientPool, id)
    91  				log.Debug("mux client", id, "closed")
    92  			}
    93  			c.clientPoolLock.Unlock()
    94  			return
    95  		}
    96  	}
    97  }
    98  
    99  func (c *Client) newMuxClient() (*smuxClientInfo, error) {
   100  	// The mutex should be locked when this function is called
   101  	id := generateMuxID()
   102  	if _, found := c.clientPool[id]; found {
   103  		return nil, common.NewError("duplicated id")
   104  	}
   105  
   106  	fakeAddr := &tunnel.Address{
   107  		DomainName:  "MUX_CONN",
   108  		AddressType: tunnel.DomainName,
   109  	}
   110  	conn, err := c.underlay.DialConn(fakeAddr, &Tunnel{})
   111  	if err != nil {
   112  		return nil, common.NewError("mux failed to dial").Base(err)
   113  	}
   114  	conn = newStickyConn(conn)
   115  
   116  	smuxConfig := smux.DefaultConfig()
   117  	// smuxConfig.KeepAliveDisabled = true
   118  	client, _ := smux.Client(conn, smuxConfig)
   119  	info := &smuxClientInfo{
   120  		client:         client,
   121  		underlayConn:   conn,
   122  		id:             id,
   123  		lastActiveTime: time.Now(),
   124  	}
   125  	c.clientPool[id] = info
   126  	return info, nil
   127  }
   128  
   129  func (c *Client) DialConn(*tunnel.Address, tunnel.Tunnel) (tunnel.Conn, error) {
   130  	createNewConn := func(info *smuxClientInfo) (tunnel.Conn, error) {
   131  		rwc, err := info.client.Open()
   132  		info.lastActiveTime = time.Now()
   133  		if err != nil {
   134  			info.underlayConn.Close()
   135  			info.client.Close()
   136  			delete(c.clientPool, info.id)
   137  			return nil, common.NewError("mux failed to open stream from client").Base(err)
   138  		}
   139  		return &Conn{
   140  			rwc:  rwc,
   141  			Conn: info.underlayConn,
   142  		}, nil
   143  	}
   144  
   145  	c.clientPoolLock.Lock()
   146  	defer c.clientPoolLock.Unlock()
   147  	for _, info := range c.clientPool {
   148  		if info.client.IsClosed() {
   149  			delete(c.clientPool, info.id)
   150  			log.Info(fmt.Sprintf("Mux client %x is closed", info.id))
   151  			continue
   152  		}
   153  		if info.client.NumStreams() < c.concurrency || c.concurrency <= 0 {
   154  			return createNewConn(info)
   155  		}
   156  	}
   157  
   158  	info, err := c.newMuxClient()
   159  	if err != nil {
   160  		return nil, common.NewError("no available mux client found").Base(err)
   161  	}
   162  	return createNewConn(info)
   163  }
   164  
   165  func (c *Client) DialPacket(tunnel.Tunnel) (tunnel.PacketConn, error) {
   166  	panic("not supported")
   167  }
   168  
   169  func NewClient(ctx context.Context, underlay tunnel.Client) (*Client, error) {
   170  	clientConfig := config.FromContext(ctx, Name).(*Config)
   171  	ctx, cancel := context.WithCancel(ctx)
   172  	client := &Client{
   173  		underlay:    underlay,
   174  		concurrency: clientConfig.Mux.Concurrency,
   175  		timeout:     time.Duration(clientConfig.Mux.IdleTimeout) * time.Second,
   176  		ctx:         ctx,
   177  		cancel:      cancel,
   178  		clientPool:  make(map[muxID]*smuxClientInfo),
   179  	}
   180  	go client.cleanLoop()
   181  	log.Debug("mux client created")
   182  	return client, nil
   183  }