github.com/ipfans/trojan-go@v0.11.0/tunnel/tproxy/server.go (about) 1 //go:build linux 2 // +build linux 3 4 package tproxy 5 6 import ( 7 "context" 8 "io" 9 "net" 10 "sync" 11 "time" 12 13 "github.com/ipfans/trojan-go/common" 14 "github.com/ipfans/trojan-go/config" 15 "github.com/ipfans/trojan-go/log" 16 "github.com/ipfans/trojan-go/tunnel" 17 ) 18 19 const MaxPacketSize = 1024 * 8 20 21 type Server struct { 22 tcpListener net.Listener 23 udpListener *net.UDPConn 24 packetChan chan tunnel.PacketConn 25 timeout time.Duration 26 mappingLock sync.RWMutex 27 mapping map[string]*PacketConn 28 ctx context.Context 29 cancel context.CancelFunc 30 } 31 32 func (s *Server) Close() error { 33 s.cancel() 34 s.tcpListener.Close() 35 return s.udpListener.Close() 36 } 37 38 func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) { 39 conn, err := s.tcpListener.Accept() 40 if err != nil { 41 select { 42 case <-s.ctx.Done(): 43 default: 44 log.Fatal(common.NewError("tproxy failed to accept connection").Base(err)) 45 } 46 return nil, common.NewError("tproxy failed to accept conn") 47 } 48 dst, err := getOriginalTCPDest(conn.(*net.TCPConn)) 49 if err != nil { 50 return nil, common.NewError("tproxy failed to obtain original address of tcp socket").Base(err) 51 } 52 address, err := tunnel.NewAddressFromAddr("tcp", dst.String()) 53 common.Must(err) 54 log.Info("tproxy connection from", conn.RemoteAddr().String(), "metadata", dst.String()) 55 return &Conn{ 56 metadata: &tunnel.Metadata{ 57 Address: address, 58 }, 59 Conn: conn, 60 }, nil 61 } 62 63 func (s *Server) packetDispatchLoop() { 64 type tproxyPacketInfo struct { 65 src *net.UDPAddr 66 dst *net.UDPAddr 67 payload []byte 68 } 69 packetQueue := make(chan *tproxyPacketInfo, 1024) 70 71 go func() { 72 for { 73 buf := make([]byte, MaxPacketSize) 74 n, src, dst, err := ReadFromUDP(s.udpListener, buf) 75 if err != nil { 76 select { 77 case <-s.ctx.Done(): 78 default: 79 log.Fatal(common.NewError("tproxy failed to read from udp socket").Base(err)) 80 } 81 s.Close() 82 return 83 } 84 log.Debug("udp packet from", src, "metadata", dst, "size", n) 85 packetQueue <- &tproxyPacketInfo{ 86 src: src, 87 dst: dst, 88 payload: buf[:n], 89 } 90 } 91 }() 92 93 for { 94 var info *tproxyPacketInfo 95 select { 96 case info = <-packetQueue: 97 case <-s.ctx.Done(): 98 log.Debug("exiting") 99 return 100 } 101 102 s.mappingLock.RLock() 103 conn, found := s.mapping[info.src.String()] 104 s.mappingLock.RUnlock() 105 106 if !found { 107 ctx, cancel := context.WithCancel(s.ctx) 108 conn = &PacketConn{ 109 input: make(chan *packetInfo, 128), 110 output: make(chan *packetInfo, 128), 111 PacketConn: s.udpListener, 112 ctx: ctx, 113 cancel: cancel, 114 src: info.src, 115 } 116 117 s.mappingLock.Lock() 118 s.mapping[info.src.String()] = conn 119 s.mappingLock.Unlock() 120 121 log.Info("new tproxy udp session from", info.src.String(), "metadata", info.dst.String()) 122 s.packetChan <- conn 123 124 go func(conn *PacketConn) { 125 defer conn.Close() 126 log.Debug("udp packet daemon for", conn.src.String()) 127 for { 128 select { 129 case info := <-conn.output: 130 if info.metadata.AddressType != tunnel.IPv4 && 131 info.metadata.AddressType != tunnel.IPv6 { 132 log.Error("tproxy invalid response metadata address", info.metadata) 133 continue 134 } 135 back, err := DialUDP( 136 "udp", 137 &net.UDPAddr{ 138 IP: info.metadata.IP, 139 Port: info.metadata.Port, 140 }, 141 conn.src.(*net.UDPAddr), 142 ) 143 if err != nil { 144 log.Error(common.NewError("failed to dial tproxy udp").Base(err)) 145 return 146 } 147 n, err := back.Write(info.payload) 148 if err != nil { 149 log.Error(common.NewError("tproxy udp write error").Base(err)) 150 return 151 } 152 log.Debug("recv packet, send back to", conn.src, "payload", len(info.payload), "sent", n) 153 back.Close() 154 case <-s.ctx.Done(): 155 log.Debug("exiting") 156 return 157 case <-time.After(s.timeout): 158 s.mappingLock.Lock() 159 delete(s.mapping, conn.src.String()) 160 s.mappingLock.Unlock() 161 log.Debug("packet session ", conn.src.String(), "timeout") 162 return 163 } 164 } 165 }(conn) 166 } 167 168 newInfo := &packetInfo{ 169 metadata: &tunnel.Metadata{ 170 Address: tunnel.NewAddressFromHostPort("udp", info.dst.IP.String(), info.dst.Port), 171 }, 172 payload: info.payload, 173 } 174 175 select { 176 case conn.input <- newInfo: 177 log.Debug("tproxy packet sent with metadata", newInfo.metadata, "size", len(info.payload)) 178 default: 179 // if we got too many packets, simply drop it 180 log.Warn("tproxy udp relay queue full!") 181 } 182 } 183 } 184 185 func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { 186 select { 187 case conn := <-s.packetChan: 188 log.Info("tproxy packet conn accepted") 189 return conn, nil 190 case <-s.ctx.Done(): 191 return nil, io.EOF 192 } 193 } 194 195 func NewServer(ctx context.Context, _ tunnel.Server) (*Server, error) { 196 cfg := config.FromContext(ctx, Name).(*Config) 197 ctx, cancel := context.WithCancel(ctx) 198 listenAddr := tunnel.NewAddressFromHostPort("tcp", cfg.LocalHost, cfg.LocalPort) 199 ip, err := listenAddr.ResolveIP() 200 if err != nil { 201 cancel() 202 return nil, common.NewError("invalid tproxy local address").Base(err) 203 } 204 tcpListener, err := ListenTCP("tcp", &net.TCPAddr{ 205 IP: ip, 206 Port: cfg.LocalPort, 207 }) 208 if err != nil { 209 cancel() 210 return nil, common.NewError("tproxy failed to listen tcp").Base(err) 211 } 212 213 udpListener, err := ListenUDP("udp", &net.UDPAddr{ 214 IP: ip, 215 Port: cfg.LocalPort, 216 }) 217 if err != nil { 218 cancel() 219 return nil, common.NewError("tproxy failed to listen udp").Base(err) 220 } 221 222 server := &Server{ 223 tcpListener: tcpListener, 224 udpListener: udpListener, 225 ctx: ctx, 226 cancel: cancel, 227 timeout: time.Duration(cfg.UDPTimeout) * time.Second, 228 mapping: make(map[string]*PacketConn), 229 packetChan: make(chan tunnel.PacketConn, 32), 230 } 231 go server.packetDispatchLoop() 232 log.Info("tproxy server listening on", tcpListener.Addr(), "(tcp)", udpListener.LocalAddr(), "(udp)") 233 log.Debug("tproxy server created") 234 return server, nil 235 }