github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/services/udp/udp.go (about)

     1  package udp
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	logger "log"
     8  	"net"
     9  	"runtime/debug"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/AntonOrnatskyi/goproxy/core/cs/server"
    15  	"github.com/AntonOrnatskyi/goproxy/services"
    16  	"github.com/AntonOrnatskyi/goproxy/utils"
    17  	"github.com/AntonOrnatskyi/goproxy/utils/mapx"
    18  )
    19  
    20  type UDPArgs struct {
    21  	Parent              *string
    22  	CertFile            *string
    23  	KeyFile             *string
    24  	CertBytes           []byte
    25  	KeyBytes            []byte
    26  	Local               *string
    27  	ParentType          *string
    28  	Timeout             *int
    29  	CheckParentInterval *int
    30  }
    31  type UDP struct {
    32  	p                mapx.ConcurrentMap
    33  	cfg              UDPArgs
    34  	sc               *server.ServerChannel
    35  	isStop           bool
    36  	log              *logger.Logger
    37  	outUDPConnCtxMap mapx.ConcurrentMap
    38  	udpConns         mapx.ConcurrentMap
    39  	dstAddr          *net.UDPAddr
    40  }
    41  type UDPConnItem struct {
    42  	conn      *net.Conn
    43  	touchtime int64
    44  	srcAddr   *net.UDPAddr
    45  	localAddr *net.UDPAddr
    46  	connid    string
    47  }
    48  type outUDPConnCtx struct {
    49  	localAddr *net.UDPAddr
    50  	srcAddr   *net.UDPAddr
    51  	udpconn   *net.UDPConn
    52  	touchtime int64
    53  }
    54  
    55  func NewUDP() services.Service {
    56  	return &UDP{
    57  		p:                mapx.NewConcurrentMap(),
    58  		isStop:           false,
    59  		outUDPConnCtxMap: mapx.NewConcurrentMap(),
    60  		udpConns:         mapx.NewConcurrentMap(),
    61  	}
    62  }
    63  func (s *UDP) CheckArgs() (err error) {
    64  	if len(*s.cfg.Parent) == 0 {
    65  		err = fmt.Errorf("parent required for udp %s", *s.cfg.Local)
    66  		return
    67  	}
    68  	if *s.cfg.ParentType == "" {
    69  		err = fmt.Errorf("parent type unkown,use -T <udp|tls|tcp>")
    70  		return
    71  	}
    72  	if *s.cfg.ParentType == "tls" {
    73  		s.cfg.CertBytes, s.cfg.KeyBytes, err = utils.TlsBytes(*s.cfg.CertFile, *s.cfg.KeyFile)
    74  		if err != nil {
    75  			return
    76  		}
    77  	}
    78  
    79  	s.dstAddr, err = net.ResolveUDPAddr("udp", *s.cfg.Parent)
    80  	if err != nil {
    81  		s.log.Printf("resolve udp addr %s fail  fail,ERR:%s", *s.cfg.Parent, err)
    82  		return
    83  	}
    84  	return
    85  }
    86  func (s *UDP) InitService() (err error) {
    87  	s.OutToUDPGCDeamon()
    88  	s.UDPGCDeamon()
    89  	return
    90  }
    91  func (s *UDP) StopService() {
    92  	defer func() {
    93  		e := recover()
    94  		if e != nil {
    95  			s.log.Printf("stop udp service crashed,%s", e)
    96  		} else {
    97  			s.log.Printf("service udp stopped")
    98  		}
    99  		s.cfg = UDPArgs{}
   100  		s.log = nil
   101  		s.p = nil
   102  		s.sc = nil
   103  		s = nil
   104  	}()
   105  	s.isStop = true
   106  	if s.sc.Listener != nil && *s.sc.Listener != nil {
   107  		(*s.sc.Listener).Close()
   108  	}
   109  	if s.sc.UDPListener != nil {
   110  		(*s.sc.UDPListener).Close()
   111  	}
   112  }
   113  func (s *UDP) Start(args interface{}, log *logger.Logger) (err error) {
   114  	s.log = log
   115  	s.cfg = args.(UDPArgs)
   116  	if err = s.CheckArgs(); err != nil {
   117  		return
   118  	}
   119  	s.log.Printf("use %s parent %s", *s.cfg.ParentType, *s.cfg.Parent)
   120  	if err = s.InitService(); err != nil {
   121  		return
   122  	}
   123  	host, port, _ := net.SplitHostPort(*s.cfg.Local)
   124  	p, _ := strconv.Atoi(port)
   125  	sc := server.NewServerChannel(host, p, s.log)
   126  	s.sc = &sc
   127  	err = sc.ListenUDP(s.callback)
   128  	if err != nil {
   129  		return
   130  	}
   131  	s.log.Printf("udp proxy on %s", (*sc.UDPListener).LocalAddr())
   132  	return
   133  }
   134  
   135  func (s *UDP) Clean() {
   136  	s.StopService()
   137  }
   138  func (s *UDP) callback(listener *net.UDPConn, packet []byte, localAddr, srcAddr *net.UDPAddr) {
   139  	defer func() {
   140  		if err := recover(); err != nil {
   141  			s.log.Printf("udp conn handler crashed with err : %s \nstack: %s", err, string(debug.Stack()))
   142  		}
   143  	}()
   144  	switch *s.cfg.ParentType {
   145  	case "tcp", "tls":
   146  		s.OutToTCP(packet, localAddr, srcAddr)
   147  	case "udp":
   148  		s.OutToUDP(packet, localAddr, srcAddr)
   149  	default:
   150  		s.log.Printf("unkown parent type %s", *s.cfg.ParentType)
   151  	}
   152  }
   153  func (s *UDP) GetConn(connKey string) (conn net.Conn, isNew bool, err error) {
   154  	isNew = !s.p.Has(connKey)
   155  	var _conn interface{}
   156  	if isNew {
   157  		_conn, err = s.GetParentConn()
   158  		if err != nil {
   159  			return nil, false, err
   160  		}
   161  		s.p.Set(connKey, _conn)
   162  	} else {
   163  		_conn, _ = s.p.Get(connKey)
   164  	}
   165  	conn = _conn.(net.Conn)
   166  	return
   167  }
   168  func (s *UDP) OutToTCP(data []byte, localAddr, srcAddr *net.UDPAddr) (err error) {
   169  	s.UDPSend(data, localAddr, srcAddr)
   170  	return
   171  }
   172  func (s *UDP) OutToUDPGCDeamon() {
   173  	gctime := int64(30)
   174  	go func() {
   175  		defer func() {
   176  			if e := recover(); e != nil {
   177  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   178  			}
   179  		}()
   180  		if s.isStop {
   181  			return
   182  		}
   183  		timer := time.NewTicker(time.Second)
   184  		for {
   185  			<-timer.C
   186  			gcKeys := []string{}
   187  			s.outUDPConnCtxMap.IterCb(func(key string, v interface{}) {
   188  				if time.Now().Unix()-v.(*outUDPConnCtx).touchtime > gctime {
   189  					(*(v.(*outUDPConnCtx).udpconn)).Close()
   190  					gcKeys = append(gcKeys, key)
   191  					s.log.Printf("gc udp conn %s <--> %s", (*v.(*outUDPConnCtx)).srcAddr, (*v.(*outUDPConnCtx)).localAddr)
   192  				}
   193  			})
   194  			for _, k := range gcKeys {
   195  				s.outUDPConnCtxMap.Remove(k)
   196  			}
   197  			gcKeys = nil
   198  		}
   199  	}()
   200  }
   201  func (s *UDP) OutToUDP(packet []byte, localAddr, srcAddr *net.UDPAddr) {
   202  	var ouc *outUDPConnCtx
   203  	if v, ok := s.outUDPConnCtxMap.Get(srcAddr.String()); !ok {
   204  		clientSrcAddr := &net.UDPAddr{IP: net.IPv4zero, Port: 0}
   205  		conn, err := net.DialUDP("udp", clientSrcAddr, s.dstAddr)
   206  		if err != nil {
   207  			s.log.Printf("connect to udp %s fail,ERR:%s", s.dstAddr.String(), err)
   208  
   209  		}
   210  		ouc = &outUDPConnCtx{
   211  			localAddr: localAddr,
   212  			srcAddr:   srcAddr,
   213  			udpconn:   conn,
   214  		}
   215  		s.outUDPConnCtxMap.Set(srcAddr.String(), ouc)
   216  		go func() {
   217  			defer func() {
   218  				if e := recover(); e != nil {
   219  					fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   220  				}
   221  			}()
   222  			s.log.Printf("udp conn %s <--> %s connected", srcAddr.String(), localAddr.String())
   223  			buf := utils.LeakyBuffer.Get()
   224  			defer func() {
   225  				utils.LeakyBuffer.Put(buf)
   226  				s.outUDPConnCtxMap.Remove(srcAddr.String())
   227  				s.log.Printf("udp conn %s <--> %s released", srcAddr.String(), localAddr.String())
   228  			}()
   229  			for {
   230  				n, err := ouc.udpconn.Read(buf)
   231  				if err != nil {
   232  					if !utils.IsNetClosedErr(err) {
   233  						s.log.Printf("udp conn read udp packet fail , err: %s ", err)
   234  					}
   235  					return
   236  				}
   237  				ouc.touchtime = time.Now().Unix()
   238  				go func() {
   239  					defer func() {
   240  						if e := recover(); e != nil {
   241  							fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   242  						}
   243  					}()
   244  					(*(s.sc).UDPListener).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   245  					_, err = (*(s.sc).UDPListener).WriteTo(buf[:n], srcAddr)
   246  					(*(s.sc).UDPListener).SetWriteDeadline(time.Time{})
   247  				}()
   248  			}
   249  		}()
   250  	} else {
   251  		ouc = v.(*outUDPConnCtx)
   252  	}
   253  	go func() {
   254  		ouc.touchtime = time.Now().Unix()
   255  		ouc.udpconn.SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   256  		ouc.udpconn.Write(packet)
   257  		ouc.udpconn.SetWriteDeadline(time.Time{})
   258  	}()
   259  	return
   260  }
   261  func (s *UDP) GetParentConn() (conn net.Conn, err error) {
   262  	if *s.cfg.ParentType == "tls" {
   263  		var _conn tls.Conn
   264  		_conn, err = utils.TlsConnectHost(*s.cfg.Parent, *s.cfg.Timeout, s.cfg.CertBytes, s.cfg.KeyBytes, nil)
   265  		if err == nil {
   266  			conn = net.Conn(&_conn)
   267  		}
   268  	} else {
   269  		conn, err = utils.ConnectHost(*s.cfg.Parent, *s.cfg.Timeout)
   270  	}
   271  	return
   272  }
   273  func (s *UDP) UDPGCDeamon() {
   274  	gctime := int64(30)
   275  	go func() {
   276  		defer func() {
   277  			if e := recover(); e != nil {
   278  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   279  			}
   280  		}()
   281  		if s.isStop {
   282  			return
   283  		}
   284  		timer := time.NewTicker(time.Second)
   285  		for {
   286  			<-timer.C
   287  			gcKeys := []string{}
   288  			s.udpConns.IterCb(func(key string, v interface{}) {
   289  				if time.Now().Unix()-v.(*UDPConnItem).touchtime > gctime {
   290  					(*(v.(*UDPConnItem).conn)).Close()
   291  					gcKeys = append(gcKeys, key)
   292  					s.log.Printf("gc udp conn %s", v.(*UDPConnItem).connid)
   293  				}
   294  			})
   295  			for _, k := range gcKeys {
   296  				s.udpConns.Remove(k)
   297  			}
   298  			gcKeys = nil
   299  		}
   300  	}()
   301  }
   302  func (s *UDP) UDPSend(data []byte, localAddr, srcAddr *net.UDPAddr) {
   303  	var (
   304  		uc      *UDPConnItem
   305  		key     = srcAddr.String()
   306  		err     error
   307  		outconn net.Conn
   308  	)
   309  	v, ok := s.udpConns.Get(key)
   310  	if !ok {
   311  		for {
   312  			outconn, err = s.GetParentConn()
   313  			if err != nil && strings.Contains(err.Error(), "can not connect at same time") {
   314  				time.Sleep(time.Millisecond * 500)
   315  				continue
   316  			} else {
   317  				break
   318  			}
   319  		}
   320  		if err != nil {
   321  			s.log.Printf("connect to %s fail, err: %s", *s.cfg.Parent, err)
   322  			return
   323  		}
   324  		uc = &UDPConnItem{
   325  			conn:      &outconn,
   326  			srcAddr:   srcAddr,
   327  			localAddr: localAddr,
   328  		}
   329  		s.udpConns.Set(key, uc)
   330  		s.UDPRevecive(key)
   331  	} else {
   332  		uc = v.(*UDPConnItem)
   333  	}
   334  	go func() {
   335  		defer func() {
   336  			if e := recover(); e != nil {
   337  				(*uc.conn).Close()
   338  				s.udpConns.Remove(key)
   339  				s.log.Printf("udp sender crashed with error : %s", e)
   340  			}
   341  		}()
   342  		uc.touchtime = time.Now().Unix()
   343  		(*uc.conn).SetWriteDeadline(time.Now().Add(time.Millisecond * time.Duration(*s.cfg.Timeout)))
   344  		_, err = (*uc.conn).Write(utils.UDPPacket(fmt.Sprintf("%s", srcAddr.String()), data))
   345  		(*uc.conn).SetWriteDeadline(time.Time{})
   346  		if err != nil {
   347  			s.log.Printf("write udp packet to %s fail ,flush err:%s ", *s.cfg.Parent, err)
   348  		}
   349  	}()
   350  }
   351  func (s *UDP) UDPRevecive(key string) {
   352  	go func() {
   353  		defer func() {
   354  			if e := recover(); e != nil {
   355  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   356  			}
   357  		}()
   358  		s.log.Printf("udp conn %s connected", key)
   359  		var uc *UDPConnItem
   360  		defer func() {
   361  			if uc != nil {
   362  				(*uc.conn).Close()
   363  			}
   364  			s.udpConns.Remove(key)
   365  			s.log.Printf("udp conn %s released", key)
   366  		}()
   367  		v, ok := s.udpConns.Get(key)
   368  		if !ok {
   369  			s.log.Printf("[warn] udp conn not exists for %s", key)
   370  			return
   371  		}
   372  		uc = v.(*UDPConnItem)
   373  		for {
   374  			_, body, err := utils.ReadUDPPacket(*uc.conn)
   375  			if err != nil {
   376  				if strings.Contains(err.Error(), "n != int(") {
   377  					continue
   378  				}
   379  				if err != io.EOF && !utils.IsNetClosedErr(err) {
   380  					s.log.Printf("udp conn read udp packet fail , err: %s ", err)
   381  				}
   382  				return
   383  			}
   384  			uc.touchtime = time.Now().Unix()
   385  			go func() {
   386  				defer func() {
   387  					if e := recover(); e != nil {
   388  						fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   389  					}
   390  				}()
   391  				s.sc.UDPListener.WriteToUDP(body, uc.srcAddr)
   392  			}()
   393  		}
   394  	}()
   395  }