github.com/xxf098/lite-proxy@v0.15.1-0.20230422081941-12c69f323218/tunnel/socks/server.go (about) 1 package socks 2 3 import ( 4 "bytes" 5 "context" 6 "errors" 7 "fmt" 8 "io" 9 "io/ioutil" 10 "net" 11 "sync" 12 "time" 13 14 "github.com/xxf098/lite-proxy/common" 15 "github.com/xxf098/lite-proxy/log" 16 "github.com/xxf098/lite-proxy/tunnel" 17 ) 18 19 const ( 20 Connect tunnel.Command = 1 21 Associate tunnel.Command = 3 22 MaxPacketSize = 1024 * 8 23 ) 24 25 type Server struct { 26 connChan chan tunnel.Conn 27 packetChan chan tunnel.PacketConn 28 underlay tunnel.Server 29 localHost string 30 localPort int 31 timeout time.Duration 32 listenPacketConn net.PacketConn 33 mapping map[string]*PacketConn 34 mappingLock sync.RWMutex 35 ctx context.Context 36 cancel context.CancelFunc 37 } 38 39 func (s *Server) AcceptConn(tunnel.Tunnel) (tunnel.Conn, error) { 40 select { 41 case conn := <-s.connChan: 42 return conn, nil 43 case <-s.ctx.Done(): 44 return nil, errors.New("socks server closed") 45 } 46 } 47 48 func (s *Server) AcceptPacket(tunnel.Tunnel) (tunnel.PacketConn, error) { 49 select { 50 case conn := <-s.packetChan: 51 return conn, nil 52 case <-s.ctx.Done(): 53 return nil, common.NewError("socks server closed") 54 } 55 } 56 57 func (s *Server) Close() error { 58 s.cancel() 59 return s.underlay.Close() 60 } 61 62 func (s *Server) handshake(conn net.Conn) (*Conn, error) { 63 version := [1]byte{} 64 if _, err := conn.Read(version[:]); err != nil { 65 return nil, common.NewError("failed to read socks version").Base(err) 66 } 67 if version[0] != 5 { 68 return nil, common.NewError(fmt.Sprintf("invalid socks version %d", version[0])) 69 } 70 nmethods := [1]byte{} 71 if _, err := conn.Read(nmethods[:]); err != nil { 72 return nil, common.NewError("failed to read NMETHODS") 73 } 74 if _, err := io.CopyN(ioutil.Discard, conn, int64(nmethods[0])); err != nil { 75 return nil, common.NewError("socks failed to read methods").Base(err) 76 } 77 if _, err := conn.Write([]byte{0x5, 0x0}); err != nil { 78 return nil, common.NewError("failed to respond auth").Base(err) 79 } 80 81 buf := [3]byte{} 82 if _, err := conn.Read(buf[:]); err != nil { 83 return nil, common.NewError("failed to read command") 84 } 85 86 addr := new(tunnel.Address) 87 if err := addr.ReadFrom(conn); err != nil { 88 return nil, err 89 } 90 91 return &Conn{ 92 metadata: &tunnel.Metadata{ 93 Command: tunnel.Command(buf[1]), 94 Address: addr, 95 }, 96 Conn: conn, 97 }, nil 98 } 99 100 func (s *Server) connect(conn net.Conn) error { 101 _, err := conn.Write([]byte{0x05, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}) 102 return err 103 } 104 105 func (s *Server) associate(conn net.Conn, addr *tunnel.Address) error { 106 buf := bytes.NewBuffer([]byte{0x05, 0x00, 0x00}) 107 common.Must(addr.WriteTo(buf)) 108 _, err := conn.Write(buf.Bytes()) 109 return err 110 } 111 112 func (s *Server) packetDispatchLoop() { 113 for { 114 buf := make([]byte, MaxPacketSize) 115 n, src, err := s.listenPacketConn.ReadFrom(buf) 116 if err != nil { 117 select { 118 case <-s.ctx.Done(): 119 log.D("exiting") 120 return 121 default: 122 continue 123 } 124 } 125 log.D("socks recv udp packet from", src) 126 s.mappingLock.RLock() 127 conn, found := s.mapping[src.String()] 128 s.mappingLock.RUnlock() 129 if !found { 130 ctx, cancel := context.WithCancel(s.ctx) 131 conn = &PacketConn{ 132 input: make(chan *packetInfo, 128), 133 output: make(chan *packetInfo, 128), 134 ctx: ctx, 135 cancel: cancel, 136 PacketConn: s.listenPacketConn, 137 src: src, 138 } 139 go func(conn *PacketConn) { 140 defer conn.Close() 141 for { 142 select { 143 case info := <-conn.output: 144 buf := bytes.NewBuffer(make([]byte, 0, MaxPacketSize)) 145 buf.Write([]byte{0, 0, 0}) //RSV, FRAG 146 common.Must(info.metadata.Address.WriteTo(buf)) 147 buf.Write(info.payload) 148 _, err := s.listenPacketConn.WriteTo(buf.Bytes(), conn.src) 149 if err != nil { 150 log.E("socks failed to respond packet to", src) 151 return 152 } 153 log.D("socks respond udp packet to", src, "metadata", info.metadata) 154 case <-time.After(time.Second * 5): 155 log.I("socks udp session timeout, closed") 156 s.mappingLock.Lock() 157 delete(s.mapping, src.String()) 158 s.mappingLock.Unlock() 159 return 160 case <-conn.ctx.Done(): 161 log.I("socks udp session closed") 162 return 163 } 164 } 165 }(conn) 166 167 s.mappingLock.Lock() 168 s.mapping[src.String()] = conn 169 s.mappingLock.Unlock() 170 171 s.packetChan <- conn 172 log.I("socks new udp session from", src) 173 } 174 r := bytes.NewBuffer(buf[3:n]) 175 address := new(tunnel.Address) 176 if err := address.ReadFrom(r); err != nil { 177 log.Error(common.NewError("socks failed to parse incoming packet").Base(err)) 178 continue 179 } 180 payload := make([]byte, MaxPacketSize) 181 length, err := r.Read(payload) 182 select { 183 case conn.input <- &packetInfo{ 184 metadata: &tunnel.Metadata{ 185 Address: address, 186 }, 187 payload: payload[:length], 188 }: 189 default: 190 log.W("socks udp queue full") 191 } 192 } 193 } 194 195 func (s *Server) acceptLoop() { 196 for { 197 conn, err := s.underlay.AcceptConn(&Tunnel{}) 198 if err != nil { 199 log.Error(common.NewError("socks accept err").Base(err)) 200 return 201 } 202 go func(conn net.Conn) { 203 newConn, err := s.handshake(conn) 204 if err != nil { 205 log.Error(common.NewError("socks failed to handshake with client").Base(err)) 206 return 207 } 208 log.I("socks connection from", conn.RemoteAddr(), "metadata", newConn.metadata.String()) 209 switch newConn.metadata.Command { 210 case Connect: 211 if err := s.connect(newConn); err != nil { 212 log.Error(common.NewError("socks failed to respond CONNECT").Base(err)) 213 newConn.Close() 214 return 215 } 216 s.connChan <- newConn 217 return 218 case Associate: 219 defer newConn.Close() 220 associateAddr := tunnel.NewAddressFromHostPort("udp", s.localHost, s.localPort) 221 if err := s.associate(newConn, associateAddr); err != nil { 222 log.Error(common.NewError("socks failed to respond to associate request").Base(err)) 223 return 224 } 225 buf := [16]byte{} 226 newConn.Read(buf[:]) 227 log.D("socks udp session ends") 228 default: 229 log.Error(common.NewError(fmt.Sprintf("unknown socks command %d", newConn.metadata.Command))) 230 newConn.Close() 231 } 232 }(conn) 233 } 234 } 235 236 func NewServer(ctx context.Context, underlay tunnel.Server) (*Server, error) { 237 listenPacketConn, err := underlay.AcceptPacket(&Tunnel{}) 238 if err != nil { 239 return nil, err 240 } 241 localHost := ctx.Value("LocalHost").(string) 242 localPort := ctx.Value("LocalPort").(int) 243 ctx, cancel := context.WithCancel(ctx) 244 server := &Server{ 245 underlay: underlay, 246 ctx: ctx, 247 cancel: cancel, 248 connChan: make(chan tunnel.Conn, 32), 249 packetChan: make(chan tunnel.PacketConn, 32), 250 localHost: localHost, 251 localPort: localPort, 252 timeout: 5 * time.Second, 253 listenPacketConn: listenPacketConn, 254 mapping: make(map[string]*PacketConn), 255 } 256 go server.acceptLoop() 257 go server.packetDispatchLoop() 258 log.D("socks server created", localPort) 259 return server, nil 260 }