github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/mux/mux_bridge.go (about)

     1  package mux
     2  
     3  import (
     4  	"bufio"
     5  	"fmt"
     6  	"io"
     7  	logger "log"
     8  	"math/rand"
     9  	"net"
    10  	"runtime/debug"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	srvtransport "github.com/AntonOrnatskyi/goproxy/core/cs/server"
    16  	"github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg"
    17  	"github.com/AntonOrnatskyi/goproxy/services"
    18  	"github.com/AntonOrnatskyi/goproxy/utils"
    19  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    20  	//"github.com/xtaci/smux"
    21  	smux "github.com/hashicorp/yamux"
    22  )
    23  
    24  type MuxBridgeArgs struct {
    25  	CertFile     *string
    26  	KeyFile      *string
    27  	CertBytes    []byte
    28  	KeyBytes     []byte
    29  	Local        *string
    30  	LocalType    *string
    31  	Timeout      *int
    32  	IsCompress   *bool
    33  	KCP          kcpcfg.KCPConfigArgs
    34  	TCPSMethod   *string
    35  	TCPSPassword *string
    36  	TOUMethod    *string
    37  	TOUPassword  *string
    38  }
    39  type MuxBridge struct {
    40  	cfg                MuxBridgeArgs
    41  	clientControlConns mapx.ConcurrentMap
    42  	serverConns        mapx.ConcurrentMap
    43  	router             utils.ClientKeyRouter
    44  	l                  *sync.Mutex
    45  	isStop             bool
    46  	sc                 *srvtransport.ServerChannel
    47  	log                *logger.Logger
    48  }
    49  
    50  func NewMuxBridge() services.Service {
    51  	b := &MuxBridge{
    52  		cfg:                MuxBridgeArgs{},
    53  		clientControlConns: mapx.NewConcurrentMap(),
    54  		serverConns:        mapx.NewConcurrentMap(),
    55  		l:                  &sync.Mutex{},
    56  		isStop:             false,
    57  	}
    58  	b.router = utils.NewClientKeyRouter(&b.clientControlConns, 50000)
    59  	return b
    60  }
    61  
    62  func (s *MuxBridge) InitService() (err error) {
    63  	return
    64  }
    65  func (s *MuxBridge) CheckArgs() (err error) {
    66  	if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
    67  		err = fmt.Errorf("cert and key file required")
    68  		return
    69  	}
    70  	if *s.cfg.LocalType == "tls" {
    71  		s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
    72  		if err != nil {
    73  			return
    74  		}
    75  	}
    76  	return
    77  }
    78  func (s *MuxBridge) StopService() {
    79  	defer func() {
    80  		e := recover()
    81  		if e != nil {
    82  			s.log.Printf("stop bridge service crashed,%s", e)
    83  		} else {
    84  			s.log.Printf("service bridge stopped")
    85  		}
    86  		s.cfg = MuxBridgeArgs{}
    87  		s.clientControlConns = nil
    88  		s.l = nil
    89  		s.log = nil
    90  		s.router = utils.ClientKeyRouter{}
    91  		s.sc = nil
    92  		s.serverConns = nil
    93  		s = nil
    94  	}()
    95  	s.isStop = true
    96  	if s.sc != nil && (*s.sc).Listener != nil {
    97  		(*(*s.sc).Listener).Close()
    98  	}
    99  	for _, g := range s.clientControlConns.Items() {
   100  		for _, session := range g.(*mapx.ConcurrentMap).Items() {
   101  			(session.(*smux.Session)).Close()
   102  		}
   103  	}
   104  	for _, c := range s.serverConns.Items() {
   105  		(*c.(*net.Conn)).Close()
   106  	}
   107  }
   108  func (s *MuxBridge) Start(args interface{}, log *logger.Logger) (err error) {
   109  	s.log = log
   110  	s.cfg = args.(MuxBridgeArgs)
   111  	if err = s.CheckArgs(); err != nil {
   112  		return
   113  	}
   114  	if err = s.InitService(); err != nil {
   115  		return
   116  	}
   117  
   118  	sc := srvtransport.NewServerChannelHost(*s.cfg.Local, s.log)
   119  	if *s.cfg.LocalType == "tcp" {
   120  		err = sc.ListenTCP(s.handler)
   121  	} else if *s.cfg.LocalType == "tls" {
   122  		err = sc.ListenTLS(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.handler)
   123  	} else if *s.cfg.LocalType == "kcp" {
   124  		err = sc.ListenKCP(s.cfg.KCP, s.handler, s.log)
   125  	} else if *s.cfg.LocalType == "tcps" {
   126  		err = sc.ListenTCPS(*s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, s.handler)
   127  	} else if *s.cfg.LocalType == "tou" {
   128  		err = sc.ListenTOU(*s.cfg.TOUMethod, *s.cfg.TOUPassword, false, s.handler)
   129  	}
   130  	if err != nil {
   131  		return
   132  	}
   133  	s.sc = &sc
   134  	if *s.cfg.LocalType == "tou" {
   135  		s.log.Printf("%s bridge on %s", *s.cfg.LocalType, sc.UDPListener.LocalAddr())
   136  	} else {
   137  		s.log.Printf("%s bridge on %s", *s.cfg.LocalType, (*sc.Listener).Addr())
   138  	}
   139  	return
   140  }
   141  func (s *MuxBridge) Clean() {
   142  	s.StopService()
   143  }
   144  func (s *MuxBridge) handler(inConn net.Conn) {
   145  	reader := bufio.NewReader(inConn)
   146  
   147  	var err error
   148  	var connType uint8
   149  	var key string
   150  	inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   151  	err = utils.ReadPacket(reader, &connType, &key)
   152  	inConn.SetDeadline(time.Time{})
   153  	if err != nil {
   154  		s.log.Printf("read error,ERR:%s", err)
   155  		return
   156  	}
   157  	switch connType {
   158  	case CONN_SERVER:
   159  		var serverID string
   160  		inAddr := inConn.RemoteAddr().String()
   161  		inConn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   162  		err = utils.ReadPacketData(reader, &serverID)
   163  		inConn.SetDeadline(time.Time{})
   164  		if err != nil {
   165  			s.log.Printf("read error,ERR:%s", err)
   166  			return
   167  		}
   168  		s.log.Printf("server connection %s %s connected", serverID, key)
   169  		if c, ok := s.serverConns.Get(inAddr); ok {
   170  			(*c.(*net.Conn)).Close()
   171  		}
   172  		s.serverConns.Set(inAddr, &inConn)
   173  		session, err := smux.Server(inConn, nil)
   174  		if err != nil {
   175  			utils.CloseConn(&inConn)
   176  			s.log.Printf("server session error,ERR:%s", err)
   177  			return
   178  		}
   179  		for {
   180  			if s.isStop {
   181  				return
   182  			}
   183  			stream, err := session.AcceptStream()
   184  			if err != nil {
   185  				session.Close()
   186  				utils.CloseConn(&inConn)
   187  				s.serverConns.Remove(inAddr)
   188  				s.log.Printf("server connection %s %s released", serverID, key)
   189  				return
   190  			}
   191  			go func() {
   192  				defer func() {
   193  					if e := recover(); e != nil {
   194  						s.log.Printf("bridge callback crashed,err: %s", e)
   195  					}
   196  				}()
   197  				s.callback(stream, serverID, key)
   198  			}()
   199  		}
   200  	case CONN_CLIENT:
   201  		s.log.Printf("client connection %s connected", key)
   202  		session, err := smux.Client(inConn, nil)
   203  		if err != nil {
   204  			utils.CloseConn(&inConn)
   205  			s.log.Printf("client session error,ERR:%s", err)
   206  			return
   207  		}
   208  		keyInfo := strings.Split(key, "-")
   209  		if len(keyInfo) != 2 {
   210  			utils.CloseConn(&inConn)
   211  			s.log.Printf("client key format error,key:%s", key)
   212  			return
   213  		}
   214  		groupKey := keyInfo[0]
   215  		index := keyInfo[1]
   216  		s.l.Lock()
   217  		defer s.l.Unlock()
   218  		var group *mapx.ConcurrentMap
   219  		if !s.clientControlConns.Has(groupKey) {
   220  			_g := mapx.NewConcurrentMap()
   221  			group = &_g
   222  			s.clientControlConns.Set(groupKey, group)
   223  			//s.log.Printf("init client session group %s", groupKey)
   224  		} else {
   225  			_group, _ := s.clientControlConns.Get(groupKey)
   226  			group = _group.(*mapx.ConcurrentMap)
   227  		}
   228  		if v, ok := group.Get(index); ok {
   229  			v.(*smux.Session).Close()
   230  		}
   231  		group.Set(index, session)
   232  		//s.log.Printf("set client session %s to group %s,grouplen:%d", index, groupKey, group.Count())
   233  		go func() {
   234  			defer func() {
   235  				if e := recover(); e != nil {
   236  					fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   237  				}
   238  			}()
   239  			for {
   240  				if s.isStop {
   241  					return
   242  				}
   243  				if session.IsClosed() {
   244  					s.l.Lock()
   245  					defer s.l.Unlock()
   246  					if sess, ok := group.Get(index); ok && sess.(*smux.Session).IsClosed() {
   247  						group.Remove(index)
   248  						//s.log.Printf("client session %s removed from group %s, grouplen:%d", key, groupKey, group.Count())
   249  						s.log.Printf("client connection %s released", key)
   250  					}
   251  					if group.IsEmpty() {
   252  						s.clientControlConns.Remove(groupKey)
   253  						//s.log.Printf("client session group %s removed", groupKey)
   254  					}
   255  					break
   256  				}
   257  				time.Sleep(time.Second * 5)
   258  			}
   259  		}()
   260  		//s.log.Printf("set client session,key: %s", key)
   261  	}
   262  
   263  }
   264  func (s *MuxBridge) callback(inConn net.Conn, serverID, key string) {
   265  	try := 20
   266  	for {
   267  		if s.isStop {
   268  			return
   269  		}
   270  		try--
   271  		if try == 0 {
   272  			break
   273  		}
   274  		if key == "*" {
   275  			key = s.router.GetKey()
   276  		}
   277  		//s.log.Printf("server get client session %s", key)
   278  		_group, ok := s.clientControlConns.Get(key)
   279  		if !ok {
   280  			s.log.Printf("client %s session not exists for server stream %s, retrying...", key, serverID)
   281  			time.Sleep(time.Second * 3)
   282  			continue
   283  		}
   284  		group := _group.(*mapx.ConcurrentMap)
   285  		keys := []string{}
   286  		group.IterCb(func(key string, v interface{}) {
   287  			keys = append(keys, key)
   288  		})
   289  		keysLen := len(keys)
   290  		//s.log.Printf("client session %s , len:%d , keysLen: %d", key, group.Count(), keysLen)
   291  		i := 0
   292  		if keysLen > 0 {
   293  			i = rand.Intn(keysLen)
   294  		} else {
   295  			s.log.Printf("client %s session empty for server stream %s, retrying...", key, serverID)
   296  			time.Sleep(time.Second * 3)
   297  			continue
   298  		}
   299  		index := keys[i]
   300  		s.log.Printf("select client : %s-%s", key, index)
   301  		session, _ := group.Get(index)
   302  		//session.(*smux.Session).SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   303  		stream, err := session.(*smux.Session).OpenStream()
   304  		//session.(*smux.Session).SetDeadline(time.Time{})
   305  		if err != nil {
   306  			s.log.Printf("%s client session open stream %s fail, err: %s, retrying...", key, serverID, err)
   307  			time.Sleep(time.Second * 3)
   308  			continue
   309  		} else {
   310  			s.log.Printf("stream %s -> %s created", serverID, key)
   311  			die1 := make(chan bool, 1)
   312  			die2 := make(chan bool, 1)
   313  			go func() {
   314  				defer func() {
   315  					if e := recover(); e != nil {
   316  						fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   317  					}
   318  				}()
   319  				io.Copy(stream, inConn)
   320  				die1 <- true
   321  			}()
   322  			go func() {
   323  				defer func() {
   324  					if e := recover(); e != nil {
   325  						fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   326  					}
   327  				}()
   328  				io.Copy(inConn, stream)
   329  				die2 <- true
   330  			}()
   331  			select {
   332  			case <-die1:
   333  			case <-die2:
   334  			}
   335  			stream.Close()
   336  			inConn.Close()
   337  			s.log.Printf("%s server %s stream released", key, serverID)
   338  			break
   339  		}
   340  	}
   341  
   342  }