github.com/artpar/rclone@v1.67.3/backend/smb/connpool.go (about)

     1  package smb
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	"time"
     8  
     9  	"github.com/artpar/rclone/fs"
    10  	"github.com/artpar/rclone/fs/accounting"
    11  	"github.com/artpar/rclone/fs/config/obscure"
    12  	"github.com/artpar/rclone/fs/fshttp"
    13  	smb2 "github.com/cloudsoda/go-smb2"
    14  )
    15  
    16  // dial starts a client connection to the given SMB server. It is a
    17  // convenience function that connects to the given network address,
    18  // initiates the SMB handshake, and then sets up a Client.
    19  func (f *Fs) dial(ctx context.Context, network, addr string) (*conn, error) {
    20  	dialer := fshttp.NewDialer(ctx)
    21  	tconn, err := dialer.Dial(network, addr)
    22  	if err != nil {
    23  		return nil, err
    24  	}
    25  
    26  	pass := ""
    27  	if f.opt.Pass != "" {
    28  		pass, err = obscure.Reveal(f.opt.Pass)
    29  		if err != nil {
    30  			return nil, err
    31  		}
    32  	}
    33  
    34  	d := &smb2.Dialer{
    35  		Initiator: &smb2.NTLMInitiator{
    36  			User:      f.opt.User,
    37  			Password:  pass,
    38  			Domain:    f.opt.Domain,
    39  			TargetSPN: f.opt.SPN,
    40  		},
    41  	}
    42  
    43  	session, err := d.DialConn(ctx, tconn, addr)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return &conn{
    49  		smbSession: session,
    50  		conn:       &tconn,
    51  	}, nil
    52  }
    53  
    54  // conn encapsulates a SMB client and corresponding SMB client
    55  type conn struct {
    56  	conn       *net.Conn
    57  	smbSession *smb2.Session
    58  	smbShare   *smb2.Share
    59  	shareName  string
    60  }
    61  
    62  // Closes the connection
    63  func (c *conn) close() (err error) {
    64  	if c.smbShare != nil {
    65  		err = c.smbShare.Umount()
    66  	}
    67  	sessionLogoffErr := c.smbSession.Logoff()
    68  	if err != nil {
    69  		return err
    70  	}
    71  	return sessionLogoffErr
    72  }
    73  
    74  // True if it's closed
    75  func (c *conn) closed() bool {
    76  	var nopErr error
    77  	if c.smbShare != nil {
    78  		// stat the current directory
    79  		_, nopErr = c.smbShare.Stat(".")
    80  	} else {
    81  		// list the shares
    82  		_, nopErr = c.smbSession.ListSharenames()
    83  	}
    84  	return nopErr != nil
    85  }
    86  
    87  // Show that we are using a SMB session
    88  //
    89  // Call removeSession() when done
    90  func (f *Fs) addSession() {
    91  	f.sessions.Add(1)
    92  }
    93  
    94  // Show the SMB session is no longer in use
    95  func (f *Fs) removeSession() {
    96  	f.sessions.Add(-1)
    97  }
    98  
    99  // getSessions shows whether there are any sessions in use
   100  func (f *Fs) getSessions() int32 {
   101  	return f.sessions.Load()
   102  }
   103  
   104  // Open a new connection to the SMB server.
   105  func (f *Fs) newConnection(ctx context.Context, share string) (c *conn, err error) {
   106  	// As we are pooling these connections we need to decouple
   107  	// them from the current context
   108  	bgCtx := context.Background()
   109  
   110  	c, err = f.dial(bgCtx, "tcp", f.opt.Host+":"+f.opt.Port)
   111  	if err != nil {
   112  		return nil, fmt.Errorf("couldn't connect SMB: %w", err)
   113  	}
   114  	if share != "" {
   115  		// mount the specified share as well if user requested
   116  		c.smbShare, err = c.smbSession.Mount(share)
   117  		if err != nil {
   118  			_ = c.smbSession.Logoff()
   119  			return nil, fmt.Errorf("couldn't initialize SMB: %w", err)
   120  		}
   121  		c.smbShare = c.smbShare.WithContext(bgCtx)
   122  	}
   123  	return c, nil
   124  }
   125  
   126  // Ensure the specified share is mounted or the session is unmounted
   127  func (c *conn) mountShare(share string) (err error) {
   128  	if c.shareName == share {
   129  		return nil
   130  	}
   131  	if c.smbShare != nil {
   132  		err = c.smbShare.Umount()
   133  		c.smbShare = nil
   134  	}
   135  	if err != nil {
   136  		return
   137  	}
   138  	if share != "" {
   139  		c.smbShare, err = c.smbSession.Mount(share)
   140  		if err != nil {
   141  			return
   142  		}
   143  	}
   144  	c.shareName = share
   145  	return nil
   146  }
   147  
   148  // Get a SMB connection from the pool, or open a new one
   149  func (f *Fs) getConnection(ctx context.Context, share string) (c *conn, err error) {
   150  	accounting.LimitTPS(ctx)
   151  	f.poolMu.Lock()
   152  	for len(f.pool) > 0 {
   153  		c = f.pool[0]
   154  		f.pool = f.pool[1:]
   155  		err = c.mountShare(share)
   156  		if err == nil {
   157  			break
   158  		}
   159  		fs.Debugf(f, "Discarding unusable SMB connection: %v", err)
   160  		c = nil
   161  	}
   162  	f.poolMu.Unlock()
   163  	if c != nil {
   164  		return c, nil
   165  	}
   166  	err = f.pacer.Call(func() (bool, error) {
   167  		c, err = f.newConnection(ctx, share)
   168  		if err != nil {
   169  			return true, err
   170  		}
   171  		return false, nil
   172  	})
   173  	return c, err
   174  }
   175  
   176  // Return a SMB connection to the pool
   177  //
   178  // It nils the pointed to connection out so it can't be reused
   179  func (f *Fs) putConnection(pc **conn) {
   180  	c := *pc
   181  	*pc = nil
   182  
   183  	var nopErr error
   184  	if c.smbShare != nil {
   185  		// stat the current directory
   186  		_, nopErr = c.smbShare.Stat(".")
   187  	} else {
   188  		// list the shares
   189  		_, nopErr = c.smbSession.ListSharenames()
   190  	}
   191  	if nopErr != nil {
   192  		fs.Debugf(f, "Connection failed, closing: %v", nopErr)
   193  		_ = c.close()
   194  		return
   195  	}
   196  
   197  	f.poolMu.Lock()
   198  	f.pool = append(f.pool, c)
   199  	if f.opt.IdleTimeout > 0 {
   200  		f.drain.Reset(time.Duration(f.opt.IdleTimeout)) // nudge on the pool emptying timer
   201  	}
   202  	f.poolMu.Unlock()
   203  }
   204  
   205  // Drain the pool of any connections
   206  func (f *Fs) drainPool(ctx context.Context) (err error) {
   207  	f.poolMu.Lock()
   208  	defer f.poolMu.Unlock()
   209  	if sessions := f.getSessions(); sessions != 0 {
   210  		fs.Debugf(f, "Not closing %d unused connections as %d sessions active", len(f.pool), sessions)
   211  		if f.opt.IdleTimeout > 0 {
   212  			f.drain.Reset(time.Duration(f.opt.IdleTimeout)) // nudge on the pool emptying timer
   213  		}
   214  		return nil
   215  	}
   216  	if f.opt.IdleTimeout > 0 {
   217  		f.drain.Stop()
   218  	}
   219  	if len(f.pool) != 0 {
   220  		fs.Debugf(f, "Closing %d unused connections", len(f.pool))
   221  	}
   222  	for i, c := range f.pool {
   223  		if !c.closed() {
   224  			cErr := c.close()
   225  			if cErr != nil {
   226  				err = cErr
   227  			}
   228  		}
   229  		f.pool[i] = nil
   230  	}
   231  	f.pool = nil
   232  	return err
   233  }