github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/core/cs/server/server.go (about)

     1  package server
     2  
     3  import (
     4  	"crypto/tls"
     5  	"crypto/x509"
     6  	"errors"
     7  	"fmt"
     8  
     9  	logger "log"
    10  	"net"
    11  	"runtime/debug"
    12  	"strconv"
    13  
    14  	tou "github.com/AntonOrnatskyi/goproxy/core/dst"
    15  	compressconn "github.com/AntonOrnatskyi/goproxy/core/lib/transport"
    16  	transportc "github.com/AntonOrnatskyi/goproxy/core/lib/transport"
    17  	encryptconn "github.com/AntonOrnatskyi/goproxy/core/lib/transport/encrypt"
    18  
    19  	"github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg"
    20  
    21  	kcp "github.com/xtaci/kcp-go"
    22  )
    23  
    24  func init() {
    25  
    26  }
    27  
    28  type ServerChannel struct {
    29  	ip               string
    30  	port             int
    31  	Listener         *net.Listener
    32  	UDPListener      *net.UDPConn
    33  	errAcceptHandler func(err error)
    34  	log              *logger.Logger
    35  	TOUServer        *tou.Mux
    36  }
    37  
    38  func NewServerChannel(ip string, port int, log *logger.Logger) ServerChannel {
    39  	return ServerChannel{
    40  		ip:   ip,
    41  		port: port,
    42  		log:  log,
    43  		errAcceptHandler: func(err error) {
    44  			log.Printf("accept error , ERR:%s", err)
    45  		},
    46  	}
    47  }
    48  func NewServerChannelHost(host string, log *logger.Logger) ServerChannel {
    49  	h, port, _ := net.SplitHostPort(host)
    50  	p, _ := strconv.Atoi(port)
    51  	return ServerChannel{
    52  		ip:   h,
    53  		port: p,
    54  		log:  log,
    55  		errAcceptHandler: func(err error) {
    56  			log.Printf("accept error , ERR:%s", err)
    57  		},
    58  	}
    59  }
    60  func (s *ServerChannel) SetErrAcceptHandler(fn func(err error)) {
    61  	s.errAcceptHandler = fn
    62  }
    63  func (s *ServerChannel) ListenSingleTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
    64  	return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, true)
    65  
    66  }
    67  func (s *ServerChannel) ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn)) (err error) {
    68  	return s._ListenTLS(certBytes, keyBytes, caCertBytes, fn, false)
    69  }
    70  func (s *ServerChannel) _ListenTLS(certBytes, keyBytes, caCertBytes []byte, fn func(conn net.Conn), single bool) (err error) {
    71  	s.Listener, err = s.listenTLS(s.ip, s.port, certBytes, keyBytes, caCertBytes, single)
    72  	if err == nil {
    73  		go func() {
    74  			defer func() {
    75  				if e := recover(); e != nil {
    76  					s.log.Printf("ListenTLS crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
    77  				}
    78  			}()
    79  			for {
    80  				var conn net.Conn
    81  				conn, err = (*s.Listener).Accept()
    82  				if err == nil {
    83  					go func() {
    84  						defer func() {
    85  							if e := recover(); e != nil {
    86  								s.log.Printf("tls connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
    87  							}
    88  						}()
    89  						fn(conn)
    90  					}()
    91  				} else {
    92  					s.errAcceptHandler(err)
    93  					(*s.Listener).Close()
    94  					break
    95  				}
    96  			}
    97  		}()
    98  	}
    99  	return
   100  }
   101  func (s *ServerChannel) listenTLS(ip string, port int, certBytes, keyBytes, caCertBytes []byte, single bool) (ln *net.Listener, err error) {
   102  	var cert tls.Certificate
   103  	cert, err = tls.X509KeyPair(certBytes, keyBytes)
   104  	if err != nil {
   105  		return
   106  	}
   107  	config := &tls.Config{
   108  		Certificates: []tls.Certificate{cert},
   109  	}
   110  	if !single {
   111  		clientCertPool := x509.NewCertPool()
   112  		caBytes := certBytes
   113  		if caCertBytes != nil {
   114  			caBytes = caCertBytes
   115  		}
   116  		ok := clientCertPool.AppendCertsFromPEM(caBytes)
   117  		if !ok {
   118  			err = errors.New("failed to parse root certificate")
   119  		}
   120  		config.ClientCAs = clientCertPool
   121  		config.ClientAuth = tls.RequireAndVerifyClientCert
   122  	}
   123  	_ln, err := tls.Listen("tcp", net.JoinHostPort(ip, fmt.Sprintf("%d", port)), config)
   124  	if err == nil {
   125  		ln = &_ln
   126  	}
   127  	return
   128  }
   129  func (s *ServerChannel) ListenTCPS(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
   130  	_, err = encryptconn.NewCipher(method, password)
   131  	if err != nil {
   132  		return
   133  	}
   134  	return s.ListenTCP(func(c net.Conn) {
   135  		if compress {
   136  			c = transportc.NewCompConn(c)
   137  		}
   138  		c, _ = encryptconn.NewConn(c, method, password)
   139  		fn(c)
   140  	})
   141  }
   142  func (s *ServerChannel) ListenTCP(fn func(conn net.Conn)) (err error) {
   143  	var l net.Listener
   144  	l, err = net.Listen("tcp", net.JoinHostPort(s.ip, fmt.Sprintf("%d", s.port)))
   145  	if err == nil {
   146  		s.Listener = &l
   147  		go func() {
   148  			defer func() {
   149  				if e := recover(); e != nil {
   150  					s.log.Printf("ListenTCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   151  				}
   152  			}()
   153  			for {
   154  				var conn net.Conn
   155  				conn, err = (*s.Listener).Accept()
   156  				if err == nil {
   157  					go func() {
   158  						defer func() {
   159  							if e := recover(); e != nil {
   160  								s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   161  							}
   162  						}()
   163  						fn(conn)
   164  					}()
   165  				} else {
   166  					s.errAcceptHandler(err)
   167  					(*s.Listener).Close()
   168  					break
   169  				}
   170  			}
   171  		}()
   172  	}
   173  	return
   174  }
   175  func (s *ServerChannel) ListenUDP(fn func(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr)) (err error) {
   176  	addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
   177  	l, err := net.ListenUDP("udp", addr)
   178  	if err == nil {
   179  		s.UDPListener = l
   180  		go func() {
   181  			defer func() {
   182  				if e := recover(); e != nil {
   183  					s.log.Printf("ListenUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   184  				}
   185  			}()
   186  			for {
   187  				var buf = make([]byte, 2048)
   188  				n, srcAddr, err := (*s.UDPListener).ReadFromUDP(buf)
   189  				if err == nil {
   190  					packet := buf[0:n]
   191  					go func() {
   192  						defer func() {
   193  							if e := recover(); e != nil {
   194  								s.log.Printf("udp data handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   195  							}
   196  						}()
   197  						fn(s.UDPListener, packet, addr, srcAddr)
   198  					}()
   199  				} else {
   200  					s.errAcceptHandler(err)
   201  					(*s.UDPListener).Close()
   202  					break
   203  				}
   204  			}
   205  		}()
   206  	}
   207  	return
   208  }
   209  func (s *ServerChannel) ListenKCP(config kcpcfg.KCPConfigArgs, fn func(conn net.Conn), log *logger.Logger) (err error) {
   210  	lis, err := kcp.ListenWithOptions(net.JoinHostPort(s.ip, fmt.Sprintf("%d", s.port)), config.Block, *config.DataShard, *config.ParityShard)
   211  	if err == nil {
   212  		if err = lis.SetDSCP(*config.DSCP); err != nil {
   213  			log.Println("SetDSCP:", err)
   214  			return
   215  		}
   216  		if err = lis.SetReadBuffer(*config.SockBuf); err != nil {
   217  			log.Println("SetReadBuffer:", err)
   218  			return
   219  		}
   220  		if err = lis.SetWriteBuffer(*config.SockBuf); err != nil {
   221  			log.Println("SetWriteBuffer:", err)
   222  			return
   223  		}
   224  		s.Listener = new(net.Listener)
   225  		*s.Listener = lis
   226  		go func() {
   227  			defer func() {
   228  				if e := recover(); e != nil {
   229  					s.log.Printf("ListenKCP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   230  				}
   231  			}()
   232  			for {
   233  				conn, err := lis.AcceptKCP()
   234  				if err == nil {
   235  					go func() {
   236  						defer func() {
   237  							if e := recover(); e != nil {
   238  								s.log.Printf("kcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   239  							}
   240  						}()
   241  						conn.SetStreamMode(true)
   242  						conn.SetWriteDelay(true)
   243  						conn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
   244  						conn.SetMtu(*config.MTU)
   245  						conn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
   246  						conn.SetACKNoDelay(*config.AckNodelay)
   247  						if *config.NoComp {
   248  							fn(conn)
   249  						} else {
   250  							cconn := transportc.NewCompStream(conn)
   251  							fn(cconn)
   252  						}
   253  					}()
   254  				} else {
   255  					s.errAcceptHandler(err)
   256  					(*s.Listener).Close()
   257  					break
   258  				}
   259  			}
   260  		}()
   261  	}
   262  	return
   263  }
   264  
   265  func (s *ServerChannel) ListenTOU(method, password string, compress bool, fn func(conn net.Conn)) (err error) {
   266  	addr := &net.UDPAddr{IP: net.ParseIP(s.ip), Port: s.port}
   267  	s.UDPListener, err = net.ListenUDP("udp", addr)
   268  	if err != nil {
   269  		s.log.Println(err)
   270  		return
   271  	}
   272  	s.TOUServer = tou.NewMux(s.UDPListener, 0)
   273  	go func() {
   274  		defer func() {
   275  			if e := recover(); e != nil {
   276  				s.log.Printf("ListenRUDP crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   277  			}
   278  		}()
   279  		for {
   280  			var conn net.Conn
   281  			conn, err = (*s.TOUServer).Accept()
   282  			if err == nil {
   283  				go func() {
   284  					defer func() {
   285  						if e := recover(); e != nil {
   286  							s.log.Printf("tcp connection handler crashed , err : %s , \ntrace:%s", e, string(debug.Stack()))
   287  						}
   288  					}()
   289  					if compress {
   290  						conn = compressconn.NewCompConn(conn)
   291  					}
   292  					conn, err = encryptconn.NewConn(conn, method, password)
   293  					if err != nil {
   294  						conn.Close()
   295  						s.log.Println(err)
   296  						return
   297  					}
   298  					fn(conn)
   299  				}()
   300  			} else {
   301  				s.errAcceptHandler(err)
   302  				s.TOUServer.Close()
   303  				s.UDPListener.Close()
   304  				break
   305  			}
   306  		}
   307  	}()
   308  
   309  	return
   310  }
   311  func (s *ServerChannel) Close() {
   312  	defer func() {
   313  		if e := recover(); e != nil {
   314  			s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
   315  		}
   316  	}()
   317  	if s.Listener != nil && *s.Listener != nil {
   318  		(*s.Listener).Close()
   319  	}
   320  	if s.TOUServer != nil {
   321  		s.TOUServer.Close()
   322  	}
   323  	if s.UDPListener != nil {
   324  		s.UDPListener.Close()
   325  	}
   326  }
   327  func (s *ServerChannel) Addr() string {
   328  	defer func() {
   329  		if e := recover(); e != nil {
   330  			s.log.Printf("close crashed :\n%s\n%s", e, string(debug.Stack()))
   331  		}
   332  	}()
   333  	if s.Listener != nil && *s.Listener != nil {
   334  		return (*s.Listener).Addr().String()
   335  	}
   336  
   337  	if s.UDPListener != nil {
   338  		return s.UDPListener.LocalAddr().String()
   339  	}
   340  	return ""
   341  }