github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/tunnel/tunnel_bridge.go (about)

     1  package tunnel
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	logger "log"
     7  	"net"
     8  	"os"
     9  	"strconv"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/AntonOrnatskyi/goproxy/core/cs/server"
    14  	"github.com/AntonOrnatskyi/goproxy/services"
    15  	"github.com/AntonOrnatskyi/goproxy/utils"
    16  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    17  
    18  	//"github.com/xtaci/smux"
    19  	smux "github.com/hashicorp/yamux"
    20  )
    21  
    22  const (
    23  	CONN_CLIENT_CONTROL = uint8(1)
    24  	CONN_SERVER         = uint8(4)
    25  	CONN_CLIENT         = uint8(5)
    26  )
    27  
    28  type TunnelBridgeArgs struct {
    29  	Parent    *string
    30  	CertFile  *string
    31  	KeyFile   *string
    32  	CertBytes []byte
    33  	KeyBytes  []byte
    34  	Local     *string
    35  	Timeout   *int
    36  }
    37  type ServerConn struct {
    38  	//ClientLocalAddr string //tcp:2.2.22:333@ID
    39  	Conn *net.Conn
    40  }
    41  type TunnelBridge struct {
    42  	cfg                TunnelBridgeArgs
    43  	serverConns        mapx.ConcurrentMap
    44  	clientControlConns mapx.ConcurrentMap
    45  	isStop             bool
    46  	log                *logger.Logger
    47  }
    48  
    49  func NewTunnelBridge() services.Service {
    50  	return &TunnelBridge{
    51  		cfg:                TunnelBridgeArgs{},
    52  		serverConns:        mapx.NewConcurrentMap(),
    53  		clientControlConns: mapx.NewConcurrentMap(),
    54  		isStop:             false,
    55  	}
    56  }
    57  
    58  func (s *TunnelBridge) InitService() (err error) {
    59  	return
    60  }
    61  func (s *TunnelBridge) CheckArgs() (err error) {
    62  	if *s.cfg.CertFile == "" || *s.cfg.KeyFile == "" {
    63  		err = fmt.Errorf("cert and key file required")
    64  		return
    65  	}
    66  	s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
    67  	return
    68  }
    69  func (s *TunnelBridge) StopService() {
    70  	defer func() {
    71  		e := recover()
    72  		if e != nil {
    73  			s.log.Printf("stop tbridge service crashed,%s", e)
    74  		} else {
    75  			s.log.Printf("service tbridge stopped")
    76  		}
    77  		s.cfg = TunnelBridgeArgs{}
    78  		s.clientControlConns = nil
    79  		s.log = nil
    80  		s.serverConns = nil
    81  		s = nil
    82  	}()
    83  	s.isStop = true
    84  	for _, sess := range s.clientControlConns.Items() {
    85  		(*sess.(*net.Conn)).Close()
    86  	}
    87  	for _, sess := range s.serverConns.Items() {
    88  		(*sess.(ServerConn).Conn).Close()
    89  	}
    90  }
    91  func (s *TunnelBridge) Start(args interface{}, log *logger.Logger) (err error) {
    92  	s.log = log
    93  	s.cfg = args.(TunnelBridgeArgs)
    94  	if err = s.CheckArgs(); err != nil {
    95  		return
    96  	}
    97  	if err = s.InitService(); err != nil {
    98  		return
    99  	}
   100  	host, port, _ := net.SplitHostPort(*s.cfg.Local)
   101  	p, _ := strconv.Atoi(port)
   102  	sc := server.NewServerChannel(host, p, s.log)
   103  
   104  	err = sc.ListenTLS(s.cfg.CertBytes, s.cfg.KeyBytes, nil, s.callback)
   105  	if err != nil {
   106  		return
   107  	}
   108  	s.log.Printf("proxy on tunnel bridge mode %s", (*sc.Listener).Addr())
   109  	return
   110  }
   111  func (s *TunnelBridge) Clean() {
   112  	s.StopService()
   113  }
   114  func (s *TunnelBridge) callback(inConn net.Conn) {
   115  	var err error
   116  	//s.log.Printf("connection from %s ", inConn.RemoteAddr())
   117  	sess, err := smux.Server(inConn, &smux.Config{
   118  		AcceptBacklog:          256,
   119  		EnableKeepAlive:        true,
   120  		KeepAliveInterval:      9 * time.Second,
   121  		ConnectionWriteTimeout: 3 * time.Second,
   122  		MaxStreamWindowSize:    512 * 1024,
   123  		LogOutput:              os.Stderr,
   124  	})
   125  	if err != nil {
   126  		s.log.Printf("new mux server conn error,ERR:%s", err)
   127  		return
   128  	}
   129  	inConn, err = sess.AcceptStream()
   130  	if err != nil {
   131  		s.log.Printf("mux server conn accept error,ERR:%s", err)
   132  		return
   133  	}
   134  	go func() {
   135  		defer func() {
   136  			_ = recover()
   137  		}()
   138  		timer := time.NewTicker(time.Second * 3)
   139  		for {
   140  			<-timer.C
   141  			if sess.NumStreams() == 0 {
   142  				sess.Close()
   143  				timer.Stop()
   144  				return
   145  			}
   146  		}
   147  	}()
   148  	var buf = make([]byte, 1024)
   149  	n, _ := inConn.Read(buf)
   150  	reader := bytes.NewReader(buf[:n])
   151  
   152  	//reader := bufio.NewReader(inConn)
   153  
   154  	var connType uint8
   155  	err = utils.ReadPacket(reader, &connType)
   156  	if err != nil {
   157  		s.log.Printf("read error,ERR:%s", err)
   158  		return
   159  	}
   160  	switch connType {
   161  	case CONN_SERVER:
   162  		var key, ID, clientLocalAddr, serverID string
   163  		err = utils.ReadPacketData(reader, &key, &ID, &clientLocalAddr, &serverID)
   164  		if err != nil {
   165  			s.log.Printf("read error,ERR:%s", err)
   166  			return
   167  		}
   168  		packet := utils.BuildPacketData(ID, clientLocalAddr, serverID)
   169  		s.log.Printf("server connection, key: %s , id: %s %s %s", key, ID, clientLocalAddr, serverID)
   170  
   171  		//addr := clientLocalAddr + "@" + ID
   172  		s.serverConns.Set(ID, ServerConn{
   173  			Conn: &inConn,
   174  		})
   175  		for {
   176  			if s.isStop {
   177  				return
   178  			}
   179  			item, ok := s.clientControlConns.Get(key)
   180  			if !ok {
   181  				s.log.Printf("client %s control conn not exists", key)
   182  				time.Sleep(time.Second * 3)
   183  				continue
   184  			}
   185  			(*item.(*net.Conn)).SetWriteDeadline(time.Now().Add(time.Second * 3))
   186  			_, err := (*item.(*net.Conn)).Write(packet)
   187  			(*item.(*net.Conn)).SetWriteDeadline(time.Time{})
   188  			if err != nil && strings.Contains(err.Error(), "stream closed") {
   189  				s.log.Printf("%s client control conn write signal fail, err: %s, retrying...", key, err)
   190  				time.Sleep(time.Second * 3)
   191  				continue
   192  			} else {
   193  				// s.cmServer.Add(serverID, ID, &inConn)
   194  				break
   195  			}
   196  		}
   197  	case CONN_CLIENT:
   198  		var key, ID, serverID string
   199  		err = utils.ReadPacketData(reader, &key, &ID, &serverID)
   200  		if err != nil {
   201  			s.log.Printf("read error,ERR:%s", err)
   202  			return
   203  		}
   204  		s.log.Printf("client connection , key: %s , id: %s, server id:%s", key, ID, serverID)
   205  
   206  		serverConnItem, ok := s.serverConns.Get(ID)
   207  		if !ok {
   208  			inConn.Close()
   209  			s.log.Printf("server conn %s exists", ID)
   210  			return
   211  		}
   212  		serverConn := serverConnItem.(ServerConn).Conn
   213  		utils.IoBind(*serverConn, inConn, func(err interface{}) {
   214  			s.serverConns.Remove(ID)
   215  			// s.cmClient.RemoveOne(key, ID)
   216  			// s.cmServer.RemoveOne(serverID, ID)
   217  			s.log.Printf("conn %s released", ID)
   218  		}, s.log)
   219  		// s.cmClient.Add(key, ID, &inConn)
   220  		s.log.Printf("conn %s created", ID)
   221  
   222  	case CONN_CLIENT_CONTROL:
   223  		var key string
   224  		err = utils.ReadPacketData(reader, &key)
   225  		if err != nil {
   226  			s.log.Printf("read error,ERR:%s", err)
   227  			return
   228  		}
   229  		s.log.Printf("client control connection, key: %s", key)
   230  		if s.clientControlConns.Has(key) {
   231  			item, _ := s.clientControlConns.Get(key)
   232  			(*item.(*net.Conn)).Close()
   233  		}
   234  		s.clientControlConns.Set(key, &inConn)
   235  		s.log.Printf("set client %s control conn", key)
   236  	}
   237  }