github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/core/lib/udp/udp.go (about)

     1  package udputils
     2  
     3  import (
     4  	"fmt"
     5  	logger "log"
     6  	"net"
     7  	"runtime/debug"
     8  	"strings"
     9  	"time"
    10  
    11  	bufx "github.com/AntonOrnatskyi/goproxy/core/lib/buf"
    12  	mapx "github.com/AntonOrnatskyi/goproxy/core/lib/mapx"
    13  )
    14  
// CreateOutUDPConnFn builds the outbound UDP connection for a source address.
// Run calls it when it finds no cached outbound connection for srcAddr,
// passing the listener, the sender's address and the packet that is about to
// be forwarded (already processed by the BeforeSendFn hook when one is set).
type CreateOutUDPConnFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, packet []byte) (outconn *net.UDPConn, err error)

// CleanFn is an optional hook that the binder's clean step invokes with the
// source address string after it has closed the connections it tracks.
type CleanFn func(srcAddr string)

// BeforeSendFn is an optional hook that Run applies to every packet read from
// the listener. The returned sendB is what gets written to the outbound
// connection; a non-nil err makes Run log the error and drop the packet.
type BeforeSendFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, b []byte) (sendB []byte, err error)

// BeforeReplyFn is an optional hook applied to every packet read from an
// outbound connection before it is written back through the listener. The
// returned replyB is what gets written; a non-nil err makes the reader log
// the error and drop the packet.
type BeforeReplyFn func(listener *net.UDPConn, srcAddr *net.UDPAddr, outconn *net.UDPConn, b []byte) (replyB []byte, err error)
    19  
// IOBinder relays UDP packets between a listening socket and outbound UDP
// connections. It is configured through the chainable setter methods and
// driven by Run.
type IOBinder struct {
	// outConns is looked up by source address string; stored values are
	// type-asserted to *net.UDPConn when read.
	outConns           mapx.ConcurrentMap
	// listener is the socket Run reads incoming packets from and replies are
	// written to.
	listener           *net.UDPConn
	// createOutUDPConnFn creates the outbound connection for a source address.
	createOutUDPConnFn CreateOutUDPConnFn
	// log receives all diagnostic messages.
	log                *logger.Logger
	// timeout, when greater than zero, is used as the read deadline for each
	// read on an outbound connection.
	timeout            time.Duration
	// cleanFn, when set, is called at the end of clean.
	cleanFn            CleanFn
	// inTCPConn is set by AliveWithServeConn and closed by clean.
	inTCPConn          *net.Conn
	// outTCPConn is set by AliveWithClientConn and closed by clean.
	outTCPConn         *net.Conn
	// beforeSendFn, when set, transforms packets read from the listener.
	beforeSendFn       BeforeSendFn
	// beforeReplyFn, when set, transforms packets read from an outbound
	// connection.
	beforeReplyFn      BeforeReplyFn
}
    32  
    33  func NewIOBinder(listener *net.UDPConn, log *logger.Logger) *IOBinder {
    34  	return &IOBinder{
    35  		listener: listener,
    36  		outConns: mapx.NewConcurrentMap(),
    37  		log:      log,
    38  	}
    39  }
// Factory sets the function Run uses to create an outbound UDP connection for
// a source address that has none cached. It returns the receiver so calls can
// be chained.
func (s *IOBinder) Factory(fn CreateOutUDPConnFn) *IOBinder {
	s.createOutUDPConnFn = fn
	return s
}
// AfterReadFromClient sets the hook Run applies to each packet read from the
// listener before the packet is forwarded. It returns the receiver so calls
// can be chained.
func (s *IOBinder) AfterReadFromClient(fn BeforeSendFn) *IOBinder {
	s.beforeSendFn = fn
	return s
}
// AfterReadFromServer sets the hook applied to each packet read from an
// outbound connection before the packet is written back through the listener.
// It returns the receiver so calls can be chained.
func (s *IOBinder) AfterReadFromServer(fn BeforeReplyFn) *IOBinder {
	s.beforeReplyFn = fn
	return s
}
// Timeout sets the read deadline duration used for reads on outbound
// connections; a value of zero or less leaves reads without a deadline. It
// returns the receiver so calls can be chained.
func (s *IOBinder) Timeout(timeout time.Duration) *IOBinder {
	s.timeout = timeout
	return s
}
// Clean sets the hook that is invoked with the source address string whenever
// the binder tears down the connections for that address. It returns the
// receiver so calls can be chained.
func (s *IOBinder) Clean(fn CleanFn) *IOBinder {
	s.cleanFn = fn
	return s
}
    60  func (s *IOBinder) AliveWithServeConn(srcAddr string, inTCPConn *net.Conn) *IOBinder {
    61  	s.inTCPConn = inTCPConn
    62  	go func() {
    63  		defer func() {
    64  			if e := recover(); e != nil {
    65  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
    66  			}
    67  		}()
    68  		buf := make([]byte, 1)
    69  		(*inTCPConn).SetReadDeadline(time.Time{})
    70  		if _, err := (*inTCPConn).Read(buf); err != nil {
    71  			s.log.Printf("udp related tcp conn of client disconnected with read , %s", err.Error())
    72  			s.clean(srcAddr)
    73  		}
    74  	}()
    75  	go func() {
    76  		defer func() {
    77  			if e := recover(); e != nil {
    78  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
    79  			}
    80  		}()
    81  		for {
    82  			(*inTCPConn).SetWriteDeadline(time.Now().Add(time.Second * 5))
    83  			if _, err := (*inTCPConn).Write([]byte{0x00}); err != nil {
    84  				s.log.Printf("udp related tcp conn of client disconnected with write , %s", err.Error())
    85  				s.clean(srcAddr)
    86  				return
    87  			}
    88  			(*inTCPConn).SetWriteDeadline(time.Time{})
    89  			time.Sleep(time.Second * 5)
    90  		}
    91  	}()
    92  	return s
    93  }
    94  func (s *IOBinder) AliveWithClientConn(srcAddr string, outTCPConn *net.Conn) *IOBinder {
    95  	s.outTCPConn = outTCPConn
    96  	go func() {
    97  		defer func() {
    98  			if e := recover(); e != nil {
    99  				fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   100  			}
   101  		}()
   102  		buf := make([]byte, 1)
   103  		(*outTCPConn).SetReadDeadline(time.Time{})
   104  		if _, err := (*outTCPConn).Read(buf); err != nil {
   105  			s.log.Printf("udp related tcp conn to parent disconnected with read , %s", err.Error())
   106  			s.clean(srcAddr)
   107  		}
   108  	}()
   109  	return s
   110  }
   111  func (s *IOBinder) Run() (err error) {
   112  	var (
   113  		isClosedErr = func(err error) bool {
   114  			return err != nil && strings.Contains(err.Error(), "use of closed network connection")
   115  		}
   116  		isTimeoutErr = func(err error) bool {
   117  			if err == nil {
   118  				return false
   119  			}
   120  			e, ok := err.(net.Error)
   121  			return ok && e.Timeout()
   122  		}
   123  		isRefusedErr = func(err error) bool {
   124  			return err != nil && strings.Contains(err.Error(), "connection refused")
   125  		}
   126  	)
   127  	for {
   128  		buf := bufx.Get()
   129  		defer bufx.Put(buf)
   130  		n, srcAddr, err := s.listener.ReadFromUDP(buf)
   131  		if err != nil {
   132  			s.log.Printf("read from client error %s", err)
   133  			if isClosedErr(err) {
   134  				return err
   135  			}
   136  			continue
   137  		}
   138  		var data []byte
   139  		if s.beforeSendFn != nil {
   140  			data, err = s.beforeSendFn(s.listener, srcAddr, buf[:n])
   141  			if err != nil {
   142  				s.log.Printf("beforeSend retured an error , %s", err)
   143  				continue
   144  			}
   145  		} else {
   146  			data = buf[:n]
   147  		}
   148  		inconnRemoteAddr := srcAddr.String()
   149  		var outconn *net.UDPConn
   150  		if v, ok := s.outConns.Get(inconnRemoteAddr); !ok {
   151  			outconn, err = s.createOutUDPConnFn(s.listener, srcAddr, data)
   152  			if err != nil {
   153  				s.log.Printf("connnect fail %s", err)
   154  				return err
   155  			}
   156  			go func() {
   157  				defer func() {
   158  					if e := recover(); e != nil {
   159  						fmt.Printf("crashed, err: %s\nstack:\n%s",e, string(debug.Stack()))
   160  					}
   161  				}()
   162  				defer func() {
   163  					s.clean(srcAddr.String())
   164  				}()
   165  				buf := bufx.Get()
   166  				defer bufx.Put(buf)
   167  				for {
   168  					if s.timeout > 0 {
   169  						outconn.SetReadDeadline(time.Now().Add(s.timeout))
   170  					}
   171  					n, srcAddr, err := outconn.ReadFromUDP(buf)
   172  					if err != nil {
   173  						s.log.Printf("read from remote error %s", err)
   174  						if isClosedErr(err) || isTimeoutErr(err) || isRefusedErr(err) {
   175  							return
   176  						}
   177  						continue
   178  					}
   179  					data := buf[:n]
   180  					if s.beforeReplyFn != nil {
   181  						data, err = s.beforeReplyFn(s.listener, srcAddr, outconn, buf[:n])
   182  						if err != nil {
   183  							s.log.Printf("beforeReply retured an error , %s", err)
   184  							continue
   185  						}
   186  					}
   187  					_, err = s.listener.WriteTo(data, srcAddr)
   188  					if err != nil {
   189  						s.log.Printf("write to remote error %s", err)
   190  						if isClosedErr(err) {
   191  							return
   192  						}
   193  						continue
   194  					}
   195  				}
   196  			}()
   197  		} else {
   198  			outconn = v.(*net.UDPConn)
   199  		}
   200  
   201  		s.log.Printf("use decrpyted data , %v", data)
   202  
   203  		_, err = outconn.Write(data)
   204  
   205  		if err != nil {
   206  			s.log.Printf("write to remote error %s", err)
   207  			if isClosedErr(err) {
   208  				return err
   209  			}
   210  		}
   211  	}
   212  }
   213  func (s *IOBinder) clean(srcAddr string) *IOBinder {
   214  	if v, ok := s.outConns.Get(srcAddr); ok {
   215  		(*v.(*net.UDPConn)).Close()
   216  		s.outConns.Remove(srcAddr)
   217  	}
   218  	if s.inTCPConn != nil {
   219  		(*s.inTCPConn).Close()
   220  	}
   221  	if s.outTCPConn != nil {
   222  		(*s.outTCPConn).Close()
   223  	}
   224  	if s.cleanFn != nil {
   225  		s.cleanFn(srcAddr)
   226  	}
   227  	return s
   228  }
   229  
   230  func (s *IOBinder) Close() {
   231  	for _, c := range s.outConns.Items() {
   232  		(*c.(*net.UDPConn)).Close()
   233  	}
   234  }