github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/utils/structs.go (about) 1 package utils 2 3 import ( 4 "bufio" 5 "bytes" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 logger "log" 12 "net" 13 "net/url" 14 "runtime/debug" 15 "strings" 16 "sync" 17 "time" 18 19 "github.com/AntonOrnatskyi/goproxy/utils/dnsx" 20 "github.com/AntonOrnatskyi/goproxy/utils/mapx" 21 "github.com/AntonOrnatskyi/goproxy/utils/sni" 22 23 "github.com/golang/snappy" 24 ) 25 26 type Checker struct { 27 data mapx.ConcurrentMap 28 blockedMap mapx.ConcurrentMap 29 directMap mapx.ConcurrentMap 30 interval int64 31 timeout int 32 isStop bool 33 intelligent string 34 log *logger.Logger 35 } 36 type CheckerItem struct { 37 Domain string 38 Address string 39 SuccessCount uint 40 FailCount uint 41 Lasttime int64 42 } 43 44 //NewChecker args: 45 //timeout : tcp timeout milliseconds ,connect to host 46 //interval: recheck domain interval seconds 47 func NewChecker(timeout int, interval int64, blockedFile, directFile string, log *logger.Logger, intelligent string) Checker { 48 ch := Checker{ 49 data: mapx.NewConcurrentMap(), 50 interval: interval, 51 timeout: timeout, 52 isStop: false, 53 intelligent: intelligent, 54 log: log, 55 } 56 ch.blockedMap = ch.loadMap(blockedFile) 57 ch.directMap = ch.loadMap(directFile) 58 if !ch.blockedMap.IsEmpty() { 59 log.Printf("blocked file loaded , domains : %d", ch.blockedMap.Count()) 60 } 61 if !ch.directMap.IsEmpty() { 62 log.Printf("direct file loaded , domains : %d", ch.directMap.Count()) 63 } 64 if interval > 0 { 65 ch.start() 66 } 67 68 return ch 69 } 70 71 func (c *Checker) loadMap(f string) (dataMap mapx.ConcurrentMap) { 72 dataMap = mapx.NewConcurrentMap() 73 if PathExists(f) { 74 _contents, err := ioutil.ReadFile(f) 75 if err != nil { 76 c.log.Printf("load file err:%s", err) 77 return 78 } 79 for _, line := range strings.Split(string(_contents), "\n") { 80 line = strings.Trim(line, "\r \t") 81 if line != "" { 82 dataMap.Set(line, true) 83 } 84 } 85 } 86 return 87 } 88 func (c *Checker) Stop() { 89 c.isStop = true 90 } 91 func (c *Checker) start() { 92 go func() { 93 defer func() { 94 if e := recover(); e != nil { 95 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 96 } 97 }() 98 //log.Printf("checker started") 99 for { 100 //log.Printf("checker did") 101 for _, v := range c.data.Items() { 102 go func(item CheckerItem) { 103 defer func() { 104 if e := recover(); e != nil { 105 fmt.Printf("crashed, err: %s\nstack:%s", e, string(debug.Stack())) 106 } 107 }() 108 if c.isNeedCheck(item) { 109 //log.Printf("check %s", item.Host) 110 var conn net.Conn 111 var err error 112 var now = time.Now().Unix() 113 conn, err = ConnectHost(item.Address, c.timeout) 114 if err == nil { 115 conn.SetDeadline(time.Now().Add(time.Millisecond)) 116 conn.Close() 117 } 118 if now-item.Lasttime > 1800 { 119 item.FailCount = 0 120 item.SuccessCount = 0 121 } 122 if err != nil { 123 item.FailCount = item.FailCount + 1 124 } else { 125 item.SuccessCount = item.SuccessCount + 1 126 } 127 item.Lasttime = now 128 c.data.Set(item.Domain, item) 129 } 130 }(v.(CheckerItem)) 131 } 132 time.Sleep(time.Second * time.Duration(c.interval)) 133 if c.isStop { 134 return 135 } 136 } 137 }() 138 } 139 func (c *Checker) isNeedCheck(item CheckerItem) bool { 140 var minCount uint = 5 141 var now = time.Now().Unix() 142 if (item.SuccessCount >= minCount && item.SuccessCount > item.FailCount && now-item.Lasttime < 1800) || 143 (item.FailCount >= minCount && item.SuccessCount > item.FailCount && now-item.Lasttime < 1800) || 144 c.domainIsInMap(item.Domain, false) || 145 c.domainIsInMap(item.Domain, true) { 146 return false 147 } 148 return true 149 } 150 func (c *Checker) IsBlocked(domain string) (blocked, isInMap bool, failN, successN uint) { 151 h, _, _ := net.SplitHostPort(domain) 152 if h != "" { 153 domain = h 154 } 155 if c.domainIsInMap(domain, true) { 156 //log.Printf("%s in blocked ? true", address) 157 return true, true, 0, 0 158 } 159 if c.domainIsInMap(domain, false) { 160 //log.Printf("%s in direct ? true", address) 161 return false, true, 0, 0 162 } 163 164 _item, ok := c.data.Get(domain) 165 if !ok { 166 //log.Printf("%s not in map, blocked true", address) 167 return true, false, 0, 0 168 } 169 switch c.intelligent { 170 case "direct": 171 return false, true, 0, 0 172 case "parent": 173 return true, true, 0, 0 174 case "intelligent": 175 fallthrough 176 default: 177 item := _item.(CheckerItem) 178 return (item.FailCount >= item.SuccessCount) && (time.Now().Unix()-item.Lasttime < 1800), true, item.FailCount, item.SuccessCount 179 } 180 } 181 182 func (c *Checker) domainIsInMap(address string, blockedMap bool) bool { 183 u, err := url.Parse("http://" + address) 184 if err != nil { 185 c.log.Printf("blocked check , url parse err:%s", err) 186 return true 187 } 188 domainSlice := strings.Split(u.Hostname(), ".") 189 if len(domainSlice) > 1 { 190 checkDomain := "" 191 for i := len(domainSlice) - 1; i >= 0; i-- { 192 checkDomain = strings.Join(domainSlice[i:], ".") 193 if !blockedMap && c.directMap.Has(checkDomain) { 194 return true 195 } 196 if blockedMap && c.blockedMap.Has(checkDomain) { 197 return true 198 } 199 } 200 } 201 return false 202 } 203 func (c *Checker) Add(domain, address string) { 204 h, _, _ := net.SplitHostPort(domain) 205 if h != "" { 206 domain = h 207 } 208 if c.domainIsInMap(domain, false) || c.domainIsInMap(domain, true) { 209 return 210 } 211 var item CheckerItem 212 item = CheckerItem{ 213 Domain: domain, 214 Address: address, 215 } 216 c.data.SetIfAbsent(item.Domain, item) 217 } 218 219 type BasicAuth struct { 220 data mapx.ConcurrentMap 221 authURL string 222 authOkCode int 223 authTimeout int 224 authRetry int 225 dns *dnsx.DomainResolver 226 log *logger.Logger 227 } 228 229 func NewBasicAuth(dns *dnsx.DomainResolver, log *logger.Logger) BasicAuth { 230 return BasicAuth{ 231 data: mapx.NewConcurrentMap(), 232 dns: dns, 233 log: log, 234 } 235 } 236 func (ba *BasicAuth) SetAuthURL(URL string, code, timeout, retry int) { 237 ba.authURL = URL 238 ba.authOkCode = code 239 ba.authTimeout = timeout 240 ba.authRetry = retry 241 } 242 func (ba *BasicAuth) AddFromFile(file string) (n int, err error) { 243 _content, err := ioutil.ReadFile(file) 244 if err != nil { 245 return 246 } 247 userpassArr := strings.Split(strings.Replace(string(_content), "\r", "", -1), "\n") 248 for _, userpass := range userpassArr { 249 if strings.HasPrefix(userpass, "#") { 250 continue 251 } 252 u := strings.Split(strings.Trim(userpass, " "), ":") 253 if len(u) == 2 { 254 ba.data.Set(u[0], u[1]) 255 n++ 256 } 257 } 258 return 259 } 260 261 func (ba *BasicAuth) Add(userpassArr []string) (n int) { 262 for _, userpass := range userpassArr { 263 u := strings.Split(userpass, ":") 264 if len(u) == 2 { 265 ba.data.Set(u[0], u[1]) 266 n++ 267 } 268 } 269 return 270 } 271 func (ba *BasicAuth) Delete(userArr []string) { 272 for _, u := range userArr { 273 ba.data.Remove(u) 274 } 275 } 276 func (ba *BasicAuth) CheckUserPass(user, pass, userIP, localIP, target string) (ok bool) { 277 278 return ba.Check(user+":"+pass, userIP, localIP, target) 279 } 280 func (ba *BasicAuth) Check(userpass string, userIP, localIP, target string) (ok bool) { 281 u := strings.Split(strings.Trim(userpass, " "), ":") 282 if len(u) == 2 { 283 if p, _ok := ba.data.Get(u[0]); _ok { 284 return p.(string) == u[1] 285 } 286 if ba.authURL != "" { 287 err := ba.checkFromURL(userpass, userIP, localIP, target) 288 if err == nil { 289 return true 290 } 291 ba.log.Printf("%s", err) 292 } 293 return false 294 } 295 return 296 } 297 func (ba *BasicAuth) checkFromURL(userpass, userIP, localIP, target string) (err error) { 298 u := strings.Split(strings.Trim(userpass, " "), ":") 299 if len(u) != 2 { 300 return 301 } 302 303 URL := ba.authURL 304 if strings.Contains(URL, "?") { 305 URL += "&" 306 } else { 307 URL += "?" 308 } 309 URL += fmt.Sprintf("user=%s&pass=%s&ip=%s&local_ip=%s&target=%s", u[0], u[1], userIP, localIP, url.QueryEscape(target)) 310 getURL := URL 311 var domain string 312 if ba.dns != nil { 313 _url, _ := url.Parse(ba.authURL) 314 domain = _url.Host 315 domainIP := ba.dns.MustResolve(domain) 316 getURL = strings.Replace(URL, domain, domainIP, 1) 317 } 318 var code int 319 var tryCount = 0 320 var body []byte 321 for tryCount <= ba.authRetry { 322 body, code, err = HttpGet(getURL, ba.authTimeout, domain) 323 if err == nil && code == ba.authOkCode { 324 break 325 } else if err != nil { 326 err = fmt.Errorf("auth fail from url %s,resonse err:%s , %s -> %s", URL, err, userIP, localIP) 327 } else { 328 if len(body) > 0 { 329 err = fmt.Errorf(string(body[0:100])) 330 } else { 331 err = fmt.Errorf("token error") 332 } 333 b := string(body) 334 if len(b) > 50 { 335 b = b[:50] 336 } 337 err = fmt.Errorf("auth fail from url %s,resonse code: %d, except: %d , %s -> %s, %s", URL, code, ba.authOkCode, userIP, localIP, b) 338 } 339 if err != nil && tryCount < ba.authRetry { 340 ba.log.Print(err) 341 time.Sleep(time.Second * 2) 342 } 343 tryCount++ 344 } 345 if err != nil { 346 return 347 } 348 //log.Printf("auth success from auth url, %s", ip) 349 return 350 } 351 352 func (ba *BasicAuth) Total() (n int) { 353 n = ba.data.Count() 354 return 355 } 356 357 type HTTPRequest struct { 358 HeadBuf []byte 359 conn *net.Conn 360 Host string 361 Method string 362 URL string 363 hostOrURL string 364 isBasicAuth bool 365 basicAuth *BasicAuth 366 log *logger.Logger 367 IsSNI bool 368 } 369 370 func NewHTTPRequest(inConn *net.Conn, bufSize int, isBasicAuth bool, basicAuth *BasicAuth, log *logger.Logger, header ...[]byte) (req HTTPRequest, err error) { 371 buf := make([]byte, bufSize) 372 n := 0 373 req = HTTPRequest{ 374 conn: inConn, 375 log: log, 376 } 377 if header != nil && len(header) == 1 && len(header[0]) > 1 { 378 buf = header[0] 379 n = len(header[0]) 380 } else { 381 n, err = (*inConn).Read(buf[:]) 382 if err != nil { 383 if err != io.EOF { 384 err = fmt.Errorf("http decoder read err:%s", err) 385 } 386 CloseConn(inConn) 387 return 388 } 389 } 390 391 req.HeadBuf = buf[:n] 392 //fmt.Println(string(req.HeadBuf)) 393 //try sni 394 serverName, err0 := sni.ServerNameFromBytes(req.HeadBuf) 395 if err0 == nil { 396 //sni success 397 req.Method = "SNI" 398 req.hostOrURL = "https://" + serverName + ":443" 399 req.IsSNI = true 400 } else { 401 //sni fail , try http 402 index := bytes.IndexByte(req.HeadBuf, '\n') 403 if index == -1 { 404 err = fmt.Errorf("http decoder data line err:%s", SubStr(string(req.HeadBuf), 0, 50)) 405 CloseConn(inConn) 406 return 407 } 408 fmt.Sscanf(string(req.HeadBuf[:index]), "%s%s", &req.Method, &req.hostOrURL) 409 } 410 if req.Method == "" || req.hostOrURL == "" { 411 err = fmt.Errorf("http decoder data err:%s", SubStr(string(req.HeadBuf), 0, 50)) 412 CloseConn(inConn) 413 return 414 } 415 req.Method = strings.ToUpper(req.Method) 416 req.isBasicAuth = isBasicAuth 417 req.basicAuth = basicAuth 418 log.Printf("%s:%s", req.Method, req.hostOrURL) 419 420 if req.IsHTTPS() { 421 err = req.HTTPS() 422 } else { 423 err = req.HTTP() 424 } 425 return 426 } 427 func (req *HTTPRequest) HTTP() (err error) { 428 if req.isBasicAuth { 429 err = req.BasicAuth() 430 if err != nil { 431 return 432 } 433 } 434 req.URL = req.getHTTPURL() 435 var u *url.URL 436 u, err = url.Parse(req.URL) 437 if err != nil { 438 return 439 } 440 req.Host = u.Host 441 req.addPortIfNot() 442 return 443 } 444 func (req *HTTPRequest) HTTPS() (err error) { 445 if req.isBasicAuth { 446 err = req.BasicAuth() 447 if err != nil { 448 return 449 } 450 } 451 req.Host = req.hostOrURL 452 req.addPortIfNot() 453 return 454 } 455 func (req *HTTPRequest) HTTPSReply() (err error) { 456 _, err = fmt.Fprint(*req.conn, "HTTP/1.1 200 Connection established\r\n\r\n") 457 return 458 } 459 func (req *HTTPRequest) IsHTTPS() bool { 460 return req.Method == "CONNECT" 461 } 462 463 func (req *HTTPRequest) GetAuthDataStr() (basicInfo string, err error) { 464 // log.Printf("request :%s", string(req.HeadBuf)) 465 authorization := req.getHeader("Proxy-Authorization") 466 467 authorization = strings.Trim(authorization, " \r\n\t") 468 if authorization == "" { 469 fmt.Fprintf((*req.conn), "HTTP/1.1 %s Proxy Authentication Required\r\nProxy-Authenticate: Basic realm=\"\"\r\n\r\nProxy Authentication Required", "407") 470 CloseConn(req.conn) 471 err = errors.New("require auth header data") 472 return 473 } 474 //log.Printf("Authorization:%authorization = req.getHeader("Authorization") 475 basic := strings.Fields(authorization) 476 if len(basic) != 2 { 477 err = fmt.Errorf("authorization data error,ERR:%s", authorization) 478 CloseConn(req.conn) 479 return 480 } 481 user, err := base64.StdEncoding.DecodeString(basic[1]) 482 if err != nil { 483 err = fmt.Errorf("authorization data parse error,ERR:%s", err) 484 CloseConn(req.conn) 485 return 486 } 487 basicInfo = string(user) 488 return 489 } 490 func (req *HTTPRequest) BasicAuth() (err error) { 491 userIP := strings.Split((*req.conn).RemoteAddr().String(), ":") 492 localIP := strings.Split((*req.conn).LocalAddr().String(), ":") 493 URL := "" 494 if req.IsHTTPS() { 495 URL = "https://" + req.Host 496 } else { 497 URL = req.getHTTPURL() 498 } 499 user, err := req.GetAuthDataStr() 500 if err != nil { 501 return 502 } 503 authOk := (*req.basicAuth).Check(string(user), userIP[0], localIP[0], URL) 504 //log.Printf("auth %s,%v", string(user), authOk) 505 if !authOk { 506 fmt.Fprintf((*req.conn), "HTTP/1.1 %s Proxy Authentication Required\r\n\r\nProxy Authentication Required", "407") 507 CloseConn(req.conn) 508 err = fmt.Errorf("basic auth fail") 509 return 510 } 511 return 512 } 513 func (req *HTTPRequest) getHTTPURL() (URL string) { 514 if !strings.HasPrefix(req.hostOrURL, "/") { 515 return req.hostOrURL 516 } 517 _host := req.getHeader("host") 518 if _host == "" { 519 return 520 } 521 URL = fmt.Sprintf("http://%s%s", _host, req.hostOrURL) 522 return 523 } 524 func (req *HTTPRequest) getHeader(key string) (val string) { 525 key = strings.ToUpper(key) 526 lines := strings.Split(string(req.HeadBuf), "\r\n") 527 //log.Println(lines) 528 for _, line := range lines { 529 hline := strings.SplitN(strings.Trim(line, "\r\n "), ":", 2) 530 if len(hline) == 2 { 531 k := strings.ToUpper(strings.Trim(hline[0], " ")) 532 v := strings.Trim(hline[1], " ") 533 if key == k { 534 val = v 535 return 536 } 537 } 538 } 539 return 540 } 541 542 func (req *HTTPRequest) addPortIfNot() (newHost string) { 543 //newHost = req.Host 544 port := "80" 545 if req.IsHTTPS() { 546 port = "443" 547 } 548 if (!strings.HasPrefix(req.Host, "[") && strings.Index(req.Host, ":") == -1) || (strings.HasPrefix(req.Host, "[") && strings.HasSuffix(req.Host, "]")) { 549 //newHost = req.Host + ":" + port 550 //req.headBuf = []byte(strings.Replace(string(req.headBuf), req.Host, newHost, 1)) 551 req.Host = req.Host + ":" + port 552 } 553 return 554 } 555 556 type ConnManager struct { 557 pool mapx.ConcurrentMap 558 l *sync.Mutex 559 log *logger.Logger 560 } 561 562 func NewConnManager(log *logger.Logger) ConnManager { 563 cm := ConnManager{ 564 pool: mapx.NewConcurrentMap(), 565 l: &sync.Mutex{}, 566 log: log, 567 } 568 return cm 569 } 570 func (cm *ConnManager) Add(key, ID string, conn *net.Conn) { 571 cm.pool.Upsert(key, nil, func(exist bool, valueInMap interface{}, newValue interface{}) interface{} { 572 var conns mapx.ConcurrentMap 573 if !exist { 574 conns = mapx.NewConcurrentMap() 575 } else { 576 conns = valueInMap.(mapx.ConcurrentMap) 577 } 578 if conns.Has(ID) { 579 v, _ := conns.Get(ID) 580 (*v.(*net.Conn)).Close() 581 } 582 conns.Set(ID, conn) 583 cm.log.Printf("%s conn added", key) 584 return conns 585 }) 586 } 587 func (cm *ConnManager) Remove(key string) { 588 var conns mapx.ConcurrentMap 589 if v, ok := cm.pool.Get(key); ok { 590 conns = v.(mapx.ConcurrentMap) 591 conns.IterCb(func(key string, v interface{}) { 592 CloseConn(v.(*net.Conn)) 593 }) 594 cm.log.Printf("%s conns closed", key) 595 } 596 cm.pool.Remove(key) 597 } 598 func (cm *ConnManager) RemoveOne(key string, ID string) { 599 defer cm.l.Unlock() 600 cm.l.Lock() 601 var conns mapx.ConcurrentMap 602 if v, ok := cm.pool.Get(key); ok { 603 conns = v.(mapx.ConcurrentMap) 604 if conns.Has(ID) { 605 v, _ := conns.Get(ID) 606 (*v.(*net.Conn)).Close() 607 conns.Remove(ID) 608 cm.pool.Set(key, conns) 609 cm.log.Printf("%s %s conn closed", key, ID) 610 } 611 } 612 } 613 func (cm *ConnManager) RemoveAll() { 614 for _, k := range cm.pool.Keys() { 615 cm.Remove(k) 616 } 617 } 618 619 type ClientKeyRouter struct { 620 keyChan chan string 621 ctrl *mapx.ConcurrentMap 622 lock *sync.Mutex 623 } 624 625 func NewClientKeyRouter(ctrl *mapx.ConcurrentMap, size int) ClientKeyRouter { 626 return ClientKeyRouter{ 627 keyChan: make(chan string, size), 628 ctrl: ctrl, 629 lock: &sync.Mutex{}, 630 } 631 } 632 func (c *ClientKeyRouter) GetKey() string { 633 defer c.lock.Unlock() 634 c.lock.Lock() 635 if len(c.keyChan) == 0 { 636 EXIT: 637 for _, k := range c.ctrl.Keys() { 638 select { 639 case c.keyChan <- k: 640 default: 641 goto EXIT 642 } 643 } 644 } 645 for { 646 if len(c.keyChan) == 0 { 647 return "*" 648 } 649 select { 650 case key := <-c.keyChan: 651 if c.ctrl.Has(key) { 652 return key 653 } 654 default: 655 return "*" 656 } 657 } 658 659 } 660 661 func NewCompStream(conn net.Conn) *CompStream { 662 c := new(CompStream) 663 c.conn = conn 664 c.w = snappy.NewBufferedWriter(conn) 665 c.r = snappy.NewReader(conn) 666 return c 667 } 668 func NewCompConn(conn net.Conn) net.Conn { 669 c := CompStream{} 670 c.conn = conn 671 c.w = snappy.NewBufferedWriter(conn) 672 c.r = snappy.NewReader(conn) 673 return &c 674 } 675 676 type CompStream struct { 677 net.Conn 678 conn net.Conn 679 w *snappy.Writer 680 r *snappy.Reader 681 } 682 683 func (c *CompStream) Read(p []byte) (n int, err error) { 684 return c.r.Read(p) 685 } 686 687 func (c *CompStream) Write(p []byte) (n int, err error) { 688 n, err = c.w.Write(p) 689 err = c.w.Flush() 690 return n, err 691 } 692 693 func (c *CompStream) Close() error { 694 return c.conn.Close() 695 } 696 func (c *CompStream) LocalAddr() net.Addr { 697 return c.conn.LocalAddr() 698 } 699 func (c *CompStream) RemoteAddr() net.Addr { 700 return c.conn.RemoteAddr() 701 } 702 func (c *CompStream) SetDeadline(t time.Time) error { 703 return c.conn.SetDeadline(t) 704 } 705 func (c *CompStream) SetReadDeadline(t time.Time) error { 706 return c.conn.SetReadDeadline(t) 707 } 708 func (c *CompStream) SetWriteDeadline(t time.Time) error { 709 return c.conn.SetWriteDeadline(t) 710 } 711 712 type BufferedConn struct { 713 r *bufio.Reader 714 net.Conn // So that most methods are embedded 715 } 716 717 func NewBufferedConn(c net.Conn) BufferedConn { 718 return BufferedConn{bufio.NewReader(c), c} 719 } 720 721 func NewBufferedConnSize(c net.Conn, n int) BufferedConn { 722 return BufferedConn{bufio.NewReaderSize(c, n), c} 723 } 724 725 func (b BufferedConn) Peek(n int) ([]byte, error) { 726 return b.r.Peek(n) 727 } 728 729 func (b BufferedConn) Read(p []byte) (int, error) { 730 return b.r.Read(p) 731 } 732 func (b BufferedConn) ReadByte() (byte, error) { 733 return b.r.ReadByte() 734 } 735 func (b BufferedConn) UnreadByte() error { 736 return b.r.UnreadByte() 737 } 738 func (b BufferedConn) Buffered() int { 739 return b.r.Buffered() 740 }