github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/tunnel/tunnel_server.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	logger "log"
     8  	"net"
     9  	"os"
    10  	"runtime/debug"
    11  	"strconv"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/AntonOrnatskyi/goproxy/core/cs/server"
    16  	"github.com/AntonOrnatskyi/goproxy/services"
    17  	"github.com/AntonOrnatskyi/goproxy/utils"
    18  	"github.com/AntonOrnatskyi/goproxy/utils/jumper"
    19  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    20  
    21  	//"github.com/xtaci/smux"
    22  	smux "github.com/hashicorp/yamux"
    23  )
    24  
    25  type TunnelServerArgs struct {
    26  	Parent    *string
    27  	CertFile  *string
    28  	KeyFile   *string
    29  	CertBytes []byte
    30  	KeyBytes  []byte
    31  	Local     *string
    32  	IsUDP     *bool
    33  	Key       *string
    34  	Remote    *string
    35  	Timeout   *int
    36  	Route     *[]string
    37  	Mgr       *TunnelServerManager
    38  	Jumper    *string
    39  }
// TunnelServer serves a single route: it listens on cfg.Local (TCP or UDP)
// and forwards traffic over TLS connections to cfg.Parent, multiplexed with
// yamux (imported here as smux).
//
// Fields:
//   - cfg: per-route arguments passed to Start.
//   - sc: listener channel created in Start.
//   - isStop: set by StopService; polled by the TCP handler's retry loop and
//     by the UDP GC goroutine.
//   - udpConn: closed by StopService when set. NOTE(review): nothing in this
//     file assigns it.
//   - userConns: TCP client remote address -> *net.Conn, so StopService can
//     close live clients.
//   - log: logger supplied to Start.
//   - jumper: set by CheckArgs when cfg.Jumper is non-empty; GetConn then
//     dials the parent through it.
//   - udpConns: UDP source address string -> *TunnelUDPConnItem.
type TunnelServer struct {
	cfg       TunnelServerArgs
	sc        server.ServerChannel
	isStop    bool
	udpConn   *net.Conn
	userConns mapx.ConcurrentMap
	log       *logger.Logger
	jumper    *jumper.Jumper
	udpConns  mapx.ConcurrentMap
}
    50  
// TunnelServerManager starts one TunnelServer per entry of cfg.Route and
// keeps each started server in servers so StopService/Clean can clean them
// together. serverID is generated by NewTunnelServerManager and is sent to
// the parent by every child server's GetOutConn (via cfg.Mgr).
type TunnelServerManager struct {
	cfg      TunnelServerArgs
	serverID string
	servers  []*services.Service
	log      *logger.Logger
}
    57  
    58  func NewTunnelServerManager() services.Service {
    59  	return &TunnelServerManager{
    60  		cfg:      TunnelServerArgs{},
    61  		serverID: utils.Uniqueid(),
    62  		servers:  []*services.Service{},
    63  	}
    64  }
    65  func (s *TunnelServerManager) Start(args interface{}, log *logger.Logger) (err error) {
    66  	s.log = log
    67  	s.cfg = args.(TunnelServerArgs)
    68  	if err = s.CheckArgs(); err != nil {
    69  		return
    70  	}
    71  	if *s.cfg.Parent != "" {
    72  		s.log.Printf("use tls parent %s", *s.cfg.Parent)
    73  	} else {
    74  		err = fmt.Errorf("parent required")
    75  		return
    76  	}
    77  
    78  	if err = s.InitService(); err != nil {
    79  		return
    80  	}
    81  
    82  	s.log.Printf("server id: %s", s.serverID)
    83  	//s.log.Printf("route:%v", *s.cfg.Route)
    84  	for _, _info := range *s.cfg.Route {
    85  		IsUDP := *s.cfg.IsUDP
    86  		if strings.HasPrefix(_info, "udp://") {
    87  			IsUDP = true
    88  		}
    89  		info := strings.TrimPrefix(_info, "udp://")
    90  		info = strings.TrimPrefix(info, "tcp://")
    91  		_routeInfo := strings.Split(info, "@")
    92  		server := NewTunnelServer()
    93  		local := _routeInfo[0]
    94  		remote := _routeInfo[1]
    95  		KEY := *s.cfg.Key
    96  		if strings.HasPrefix(remote, "[") {
    97  			KEY = remote[1:strings.LastIndex(remote, "]")]
    98  			remote = remote[strings.LastIndex(remote, "]")+1:]
    99  		}
   100  		if strings.HasPrefix(remote, ":") {
   101  			remote = fmt.Sprintf("127.0.0.1%s", remote)
   102  		}
   103  		err = server.Start(TunnelServerArgs{
   104  			CertBytes: s.cfg.CertBytes,
   105  			KeyBytes:  s.cfg.KeyBytes,
   106  			Parent:    s.cfg.Parent,
   107  			CertFile:  s.cfg.CertFile,
   108  			KeyFile:   s.cfg.KeyFile,
   109  			Local:     &local,
   110  			IsUDP:     &IsUDP,
   111  			Remote:    &remote,
   112  			Key:       &KEY,
   113  			Timeout:   s.cfg.Timeout,
   114  			Mgr:       s,
   115  			Jumper:    s.cfg.Jumper,
   116  		}, log)
   117  
   118  		if err != nil {
   119  			return
   120  		}
   121  		s.servers = append(s.servers, &server)
   122  	}
   123  	return
   124  }
   125  func (s *TunnelServerManager) Clean() {
   126  	s.StopService()
   127  }
   128  func (s *TunnelServerManager) StopService() {
   129  	for _, server := range s.servers {
   130  		(*server).Clean()
   131  	}
   132  	s.cfg = TunnelServerArgs{}
   133  	s.log = nil
   134  	s.serverID = ""
   135  	s.servers = nil
   136  	s = nil
   137  }
   138  func (s *TunnelServerManager) CheckArgs() (err error) {
   139  	if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
   140  		err = fmt.Errorf("cert and key file required")
   141  		return
   142  	}
   143  	s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
   144  	return
   145  }
   146  func (s *TunnelServerManager) InitService() (err error) {
   147  	return
   148  }
   149  
   150  func NewTunnelServer() services.Service {
   151  	return &TunnelServer{
   152  		cfg:       TunnelServerArgs{},
   153  		isStop:    false,
   154  		userConns: mapx.NewConcurrentMap(),
   155  		udpConns:  mapx.NewConcurrentMap(),
   156  	}
   157  }
   158  
   159  type TunnelUDPPacketItem struct {
   160  	packet    *[]byte
   161  	localAddr *net.UDPAddr
   162  	srcAddr   *net.UDPAddr
   163  }
   164  type TunnelUDPConnItem struct {
   165  	conn      *net.Conn
   166  	isActive  bool
   167  	touchtime int64
   168  	srcAddr   *net.UDPAddr
   169  	localAddr *net.UDPAddr
   170  	connid    string
   171  }
   172  
   173  func (s *TunnelServer) StopService() {
   174  	defer func() {
   175  		e := recover()
   176  		if e != nil {
   177  			s.log.Printf("stop server service crashed,%s", e)
   178  		} else {
   179  			s.log.Printf("service server stopped")
   180  		}
   181  		s.cfg = TunnelServerArgs{}
   182  		s.jumper = nil
   183  		s.log = nil
   184  		s.sc = server.ServerChannel{}
   185  		s.udpConn = nil
   186  		s.udpConns = nil
   187  		s.userConns = nil
   188  		s = nil
   189  	}()
   190  	s.isStop = true
   191  
   192  	if s.sc.Listener != nil {
   193  		(*s.sc.Listener).Close()
   194  	}
   195  	if s.sc.UDPListener != nil {
   196  		(*s.sc.UDPListener).Close()
   197  	}
   198  	if s.udpConn != nil {
   199  		(*s.udpConn).Close()
   200  	}
   201  	for _, c := range s.userConns.Items() {
   202  		(*c.(*net.Conn)).Close()
   203  	}
   204  }
   205  func (s *TunnelServer) InitService() (err error) {
   206  	s.UDPGCDeamon()
   207  	return
   208  }
   209  func (s *TunnelServer) CheckArgs() (err error) {
   210  	if *s.cfg.Remote == "" {
   211  		err = fmt.Errorf("remote required")
   212  		return
   213  	}
   214  	if *s.cfg.Jumper != "" {
   215  		var j jumper.Jumper
   216  		j, err = jumper.New(*s.cfg.Jumper, time.Millisecond*time.Duration(*s.cfg.Timeout))
   217  		if err != nil {
   218  			err = fmt.Errorf("parse jumper fail, err %s", err)
   219  			return
   220  		}
   221  		s.jumper = &j
   222  	}
   223  	return
   224  }
   225  
   226  func (s *TunnelServer) Start(args interface{}, log *logger.Logger) (err error) {
   227  	s.log = log
   228  	s.cfg = args.(TunnelServerArgs)
   229  	if err = s.CheckArgs(); err != nil {
   230  		return
   231  	}
   232  	if err = s.InitService(); err != nil {
   233  		return
   234  	}
   235  	host, port, _ := net.SplitHostPort(*s.cfg.Local)
   236  	p, _ := strconv.Atoi(port)
   237  	s.sc = server.NewServerChannel(host, p, s.log)
   238  	if *s.cfg.IsUDP {
   239  		err = s.sc.ListenUDP(func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) {
   240  			s.UDPSend(packet, localAddr, srcAddr)
   241  		})
   242  		if err != nil {
   243  			return
   244  		}
   245  		s.log.Printf("proxy on udp tunnel server mode %s", (*s.sc.UDPListener).LocalAddr())
   246  	} else {
   247  		err = s.sc.ListenTCP(func(inConn net.Conn) {
   248  			defer func() {
   249  				if err := recover(); err != nil {
   250  					s.log.Printf("tserver conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
   251  				}
   252  			}()
   253  			var outConn net.Conn
   254  			var ID string
   255  			for {
   256  				if s.isStop {
   257  					return
   258  				}
   259  				outConn, ID, err = s.GetOutConn(CONN_SERVER)
   260  				if err != nil {
   261  					utils.CloseConn(&outConn)
   262  					s.log.Printf("connect to %s fail, err: %s, retrying...", *s.cfg.Parent, err)
   263  					time.Sleep(time.Second * 3)
   264  					continue
   265  				} else {
   266  					break
   267  				}
   268  			}
   269  			inAddr := inConn.RemoteAddr().String()
   270  			utils.IoBind(inConn, outConn, func(err interface{}) {
   271  				s.userConns.Remove(inAddr)
   272  				s.log.Printf("%s conn %s released", *s.cfg.Key, ID)
   273  			}, s.log)
   274  			if c, ok := s.userConns.Get(inAddr); ok {
   275  				(*c.(*net.Conn)).Close()
   276  			}
   277  			s.userConns.Set(inAddr, &inConn)
   278  			s.log.Printf("%s conn %s created", *s.cfg.Key, ID)
   279  		})
   280  		if err != nil {
   281  			return
   282  		}
   283  		s.log.Printf("proxy on tunnel server mode %s", (*s.sc.Listener).Addr())
   284  	}
   285  	return
   286  }
   287  func (s *TunnelServer) Clean() {
   288  
   289  }
   290  func (s *TunnelServer) GetOutConn(typ uint8) (outConn net.Conn, ID string, err error) {
   291  	outConn, err = s.GetConn()
   292  	if err != nil {
   293  		s.log.Printf("connection err: %s", err)
   294  		return
   295  	}
   296  	remoteAddr := "tcp:" + *s.cfg.Remote
   297  	if *s.cfg.IsUDP {
   298  		remoteAddr = "udp:" + *s.cfg.Remote
   299  	}
   300  	ID = utils.Uniqueid()
   301  	_, err = outConn.Write(utils.BuildPacket(typ, *s.cfg.Key, ID, remoteAddr, s.cfg.Mgr.serverID))
   302  	if err != nil {
   303  		s.log.Printf("write connection data err: %s ,retrying...", err)
   304  		utils.CloseConn(&outConn)
   305  		return
   306  	}
   307  	return
   308  }
   309  func (s *TunnelServer) GetConn() (conn net.Conn, err error) {
   310  	var dconn net.Conn
   311  	if s.jumper == nil {
   312  		var _conn tls.Conn
   313  		_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   314  		if err == nil {
   315  			dconn = net.Conn(&_conn)
   316  		}
   317  	} else {
   318  		conf, err := utils.TlsConfig(s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   319  		if err != nil {
   320  			return nil, err
   321  		}
   322  		var _c net.Conn
   323  		_c, err = s.jumper.Dial(*s.cfg.Parent, time.Millisecond*time.Duration(*s.cfg.Timeout))
   324  		if err == nil {
   325  			dconn = net.Conn(tls.Client(_c, conf))
   326  		}
   327  	}
   328  	if err == nil {
   329  		sess, e := smux.Client(dconn, &smux.Config{
   330  			AcceptBacklog:          256,
   331  			EnableKeepAlive:        true,
   332  			KeepAliveInterval:      9 * time.Second,
   333  			ConnectionWriteTimeout: 3 * time.Second,
   334  			MaxStreamWindowSize:    512 * 1024,
   335  			LogOutput:              os.Stderr,
   336  		})
   337  		if e != nil {
   338  			s.log.Printf("new mux client conn error,ERR:%s", e)
   339  			err = e
   340  			dconn.Close()
   341  			return
   342  		}
   343  		conn, e = sess.OpenStream()
   344  		if e != nil {
   345  			s.log.Printf("mux client conn open stream error,ERR:%s", e)
   346  			err = e
   347  			dconn.Close()
   348  			return
   349  		}
   350  		go func() {
   351  			defer func() {
   352  				_ = recover()
   353  			}()
   354  			timer := time.NewTicker(time.Second * 3)
   355  			for {
   356  				<-timer.C
   357  				if sess.NumStreams() == 0 {
   358  					sess.Close()
   359  					timer.Stop()
   360  					return
   361  				}
   362  			}
   363  		}()
   364  	}
   365  	return
   366  }
   367  func (s *TunnelServer) UDPGCDeamon() {
   368  	gctime := int64(30)
   369  	go func() {
   370  		defer func() {
   371  			if e := recover(); e != nil {
   372  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   373  			}
   374  		}()
   375  		if s.isStop {
   376  			return
   377  		}
   378  		timer := time.NewTicker(time.Second)
   379  		for {
   380  			<-timer.C
   381  			gcKeys := []string{}
   382  			s.udpConns.IterCb(func(key string, v interface{}) {
   383  				if time.Now().Unix()-v.(*TunnelUDPConnItem).touchtime > gctime {
   384  					(*(v.(*TunnelUDPConnItem).conn)).Close()
   385  					gcKeys = append(gcKeys, key)
   386  					s.log.Printf("gc udp conn %s", v.(*TunnelUDPConnItem).connid)
   387  				}
   388  			})
   389  			for _, k := range gcKeys {
   390  				s.udpConns.Remove(k)
   391  			}
   392  			gcKeys = nil
   393  		}
   394  	}()
   395  }
   396  func (s *TunnelServer) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) {
   397  	var (
   398  		uc      *TunnelUDPConnItem
   399  		key     = srcAddr.String()
   400  		ID      string
   401  		err     error
   402  		outconn net.Conn
   403  	)
   404  	v, ok := s.udpConns.Get(key)
   405  	if !ok {
   406  		outconn, ID, err = s.GetOutConn(CONN_SERVER)
   407  		if err != nil {
   408  			s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err)
   409  			return
   410  		}
   411  		uc = &TunnelUDPConnItem{
   412  			conn:      &outconn,
   413  			srcAddr:   srcAddr,
   414  			localAddr: localAddr,
   415  			connid:    ID,
   416  		}
   417  		s.udpConns.Set(key, uc)
   418  		s.UDPRevecive(key, ID)
   419  	} else {
   420  		uc = v.(*TunnelUDPConnItem)
   421  	}
   422  	go func() {
   423  		defer func() {
   424  			if e := recover(); e != nil {
   425  				(*uc.conn).Close()
   426  				s.udpConns.Remove(key)
   427  				s.log.Printf("udp sender crashed with error : %s", e)
   428  			}
   429  		}()
   430  		uc.touchtime = time.Now().Unix()
   431  		(*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   432  		_, err = (*uc.conn).Write(utils.UDPPacket(srcAddr.String(), data))
   433  		(*uc.conn).SetWriteDeadline(time.Time{})
   434  		if err != nil {
   435  			s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err)
   436  		}
   437  	}()
   438  }
   439  func (s *TunnelServer) UDPRevecive(key, ID string) {
   440  	go func() {
   441  		defer func() {
   442  			if e := recover(); e != nil {
   443  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   444  			}
   445  		}()
   446  		s.log.Printf("udp conn %s connected", ID)
   447  		var uc *TunnelUDPConnItem
   448  		defer func() {
   449  			if uc != nil {
   450  				(*uc.conn).Close()
   451  			}
   452  			s.udpConns.Remove(key)
   453  			s.log.Printf("udp conn %s released", ID)
   454  		}()
   455  		v, ok := s.udpConns.Get(key)
   456  		if !ok {
   457  			s.log.Printf("[warn] udp conn not exists for %s, connid : %s", key, ID)
   458  			return
   459  		}
   460  		uc = v.(*TunnelUDPConnItem)
   461  		for {
   462  			_, body, err := utils.ReadUDPPacket(*uc.conn)
   463  			if err != nil {
   464  				if strings.Contains(err.Error(), "n != int(") {
   465  					continue
   466  				}
   467  				if err != io.EOF {
   468  					s.log.Printf("udp conn read udp packet fail , err: %s ", err)
   469  				}
   470  				return
   471  			}
   472  			uc.touchtime = time.Now().Unix()
   473  			go func() {
   474  				defer func() {
   475  					if e := recover(); e != nil {
   476  						fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   477  					}
   478  				}()
   479  				s.sc.UDPListener.WriteToUDP(body, uc.srcAddr)
   480  			}()
   481  		}
   482  	}()
   483  }