github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/proxy/shadowsocks2022/client_session.go (about)

     1  package shadowsocks2022
     2  
     3  import (
     4  	"context"
     5  	"crypto/rand"
     6  	"io"
     7  	gonet "net"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/v2fly/v2ray-core/v5/common/buf"
    12  	"github.com/v2fly/v2ray-core/v5/common/net"
    13  	"github.com/v2fly/v2ray-core/v5/transport/internet"
    14  
    15  	"github.com/pion/transport/v2/replaydetector"
    16  )
    17  
    18  func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
    19  	session := &ClientUDPSession{
    20  		locker:          &sync.RWMutex{},
    21  		conn:            conn,
    22  		packetProcessor: packetProcessor,
    23  		sessionMap:      make(map[string]*ClientUDPSessionConn),
    24  		sessionMapAlias: make(map[string]string),
    25  	}
    26  	session.ctx, session.finish = context.WithCancel(ctx)
    27  
    28  	go session.KeepReading()
    29  	return session
    30  }
    31  
// ClientUDPSession multiplexes any number of ClientUDPSessionConn over one
// underlying conn, using packetProcessor to encode and decode packets.
type ClientUDPSession struct {
	// locker is taken around accesses to sessionMap and sessionMapAlias.
	locker *sync.RWMutex

	// conn is where encoded requests are written and responses are read.
	conn            io.ReadWriteCloser
	// packetProcessor encodes UDP requests and decodes UDP responses.
	packetProcessor UDPClientPacketProcessor
	// sessionMap holds the open session conns keyed by their client session ID.
	sessionMap      map[string]*ClientUDPSessionConn

	// sessionMapAlias maps a server session ID to the client session ID of
	// the conn that tracks it.
	sessionMapAlias map[string]string

	// ctx is cancelled by finish; KeepReading stops looping once it is done.
	ctx    context.Context
	finish func()
}
    44  
    45  func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
    46  	c.locker.RLock()
    47  	defer c.locker.RUnlock()
    48  
    49  	state, ok := c.sessionMap[sessionID]
    50  	if !ok {
    51  		return nil
    52  	}
    53  	return state.cachedProcessorState
    54  }
    55  
    56  func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
    57  	c.locker.RLock()
    58  	defer c.locker.RUnlock()
    59  
    60  	clientSessionID := c.getCachedStateAlias(serverSessionID)
    61  	if clientSessionID == "" {
    62  		return nil
    63  	}
    64  	state, ok := c.sessionMap[clientSessionID]
    65  	if !ok {
    66  		return nil
    67  	}
    68  
    69  	if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
    70  		return nil
    71  	} else {
    72  		return serverState.cachedRecvProcessorState
    73  	}
    74  }
    75  
    76  func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
    77  	state, ok := c.sessionMapAlias[serverSessionID]
    78  	if !ok {
    79  		return ""
    80  	}
    81  	return state
    82  }
    83  
    84  func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
    85  	c.locker.RLock()
    86  	defer c.locker.RUnlock()
    87  
    88  	state, ok := c.sessionMap[sessionID]
    89  	if !ok {
    90  		return
    91  	}
    92  	state.cachedProcessorState = cache
    93  }
    94  
    95  func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
    96  	c.locker.RLock()
    97  	defer c.locker.RUnlock()
    98  
    99  	clientSessionID := c.getCachedStateAlias(serverSessionID)
   100  	if clientSessionID == "" {
   101  		return
   102  	}
   103  	state, ok := c.sessionMap[clientSessionID]
   104  	if !ok {
   105  		return
   106  	}
   107  
   108  	if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
   109  		serverState.cachedRecvProcessorState = cache
   110  		return
   111  	}
   112  }
   113  
   114  func (c *ClientUDPSession) Close() error {
   115  	c.finish()
   116  	return c.conn.Close()
   117  }
   118  
   119  func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
   120  	buffer := buf.New()
   121  	defer buffer.Release()
   122  	err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
   123  	if request.Payload != nil {
   124  		request.Payload.Release()
   125  	}
   126  	if err != nil {
   127  		return newError("unable to encode udp request").Base(err)
   128  	}
   129  	_, err = c.conn.Write(buffer.Bytes())
   130  	if err != nil {
   131  		return newError("unable to write to conn").Base(err)
   132  	}
   133  	return nil
   134  }
   135  
   136  func (c *ClientUDPSession) KeepReading() {
   137  	for c.ctx.Err() == nil {
   138  		udpResp := &UDPResponse{}
   139  		buffer := make([]byte, 1600)
   140  		n, err := c.conn.Read(buffer)
   141  		if err != nil {
   142  			newError("unable to read from conn").Base(err).WriteToLog()
   143  			return
   144  		}
   145  		if n != 0 {
   146  			err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
   147  			if err != nil {
   148  				newError("unable to decode udp response").Base(err).WriteToLog()
   149  				continue
   150  			}
   151  
   152  			{
   153  				timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
   154  				if timeDifference < -30 || timeDifference > 30 {
   155  					newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
   156  					continue
   157  				}
   158  			}
   159  
   160  			c.locker.RLock()
   161  			session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
   162  			c.locker.RUnlock()
   163  			if ok {
   164  				select {
   165  				case session.readChan <- udpResp:
   166  				default:
   167  				}
   168  			} else {
   169  				newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
   170  			}
   171  		}
   172  	}
   173  }
   174  
   175  func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error) {
   176  	sessionID := make([]byte, 8)
   177  	_, err := rand.Read(sessionID)
   178  	if err != nil {
   179  		return nil, newError("unable to generate session id").Base(err)
   180  	}
   181  
   182  	connctx, connfinish := context.WithCancel(c.ctx)
   183  
   184  	sessionConn := &ClientUDPSessionConn{
   185  		sessionID:              string(sessionID),
   186  		readChan:               make(chan *UDPResponse, 128),
   187  		parent:                 c,
   188  		ctx:                    connctx,
   189  		finish:                 connfinish,
   190  		nextWritePacketID:      0,
   191  		trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
   192  	}
   193  	c.locker.Lock()
   194  	c.sessionMap[sessionConn.sessionID] = sessionConn
   195  	c.locker.Unlock()
   196  	return sessionConn, nil
   197  }
   198  
// ClientUDPSessionServerTracker holds what a session conn remembers about one
// server session it has received packets from.
type ClientUDPSessionServerTracker struct {
	// cachedRecvProcessorState is the state read and written through
	// GetCachedServerState and PutCachedServerState.
	cachedRecvProcessorState UDPClientPacketProcessorCachedState
	// rxReplayDetector is consulted with each received packet ID in ReadFrom.
	rxReplayDetector         replaydetector.ReplayDetector
	// lastSeen is set when a packet is accepted; ReadFrom drops trackers whose
	// lastSeen is more than 65 seconds old when a new server session appears.
	lastSeen                 time.Time
}
   204  
// ClientUDPSessionConn is one UDP session carried over a parent
// ClientUDPSession. It is created by ClientUDPSession.NewSessionConn.
type ClientUDPSessionConn struct {
	// sessionID is the key of this conn in parent.sessionMap and is copied
	// into the SessionID of every request written by WriteTo.
	sessionID string
	// readChan receives decoded responses from the parent's KeepReading loop;
	// ReadFrom consumes them.
	readChan  chan *UDPResponse
	parent    *ClientUDPSession

	// nextWritePacketID is the packet ID given to the next request.
	nextWritePacketID      uint64
	// trackedServerSessionID holds one tracker per server session ID seen in
	// received responses.
	trackedServerSessionID map[string]*ClientUDPSessionServerTracker

	// cachedProcessorState is the state read and written through the parent's
	// GetCachedState and PutCachedState.
	cachedProcessorState UDPClientPacketProcessorCachedState

	// ctx is derived from the parent's context and cancelled by finish;
	// ReadFrom returns io.EOF once it is done.
	ctx    context.Context
	finish func()
}
   218  
   219  func (c *ClientUDPSessionConn) Close() error {
   220  	c.parent.locker.Lock()
   221  	delete(c.parent.sessionMap, c.sessionID)
   222  	for k := range c.trackedServerSessionID {
   223  		delete(c.parent.sessionMapAlias, k)
   224  	}
   225  	c.parent.locker.Unlock()
   226  	c.finish()
   227  	return nil
   228  }
   229  
   230  func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err error) {
   231  	thisPacketID := c.nextWritePacketID
   232  	c.nextWritePacketID += 1
   233  	req := &UDPRequest{
   234  		SessionID: [8]byte{},
   235  		PacketID:  thisPacketID,
   236  		TimeStamp: uint64(time.Now().Unix()),
   237  		Address:   net.IPAddress(addr.(*gonet.UDPAddr).IP),
   238  		Port:      addr.(*net.UDPAddr).Port,
   239  		Payload:   nil,
   240  	}
   241  	copy(req.SessionID[:], c.sessionID)
   242  	req.Payload = buf.New()
   243  	req.Payload.Write(p)
   244  	err = c.parent.WriteUDPRequest(req)
   245  	if err != nil {
   246  		return 0, newError("unable to write to parent session").Base(err)
   247  	}
   248  	return len(p), nil
   249  }
   250  
   251  func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   252  	for {
   253  		select {
   254  		case <-c.ctx.Done():
   255  			return 0, nil, io.EOF
   256  		case resp := <-c.readChan:
   257  			n = copy(p, resp.Payload.Bytes())
   258  			resp.Payload.Release()
   259  
   260  			var trackedState *ClientUDPSessionServerTracker
   261  			if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
   262  				for key, value := range c.trackedServerSessionID {
   263  					if time.Since(value.lastSeen) > 65*time.Second {
   264  						delete(c.trackedServerSessionID, key)
   265  					}
   266  				}
   267  
   268  				state := &ClientUDPSessionServerTracker{
   269  					rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
   270  				}
   271  				c.trackedServerSessionID[string(resp.SessionID[:])] = state
   272  				c.parent.locker.RLock()
   273  				c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
   274  				c.parent.locker.RUnlock()
   275  				trackedState = state
   276  			} else {
   277  				trackedState = trackedStateReceived
   278  			}
   279  
   280  			if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
   281  				accept()
   282  			} else {
   283  				newError("misbehaving server: replayed packet").Base(err).WriteToLog()
   284  				continue
   285  			}
   286  			trackedState.lastSeen = time.Now()
   287  
   288  			addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
   289  		}
   290  		return n, addr, nil
   291  	}
   292  }