github.com/ipfans/trojan-go@v0.11.0/tunnel/socks/server.go (about) 1 package socks 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "io" 8 "io/ioutil" 9 "net" 10 "sync" 11 "time" 12 13 "github.com/ipfans/trojan-go/common" 14 "github.com/ipfans/trojan-go/config" 15 "github.com/ipfans/trojan-go/log" 16 "github.com/ipfans/trojan-go/tunnel" 17 ) 18 19 const ( 20 Connect tunnel.Command = 1 21 Associate tunnel.Command = 3 22 ) 23 24 const ( 25 MaxPacketSize = 1024 * 8 26 ) 27 28 type Server struct { 29 connChan chan tunnel.Conn 30 packetChan chan tunnel.PacketConn 31 underlay tunnel.Server 32 localHost string 33 localPort int 34 timeout time.Duration 35 listenPacketConn tunnel.PacketConn 36 mapping map[string]*PacketConn 37 mappingLock sync.RWMutex 38 ctx context.Context 39 cancel context.CancelFunc 40 } 41 42 func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) { 43 select { 44 case conn := <-s.connChan: 45 return conn, nil 46 case <-s.ctx.Done(): 47 return nil, common.NewError("socks server closed") 48 } 49 } 50 51 func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { 52 select { 53 case conn := <-s.packetChan: 54 return conn, nil 55 case <-s.ctx.Done(): 56 return nil, common.NewError("socks server closed") 57 } 58 } 59 60 func (s *Server) Close() error { 61 s.cancel() 62 return s.underlay.Close() 63 } 64 65 func (s *Server) handshake(conn net.Conn) (*Conn, error) { 66 version := [1]byte{} 67 if _, err := conn.Read(version[:]); err != nil { 68 return nil, common.NewError("failed to read socks version").Base(err) 69 } 70 if version[0] != 5 { 71 return nil, common.NewError(fmt.Sprintf("invalid socks version %d", version[0])) 72 } 73 nmethods := [1]byte{} 74 if _, err := conn.Read(nmethods[:]); err != nil { 75 return nil, common.NewError("failed to read NMETHODS") 76 } 77 if _, err := io.CopyN(ioutil.Discard, conn, int64(nmethods[0])); err != nil { 78 return nil, common.NewError("socks failed to read methods").Base(err) 79 } 80 if _, err := conn.Write([]byte{0x5, 0x0}); err != nil { 81 return nil, common.NewError("failed to respond auth").Base(err) 82 } 83 84 buf := [3]byte{} 85 if _, err := conn.Read(buf[:]); err != nil { 86 return nil, common.NewError("failed to read command") 87 } 88 89 addr := new(tunnel.Address) 90 if err := addr.ReadFrom(conn); err != nil { 91 return nil, err 92 } 93 94 return &Conn{ 95 metadata: &tunnel.Metadata{ 96 Command: tunnel.Command(buf[1]), 97 Address: addr, 98 }, 99 Conn: conn, 100 }, nil 101 } 102 103 func (s *Server) connect(conn net.Conn) error { 104 _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 105 return err 106 } 107 108 func (s *Server) associate(conn net.Conn, addr *tunnel.Address) error { 109 buf := bytes.NewBuffer([]byte{0x05, 0x00, 0x00}) 110 common.Must(addr.WriteTo(buf)) 111 _, err := conn.Write(buf.Bytes()) 112 return err 113 } 114 115 func (s *Server) packetDispatchLoop() { 116 for { 117 buf := make([]byte, MaxPacketSize) 118 n, src, err := s.listenPacketConn.ReadFrom(buf) 119 if err != nil { 120 select { 121 case <-s.ctx.Done(): 122 log.Debug("exiting") 123 return 124 default: 125 continue 126 } 127 } 128 log.Debug("socks recv udp packet from", src) 129 s.mappingLock.RLock() 130 conn, found := s.mapping[src.String()] 131 s.mappingLock.RUnlock() 132 if !found { 133 ctx, cancel := context.WithCancel(s.ctx) 134 conn = &PacketConn{ 135 input: make(chan *packetInfo, 128), 136 output: make(chan *packetInfo, 128), 137 ctx: ctx, 138 cancel: cancel, 139 PacketConn: s.listenPacketConn, 140 src: src, 141 } 142 go func(conn *PacketConn) { 143 defer conn.Close() 144 for { 145 select { 146 case info := <-conn.output: 147 buf := bytes.NewBuffer(make([]byte, 0, MaxPacketSize)) 148 buf.Write([]byte{0, 0, 0}) // RSV, FRAG 149 common.Must(info.metadata.Address.WriteTo(buf)) 150 buf.Write(info.payload) 151 _, err := s.listenPacketConn.WriteTo(buf.Bytes(), conn.src) 152 if err != nil { 153 log.Error("socks failed to respond packet to", src) 154 return 155 } 156 log.Debug("socks respond udp packet to", src, "metadata", info.metadata) 157 case <-time.After(time.Second * 5): 158 log.Info("socks udp session timeout, closed") 159 s.mappingLock.Lock() 160 delete(s.mapping, src.String()) 161 s.mappingLock.Unlock() 162 return 163 case <-conn.ctx.Done(): 164 log.Info("socks udp session closed") 165 return 166 } 167 } 168 }(conn) 169 170 s.mappingLock.Lock() 171 s.mapping[src.String()] = conn 172 s.mappingLock.Unlock() 173 174 s.packetChan <- conn 175 log.Info("socks new udp session from", src) 176 } 177 r := bytes.NewBuffer(buf[3:n]) 178 address := new(tunnel.Address) 179 if err := address.ReadFrom(r); err != nil { 180 log.Error(common.NewError("socks failed to parse incoming packet").Base(err)) 181 continue 182 } 183 payload := make([]byte, MaxPacketSize) 184 length, _ := r.Read(payload) 185 select { 186 case conn.input <- &packetInfo{ 187 metadata: &tunnel.Metadata{ 188 Address: address, 189 }, 190 payload: payload[:length], 191 }: 192 default: 193 log.Warn("socks udp queue full") 194 } 195 } 196 } 197 198 func (s *Server) acceptLoop() { 199 for { 200 conn, err := s.underlay.AcceptConn(&Tunnel{}) 201 if err != nil { 202 log.Error(common.NewError("socks accept err").Base(err)) 203 return 204 } 205 go func(conn net.Conn) { 206 newConn, err := s.handshake(conn) 207 if err != nil { 208 log.Error(common.NewError("socks failed to handshake with client").Base(err)) 209 return 210 } 211 log.Info("socks connection from", conn.RemoteAddr(), "metadata", newConn.metadata.String()) 212 switch newConn.metadata.Command { 213 case Connect: 214 if err := s.connect(newConn); err != nil { 215 log.Error(common.NewError("socks failed to respond CONNECT").Base(err)) 216 newConn.Close() 217 return 218 } 219 s.connChan <- newConn 220 return 221 case Associate: 222 defer newConn.Close() 223 associateAddr := tunnel.NewAddressFromHostPort("udp", s.localHost, s.localPort) 224 if err := s.associate(newConn, associateAddr); err != nil { 225 log.Error(common.NewError("socks failed to respond to associate request").Base(err)) 226 return 227 } 228 buf := [16]byte{} 229 newConn.Read(buf[:]) 230 log.Debug("socks udp session ends") 231 default: 232 log.Error(common.NewError(fmt.Sprintf("unknown socks command %d", newConn.metadata.Command))) 233 newConn.Close() 234 } 235 }(conn) 236 } 237 } 238 239 // NewServer create a socks server 240 func NewServer(ctx context.Context, underlay tunnel.Server) (tunnel.Server, error) { 241 cfg := config.FromContext(ctx, Name).(*Config) 242 listenPacketConn, err := underlay.AcceptPacket(&Tunnel{}) 243 if err != nil { 244 return nil, common.NewError("socks failed to listen packet from underlying server") 245 } 246 ctx, cancel := context.WithCancel(ctx) 247 server := &Server{ 248 underlay: underlay, 249 ctx: ctx, 250 cancel: cancel, 251 connChan: make(chan tunnel.Conn, 32), 252 packetChan: make(chan tunnel.PacketConn, 32), 253 localHost: cfg.LocalHost, 254 localPort: cfg.LocalPort, 255 timeout: time.Duration(cfg.UDPTimeout) * time.Second, 256 listenPacketConn: listenPacketConn, 257 mapping: make(map[string]*PacketConn), 258 } 259 go server.acceptLoop() 260 go server.packetDispatchLoop() 261 log.Debug("socks server created") 262 return server, nil 263 }