github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/tunnel/tunnel_client.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	logger "log"
     8  	"net"
     9  	"os"
    10  	"runtime/debug"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/AntonOrnatskyi/goproxy/services"
    15  	"github.com/AntonOrnatskyi/goproxy/utils"
    16  	"github.com/AntonOrnatskyi/goproxy/utils/jumper"
    17  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    18  
    19  	//"github.com/xtaci/smux"
    20  	smux "github.com/hashicorp/yamux"
    21  )
    22  
    23  type TunnelClientArgs struct {
    24  	Parent    *string
    25  	CertFile  *string
    26  	KeyFile   *string
    27  	CertBytes []byte
    28  	KeyBytes  []byte
    29  	Key       *string
    30  	Timeout   *int
    31  	Jumper    *string
    32  }
    33  type ClientUDPConnItem struct {
    34  	conn      *net.Conn
    35  	isActive  bool
    36  	touchtime int64
    37  	srcAddr   *net.UDPAddr
    38  	localAddr *net.UDPAddr
    39  	udpConn   *net.UDPConn
    40  	connid    string
    41  }
// TunnelClient is the client side of the TLS tunnel. It keeps a control
// connection to the parent open and, for each signal read from it, opens a
// data connection back to the parent and relays it to a local address.
type TunnelClient struct {
	cfg       TunnelClientArgs   // configuration passed to Start
	ctrlConn  *net.Conn          // control connection; replaced on each reconnect
	isStop    bool               // set by StopService; polled (unsynchronized) by worker loops
	userConns mapx.ConcurrentMap // live TCP relay connections (*net.Conn values)
	log       *logger.Logger     // logger supplied to Start
	jumper    *jumper.Jumper     // nil unless a jumper was configured
	udpConns  mapx.ConcurrentMap // *ClientUDPConnItem values keyed by source address
}
    51  
    52  func NewTunnelClient() services.Service {
    53  	return &TunnelClient{
    54  		cfg:       TunnelClientArgs{},
    55  		userConns: mapx.NewConcurrentMap(),
    56  		isStop:    false,
    57  		udpConns:  mapx.NewConcurrentMap(),
    58  	}
    59  }
    60  
    61  func (s *TunnelClient) InitService() (err error) {
    62  	s.UDPGCDeamon()
    63  	return
    64  }
    65  
    66  func (s *TunnelClient) CheckArgs() (err error) {
    67  	if *s.cfg.Parent != "" {
    68  		s.log.Printf("use tls parent %s", *s.cfg.Parent)
    69  	} else {
    70  		err = fmt.Errorf("parent required")
    71  		return
    72  	}
    73  	if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
    74  		err = fmt.Errorf("cert and key file required")
    75  		return
    76  	}
    77  	s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
    78  	if err != nil {
    79  		return
    80  	}
    81  	if *s.cfg.Jumper != "" {
    82  		var j jumper.Jumper
    83  		j, err = jumper.New(*s.cfg.Jumper, time.Millisecond*time.Duration(*s.cfg.Timeout))
    84  		if err != nil {
    85  			err = fmt.Errorf("parse jumper fail, err %s", err)
    86  			return
    87  		}
    88  		s.jumper = &j
    89  	}
    90  	return
    91  }
    92  func (s *TunnelClient) StopService() {
    93  	defer func() {
    94  		e := recover()
    95  		if e != nil {
    96  			s.log.Printf("stop tclient service crashed,%s", e)
    97  		} else {
    98  			s.log.Printf("service tclient stopped")
    99  		}
   100  		s.cfg = TunnelClientArgs{}
   101  		s.ctrlConn = nil
   102  		s.jumper = nil
   103  		s.log = nil
   104  		s.udpConns = nil
   105  		s.userConns = nil
   106  		s = nil
   107  	}()
   108  	s.isStop = true
   109  	if s.ctrlConn != nil {
   110  		(*s.ctrlConn).Close()
   111  	}
   112  	for _, c := range s.userConns.Items() {
   113  		(*c.(*net.Conn)).Close()
   114  	}
   115  }
   116  func (s *TunnelClient) Start(args interface{}, log *logger.Logger) (err error) {
   117  	s.log = log
   118  	s.cfg = args.(TunnelClientArgs)
   119  	if err = s.CheckArgs(); err != nil {
   120  		return
   121  	}
   122  	if err = s.InitService(); err != nil {
   123  		return
   124  	}
   125  	s.log.Printf("proxy on tunnel client mode")
   126  
   127  	for {
   128  		if s.isStop {
   129  			return
   130  		}
   131  		if s.ctrlConn != nil {
   132  			(*s.ctrlConn).Close()
   133  		}
   134  		var c net.Conn
   135  		c, err = s.GetInConn(CONN_CLIENT_CONTROL, *s.cfg.Key)
   136  		if err != nil {
   137  			s.log.Printf("control connection err: %s, retrying...", err)
   138  			time.Sleep(time.Second * 3)
   139  			continue
   140  		}
   141  		s.ctrlConn = &c
   142  		for {
   143  			if s.isStop {
   144  				return
   145  			}
   146  			var ID, clientLocalAddr, serverID string
   147  			err = utils.ReadPacketData(*s.ctrlConn, &ID, &clientLocalAddr, &serverID)
   148  			if err != nil {
   149  				if s.ctrlConn != nil {
   150  					(*s.ctrlConn).Close()
   151  				}
   152  				s.log.Printf("read connection signal err: %s, retrying...", err)
   153  				break
   154  			}
   155  			//s.log.Printf("signal revecived:%s %s %s", serverID, ID, clientLocalAddr)
   156  			protocol := clientLocalAddr[:3]
   157  			localAddr := clientLocalAddr[4:]
   158  			if protocol == "udp" {
   159  				go func() {
   160  					defer func() {
   161  						if e := recover(); e != nil {
   162  							fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   163  						}
   164  					}()
   165  					s.ServeUDP(localAddr, ID, serverID)
   166  				}()
   167  			} else {
   168  				go func() {
   169  					defer func() {
   170  						if e := recover(); e != nil {
   171  							fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   172  						}
   173  					}()
   174  					s.ServeConn(localAddr, ID, serverID)
   175  				}()
   176  			}
   177  		}
   178  	}
   179  }
   180  func (s *TunnelClient) Clean() {
   181  	s.StopService()
   182  }
   183  func (s *TunnelClient) GetInConn(typ uint8, data ...string) (outConn net.Conn, err error) {
   184  	outConn, err = s.GetConn()
   185  	if err != nil {
   186  		err = fmt.Errorf("connection err: %s", err)
   187  		return
   188  	}
   189  	_, err = outConn.Write(utils.BuildPacket(typ, data...))
   190  	if err != nil {
   191  		err = fmt.Errorf("write connection data err: %s ,retrying...", err)
   192  		utils.CloseConn(&outConn)
   193  		return
   194  	}
   195  	return
   196  }
   197  func (s *TunnelClient) GetConn() (conn net.Conn, err error) {
   198  	if s.jumper == nil {
   199  		var _conn tls.Conn
   200  		_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   201  		if err == nil {
   202  			conn = net.Conn(&_conn)
   203  		}
   204  	} else {
   205  		conf, err := utils.TlsConfig(s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   206  		if err != nil {
   207  			return nil, err
   208  		}
   209  		var _c net.Conn
   210  		_c, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout))
   211  		if err == nil {
   212  			conn = net.Conn(tls.Client(_c, conf))
   213  		}
   214  	}
   215  	if err == nil {
   216  		sess, e := smux.Client(conn, &smux.Config{
   217  			AcceptBacklog:          256,
   218  			EnableKeepAlive:        true,
   219  			KeepAliveInterval:      9 * time.Second,
   220  			ConnectionWriteTimeout: 3 * time.Second,
   221  			MaxStreamWindowSize:    512 * 1024,
   222  			LogOutput:              os.Stderr,
   223  		})
   224  		if e != nil {
   225  			s.log.Printf("new mux client conn error,ERR:%s", e)
   226  			err = e
   227  			return
   228  		}
   229  		conn, e = sess.OpenStream()
   230  		if e != nil {
   231  			s.log.Printf("mux client conn open stream error,ERR:%s", e)
   232  			err = e
   233  			return
   234  		}
   235  		go func() {
   236  			defer func() {
   237  				_ = recover()
   238  			}()
   239  			timer := time.NewTicker(time.Second * 3)
   240  			for {
   241  				<-timer.C
   242  				if sess.NumStreams() == 0 {
   243  					sess.Close()
   244  					timer.Stop()
   245  					return
   246  				}
   247  			}
   248  		}()
   249  	}
   250  	return
   251  }
   252  func (s *TunnelClient) ServeUDP(localAddr, ID, serverID string) {
   253  	var inConn net.Conn
   254  	var err error
   255  	// for {
   256  	for {
   257  		if s.isStop {
   258  			if inConn != nil {
   259  				inConn.Close()
   260  			}
   261  			return
   262  		}
   263  		// s.cm.RemoveOne(*s.cfg.Key, ID)
   264  		inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
   265  		if err != nil {
   266  			utils.CloseConn(&inConn)
   267  			s.log.Printf("connection err: %s, retrying...", err)
   268  			time.Sleep(time.Second * 3)
   269  			continue
   270  		} else {
   271  			break
   272  		}
   273  	}
   274  	// s.cm.Add(*s.cfg.Key, ID, &inConn)
   275  	s.log.Printf("conn %s created", ID)
   276  	var item *ClientUDPConnItem
   277  	var body []byte
   278  	srcAddr := ""
   279  	defer func() {
   280  		if item != nil {
   281  			(*(*item).conn).Close()
   282  			(*item).udpConn.Close()
   283  			s.udpConns.Remove(srcAddr)
   284  			inConn.Close()
   285  		}
   286  	}()
   287  	for {
   288  		if s.isStop {
   289  			return
   290  		}
   291  		srcAddr, body, err = utils.ReadUDPPacket(inConn)
   292  		if err != nil {
   293  			if strings.Contains(err.Error(), "n != int(") {
   294  				continue
   295  			}
   296  			if !utils.IsNetDeadlineErr(err) && err != io.EOF {
   297  				s.log.Printf("udp packet revecived from bridge fail, err: %s", err)
   298  			}
   299  			return
   300  		}
   301  		if v, ok := s.udpConns.Get(srcAddr); !ok {
   302  			_srcAddr, _ := net.ResolveUDPAddr("udp", srcAddr)
   303  			zeroAddr, _ := net.ResolveUDPAddr("udp", ":")
   304  			_localAddr, _ := net.ResolveUDPAddr("udp", localAddr)
   305  			c, err := net.DialUDP("udp", zeroAddr, _localAddr)
   306  			if err != nil {
   307  				s.log.Printf("create local udp conn fail, err : %s", err)
   308  				inConn.Close()
   309  				return
   310  			}
   311  			item = &ClientUDPConnItem{
   312  				conn:      &inConn,
   313  				srcAddr:   _srcAddr,
   314  				localAddr: _localAddr,
   315  				udpConn:   c,
   316  				connid:    ID,
   317  			}
   318  			s.udpConns.Set(srcAddr, item)
   319  			s.UDPRevecive(srcAddr, ID)
   320  		} else {
   321  			item = v.(*ClientUDPConnItem)
   322  		}
   323  		(*item).touchtime = time.Now().Unix()
   324  		go func() {
   325  			defer func() {
   326  				if e := recover(); e != nil {
   327  					fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   328  				}
   329  			}()
   330  			(*item).udpConn.Write(body)
   331  		}()
   332  	}
   333  }
   334  func (s *TunnelClient) UDPRevecive(key, ID string) {
   335  	go func() {
   336  		defer func() {
   337  			if e := recover(); e != nil {
   338  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   339  			}
   340  		}()
   341  		s.log.Printf("udp conn %s connected", ID)
   342  		v, ok := s.udpConns.Get(key)
   343  		if !ok {
   344  			s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
   345  			return
   346  		}
   347  		cui := v.(*ClientUDPConnItem)
   348  		buf := utils.LeakyBuffer.Get()
   349  		defer func() {
   350  			utils.LeakyBuffer.Put(buf)
   351  			(*cui.conn).Close()
   352  			cui.udpConn.Close()
   353  			s.udpConns.Remove(key)
   354  			s.log.Printf("udp conn %s released", ID)
   355  		}()
   356  		for {
   357  			n, err := cui.udpConn.Read(buf)
   358  			if err != nil {
   359  				if !utils.IsNetClosedErr(err) {
   360  					s.log.Printf("udp conn read udp packet fail , err: %s ", err)
   361  				}
   362  				return
   363  			}
   364  			cui.touchtime = time.Now().Unix()
   365  			go func() {
   366  				defer func() {
   367  					if e := recover(); e != nil {
   368  						fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   369  					}
   370  				}()
   371  				(*cui.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   372  				_, err = (*cui.conn).Write(utils.UDPPacket(cui.srcAddr.String(), buf[:n]))
   373  				(*cui.conn).SetWriteDeadline(time.Time{})
   374  				if err != nil {
   375  					cui.udpConn.Close()
   376  					return
   377  				}
   378  			}()
   379  		}
   380  	}()
   381  }
   382  func (s *TunnelClient) UDPGCDeamon() {
   383  	gctime := int64(30)
   384  	go func() {
   385  		defer func() {
   386  			if e := recover(); e != nil {
   387  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   388  			}
   389  		}()
   390  		if s.isStop {
   391  			return
   392  		}
   393  		timer := time.NewTicker(time.Second)
   394  		for {
   395  			<-timer.C
   396  			gcKeys := []string{}
   397  			s.udpConns.IterCb(func(key string, v interface{}) {
   398  				if time.Now().Unix()-v.(*ClientUDPConnItem).touchtime > gctime {
   399  					(*(v.(*ClientUDPConnItem).conn)).Close()
   400  					(v.(*ClientUDPConnItem).udpConn).Close()
   401  					gcKeys = append(gcKeys, key)
   402  					s.log.Printf("gc udp conn %s", v.(*ClientUDPConnItem).connid)
   403  				}
   404  			})
   405  			for _, k := range gcKeys {
   406  				s.udpConns.Remove(k)
   407  			}
   408  			gcKeys = nil
   409  		}
   410  	}()
   411  }
   412  func (s *TunnelClient) ServeConn(localAddr, ID, serverID string) {
   413  	var inConn, outConn net.Conn
   414  	var err error
   415  	for {
   416  		if s.isStop {
   417  			return
   418  		}
   419  		inConn, err = s.GetInConn(CONN_CLIENT, *s.cfg.Key, ID, serverID)
   420  		if err != nil {
   421  			utils.CloseConn(&inConn)
   422  			s.log.Printf("connection err: %s, retrying...", err)
   423  			time.Sleep(time.Second * 3)
   424  			continue
   425  		} else {
   426  			break
   427  		}
   428  	}
   429  
   430  	i := 0
   431  	for {
   432  		if s.isStop {
   433  			return
   434  		}
   435  		i++
   436  		outConn, err = utils.ConnectHost(localAddr, *s.cfg.Timeout)
   437  		if err == nil || i == 3 {
   438  			break
   439  		} else {
   440  			if i == 3 {
   441  				s.log.Printf("connect to %s err: %s, retrying...", localAddr, err)
   442  				time.Sleep(2 * time.Second)
   443  				continue
   444  			}
   445  		}
   446  	}
   447  	if err != nil {
   448  		utils.CloseConn(&inConn)
   449  		utils.CloseConn(&outConn)
   450  		s.log.Printf("build connection error, err: %s", err)
   451  		return
   452  	}
   453  	inAddr := inConn.RemoteAddr().String()
   454  	utils.IoBind(inConn, outConn, func(err interface{}) {
   455  		s.log.Printf("conn %s released", ID)
   456  		s.userConns.Remove(inAddr)
   457  	}, s.log)
   458  	if c, ok := s.userConns.Get(inAddr); ok {
   459  		(*c.(*net.Conn)).Close()
   460  	}
   461  	s.userConns.Set(inAddr, &inConn)
   462  	s.log.Printf("conn %s created", ID)
   463  }