github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/proxy/shadowsocks/client.go (about) 1 package shadowsocks 2 3 import ( 4 "context" 5 "time" 6 7 "github.com/xmplusdev/xmcore/common" 8 "github.com/xmplusdev/xmcore/common/buf" 9 "github.com/xmplusdev/xmcore/common/net" 10 "github.com/xmplusdev/xmcore/common/protocol" 11 "github.com/xmplusdev/xmcore/common/retry" 12 "github.com/xmplusdev/xmcore/common/session" 13 "github.com/xmplusdev/xmcore/common/signal" 14 "github.com/xmplusdev/xmcore/common/task" 15 "github.com/xmplusdev/xmcore/core" 16 "github.com/xmplusdev/xmcore/features/policy" 17 "github.com/xmplusdev/xmcore/transport" 18 "github.com/xmplusdev/xmcore/transport/internet" 19 "github.com/xmplusdev/xmcore/transport/internet/stat" 20 ) 21 22 // Client is a inbound handler for Shadowsocks protocol 23 type Client struct { 24 serverPicker protocol.ServerPicker 25 policyManager policy.Manager 26 } 27 28 // NewClient create a new Shadowsocks client. 29 func NewClient(ctx context.Context, config *ClientConfig) (*Client, error) { 30 serverList := protocol.NewServerList() 31 for _, rec := range config.Server { 32 s, err := protocol.NewServerSpecFromPB(rec) 33 if err != nil { 34 return nil, newError("failed to parse server spec").Base(err) 35 } 36 serverList.AddServer(s) 37 } 38 if serverList.Size() == 0 { 39 return nil, newError("0 server") 40 } 41 42 v := core.MustFromContext(ctx) 43 client := &Client{ 44 serverPicker: protocol.NewRoundRobinServerPicker(serverList), 45 policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager), 46 } 47 return client, nil 48 } 49 50 // Process implements OutboundHandler.Process(). 51 func (c *Client) Process(ctx context.Context, link *transport.Link, dialer internet.Dialer) error { 52 outbound := session.OutboundFromContext(ctx) 53 if outbound == nil || !outbound.Target.IsValid() { 54 return newError("target not specified") 55 } 56 outbound.Name = "shadowsocks" 57 inbound := session.InboundFromContext(ctx) 58 if inbound != nil { 59 inbound.SetCanSpliceCopy(3) 60 } 61 destination := outbound.Target 62 network := destination.Network 63 64 var server *protocol.ServerSpec 65 var conn stat.Connection 66 67 err := retry.ExponentialBackoff(5, 100).On(func() error { 68 server = c.serverPicker.PickServer() 69 dest := server.Destination() 70 dest.Network = network 71 rawConn, err := dialer.Dial(ctx, dest) 72 if err != nil { 73 return err 74 } 75 conn = rawConn 76 77 return nil 78 }) 79 if err != nil { 80 return newError("failed to find an available destination").AtWarning().Base(err) 81 } 82 newError("tunneling request to ", destination, " via ", network, ":", server.Destination().NetAddr()).WriteToLog(session.ExportIDToError(ctx)) 83 84 defer conn.Close() 85 86 request := &protocol.RequestHeader{ 87 Version: Version, 88 Address: destination.Address, 89 Port: destination.Port, 90 } 91 if destination.Network == net.Network_TCP { 92 request.Command = protocol.RequestCommandTCP 93 } else { 94 request.Command = protocol.RequestCommandUDP 95 } 96 97 user := server.PickUser() 98 _, ok := user.Account.(*MemoryAccount) 99 if !ok { 100 return newError("user account is not valid") 101 } 102 request.User = user 103 104 var newCtx context.Context 105 var newCancel context.CancelFunc 106 if session.TimeoutOnlyFromContext(ctx) { 107 newCtx, newCancel = context.WithCancel(context.Background()) 108 } 109 110 sessionPolicy := c.policyManager.ForLevel(user.Level) 111 ctx, cancel := context.WithCancel(ctx) 112 timer := signal.CancelAfterInactivity(ctx, func() { 113 cancel() 114 if newCancel != nil { 115 newCancel() 116 } 117 }, sessionPolicy.Timeouts.ConnectionIdle) 118 119 if newCtx != nil { 120 ctx = newCtx 121 } 122 123 if request.Command == protocol.RequestCommandTCP { 124 requestDone := func() error { 125 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 126 bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn)) 127 bodyWriter, err := WriteTCPRequest(request, bufferedWriter) 128 if err != nil { 129 return newError("failed to write request").Base(err) 130 } 131 132 if err = buf.CopyOnceTimeout(link.Reader, bodyWriter, time.Millisecond*100); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout { 133 return newError("failed to write A request payload").Base(err).AtWarning() 134 } 135 136 if err := bufferedWriter.SetBuffered(false); err != nil { 137 return err 138 } 139 140 return buf.Copy(link.Reader, bodyWriter, buf.UpdateActivity(timer)) 141 } 142 143 responseDone := func() error { 144 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 145 146 responseReader, err := ReadTCPResponse(user, conn) 147 if err != nil { 148 return err 149 } 150 151 return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer)) 152 } 153 154 responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer)) 155 if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { 156 return newError("connection ends").Base(err) 157 } 158 159 return nil 160 } 161 162 if request.Command == protocol.RequestCommandUDP { 163 164 requestDone := func() error { 165 defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly) 166 167 writer := &UDPWriter{ 168 Writer: conn, 169 Request: request, 170 } 171 172 if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil { 173 return newError("failed to transport all UDP request").Base(err) 174 } 175 return nil 176 } 177 178 responseDone := func() error { 179 defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly) 180 181 reader := &UDPReader{ 182 Reader: conn, 183 User: user, 184 } 185 186 if err := buf.Copy(reader, link.Writer, buf.UpdateActivity(timer)); err != nil { 187 return newError("failed to transport all UDP response").Base(err) 188 } 189 return nil 190 } 191 192 responseDoneAndCloseWriter := task.OnSuccess(responseDone, task.Close(link.Writer)) 193 if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil { 194 return newError("connection ends").Base(err) 195 } 196 197 return nil 198 } 199 200 return nil 201 } 202 203 func init() { 204 common.Must(common.RegisterConfig((*ClientConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 205 return NewClient(ctx, config.(*ClientConfig)) 206 })) 207 }