github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/mux/mux_client.go (about)

     1  package mux
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	logger "log"
     8  	"net"
     9  	"runtime/debug"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/golang/snappy"
    15  	clienttransport "github.com/AntonOrnatskyi/goproxy/core/cs/client"
    16  	"github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg"
    17  	encryptconn "github.com/AntonOrnatskyi/goproxy/core/lib/transport/encrypt"
    18  	"github.com/AntonOrnatskyi/goproxy/services"
    19  	"github.com/AntonOrnatskyi/goproxy/utils"
    20  	"github.com/AntonOrnatskyi/goproxy/utils/jumper"
    21  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    22  	//"github.com/xtaci/smux"
    23  	smux "github.com/hashicorp/yamux"
    24  )
    25  
// MuxClientArgs holds the configuration of a MuxClient. Start receives it by
// value (args.(MuxClientArgs)). Most fields are pointers and are dereferenced
// without nil checks; presumably the caller binds them to command-line flags
// (TODO confirm).
type MuxClientArgs struct {
	// Parent is the address of the parent to dial; CheckArgs rejects "".
	Parent       *string
	// ParentType selects the transport in getParentConn: "tls", "kcp",
	// "tcps", "tou"; any other value dials plain TCP.
	ParentType   *string
	// CertFile and KeyFile are required by CheckArgs for every parent type.
	CertFile     *string
	KeyFile      *string
	// CertBytes and KeyBytes are filled by CheckArgs from CertFile/KeyFile
	// when ParentType is "tls".
	CertBytes    []byte
	KeyBytes     []byte
	// Key is sent to the parent at registration as "<Key>-<worker index>".
	Key          *string
	// Timeout is in milliseconds (converted with time.Millisecond).
	Timeout      *int
	// IsCompress enables snappy framing of TCP streams in ServeConn.
	IsCompress   *bool
	// SessionCount is the number of session workers; values <= 0 mean 1.
	SessionCount *int
	// KCP configures the "kcp" parent transport.
	KCP          kcpcfg.KCPConfigArgs
	// Jumper, when non-empty, is parsed by jumper.New and used to dial the
	// parent.
	Jumper       *string
	// TCPSMethod and TCPSPassword configure the encryption of the "tcps"
	// transport; getParentConn also passes them to the "tou" transport.
	TCPSMethod   *string
	TCPSPassword *string
	// TOUMethod and TOUPassword are not read anywhere in this file.
	TOUMethod    *string
	TOUPassword  *string
}
// ClientUDPConnItem tracks one relayed UDP flow. It is keyed in
// MuxClient.udpConns by the source address reported by the bridge.
type ClientUDPConnItem struct {
	// conn is the bridge stream that created the item; UDPRevecive writes
	// replies from udpConn back to it.
	conn      *smux.Stream
	// isActive is not read or written anywhere in this file.
	isActive  bool
	// touchtime is the Unix time (seconds) of the last datagram seen in
	// either direction; UDPGCDeamon reclaims items idle for over 30 seconds.
	touchtime int64
	// srcAddr is the resolved remote source address, echoed back in every
	// reply packet.
	srcAddr   *net.UDPAddr
	// localAddr is the resolved address of the local UDP service.
	localAddr *net.UDPAddr
	// udpConn is the UDP socket connected to localAddr.
	udpConn   *net.UDPConn
	// connid is the stream ID, used only in log messages.
	connid    string
}
// MuxClient connects to a parent over a multiplexed (yamux) session and
// forwards the streams the parent opens to local TCP/UDP services.
type MuxClient struct {
	// cfg is set by Start and reset by StopService.
	cfg      MuxClientArgs
	// isStop is set by StopService and polled by the worker loops.
	// NOTE(review): it is a plain bool read by many goroutines without
	// synchronization, so the race detector will flag it.
	isStop   bool
	// sessions maps a worker key ("worker[N]") to its *smux.Session.
	sessions mapx.ConcurrentMap
	// log is the logger passed to Start; StopService sets it to nil.
	log      *logger.Logger
	// jumper is nil unless CheckArgs parsed a jumper URL.
	jumper   *jumper.Jumper
	// udpConns maps a bridge-reported source address to its
	// *ClientUDPConnItem.
	udpConns mapx.ConcurrentMap
}
    61  
    62  func NewMuxClient() services.Service {
    63  	return &MuxClient{
    64  		cfg:      MuxClientArgs{},
    65  		isStop:   false,
    66  		sessions: mapx.NewConcurrentMap(),
    67  		udpConns: mapx.NewConcurrentMap(),
    68  	}
    69  }
    70  
    71  func (s *MuxClient) InitService() (err error) {
    72  	s.UDPGCDeamon()
    73  	return
    74  }
    75  
    76  func (s *MuxClient) CheckArgs() (err error) {
    77  	if *s.cfg.Parent != "" {
    78  		s.log.Printf("use tls parent %s", *s.cfg.Parent)
    79  	} else {
    80  		err = fmt.Errorf("parent required")
    81  		return
    82  	}
    83  	if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
    84  		err = fmt.Errorf("cert and key file required")
    85  		return
    86  	}
    87  	if *s.cfg.ParentType == "tls" {
    88  		s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
    89  		if err != nil {
    90  			return
    91  		}
    92  	}
    93  	if *s.cfg.Jumper != "" {
    94  		if *s.cfg.ParentType != "tls" && *s.cfg.ParentType != "tcp" {
    95  			err = fmt.Errorf("jumper only worked of -T is tls or tcp")
    96  			return
    97  		}
    98  		var j jumper.Jumper
    99  		j, err = jumper.New(*s.cfg.Jumper, time.Millisecond*time.Duration(*s.cfg.Timeout))
   100  		if err != nil {
   101  			err = fmt.Errorf("parse jumper fail, err %s", err)
   102  			return
   103  		}
   104  		s.jumper = &j
   105  	}
   106  	return
   107  }
   108  func (s *MuxClient) StopService() {
   109  	defer func() {
   110  		e := recover()
   111  		if e != nil {
   112  			s.log.Printf("stop client service crashed,%s", e)
   113  		} else {
   114  			s.log.Printf("service client stopped")
   115  		}
   116  		s.cfg = MuxClientArgs{}
   117  		s.jumper = nil
   118  		s.log = nil
   119  		s.sessions = nil
   120  		s.udpConns = nil
   121  		s = nil
   122  	}()
   123  	s.isStop = true
   124  	for _, sess := range s.sessions.Items() {
   125  		sess.(*smux.Session).Close()
   126  	}
   127  }
   128  func (s *MuxClient) Start(args interface{}, log *logger.Logger) (err error) {
   129  	s.log = log
   130  	s.cfg = args.(MuxClientArgs)
   131  	if err = s.CheckArgs(); err != nil {
   132  		return
   133  	}
   134  	if err = s.InitService(); err != nil {
   135  		return
   136  	}
   137  	s.log.Printf("client started")
   138  	count := 1
   139  	if *s.cfg.SessionCount > 0 {
   140  		count = *s.cfg.SessionCount
   141  	}
   142  	for i := 1; i <= count; i++ {
   143  		key := fmt.Sprintf("worker[%d]", i)
   144  		s.log.Printf("session %s started", key)
   145  		go func(i int) {
   146  			defer func() {
   147  				e := recover()
   148  				if e != nil {
   149  					s.log.Printf("session worker crashed: %s\nstack:%s", e, string(debug.Stack()))
   150  				}
   151  			}()
   152  			for {
   153  				if s.isStop {
   154  					return
   155  				}
   156  				conn, err := s.getParentConn()
   157  				if err != nil {
   158  					s.log.Printf("connection err: %s, retrying...", err)
   159  					time.Sleep(time.Second * 3)
   160  					continue
   161  				}
   162  				conn.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   163  				g := sync.WaitGroup{}
   164  				g.Add(1)
   165  				go func() {
   166  					defer func() {
   167  						_ = recover()
   168  						g.Done()
   169  					}()
   170  					_, err = conn.Write(utils.BuildPacket(CONN_CLIENT, fmt.Sprintf("%s-%d", *s.cfg.Key, i)))
   171  				}()
   172  				g.Wait()
   173  				conn.SetDeadline(time.Time{})
   174  				if err != nil {
   175  					conn.Close()
   176  					s.log.Printf("connection err: %s, retrying...", err)
   177  					time.Sleep(time.Second * 3)
   178  					continue
   179  				}
   180  				session, err := smux.Server(conn, nil)
   181  				if err != nil {
   182  					s.log.Printf("session err: %s, retrying...", err)
   183  					conn.Close()
   184  					time.Sleep(time.Second * 3)
   185  					continue
   186  				}
   187  				if _sess, ok := s.sessions.Get(key); ok {
   188  					_sess.(*smux.Session).Close()
   189  				}
   190  				s.sessions.Set(key, session)
   191  				for {
   192  					if s.isStop {
   193  						return
   194  					}
   195  					stream, err := session.AcceptStream()
   196  					if err != nil {
   197  						s.log.Printf("accept stream err: %s, retrying...", err)
   198  						session.Close()
   199  						time.Sleep(time.Second * 3)
   200  						break
   201  					}
   202  					go func() {
   203  						defer func() {
   204  							e := recover()
   205  							if e != nil {
   206  								s.log.Printf("stream handler crashed: %s", e)
   207  							}
   208  						}()
   209  						var ID, clientLocalAddr, serverID string
   210  						stream.SetDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   211  						err = utils.ReadPacketData(stream, &ID, &clientLocalAddr, &serverID)
   212  						stream.SetDeadline(time.Time{})
   213  						if err != nil {
   214  							s.log.Printf("read stream signal err: %s", err)
   215  							stream.Close()
   216  							return
   217  						}
   218  						//s.log.Printf("worker[%d] signal revecived,server %s stream %s %s", i, serverID, ID, clientLocalAddr)
   219  						protocol := clientLocalAddr[:3]
   220  						localAddr := clientLocalAddr[4:]
   221  						if protocol == "udp" {
   222  							s.ServeUDP(stream, localAddr, ID)
   223  						} else {
   224  							s.ServeConn(stream, localAddr, ID)
   225  						}
   226  					}()
   227  				}
   228  			}
   229  		}(i)
   230  	}
   231  	return
   232  }
   233  func (s *MuxClient) Clean() {
   234  	s.StopService()
   235  }
   236  func (s *MuxClient) getParentConn() (conn net.Conn, err error) {
   237  	if *s.cfg.ParentType == "tls" {
   238  		if s.jumper == nil {
   239  			var _conn tls.Conn
   240  			_conn, err = clienttransport.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   241  			if err == nil {
   242  				conn = net.Conn(&_conn)
   243  			}
   244  		} else {
   245  			conf, e := utils.TlsConfig(s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   246  			if e != nil {
   247  				return nil, e
   248  			}
   249  			var _c net.Conn
   250  			_c, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout))
   251  			if err == nil {
   252  				conn = net.Conn(tls.Client(_c, conf))
   253  			}
   254  		}
   255  
   256  	} else if *s.cfg.ParentType == "kcp" {
   257  		conn, err = clienttransport.KCPConnectHost(*s.cfg.Parent, s.cfg.KCP)
   258  	} else if *s.cfg.ParentType == "tcps" {
   259  		if s.jumper == nil {
   260  			conn, err = clienttransport.TCPSConnectHost(*s.cfg.Parent, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, *s.cfg.Timeout)
   261  		} else {
   262  			conn, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout))
   263  			if err == nil {
   264  				conn, err = encryptconn.NewConn(conn, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword)
   265  			}
   266  		}
   267  
   268  	} else if *s.cfg.ParentType == "tou" {
   269  		conn, err = clienttransport.TOUConnectHost(*s.cfg.Parent, *s.cfg.TCPSMethod, *s.cfg.TCPSPassword, false, *s.cfg.Timeout)
   270  	} else {
   271  		if s.jumper == nil {
   272  			conn, err = clienttransport.TCPConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
   273  		} else {
   274  			conn, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout))
   275  		}
   276  	}
   277  	return
   278  }
   279  func (s *MuxClient) ServeUDP(inConn *smux.Stream, localAddr, ID string) {
   280  	var item *ClientUDPConnItem
   281  	var body []byte
   282  	var err error
   283  	srcAddr := ""
   284  	defer func() {
   285  		if item != nil {
   286  			(*item).conn.Close()
   287  			(*item).udpConn.Close()
   288  			s.udpConns.Remove(srcAddr)
   289  			inConn.Close()
   290  		}
   291  	}()
   292  	for {
   293  		if s.isStop {
   294  			return
   295  		}
   296  		srcAddr, body, err = utils.ReadUDPPacket(inConn)
   297  		if err != nil {
   298  			if strings.Contains(err.Error(), "n != int(") {
   299  				continue
   300  			}
   301  			if !utils.IsNetDeadlineErr(err) && err != io.EOF {
   302  				s.log.Printf("udp packet revecived from bridge fail, err: %s", err)
   303  			}
   304  			return
   305  		}
   306  		if v, ok := s.udpConns.Get(srcAddr); !ok {
   307  			_srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr)
   308  			zeroAddr, _ := net.ResolveUDPAddr("udp", ":")
   309  			_localAddr, _ := net.ResolveUDPAddr("udp", localAddr)
   310  			c, err := net.DialUDP("udp", zeroAddr, _localAddr)
   311  			if err != nil {
   312  				s.log.Printf("create local udp conn fail, err : %s", err)
   313  				inConn.Close()
   314  				return
   315  			}
   316  			item = &ClientUDPConnItem{
   317  				conn:      inConn,
   318  				srcAddr:   _srcAddr,
   319  				localAddr: _localAddr,
   320  				udpConn:   c,
   321  				connid:    ID,
   322  			}
   323  			s.udpConns.Set(srcAddr, item)
   324  			s.UDPRevecive(srcAddr, ID)
   325  		} else {
   326  			item = v.(*ClientUDPConnItem)
   327  		}
   328  		(*item).touchtime = time.Now().Unix()
   329  		go func() {
   330  			defer func() { _ = recover() }()
   331  			(*item).udpConn.Write(body)
   332  		}()
   333  	}
   334  }
   335  func (s *MuxClient) UDPRevecive(key, ID string) {
   336  	go func() {
   337  		defer func() {
   338  			if e := recover(); e != nil {
   339  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   340  			}
   341  		}()
   342  		s.log.Printf("udp conn %s connected", ID)
   343  		v, ok := s.udpConns.Get(key)
   344  		if !ok {
   345  			s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
   346  			return
   347  		}
   348  		cui := v.(*ClientUDPConnItem)
   349  		buf := utils.LeakyBuffer.Get()
   350  		defer func() {
   351  			utils.LeakyBuffer.Put(buf)
   352  			cui.conn.Close()
   353  			cui.udpConn.Close()
   354  			s.udpConns.Remove(key)
   355  			s.log.Printf("udp conn %s released", ID)
   356  		}()
   357  		for {
   358  			n, err := cui.udpConn.Read(buf)
   359  			if err != nil {
   360  				if !utils.IsNetClosedErr(err) {
   361  					s.log.Printf("udp conn read udp packet fail , err: %s ", err)
   362  				}
   363  				return
   364  			}
   365  			cui.touchtime = time.Now().Unix()
   366  			go func() {
   367  				defer func() {
   368  					if e := recover(); e != nil {
   369  						fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   370  					}
   371  				}()
   372  				cui.conn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   373  				_, err = cui.conn.Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n]))
   374  				cui.conn.SetWriteDeadline(time.Time{})
   375  				if err != nil {
   376  					cui.udpConn.Close()
   377  					return
   378  				}
   379  			}()
   380  		}
   381  	}()
   382  }
   383  func (s *MuxClient) UDPGCDeamon() {
   384  	gctime := int64(30)
   385  	go func() {
   386  		defer func() {
   387  			if e := recover(); e != nil {
   388  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   389  			}
   390  		}()
   391  		if s.isStop {
   392  			return
   393  		}
   394  		timer := time.NewTicker(time.Second)
   395  		for {
   396  			<-timer.C
   397  			gcKeys := []string{}
   398  			s.udpConns.IterCb(func(key string, v interface{}) {
   399  				if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime {
   400  					(*(v.(*ClientUDPConnItem).conn)).Close()
   401  					(v.(*ClientUDPConnItem).udpConn).Close()
   402  					gcKeys = append(gcKeys, key)
   403  					s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid)
   404  				}
   405  			})
   406  			for _, k := range gcKeys {
   407  				s.udpConns.Remove(k)
   408  			}
   409  			gcKeys = nil
   410  		}
   411  	}()
   412  }
   413  func (s *MuxClient) ServeConn(inConn *smux.Stream, localAddr, ID string) {
   414  	var err error
   415  	var outConn net.Conn
   416  	i := 0
   417  	for {
   418  		if s.isStop {
   419  			return
   420  		}
   421  		i++
   422  		outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout)
   423  		if err == nil || i == 3 {
   424  			break
   425  		} else {
   426  			if i == 3 {
   427  				s.log.Printf("connect to %s err: %s, retrying...", localAddr, err)
   428  				time.Sleep(2 * time.Second)
   429  				continue
   430  			}
   431  		}
   432  	}
   433  	if err != nil {
   434  		inConn.Close()
   435  		utils.CloseConn(&outConn)
   436  		s.log.Printf("build connection error, err: %s", err)
   437  		return
   438  	}
   439  
   440  	s.log.Printf("stream %s created", ID)
   441  	if *s.cfg.IsCompress {
   442  		die1 := make(chan bool, 1)
   443  		die2 := make(chan bool, 1)
   444  		go func() {
   445  			defer func() {
   446  				if e := recover(); e != nil {
   447  					fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   448  				}
   449  			}()
   450  			io.Copy(outConn, snappy.NewReader(inConn))
   451  			die1 <- true
   452  		}()
   453  		go func() {
   454  			defer func() {
   455  				if e := recover(); e != nil {
   456  					fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   457  				}
   458  			}()
   459  			io.Copy(snappy.NewWriter(inConn), outConn)
   460  			die2 <- true
   461  		}()
   462  		select {
   463  		case <-die1:
   464  		case <-die2:
   465  		}
   466  		outConn.Close()
   467  		inConn.Close()
   468  		s.log.Printf("%s stream %s released", *s.cfg.Key, ID)
   469  	} else {
   470  		utils.IoBind(inConn, outConn, func(err interface{}) {
   471  			s.log.Printf("stream %s released", ID)
   472  		}, s.log)
   473  	}
   474  }