github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/functions.go (about)

     1  package utils
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"context"
     7  	"crypto/sha1"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"encoding/base64"
    11  	"encoding/binary"
    12  	"encoding/hex"
    13  	"encoding/pem"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"io/ioutil"
    18  	logger "log"
    19  	"math/rand"
    20  	"net"
    21  	"net/http"
    22  	"os"
    23  	"strings"
    24  
    25  	"github.com/AntonOrnatskyi/goproxy/core/lib/kcpcfg"
    26  	"github.com/AntonOrnatskyi/goproxy/utils/lb"
    27  
    28  	"golang.org/x/crypto/pbkdf2"
    29  
    30  	"strconv"
    31  
    32  	"time"
    33  
    34  	"github.com/AntonOrnatskyi/goproxy/utils/id"
    35  
    36  	kcp "github.com/xtaci/kcp-go"
    37  )
    38  
    39  func IoBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
    40  	ioBind(dst, src, fn, log, true)
    41  }
    42  func IoBindNoClose(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger) {
    43  	ioBind(dst, src, fn, log, false)
    44  }
    45  func ioBind(dst io.ReadWriteCloser, src io.ReadWriteCloser, fn func(err interface{}), log *logger.Logger, close bool) {
    46  	go func() {
    47  		defer func() {
    48  			if err := recover(); err != nil {
    49  				log.Printf("bind crashed %s", err)
    50  			}
    51  		}()
    52  		e1 := make(chan interface{}, 1)
    53  		e2 := make(chan interface{}, 1)
    54  		go func() {
    55  			defer func() {
    56  				if err := recover(); err != nil {
    57  					log.Printf("bind crashed %s", err)
    58  				}
    59  			}()
    60  			//_, err := io.Copy(dst, src)
    61  			err := ioCopy(dst, src)
    62  			e1 <- err
    63  		}()
    64  		go func() {
    65  			defer func() {
    66  				if err := recover(); err != nil {
    67  					log.Printf("bind crashed %s", err)
    68  				}
    69  			}()
    70  			//_, err := io.Copy(src, dst)
    71  			err := ioCopy(src, dst)
    72  			e2 <- err
    73  		}()
    74  		var err interface{}
    75  		select {
    76  		case err = <-e1:
    77  			//log.Printf("e1")
    78  		case err = <-e2:
    79  			//log.Printf("e2")
    80  		}
    81  		func() {
    82  			defer func() {
    83  				_ = recover()
    84  			}()
    85  			if close {
    86  				src.Close()
    87  			}
    88  		}()
    89  		func() {
    90  			defer func() {
    91  				_ = recover()
    92  			}()
    93  			if close {
    94  				dst.Close()
    95  			}
    96  		}()
    97  		if fn != nil {
    98  			fn(err)
    99  		}
   100  	}()
   101  }
   102  func ioCopy(dst io.ReadWriter, src io.ReadWriter) (err error) {
   103  	defer func() {
   104  		if e := recover(); e != nil {
   105  		}
   106  	}()
   107  	buf := LeakyBuffer.Get()
   108  	defer LeakyBuffer.Put(buf)
   109  	n := 0
   110  	for {
   111  		n, err = src.Read(buf)
   112  		if n > 0 {
   113  			if n > len(buf) {
   114  				n = len(buf)
   115  			}
   116  			if _, e := dst.Write(buf[0:n]); e != nil {
   117  				return e
   118  			}
   119  		}
   120  		if err != nil {
   121  			return
   122  		}
   123  	}
   124  }
   125  func SingleTlsConnectHost(host string, timeout int, caCertBytes []byte) (conn tls.Conn, err error) {
   126  	h := strings.Split(host, ":")
   127  	port, _ := strconv.Atoi(h[1])
   128  	return SingleTlsConnect(h[0], port, timeout, caCertBytes)
   129  }
   130  func SingleTlsConnect(host string, port, timeout int, caCertBytes []byte) (conn tls.Conn, err error) {
   131  	conf, err := getRequestSingleTlsConfig(caCertBytes)
   132  	if err != nil {
   133  		return
   134  	}
   135  	_conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
   136  	if err != nil {
   137  		return
   138  	}
   139  	return *tls.Client(_conn, conf), err
   140  }
   141  func SingleTlsConfig(caCertBytes []byte) (conf *tls.Config, err error) {
   142  	return getRequestSingleTlsConfig(caCertBytes)
   143  }
   144  func getRequestSingleTlsConfig(caCertBytes []byte) (conf *tls.Config, err error) {
   145  	conf = &tls.Config{InsecureSkipVerify: true}
   146  	serverCertPool := x509.NewCertPool()
   147  	if caCertBytes != nil {
   148  		ok := serverCertPool.AppendCertsFromPEM(caCertBytes)
   149  		if !ok {
   150  			err = errors.New("failed to parse root certificate")
   151  		}
   152  		conf.RootCAs = serverCertPool
   153  		conf.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   154  			opts := x509.VerifyOptions{
   155  				Roots: serverCertPool,
   156  			}
   157  			for _, rawCert := range rawCerts {
   158  				cert, _ := x509.ParseCertificate(rawCert)
   159  				_, err := cert.Verify(opts)
   160  				if err != nil {
   161  					return err
   162  				}
   163  			}
   164  			return nil
   165  		}
   166  	}
   167  	return
   168  }
   169  func TlsConnectHost(host string, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
   170  	h := strings.Split(host, ":")
   171  	port, _ := strconv.Atoi(h[1])
   172  	return TlsConnect(h[0], port, timeout, certBytes, keyBytes, caCertBytes)
   173  }
   174  func TlsConnect(host string, port, timeout int, certBytes, keyBytes, caCertBytes []byte) (conn tls.Conn, err error) {
   175  	conf, err := getRequestTlsConfig(certBytes, keyBytes, caCertBytes)
   176  	if err != nil {
   177  		return
   178  	}
   179  	_conn, err := net.DialTimeout("tcp", fmt.Sprintf("%s:%d", host, port), time.Duration(timeout)*time.Millisecond)
   180  	if err != nil {
   181  		return
   182  	}
   183  	return *tls.Client(_conn, conf), err
   184  }
   185  func TlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) {
   186  	return getRequestTlsConfig(certBytes, keyBytes, caCertBytes)
   187  }
   188  func getRequestTlsConfig(certBytes, keyBytes, caCertBytes []byte) (conf *tls.Config, err error) {
   189  
   190  	var cert tls.Certificate
   191  	cert, err = tls.X509KeyPair(certBytes, keyBytes)
   192  	if err != nil {
   193  		return
   194  	}
   195  	serverCertPool := x509.NewCertPool()
   196  	caBytes := certBytes
   197  	if caCertBytes != nil {
   198  		caBytes = caCertBytes
   199  
   200  	}
   201  	ok := serverCertPool.AppendCertsFromPEM(caBytes)
   202  	if !ok {
   203  		err = errors.New("failed to parse root certificate")
   204  	}
   205  	block, _ := pem.Decode(caBytes)
   206  	if block == nil {
   207  		panic("failed to parse certificate PEM")
   208  	}
   209  	x509Cert, _ := x509.ParseCertificate(block.Bytes)
   210  	if x509Cert == nil {
   211  		panic("failed to parse block")
   212  	}
   213  	conf = &tls.Config{
   214  		RootCAs:            serverCertPool,
   215  		Certificates:       []tls.Certificate{cert},
   216  		InsecureSkipVerify: true,
   217  		ServerName:         x509Cert.Subject.CommonName,
   218  		VerifyPeerCertificate: func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   219  			opts := x509.VerifyOptions{
   220  				Roots: serverCertPool,
   221  			}
   222  			for _, rawCert := range rawCerts {
   223  				cert, _ := x509.ParseCertificate(rawCert)
   224  				_, err := cert.Verify(opts)
   225  				if err != nil {
   226  					return err
   227  				}
   228  			}
   229  			return nil
   230  		},
   231  	}
   232  	return
   233  }
   234  
   235  func ConnectHost(hostAndPort string, timeout int) (conn net.Conn, err error) {
   236  	conn, err = net.DialTimeout("tcp", hostAndPort, time.Duration(timeout)*time.Millisecond)
   237  	return
   238  }
   239  func ConnectKCPHost(hostAndPort string, config kcpcfg.KCPConfigArgs) (conn net.Conn, err error) {
   240  	kcpconn, err := kcp.DialWithOptions(hostAndPort, config.Block, *config.DataShard, *config.ParityShard)
   241  	if err != nil {
   242  		return
   243  	}
   244  	kcpconn.SetStreamMode(true)
   245  	kcpconn.SetWriteDelay(true)
   246  	kcpconn.SetNoDelay(*config.NoDelay, *config.Interval, *config.Resend, *config.NoCongestion)
   247  	kcpconn.SetMtu(*config.MTU)
   248  	kcpconn.SetWindowSize(*config.SndWnd, *config.RcvWnd)
   249  	kcpconn.SetACKNoDelay(*config.AckNodelay)
   250  	if *config.NoComp {
   251  		return kcpconn, err
   252  	}
   253  	return NewCompStream(kcpconn), err
   254  }
   255  
   256  func PathExists(_path string) bool {
   257  	_, err := os.Stat(_path)
   258  	if err != nil && os.IsNotExist(err) {
   259  		return false
   260  	}
   261  	return true
   262  }
   263  func HTTPGet(URL string, timeout int) (err error) {
   264  	tr := &http.Transport{}
   265  	var resp *http.Response
   266  	var client *http.Client
   267  	defer func() {
   268  		if resp != nil && resp.Body != nil {
   269  			resp.Body.Close()
   270  		}
   271  		tr.CloseIdleConnections()
   272  	}()
   273  	client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
   274  	resp, err = client.Get(URL)
   275  	if err != nil {
   276  		return
   277  	}
   278  	return
   279  }
   280  
   281  func CloseConn(conn *net.Conn) {
   282  	defer func() {
   283  		_ = recover()
   284  	}()
   285  	if conn != nil && *conn != nil {
   286  		(*conn).SetDeadline(time.Now().Add(time.Millisecond))
   287  		(*conn).Close()
   288  	}
   289  }
   290  
   291  var allInterfaceAddrCache []net.IP
   292  
   293  func GetAllInterfaceAddr() ([]net.IP, error) {
   294  	if allInterfaceAddrCache != nil {
   295  		return allInterfaceAddrCache, nil
   296  	}
   297  	ifaces, err := net.Interfaces()
   298  	if err != nil {
   299  		return nil, err
   300  	}
   301  	addresses := []net.IP{}
   302  	for _, iface := range ifaces {
   303  
   304  		if iface.Flags&net.FlagUp == 0 {
   305  			continue // interface down
   306  		}
   307  		// if iface.Flags&net.FlagLoopback != 0 {
   308  		// 	continue // loopback interface
   309  		// }
   310  		addrs, err := iface.Addrs()
   311  		if err != nil {
   312  			continue
   313  		}
   314  
   315  		for _, addr := range addrs {
   316  			var ip net.IP
   317  			switch v := addr.(type) {
   318  			case *net.IPNet:
   319  				ip = v.IP
   320  			case *net.IPAddr:
   321  				ip = v.IP
   322  			}
   323  			// if ip == nil || ip.IsLoopback() {
   324  			// 	continue
   325  			// }
   326  			ip = ip.To4()
   327  			if ip == nil {
   328  				continue // not an ipv4 address
   329  			}
   330  			addresses = append(addresses, ip)
   331  		}
   332  	}
   333  	if len(addresses) == 0 {
   334  		return nil, fmt.Errorf("no address Found, net.InterfaceAddrs: %v", addresses)
   335  	}
   336  	//only need first
   337  	allInterfaceAddrCache = addresses
   338  	return addresses, nil
   339  }
   340  func UDPPacket(srcAddr string, packet []byte) []byte {
   341  	addrBytes := []byte(srcAddr)
   342  	addrLength := uint16(len(addrBytes))
   343  	bodyLength := uint16(len(packet))
   344  	//log.Printf("build packet : addr len %d, body len %d", addrLength, bodyLength)
   345  	pkg := new(bytes.Buffer)
   346  	binary.Write(pkg, binary.LittleEndian, addrLength)
   347  	binary.Write(pkg, binary.LittleEndian, addrBytes)
   348  	binary.Write(pkg, binary.LittleEndian, bodyLength)
   349  	binary.Write(pkg, binary.LittleEndian, packet)
   350  	return pkg.Bytes()
   351  }
   352  func ReadUDPPacket(_reader io.Reader) (srcAddr string, packet []byte, err error) {
   353  	reader := bufio.NewReader(_reader)
   354  	var addrLength uint16
   355  	var bodyLength uint16
   356  	err = binary.Read(reader, binary.LittleEndian, &addrLength)
   357  	if err != nil {
   358  		return
   359  	}
   360  	_srcAddr := make([]byte, addrLength)
   361  	n, err := reader.Read(_srcAddr)
   362  	if err != nil {
   363  		return
   364  	}
   365  	if n != int(addrLength) {
   366  		err = fmt.Errorf("n != int(addrLength), %d,%d", n, addrLength)
   367  		return
   368  	}
   369  	srcAddr = string(_srcAddr)
   370  
   371  	err = binary.Read(reader, binary.LittleEndian, &bodyLength)
   372  	if err != nil {
   373  
   374  		return
   375  	}
   376  	packet = make([]byte, bodyLength)
   377  	n, err = reader.Read(packet)
   378  	if err != nil {
   379  		return
   380  	}
   381  	if n != int(bodyLength) {
   382  		err = fmt.Errorf("n != int(bodyLength), %d,%d", n, bodyLength)
   383  		return
   384  	}
   385  	return
   386  }
   387  func Uniqueid() string {
   388  	str := fmt.Sprintf("%d%s", time.Now().UnixNano(), xid.New().String())
   389  	hash := sha1.New()
   390  	hash.Write([]byte(str))
   391  	return hex.EncodeToString(hash.Sum(nil))
   392  }
   393  func RandString(strlen int) string {
   394  	codes := "QWERTYUIOPLKJHGFDSAZXCVBNMabcdefghijklmnopqrstuvwxyz0123456789"
   395  	codeLen := len(codes)
   396  	data := make([]byte, strlen)
   397  	rand.Seed(time.Now().UnixNano() + rand.Int63() + rand.Int63() + rand.Int63() + rand.Int63())
   398  	for i := 0; i < strlen; i++ {
   399  		idx := rand.Intn(codeLen)
   400  		data[i] = byte(codes[idx])
   401  	}
   402  	return string(data)
   403  }
   404  func RandInt(strLen int) int64 {
   405  	codes := "123456789"
   406  	codeLen := len(codes)
   407  	data := make([]byte, strLen)
   408  	rand.Seed(time.Now().UnixNano() + rand.Int63() + rand.Int63() + rand.Int63() + rand.Int63())
   409  	for i := 0; i < strLen; i++ {
   410  		idx := rand.Intn(codeLen)
   411  		data[i] = byte(codes[idx])
   412  	}
   413  	i, _ := strconv.ParseInt(string(data), 10, 64)
   414  	return i
   415  }
   416  func ReadBytes(r io.Reader) (data []byte, err error) {
   417  	defer func() {
   418  		if e := recover(); e != nil {
   419  			err = fmt.Errorf("read bytes fail ,err : %s", e)
   420  		}
   421  	}()
   422  	var len uint64
   423  	err = binary.Read(r, binary.LittleEndian, &len)
   424  	if err != nil {
   425  		return
   426  	}
   427  	if len == 0 || len > ^uint64(0) {
   428  		err = fmt.Errorf("data len out of range, %d", len)
   429  		return
   430  	}
   431  	var n int
   432  	data = make([]byte, len)
   433  	n, err = r.Read(data)
   434  	if err != nil {
   435  		return
   436  	}
   437  	if n != int(len) {
   438  		err = fmt.Errorf("error data len")
   439  		return
   440  	}
   441  	return
   442  }
   443  func ReadData(r io.Reader) (data string, err error) {
   444  	_data, err := ReadBytes(r)
   445  	if err != nil {
   446  		return
   447  	}
   448  	data = string(_data)
   449  	return
   450  }
   451  
   452  //non typed packet with Bytes
   453  func ReadPacketBytes(r io.Reader, data ...*[]byte) (err error) {
   454  	for _, d := range data {
   455  		*d, err = ReadBytes(r)
   456  		if err != nil {
   457  			return
   458  		}
   459  	}
   460  	return
   461  }
   462  func BuildPacketBytes(data ...[]byte) []byte {
   463  	pkg := new(bytes.Buffer)
   464  	for _, d := range data {
   465  		binary.Write(pkg, binary.LittleEndian, uint64(len(d)))
   466  		binary.Write(pkg, binary.LittleEndian, d)
   467  	}
   468  	return pkg.Bytes()
   469  }
   470  
   471  //non typed packet with string
   472  func ReadPacketData(r io.Reader, data ...*string) (err error) {
   473  	for _, d := range data {
   474  		*d, err = ReadData(r)
   475  		if err != nil {
   476  			return
   477  		}
   478  	}
   479  	return
   480  }
   481  func BuildPacketData(data ...string) []byte {
   482  	pkg := new(bytes.Buffer)
   483  	for _, d := range data {
   484  		bytes := []byte(d)
   485  		binary.Write(pkg, binary.LittleEndian, uint64(len(bytes)))
   486  		binary.Write(pkg, binary.LittleEndian, bytes)
   487  	}
   488  	return pkg.Bytes()
   489  }
   490  
   491  //typed packet with bytes
   492  func ReadBytesPacket(r io.Reader, packetType *uint8, data ...*[]byte) (err error) {
   493  	var connType uint8
   494  	err = binary.Read(r, binary.LittleEndian, &connType)
   495  	if err != nil {
   496  		return
   497  	}
   498  	*packetType = connType
   499  	for _, d := range data {
   500  		*d, err = ReadBytes(r)
   501  		if err != nil {
   502  			return
   503  		}
   504  	}
   505  	return
   506  }
   507  func BuildBytesPacket(packetType uint8, data ...[]byte) []byte {
   508  	pkg := new(bytes.Buffer)
   509  	binary.Write(pkg, binary.LittleEndian, packetType)
   510  	for _, d := range data {
   511  		binary.Write(pkg, binary.LittleEndian, uint64(len(d)))
   512  		binary.Write(pkg, binary.LittleEndian, d)
   513  	}
   514  	return pkg.Bytes()
   515  }
   516  
   517  //typed packet with string
   518  func ReadPacket(r io.Reader, packetType *uint8, data ...*string) (err error) {
   519  	var connType uint8
   520  	err = binary.Read(r, binary.LittleEndian, &connType)
   521  	if err != nil {
   522  		return
   523  	}
   524  	*packetType = connType
   525  	for _, d := range data {
   526  		*d, err = ReadData(r)
   527  		if err != nil {
   528  			return
   529  		}
   530  	}
   531  	return
   532  }
   533  
   534  func BuildPacket(packetType uint8, data ...string) []byte {
   535  	pkg := new(bytes.Buffer)
   536  	binary.Write(pkg, binary.LittleEndian, packetType)
   537  	for _, d := range data {
   538  		bytes := []byte(d)
   539  		binary.Write(pkg, binary.LittleEndian, uint64(len(bytes)))
   540  		binary.Write(pkg, binary.LittleEndian, bytes)
   541  	}
   542  	return pkg.Bytes()
   543  }
   544  
   545  func SubStr(str string, start, end int) string {
   546  	if len(str) == 0 {
   547  		return ""
   548  	}
   549  	if end >= len(str) {
   550  		end = len(str) - 1
   551  	}
   552  	return str[start:end]
   553  }
   554  func SubBytes(bytes []byte, start, end int) []byte {
   555  	if len(bytes) == 0 {
   556  		return []byte{}
   557  	}
   558  	if end >= len(bytes) {
   559  		end = len(bytes) - 1
   560  	}
   561  	return bytes[start:end]
   562  }
   563  func TlsBytes(cert, key string) (certBytes, keyBytes []byte, err error) {
   564  	base64Prefix := "base64://"
   565  	if strings.HasPrefix(cert, base64Prefix) {
   566  		certBytes, err = base64.StdEncoding.DecodeString(cert[len(base64Prefix):])
   567  	} else {
   568  		certBytes, err = ioutil.ReadFile(cert)
   569  	}
   570  	if err != nil {
   571  		err = fmt.Errorf("err : %s", err)
   572  		return
   573  	}
   574  	if strings.HasPrefix(key, base64Prefix) {
   575  		keyBytes, err = base64.StdEncoding.DecodeString(key[len(base64Prefix):])
   576  	} else {
   577  		keyBytes, err = ioutil.ReadFile(key)
   578  	}
   579  	if err != nil {
   580  		err = fmt.Errorf("err : %s", err)
   581  		return
   582  	}
   583  	return
   584  }
   585  func GetKCPBlock(method, key string) (block kcp.BlockCrypt) {
   586  	pass := pbkdf2.Key([]byte(key), []byte(key), 4096, 32, sha1.New)
   587  	switch method {
   588  	case "sm4":
   589  		block, _ = kcp.NewSM4BlockCrypt(pass[:16])
   590  	case "tea":
   591  		block, _ = kcp.NewTEABlockCrypt(pass[:16])
   592  	case "xor":
   593  		block, _ = kcp.NewSimpleXORBlockCrypt(pass)
   594  	case "none":
   595  		block, _ = kcp.NewNoneBlockCrypt(pass)
   596  	case "aes-128":
   597  		block, _ = kcp.NewAESBlockCrypt(pass[:16])
   598  	case "aes-192":
   599  		block, _ = kcp.NewAESBlockCrypt(pass[:24])
   600  	case "blowfish":
   601  		block, _ = kcp.NewBlowfishBlockCrypt(pass)
   602  	case "twofish":
   603  		block, _ = kcp.NewTwofishBlockCrypt(pass)
   604  	case "cast5":
   605  		block, _ = kcp.NewCast5BlockCrypt(pass[:16])
   606  	case "3des":
   607  		block, _ = kcp.NewTripleDESBlockCrypt(pass[:24])
   608  	case "xtea":
   609  		block, _ = kcp.NewXTEABlockCrypt(pass[:16])
   610  	case "salsa20":
   611  		block, _ = kcp.NewSalsa20BlockCrypt(pass)
   612  	default:
   613  		block, _ = kcp.NewAESBlockCrypt(pass)
   614  	}
   615  	return
   616  }
   617  func HttpGet(URL string, timeout int, host ...string) (body []byte, code int, err error) {
   618  	var tr *http.Transport
   619  	var client *http.Client
   620  	conf := &tls.Config{
   621  		InsecureSkipVerify: true,
   622  	}
   623  	if strings.Contains(URL, "https://") {
   624  		tr = &http.Transport{TLSClientConfig: conf}
   625  		client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
   626  	} else {
   627  		tr = &http.Transport{}
   628  		client = &http.Client{Timeout: time.Millisecond * time.Duration(timeout), Transport: tr}
   629  	}
   630  	defer tr.CloseIdleConnections()
   631  
   632  	//resp, err := client.Get(URL)
   633  	req, err := http.NewRequest("GET", URL, nil)
   634  	if err != nil {
   635  		return
   636  	}
   637  	if len(host) == 1 && host[0] != "" {
   638  		req.Host = host[0]
   639  	}
   640  	resp, err := client.Do(req)
   641  	if err != nil {
   642  		return
   643  	}
   644  	defer resp.Body.Close()
   645  	code = resp.StatusCode
   646  	body, err = ioutil.ReadAll(resp.Body)
   647  	return
   648  }
   649  func IsInternalIP(domainOrIP string, always bool) bool {
   650  	var outIPs []net.IP
   651  	var err error
   652  	var isDomain bool
   653  	if net.ParseIP(domainOrIP) == nil {
   654  		isDomain = true
   655  	}
   656  	if always && isDomain {
   657  		return false
   658  	}
   659  
   660  	if isDomain {
   661  		outIPs, err = LookupIP(domainOrIP)
   662  	} else {
   663  		outIPs = []net.IP{net.ParseIP(domainOrIP)}
   664  	}
   665  
   666  	if err != nil {
   667  		return false
   668  	}
   669  
   670  	for _, ip := range outIPs {
   671  		if ip.IsLoopback() {
   672  			return true
   673  		}
   674  		if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "10.0.0.0" {
   675  			return true
   676  		}
   677  		if ip.To4().Mask(net.IPv4Mask(255, 255, 0, 0)).String() == "192.168.0.0" {
   678  			return true
   679  		}
   680  		if ip.To4().Mask(net.IPv4Mask(255, 0, 0, 0)).String() == "172.0.0.0" {
   681  			i, _ := strconv.Atoi(strings.Split(ip.To4().String(), ".")[1])
   682  			return i >= 16 && i <= 31
   683  		}
   684  	}
   685  	return false
   686  }
   687  func IsHTTP(head []byte) bool {
   688  	keys := []string{"GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"}
   689  	for _, key := range keys {
   690  		if bytes.HasPrefix(head, []byte(key)) || bytes.HasPrefix(head, []byte(strings.ToLower(key))) {
   691  			return true
   692  		}
   693  	}
   694  	return false
   695  }
   696  func IsSocks5(head []byte) bool {
   697  	if len(head) < 3 {
   698  		return false
   699  	}
   700  	if head[0] == uint8(0x05) && 0 < int(head[1]) && int(head[1]) < 255 {
   701  		if len(head) == 2+int(head[1]) {
   702  			return true
   703  		}
   704  	}
   705  	return false
   706  }
   707  func RemoveProxyHeaders(head []byte) []byte {
   708  	newLines := [][]byte{}
   709  	var keys = map[string]bool{}
   710  	lines := bytes.Split(head, []byte("\r\n"))
   711  	IsBody := false
   712  	i := -1
   713  	for _, line := range lines {
   714  		i++
   715  		if len(line) == 0 || IsBody {
   716  			newLines = append(newLines, line)
   717  			IsBody = true
   718  		} else {
   719  			hline := bytes.SplitN(line, []byte(":"), 2)
   720  			if i == 0 && IsHTTP(head) {
   721  				newLines = append(newLines, line)
   722  				continue
   723  			}
   724  			if len(hline) != 2 {
   725  				continue
   726  			}
   727  			k := strings.ToUpper(string(hline[0]))
   728  			if _, ok := keys[k]; ok || strings.HasPrefix(k, "PROXY-") {
   729  				continue
   730  			}
   731  			keys[k] = true
   732  			newLines = append(newLines, line)
   733  		}
   734  	}
   735  	return bytes.Join(newLines, []byte("\r\n"))
   736  }
   737  func InsertProxyHeaders(head []byte, headers string) []byte {
   738  	return bytes.Replace(head, []byte("\r\n"), []byte("\r\n"+headers), 1)
   739  }
   740  func LBMethod(key string) int {
   741  	typs := map[string]int{"weight": lb.SELECT_WEITHT, "leasttime": lb.SELECT_LEASTTIME, "leastconn": lb.SELECT_LEASTCONN, "hash": lb.SELECT_HASH, "roundrobin": lb.SELECT_ROUNDROBIN}
   742  	return typs[key]
   743  }
   744  func UDPCopy(dst, src *net.UDPConn, dstAddr net.Addr, readTimeout time.Duration, beforeWriteFn func(data []byte) []byte, deferFn func(e interface{})) {
   745  	go func() {
   746  		defer func() {
   747  			deferFn(recover())
   748  		}()
   749  		buf := LeakyBuffer.Get()
   750  		defer LeakyBuffer.Put(buf)
   751  		for {
   752  			if readTimeout > 0 {
   753  				src.SetReadDeadline(time.Now().Add(readTimeout))
   754  			}
   755  			n, err := src.Read(buf)
   756  			if readTimeout > 0 {
   757  				src.SetReadDeadline(time.Time{})
   758  			}
   759  			if err != nil {
   760  				if IsNetClosedErr(err) || IsNetTimeoutErr(err) || IsNetRefusedErr(err) {
   761  					return
   762  				}
   763  				continue
   764  			}
   765  			_, err = dst.WriteTo(beforeWriteFn(buf[:n]), dstAddr)
   766  			if err != nil {
   767  				if IsNetClosedErr(err) {
   768  					return
   769  				}
   770  				continue
   771  			}
   772  		}
   773  	}()
   774  }
   775  func IsNetClosedErr(err error) bool {
   776  	return err != nil && strings.Contains(err.Error(), "use of closed network connection")
   777  }
   778  func IsNetTimeoutErr(err error) bool {
   779  	if err == nil {
   780  		return false
   781  	}
   782  	e, ok := err.(net.Error)
   783  	return ok && e.Timeout()
   784  }
   785  func IsNetRefusedErr(err error) bool {
   786  	return err != nil && strings.Contains(err.Error(), "connection refused")
   787  }
   788  func IsNetDeadlineErr(err error) bool {
   789  	return err != nil && strings.Contains(err.Error(), "i/o deadline reached")
   790  }
   791  func IsNetSocketNotConnectedErr(err error) bool {
   792  	return err != nil && strings.Contains(err.Error(), "socket is not connected")
   793  }
   794  func NewDefaultLogger() *logger.Logger {
   795  	return logger.New(os.Stderr, "", logger.LstdFlags)
   796  }
   797  
   798  // type sockaddr struct {
   799  // 	family uint16
   800  // 	data   [14]byte
   801  // }
   802  
   803  // const SO_ORIGINAL_DST = 80
   804  
   805  // realServerAddress returns an intercepted connection's original destination.
   806  // func realServerAddress(conn *net.Conn) (string, error) {
   807  // 	tcpConn, ok := (*conn).(*net.TCPConn)
   808  // 	if !ok {
   809  // 		return "", errors.New("not a TCPConn")
   810  // 	}
   811  
   812  // 	file, err := tcpConn.File()
   813  // 	if err != nil {
   814  // 		return "", err
   815  // 	}
   816  
   817  // 	// To avoid potential problems from making the socket non-blocking.
   818  // 	tcpConn.Close()
   819  // 	*conn, err = net.FileConn(file)
   820  // 	if err != nil {
   821  // 		return "", err
   822  // 	}
   823  
   824  // 	defer file.Close()
   825  // 	fd := file.Fd()
   826  
   827  // 	var addr sockaddr
   828  // 	size := uint32(unsafe.Sizeof(addr))
   829  // 	err = getsockopt(int(fd), syscall.SOL_IP, SO_ORIGINAL_DST, uintptr(unsafe.Pointer(&addr)), &size)
   830  // 	if err != nil {
   831  // 		return "", err
   832  // 	}
   833  
   834  // 	var ip net.IP
   835  // 	switch addr.family {
   836  // 	case syscall.AF_INET:
   837  // 		ip = addr.data[2:6]
   838  // 	default:
   839  // 		return "", errors.New("unrecognized address family")
   840  // 	}
   841  
   842  // 	port := int(addr.data[0])<<8 + int(addr.data[1])
   843  
   844  // 	return net.JoinHostPort(ip.String(), strconv.Itoa(port)), nil
   845  // }
   846  
   847  // func getsockopt(s int, level int, name int, val uintptr, vallen *uint32) (err error) {
   848  // 	_, _, e1 := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(s), uintptr(level), uintptr(name), uintptr(val), uintptr(unsafe.Pointer(vallen)), 0)
   849  // 	if e1 != 0 {
   850  // 		err = e1
   851  // 	}
   852  // 	return
   853  // }
   854  
   855  /*
   856  net.LookupIP may cause  deadlock in windows
   857  https://github.com/golang/go/issues/24178
   858  */
   859  
   860  func LookupIP(host string) ([]net.IP, error) {
   861  
   862  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(3))
   863  	defer func() {
   864  		cancel()
   865  		//ctx.Done()
   866  	}()
   867  	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, host)
   868  	if err != nil {
   869  		return nil, err
   870  	}
   871  	ips := make([]net.IP, len(addrs))
   872  	for i, ia := range addrs {
   873  		ips[i] = ia.IP
   874  	}
   875  	return ips, nil
   876  }