github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/structs.go (about)

     1  package utils
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding/base64"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"io/ioutil"
    11  	logger "log"
    12  	"net"
    13  	"net/url"
    14  	"runtime/debug"
    15  	"strings"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/AntonOrnatskyi/goproxy/utils/dnsx"
    20  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    21  	"github.com/AntonOrnatskyi/goproxy/utils/sni"
    22  
    23  	"github.com/golang/snappy"
    24  )
    25  
    26  type Checker struct {
    27  	data        mapx.ConcurrentMap
    28  	blockedMap  mapx.ConcurrentMap
    29  	directMap   mapx.ConcurrentMap
    30  	interval    int64
    31  	timeout     int
    32  	isStop      bool
    33  	intelligent string
    34  	log         *logger.Logger
    35  }
    36  type CheckerItem struct {
    37  	Domain       string
    38  	Address      string
    39  	SuccessCount uint
    40  	FailCount    uint
    41  	Lasttime     int64
    42  }
    43  
    44  //NewChecker args:
    45  //timeout : tcp timeout milliseconds ,connect to host
    46  //interval: recheck domain interval seconds
    47  func NewChecker(timeout int, interval int64, blockedFile, directFile string, log *logger.Logger, intelligent string) Checker {
    48  	ch := Checker{
    49  		data:        mapx.NewConcurrentMap(),
    50  		interval:    interval,
    51  		timeout:     timeout,
    52  		isStop:      false,
    53  		intelligent: intelligent,
    54  		log:         log,
    55  	}
    56  	ch.blockedMap = ch.loadMap(blockedFile)
    57  	ch.directMap = ch.loadMap(directFile)
    58  	if !ch.blockedMap.IsEmpty() {
    59  		log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count())
    60  	}
    61  	if !ch.directMap.IsEmpty() {
    62  		log.Printf("direct file loaded , domains : %d", ch.directMap.Count())
    63  	}
    64  	if interval > 0 {
    65  		ch.start()
    66  	}
    67  
    68  	return ch
    69  }
    70  
    71  func (c *Checker) loadMap(f string) (dataMap mapx.ConcurrentMap) {
    72  	dataMap = mapx.NewConcurrentMap()
    73  	if PathExists(f) {
    74  		_contents, err := ioutil.ReadFile(f)
    75  		if err != nil {
    76  			c.log.Printf("load file err:%s", err)
    77  			return
    78  		}
    79  		for _, line := range strings.Split(string(_contents), "\n") {
    80  			line = strings.Trim(line, "\r \t")
    81  			if line != "" {
    82  				dataMap.Set(line, true)
    83  			}
    84  		}
    85  	}
    86  	return
    87  }
    88  func (c *Checker) Stop() {
    89  	c.isStop = true
    90  }
    91  func (c *Checker) start() {
    92  	go func() {
    93  		defer func() {
    94  			if e := recover(); e != nil {
    95  				fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
    96  			}
    97  		}()
    98  		//log.Printf("checker started")
    99  		for {
   100  			//log.Printf("checker did")
   101  			for _, v := range c.data.Items() {
   102  				go func(item CheckerItem) {
   103  					defer func() {
   104  						if e := recover(); e != nil {
   105  							fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack()))
   106  						}
   107  					}()
   108  					if c.isNeedCheck(item) {
   109  						//log.Printf("check %s", item.Host)
   110  						var conn net.Conn
   111  						var err error
   112  						var now = time.Now().Unix()
   113  						conn, err = ConnectHost(item.Address, c.timeout)
   114  						if err == nil {
   115  							conn.SetDeadline(time.Now().Add(time.Millisecond))
   116  							conn.Close()
   117  						}
   118  						if now-item.Lasttime > 1800 {
   119  							item.FailCount = 0
   120  							item.SuccessCount = 0
   121  						}
   122  						if err != nil {
   123  							item.FailCount = item.FailCount + 1
   124  						} else {
   125  							item.SuccessCount = item.SuccessCount + 1
   126  						}
   127  						item.Lasttime = now
   128  						c.data.Set(item.Domain, item)
   129  					}
   130  				}(v.(CheckerItem))
   131  			}
   132  			time.Sleep(time.Second * time.Duration(c.interval))
   133  			if c.isStop {
   134  				return
   135  			}
   136  		}
   137  	}()
   138  }
   139  func (c *Checker) isNeedCheck(item CheckerItem) bool {
   140  	var minCount uint = 5
   141  	var now = time.Now().Unix()
   142  	if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount && now-item.Lasttime < 1800) ||
   143  		(item.FailCount >= minCount && item.SuccessCount > item.FailCount && now-item.Lasttime < 1800) ||
   144  		c.domainIsInMap(item.Domain, false) ||
   145  		c.domainIsInMap(item.Domain, true) {
   146  		return false
   147  	}
   148  	return true
   149  }
   150  func (c *Checker) IsBlocked(domain string) (blocked, isInMap bool, failN, successN uint) {
   151  	h, _, _ := net.SplitHostPort(domain)
   152  	if h != "" {
   153  		domain = h
   154  	}
   155  	if c.domainIsInMap(domain, true) {
   156  		//log.Printf("%s in blocked ? true", address)
   157  		return true, true, 0, 0
   158  	}
   159  	if c.domainIsInMap(domain, false) {
   160  		//log.Printf("%s in direct ? true", address)
   161  		return false, true, 0, 0
   162  	}
   163  
   164  	_item, ok := c.data.Get(domain)
   165  	if !ok {
   166  		//log.Printf("%s not in map, blocked true", address)
   167  		return true, false, 0, 0
   168  	}
   169  	switch c.intelligent {
   170  	case "direct":
   171  		return false, true, 0, 0
   172  	case "parent":
   173  		return true, true, 0, 0
   174  	case "intelligent":
   175  		fallthrough
   176  	default:
   177  		item := _item.(CheckerItem)
   178  		return (item.FailCount >= item.SuccessCount) && (time.Now().Unix()-item.Lasttime < 1800), true, item.FailCount, item.SuccessCount
   179  	}
   180  }
   181  
   182  func (c *Checker) domainIsInMap(address string, blockedMap bool) bool {
   183  	u, err := url.Parse("http://" + address)
   184  	if err != nil {
   185  		c.log.Printf("blocked check , url parse err:%s", err)
   186  		return true
   187  	}
   188  	domainSlice := strings.Split(u.Hostname(), ".")
   189  	if len(domainSlice) > 1 {
   190  		checkDomain := ""
   191  		for i := len(domainSlice) - 1; i >= 0; i-- {
   192  			checkDomain = strings.Join(domainSlice[i:], ".")
   193  			if !blockedMap && c.directMap.Has(checkDomain) {
   194  				return true
   195  			}
   196  			if blockedMap && c.blockedMap.Has(checkDomain) {
   197  				return true
   198  			}
   199  		}
   200  	}
   201  	return false
   202  }
   203  func (c *Checker) Add(domain, address string) {
   204  	h, _, _ := net.SplitHostPort(domain)
   205  	if h != "" {
   206  		domain = h
   207  	}
   208  	if c.domainIsInMap(domain, false) || c.domainIsInMap(domain, true) {
   209  		return
   210  	}
   211  	var item CheckerItem
   212  	item = CheckerItem{
   213  		Domain:  domain,
   214  		Address: address,
   215  	}
   216  	c.data.SetIfAbsent(item.Domain, item)
   217  }
   218  
// BasicAuth validates proxy credentials against a local user table and,
// optionally, a remote auth URL configured with SetAuthURL.
type BasicAuth struct {
	// data maps a user name to its password (stored as string).
	data mapx.ConcurrentMap
	// authURL, authOkCode, authTimeout and authRetry are set by SetAuthURL
	// and used by checkFromURL.
	authURL     string
	authOkCode  int
	authTimeout int
	authRetry   int
	// dns, if non-nil, resolves the auth URL's host before each remote check.
	dns *dnsx.DomainResolver
	log *logger.Logger
}
   228  
   229  func NewBasicAuth(dns *dnsx.DomainResolver, log *logger.Logger) BasicAuth {
   230  	return BasicAuth{
   231  		data: mapx.NewConcurrentMap(),
   232  		dns:  dns,
   233  		log:  log,
   234  	}
   235  }
   236  func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) {
   237  	ba.authURL = URL
   238  	ba.authOkCode = code
   239  	ba.authTimeout = timeout
   240  	ba.authRetry = retry
   241  }
   242  func (ba *BasicAuth) AddFromFile(file string) (n int, err error) {
   243  	_content, err := ioutil.ReadFile(file)
   244  	if err != nil {
   245  		return
   246  	}
   247  	userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n")
   248  	for _, userpass := range userpassArr {
   249  		if strings.HasPrefix(userpass, "#") {
   250  			continue
   251  		}
   252  		u := strings.Split(strings.Trim(userpass, " "), ":")
   253  		if len(u) == 2 {
   254  			ba.data.Set(u[0], u[1])
   255  			n++
   256  		}
   257  	}
   258  	return
   259  }
   260  
   261  func (ba *BasicAuth) Add(userpassArr []string) (n int) {
   262  	for _, userpass := range userpassArr {
   263  		u := strings.Split(userpass, ":")
   264  		if len(u) == 2 {
   265  			ba.data.Set(u[0], u[1])
   266  			n++
   267  		}
   268  	}
   269  	return
   270  }
   271  func (ba *BasicAuth) Delete(userArr []string) {
   272  	for _, u := range userArr {
   273  		ba.data.Remove(u)
   274  	}
   275  }
   276  func (ba *BasicAuth) CheckUserPass(user, pass, userIP, localIP, target string) (ok bool) {
   277  
   278  	return ba.Check(user+":"+pass, userIP, localIP, target)
   279  }
   280  func (ba *BasicAuth) Check(userpass string, userIP, localIP, target string) (ok bool) {
   281  	u := strings.Split(strings.Trim(userpass, " "), ":")
   282  	if len(u) == 2 {
   283  		if p, _ok := ba.data.Get(u[0]); _ok {
   284  			return p.(string) == u[1]
   285  		}
   286  		if ba.authURL != "" {
   287  			err := ba.checkFromURL(userpass, userIP, localIP, target)
   288  			if err == nil {
   289  				return true
   290  			}
   291  			ba.log.Printf("%s", err)
   292  		}
   293  		return false
   294  	}
   295  	return
   296  }
   297  func (ba *BasicAuth) checkFromURL(userpass, userIP, localIP, target string) (err error) {
   298  	u := strings.Split(strings.Trim(userpass, " "), ":")
   299  	if len(u) != 2 {
   300  		return
   301  	}
   302  
   303  	URL := ba.authURL
   304  	if strings.Contains(URL, "?") {
   305  		URL += "&"
   306  	} else {
   307  		URL += "?"
   308  	}
   309  	URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&local_ip=%s&target=%s", u[0], u[1], userIP, localIP, url.QueryEscape(target))
   310  	getURL := URL
   311  	var domain string
   312  	if ba.dns != nil {
   313  		_url, _ := url.Parse(ba.authURL)
   314  		domain = _url.Host
   315  		domainIP := ba.dns.MustResolve(domain)
   316  		getURL = strings.Replace(URL, domain, domainIP, 1)
   317  	}
   318  	var code int
   319  	var tryCount = 0
   320  	var body []byte
   321  	for tryCount <= ba.authRetry {
   322  		body, code, err = HttpGet(getURL, ba.authTimeout, domain)
   323  		if err == nil && code == ba.authOkCode {
   324  			break
   325  		} else if err != nil {
   326  			err = fmt.Errorf("auth fail from url %s,resonse err:%s , %s -> %s", URL, err, userIP, localIP)
   327  		} else {
   328  			if len(body) > 0 {
   329  				err = fmt.Errorf(string(body[0:100]))
   330  			} else {
   331  				err = fmt.Errorf("token error")
   332  			}
   333  			b := string(body)
   334  			if len(b) > 50 {
   335  				b = b[:50]
   336  			}
   337  			err = fmt.Errorf("auth fail from url %s,resonse code: %d, except: %d , %s -> %s, %s", URL, code, ba.authOkCode, userIP, localIP, b)
   338  		}
   339  		if err != nil && tryCount < ba.authRetry {
   340  			ba.log.Print(err)
   341  			time.Sleep(time.Second * 2)
   342  		}
   343  		tryCount++
   344  	}
   345  	if err != nil {
   346  		return
   347  	}
   348  	//log.Printf("auth success from auth url, %s", ip)
   349  	return
   350  }
   351  
   352  func (ba *BasicAuth) Total() (n int) {
   353  	n = ba.data.Count()
   354  	return
   355  }
   356  
// HTTPRequest is the parsed head of a client connection. It is either a
// plain HTTP request, a CONNECT request, or a TLS ClientHello whose SNI
// server name was extracted (Method "SNI", IsSNI true).
type HTTPRequest struct {
	// HeadBuf holds the request head: the first read from the client, or
	// the header slice passed to NewHTTPRequest.
	HeadBuf []byte
	conn    *net.Conn
	// Host is the target as host:port once HTTP/HTTPS has run.
	Host string
	// Method is the upper-cased request method, or "SNI".
	Method string
	// URL is the absolute request URL; set only by HTTP (non-CONNECT).
	URL string
	// hostOrURL is the raw request target taken from the request line.
	hostOrURL   string
	isBasicAuth bool
	basicAuth   *BasicAuth
	log         *logger.Logger
	IsSNI       bool
}
   369  
// NewHTTPRequest obtains the head of a client request on inConn and parses
// it into an HTTPRequest.
//
// If header holds exactly one slice longer than one byte, that slice is used
// as the request head. Otherwise a single Read of up to bufSize bytes is made
// on inConn. The head is first tried as a TLS ClientHello (SNI). If that
// fails, it is parsed as an HTTP request line "METHOD TARGET ...". Then HTTPS
// (for CONNECT) or HTTP (everything else, including SNI) completes the
// request, running basic auth when isBasicAuth is set.
//
// On read failures and request-line parse failures inConn is closed and an
// error is returned. A plain io.EOF from the read is returned unwrapped.
func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth, log *logger.Logger, header ...[]byte) (req HTTPRequest, err error) {
	buf := make([]byte, bufSize)
	n := 0
	req = HTTPRequest{
		conn: inConn,
		log:  log,
	}
	// Prefer a caller-supplied head (e.g. bytes already consumed elsewhere).
	if header != nil && len(header) == 1 && len(header[0]) > 1 {
		buf = header[0]
		n = len(header[0])
	} else {
		// NOTE(review): a single Read assumes the whole head arrives in one
		// chunk that fits in bufSize bytes; a head split across reads is
		// truncated. Confirm callers accept this.
		n, err = (*inConn).Read(buf[:])
		if err != nil {
			if err != io.EOF {
				err = fmt.Errorf("http decoder read err:%s", err)
			}
			CloseConn(inConn)
			return
		}
	}

	req.HeadBuf = buf[:n]
	//fmt.Println(string(req.HeadBuf))
	//try sni
	serverName, err0 := sni.ServerNameFromBytes(req.HeadBuf)
	if err0 == nil {
		//sni success
		// Treat the ClientHello as a request for serverName on port 443.
		req.Method = "SNI"
		req.hostOrURL = "https://" + serverName + ":443"
		req.IsSNI = true
	} else {
		//sni fail , try http
		// The request line is everything before the first '\n'.
		index := bytes.IndexByte(req.HeadBuf, '\n')
		if index == -1 {
			err = fmt.Errorf("http decoder data line err:%s", SubStr(string(req.HeadBuf), 0, 50))
			CloseConn(inConn)
			return
		}
		// Scan errors are ignored; empty fields are rejected just below.
		fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL)
	}
	if req.Method == "" || req.hostOrURL == "" {
		err = fmt.Errorf("http decoder data err:%s", SubStr(string(req.HeadBuf), 0, 50))
		CloseConn(inConn)
		return
	}
	req.Method = strings.ToUpper(req.Method)
	req.isBasicAuth = isBasicAuth
	req.basicAuth = basicAuth
	log.Printf("%s:%s", req.Method, req.hostOrURL)

	// CONNECT is a tunnel request; all other methods (and SNI) go through
	// the plain-HTTP path.
	if req.IsHTTPS() {
		err = req.HTTPS()
	} else {
		err = req.HTTP()
	}
	return
}
   427  func (req *HTTPRequest) HTTP() (err error) {
   428  	if req.isBasicAuth {
   429  		err = req.BasicAuth()
   430  		if err != nil {
   431  			return
   432  		}
   433  	}
   434  	req.URL = req.getHTTPURL()
   435  	var u *url.URL
   436  	u, err = url.Parse(req.URL)
   437  	if err != nil {
   438  		return
   439  	}
   440  	req.Host = u.Host
   441  	req.addPortIfNot()
   442  	return
   443  }
   444  func (req *HTTPRequest) HTTPS() (err error) {
   445  	if req.isBasicAuth {
   446  		err = req.BasicAuth()
   447  		if err != nil {
   448  			return
   449  		}
   450  	}
   451  	req.Host = req.hostOrURL
   452  	req.addPortIfNot()
   453  	return
   454  }
   455  func (req *HTTPRequest) HTTPSReply() (err error) {
   456  	_, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n")
   457  	return
   458  }
   459  func (req *HTTPRequest) IsHTTPS() bool {
   460  	return req.Method == "CONNECT"
   461  }
   462  
   463  func (req *HTTPRequest) GetAuthDataStr() (basicInfo string, err error) {
   464  	// log.Printf("request :%s", string(req.HeadBuf))
   465  	authorization := req.getHeader("Proxy-Authorization")
   466  
   467  	authorization = strings.Trim(authorization, " \r\n\t")
   468  	if authorization == "" {
   469  		fmt.Fprintf((*req.conn), "HTTP/1.1 %s Proxy Authentication Required\r\nProxy-Authenticate: Basic realm=\"\"\r\n\r\nProxy Authentication Required", "407")
   470  		CloseConn(req.conn)
   471  		err = errors.New("require auth header data")
   472  		return
   473  	}
   474  	//log.Printf("Authorization:%authorization = req.getHeader("Authorization")
   475  	basic := strings.Fields(authorization)
   476  	if len(basic) != 2 {
   477  		err = fmt.Errorf("authorization data error,ERR:%s", authorization)
   478  		CloseConn(req.conn)
   479  		return
   480  	}
   481  	user, err := base64.StdEncoding.DecodeString(basic[1])
   482  	if err != nil {
   483  		err = fmt.Errorf("authorization data parse error,ERR:%s", err)
   484  		CloseConn(req.conn)
   485  		return
   486  	}
   487  	basicInfo = string(user)
   488  	return
   489  }
   490  func (req *HTTPRequest) BasicAuth() (err error) {
   491  	userIP := strings.Split((*req.conn).RemoteAddr().String(), ":")
   492  	localIP := strings.Split((*req.conn).LocalAddr().String(), ":")
   493  	URL := ""
   494  	if req.IsHTTPS() {
   495  		URL = "https://" + req.Host
   496  	} else {
   497  		URL = req.getHTTPURL()
   498  	}
   499  	user, err := req.GetAuthDataStr()
   500  	if err != nil {
   501  		return
   502  	}
   503  	authOk := (*req.basicAuth).Check(string(user), userIP[0], localIP[0], URL)
   504  	//log.Printf("auth %s,%v", string(user), authOk)
   505  	if !authOk {
   506  		fmt.Fprintf((*req.conn), "HTTP/1.1 %s Proxy Authentication Required\r\n\r\nProxy Authentication Required", "407")
   507  		CloseConn(req.conn)
   508  		err = fmt.Errorf("basic auth fail")
   509  		return
   510  	}
   511  	return
   512  }
   513  func (req *HTTPRequest) getHTTPURL() (URL string) {
   514  	if !strings.HasPrefix(req.hostOrURL, "/") {
   515  		return req.hostOrURL
   516  	}
   517  	_host := req.getHeader("host")
   518  	if _host == "" {
   519  		return
   520  	}
   521  	URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL)
   522  	return
   523  }
   524  func (req *HTTPRequest) getHeader(key string) (val string) {
   525  	key = strings.ToUpper(key)
   526  	lines := strings.Split(string(req.HeadBuf), "\r\n")
   527  	//log.Println(lines)
   528  	for _, line := range lines {
   529  		hline := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2)
   530  		if len(hline) == 2 {
   531  			k := strings.ToUpper(strings.Trim(hline[0], " "))
   532  			v := strings.Trim(hline[1], " ")
   533  			if key == k {
   534  				val = v
   535  				return
   536  			}
   537  		}
   538  	}
   539  	return
   540  }
   541  
   542  func (req *HTTPRequest) addPortIfNot() (newHost string) {
   543  	//newHost = req.Host
   544  	port := "80"
   545  	if req.IsHTTPS() {
   546  		port = "443"
   547  	}
   548  	if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) {
   549  		//newHost = req.Host + ":" + port
   550  		//req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1))
   551  		req.Host = req.Host + ":" + port
   552  	}
   553  	return
   554  }
   555  
// ConnManager tracks live connections grouped by key and then by ID.
// pool maps key -> mapx.ConcurrentMap of ID -> *net.Conn.
type ConnManager struct {
	pool mapx.ConcurrentMap
	// l serializes RemoveOne.
	l   *sync.Mutex
	log *logger.Logger
}
   561  
   562  func NewConnManager(log *logger.Logger) ConnManager {
   563  	cm := ConnManager{
   564  		pool: mapx.NewConcurrentMap(),
   565  		l:    &sync.Mutex{},
   566  		log:  log,
   567  	}
   568  	return cm
   569  }
   570  func (cm *ConnManager) Add(key, ID string, conn *net.Conn) {
   571  	cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} {
   572  		var conns mapx.ConcurrentMap
   573  		if !exist {
   574  			conns = mapx.NewConcurrentMap()
   575  		} else {
   576  			conns = valueInMap.(mapx.ConcurrentMap)
   577  		}
   578  		if conns.Has(ID) {
   579  			v, _ := conns.Get(ID)
   580  			(*v.(*net.Conn)).Close()
   581  		}
   582  		conns.Set(ID, conn)
   583  		cm.log.Printf("%s conn added", key)
   584  		return conns
   585  	})
   586  }
   587  func (cm *ConnManager) Remove(key string) {
   588  	var conns mapx.ConcurrentMap
   589  	if v, ok := cm.pool.Get(key); ok {
   590  		conns = v.(mapx.ConcurrentMap)
   591  		conns.IterCb(func(key string, v interface{}) {
   592  			CloseConn(v.(*net.Conn))
   593  		})
   594  		cm.log.Printf("%s conns closed", key)
   595  	}
   596  	cm.pool.Remove(key)
   597  }
   598  func (cm *ConnManager) RemoveOne(key string, ID string) {
   599  	defer cm.l.Unlock()
   600  	cm.l.Lock()
   601  	var conns mapx.ConcurrentMap
   602  	if v, ok := cm.pool.Get(key); ok {
   603  		conns = v.(mapx.ConcurrentMap)
   604  		if conns.Has(ID) {
   605  			v, _ := conns.Get(ID)
   606  			(*v.(*net.Conn)).Close()
   607  			conns.Remove(ID)
   608  			cm.pool.Set(key, conns)
   609  			cm.log.Printf("%s %s conn closed", key, ID)
   610  		}
   611  	}
   612  }
   613  func (cm *ConnManager) RemoveAll() {
   614  	for _, k := range cm.pool.Keys() {
   615  		cm.Remove(k)
   616  	}
   617  }
   618  
   619  type ClientKeyRouter struct {
   620  	keyChan chan string
   621  	ctrl    *mapx.ConcurrentMap
   622  	lock    *sync.Mutex
   623  }
   624  
   625  func NewClientKeyRouter(ctrl *mapx.ConcurrentMap, size int) ClientKeyRouter {
   626  	return ClientKeyRouter{
   627  		keyChan: make(chan string, size),
   628  		ctrl:    ctrl,
   629  		lock:    &sync.Mutex{},
   630  	}
   631  }
   632  func (c *ClientKeyRouter) GetKey() string {
   633  	defer c.lock.Unlock()
   634  	c.lock.Lock()
   635  	if len(c.keyChan) == 0 {
   636  	EXIT:
   637  		for _, k := range c.ctrl.Keys() {
   638  			select {
   639  			case c.keyChan <- k:
   640  			default:
   641  				goto EXIT
   642  			}
   643  		}
   644  	}
   645  	for {
   646  		if len(c.keyChan) == 0 {
   647  			return "*"
   648  		}
   649  		select {
   650  		case key := <-c.keyChan:
   651  			if c.ctrl.Has(key) {
   652  				return key
   653  			}
   654  		default:
   655  			return "*"
   656  		}
   657  	}
   658  
   659  }
   660  
   661  func NewCompStream(conn net.Conn) *CompStream {
   662  	c := new(CompStream)
   663  	c.conn = conn
   664  	c.w = snappy.NewBufferedWriter(conn)
   665  	c.r = snappy.NewReader(conn)
   666  	return c
   667  }
   668  func NewCompConn(conn net.Conn) net.Conn {
   669  	c := CompStream{}
   670  	c.conn = conn
   671  	c.w = snappy.NewBufferedWriter(conn)
   672  	c.r = snappy.NewReader(conn)
   673  	return &c
   674  }
   675  
// CompStream is a net.Conn whose data is snappy-compressed on write and
// decompressed on read.
type CompStream struct {
	// The embedded net.Conn only makes CompStream satisfy net.Conn. The
	// constructors leave it nil, and every net.Conn method is defined on
	// CompStream to use conn instead.
	net.Conn
	// conn is the real underlying connection; w and r compress/decompress
	// over it.
	conn net.Conn
	w    *snappy.Writer
	r    *snappy.Reader
}
   682  
   683  func (c *CompStream) Read(p []byte) (n int, err error) {
   684  	return c.r.Read(p)
   685  }
   686  
   687  func (c *CompStream) Write(p []byte) (n int, err error) {
   688  	n, err = c.w.Write(p)
   689  	err = c.w.Flush()
   690  	return n, err
   691  }
   692  
   693  func (c *CompStream) Close() error {
   694  	return c.conn.Close()
   695  }
   696  func (c *CompStream) LocalAddr() net.Addr {
   697  	return c.conn.LocalAddr()
   698  }
   699  func (c *CompStream) RemoteAddr() net.Addr {
   700  	return c.conn.RemoteAddr()
   701  }
   702  func (c *CompStream) SetDeadline(t time.Time) error {
   703  	return c.conn.SetDeadline(t)
   704  }
   705  func (c *CompStream) SetReadDeadline(t time.Time) error {
   706  	return c.conn.SetReadDeadline(t)
   707  }
   708  func (c *CompStream) SetWriteDeadline(t time.Time) error {
   709  	return c.conn.SetWriteDeadline(t)
   710  }
   711  
   712  type BufferedConn struct {
   713  	r        *bufio.Reader
   714  	net.Conn // So that most methods are embedded
   715  }
   716  
   717  func NewBufferedConn(c net.Conn) BufferedConn {
   718  	return BufferedConn{bufio.NewReader(c), c}
   719  }
   720  
   721  func NewBufferedConnSize(c net.Conn, n int) BufferedConn {
   722  	return BufferedConn{bufio.NewReaderSize(c, n), c}
   723  }
   724  
   725  func (b BufferedConn) Peek(n int) ([]byte, error) {
   726  	return b.r.Peek(n)
   727  }
   728  
   729  func (b BufferedConn) Read(p []byte) (int, error) {
   730  	return b.r.Read(p)
   731  }
   732  func (b BufferedConn) ReadByte() (byte, error) {
   733  	return b.r.ReadByte()
   734  }
   735  func (b BufferedConn) UnreadByte() error {
   736  	return b.r.UnreadByte()
   737  }
   738  func (b BufferedConn) Buffered() int {
   739  	return b.r.Buffered()
   740  }