github.com/cyverse/go-irodsclient@v0.13.2/irods/connection/connection.go (about)

     1  package connection
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"encoding/binary"
    10  	"encoding/hex"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"sync"
    15  	"time"
    16  
    17  	"github.com/cyverse/go-irodsclient/irods/auth"
    18  	"github.com/cyverse/go-irodsclient/irods/common"
    19  	"github.com/cyverse/go-irodsclient/irods/message"
    20  	"github.com/cyverse/go-irodsclient/irods/metrics"
    21  	"github.com/cyverse/go-irodsclient/irods/types"
    22  	"github.com/cyverse/go-irodsclient/irods/util"
    23  	"golang.org/x/xerrors"
    24  
    25  	log "github.com/sirupsen/logrus"
    26  )
    27  
    28  const (
    29  	TCPBufferSizeDefault int = 4 * 1024 * 1024
    30  )
    31  
// IRODSConnection connects to iRODS
//
// A connection is created unconnected by NewIRODSConnection or
// NewIRODSConnectionWithMetrics. Connect opens the TCP socket, performs the
// startup handshake and authenticates. The I/O methods (Send*, Recv*,
// SendMessage*, ReadMessage*) refuse to run unless the caller holds the
// connection lock (see Lock/Unlock).
type IRODSConnection struct {
	// Configuration: the account to connect with, the per-operation socket
	// deadline that Send/Recv apply when > 0, the socket buffer size that
	// Connect requests, and the application name sent in the startup pack.
	account         *types.IRODSAccount
	requestTimeout  time.Duration
	tcpBufferSize   int
	applicationName string

	// Session state. socket is nil while disconnected and is replaced by a
	// *tls.Conn after SSL startup. serverVersion is set by the startup
	// handshake. clientSignature is derived from the auth challenge during
	// login. lastSuccessfulAccess is refreshed after successful I/O.
	// dirtyTransaction is only stored here (Set/IsTransactionDirty); its
	// meaning is defined by callers. locked mirrors whether mutex is held
	// via Lock.
	connected               bool
	socket                  net.Conn
	serverVersion           *types.IRODSVersion
	generatedPasswordForPAM string // used for PAM auth
	creationTime            time.Time
	lastSuccessfulAccess    time.Time
	clientSignature         string
	dirtyTransaction        bool
	mutex                   sync.Mutex
	locked                  bool // true if mutex is locked

	// Optional metrics sink; every use in this file is guarded by a nil check.
	metrics *metrics.IRODSMetrics
}
    52  
    53  // NewIRODSConnection create a IRODSConnection
    54  func NewIRODSConnection(account *types.IRODSAccount, requestTimeout time.Duration, applicationName string) *IRODSConnection {
    55  	return &IRODSConnection{
    56  		account:         account,
    57  		requestTimeout:  requestTimeout,
    58  		tcpBufferSize:   TCPBufferSizeDefault,
    59  		applicationName: applicationName,
    60  
    61  		creationTime:     time.Now(),
    62  		clientSignature:  "",
    63  		dirtyTransaction: false,
    64  		mutex:            sync.Mutex{},
    65  
    66  		metrics: &metrics.IRODSMetrics{},
    67  	}
    68  }
    69  
    70  // NewIRODSConnectionWithMetrics create a IRODSConnection
    71  func NewIRODSConnectionWithMetrics(account *types.IRODSAccount, requestTimeout time.Duration, applicationName string, metrics *metrics.IRODSMetrics) *IRODSConnection {
    72  	return &IRODSConnection{
    73  		account:         account,
    74  		requestTimeout:  requestTimeout,
    75  		tcpBufferSize:   TCPBufferSizeDefault,
    76  		applicationName: applicationName,
    77  
    78  		creationTime:     time.Now(),
    79  		clientSignature:  "",
    80  		dirtyTransaction: false,
    81  		mutex:            sync.Mutex{},
    82  
    83  		metrics: metrics,
    84  	}
    85  }
    86  
// Lock locks connection
//
// It acquires the connection mutex (blocking while another caller holds it)
// and marks the connection as locked; the I/O methods return an error unless
// this flag is set. Connect and Disconnect take this lock themselves, so they
// must not be called while it is held (sync.Mutex is not reentrant).
func (conn *IRODSConnection) Lock() {
	conn.mutex.Lock()
	// set only after the mutex is acquired, so the flag reflects ownership
	conn.locked = true
}
    92  
// Unlock unlocks connection
//
// It clears the locked flag and releases the connection mutex. It must only
// be called by the holder of a prior Lock.
func (conn *IRODSConnection) Unlock() {
	// cleared before releasing, so the flag is never true while the mutex is free
	conn.locked = false
	conn.mutex.Unlock()
}
    98  
// GetAccount returns iRODSAccount
//
// This is the account the connection was constructed with. Note that Connect
// calls FixAuthConfiguration on it, so it may have been adjusted in place.
func (conn *IRODSConnection) GetAccount() *types.IRODSAccount {
	return conn.account
}
   103  
// GetVersion returns iRODS version
//
// The version is reported by the server during Connect's startup handshake.
// It is nil until that handshake has succeeded.
func (conn *IRODSConnection) GetVersion() *types.IRODSVersion {
	return conn.serverVersion
}
   108  
// SetTCPBufferSize sets TCP Buffer Size
//
// bufferSize (in bytes) is requested for both the socket read and write
// buffers the next time Connect opens a TCP socket; it does not affect a
// socket that is already open.
func (conn *IRODSConnection) SetTCPBufferSize(bufferSize int) {
	conn.tcpBufferSize = bufferSize
}
   113  
   114  // SupportParallelUpload checks if the server supports parallel upload
   115  // available from 4.2.9
   116  func (conn *IRODSConnection) SupportParallelUpload() bool {
   117  	return conn.serverVersion.HasHigherVersionThan(4, 2, 9)
   118  }
   119  
// requiresCSNegotiation reports whether the account asks for client-server
// negotiation; Connect uses it to pick the startup handshake variant.
func (conn *IRODSConnection) requiresCSNegotiation() bool {
	return conn.account.ClientServerNegotiation
}
   123  
// GetGeneratedPasswordForPAMAuth returns generated Password For PAM Auth
//
// This is the password the server returned during PAM login (loginPAM). It
// is empty if PAM login has not been performed on this connection.
func (conn *IRODSConnection) GetGeneratedPasswordForPAMAuth() string {
	return conn.generatedPasswordForPAM
}
   128  
// IsConnected returns if the connection is live
//
// The flag is set at the end of a successful Connect and cleared whenever
// the socket is torn down (Disconnect, or an I/O failure). It is read
// without taking the connection lock.
func (conn *IRODSConnection) IsConnected() bool {
	return conn.connected
}
   133  
// GetCreationTime returns creation time
//
// This is when the IRODSConnection value was constructed, not when Connect ran.
func (conn *IRODSConnection) GetCreationTime() time.Time {
	return conn.creationTime
}
   138  
// GetLastSuccessfulAccess returns last successful access time
//
// The timestamp is refreshed after each successful send/receive, at the end
// of a successful Connect, and by Disconnect. It is the zero time if none of
// these has happened yet.
func (conn *IRODSConnection) GetLastSuccessfulAccess() time.Time {
	return conn.lastSuccessfulAccess
}
   143  
// GetClientSignature returns client signature to be used in password obfuscation
//
// The signature is derived from the server's authentication challenge during
// login. It is empty before login.
func (conn *IRODSConnection) GetClientSignature() string {
	return conn.clientSignature
}
   148  
// SetTransactionDirty sets if transaction is dirty
//
// This file only stores the flag; what counts as "dirty" is decided by callers.
func (conn *IRODSConnection) SetTransactionDirty(dirtyTransaction bool) {
	conn.dirtyTransaction = dirtyTransaction
}
   153  
// IsTransactionDirty returns true if transaction is dirty
//
// It reports the value last stored by SetTransactionDirty (false initially).
func (conn *IRODSConnection) IsTransactionDirty() bool {
	return conn.dirtyTransaction
}
   158  
   159  // Connect connects to iRODS
   160  func (conn *IRODSConnection) Connect() error {
   161  	logger := log.WithFields(log.Fields{
   162  		"package":  "connection",
   163  		"struct":   "IRODSConnection",
   164  		"function": "Connect",
   165  	})
   166  
   167  	conn.connected = false
   168  
   169  	conn.account.FixAuthConfiguration()
   170  
   171  	err := conn.account.Validate()
   172  	if err != nil {
   173  		return xerrors.Errorf("invalid account (%s): %w", err.Error(), types.NewConnectionConfigError(conn.account))
   174  	}
   175  
   176  	// lock the connection
   177  	conn.Lock()
   178  	defer conn.Unlock()
   179  
   180  	server := fmt.Sprintf("%s:%d", conn.account.Host, conn.account.Port)
   181  	logger.Debugf("Connecting to %s", server)
   182  
   183  	// must connect to the server in 10 sec
   184  	var dialer net.Dialer
   185  	ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second)
   186  	defer cancelFunc()
   187  
   188  	socket, err := dialer.DialContext(ctx, "tcp", server)
   189  	if err != nil {
   190  		connErr := xerrors.Errorf("failed to connect to specified host %s and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError())
   191  		logger.Errorf("%+v", connErr)
   192  
   193  		if conn.metrics != nil {
   194  			conn.metrics.IncreaseCounterForConnectionFailures(1)
   195  		}
   196  		return connErr
   197  	}
   198  
   199  	if tcpSocket, ok := socket.(*net.TCPConn); ok {
   200  		sockErr := tcpSocket.SetReadBuffer(conn.tcpBufferSize)
   201  		if sockErr != nil {
   202  			sockBuffErr := xerrors.Errorf("failed to set tcp read buffer size %d: %w", conn.tcpBufferSize, sockErr)
   203  			logger.Errorf("%+v", sockBuffErr)
   204  		}
   205  
   206  		sockErr = tcpSocket.SetWriteBuffer(conn.tcpBufferSize)
   207  		if sockErr != nil {
   208  			sockBuffErr := xerrors.Errorf("failed to set tcp write buffer size %d: %w", conn.tcpBufferSize, sockErr)
   209  			logger.Errorf("%+v", sockBuffErr)
   210  		}
   211  	}
   212  
   213  	if conn.metrics != nil {
   214  		conn.metrics.IncreaseConnectionsOpened(1)
   215  	}
   216  
   217  	conn.socket = socket
   218  	var irodsVersion *types.IRODSVersion
   219  
   220  	if conn.requiresCSNegotiation() {
   221  		// client-server negotiation
   222  		irodsVersion, err = conn.connectWithCSNegotiation()
   223  	} else {
   224  		// No client-server negotiation
   225  		irodsVersion, err = conn.connectWithoutCSNegotiation()
   226  	}
   227  
   228  	if err != nil {
   229  		connErr := xerrors.Errorf("failed to startup an iRODS connection to server %s and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError())
   230  		logger.Errorf("%+v", connErr)
   231  		_ = conn.disconnectNow()
   232  		if conn.metrics != nil {
   233  			conn.metrics.IncreaseCounterForConnectionFailures(1)
   234  		}
   235  		return connErr
   236  	}
   237  
   238  	conn.serverVersion = irodsVersion
   239  
   240  	switch conn.account.AuthenticationScheme {
   241  	case types.AuthSchemeNative:
   242  		err = conn.loginNative(conn.account.Password)
   243  	case types.AuthSchemeGSI:
   244  		err = conn.loginGSI()
   245  	case types.AuthSchemePAM:
   246  		err = conn.loginPAM()
   247  	default:
   248  		logger.Errorf("unknown Authentication Scheme - %s", conn.account.AuthenticationScheme)
   249  		return xerrors.Errorf("unknown Authentication Scheme - %s: %w", conn.account.AuthenticationScheme, types.NewConnectionConfigError(conn.account))
   250  	}
   251  
   252  	if err != nil {
   253  		connErr := xerrors.Errorf("failed to login to irods: %w", err)
   254  		logger.Errorf("%+v", connErr)
   255  		_ = conn.disconnectNow()
   256  		return connErr
   257  	}
   258  
   259  	if conn.account.UseTicket() {
   260  		req := message.NewIRODSMessageTicketAdminRequest("session", conn.account.Ticket)
   261  		err := conn.RequestAndCheck(req, &message.IRODSMessageAdminResponse{}, nil)
   262  		if err != nil {
   263  			return xerrors.Errorf("received supply ticket error (%s): %w", err.Error(), types.NewAuthError(conn.account))
   264  		}
   265  	}
   266  
   267  	conn.connected = true
   268  	conn.lastSuccessfulAccess = time.Now()
   269  
   270  	return nil
   271  }
   272  
   273  func (conn *IRODSConnection) connectWithCSNegotiation() (*types.IRODSVersion, error) {
   274  	logger := log.WithFields(log.Fields{
   275  		"package":  "connection",
   276  		"struct":   "IRODSConnection",
   277  		"function": "connectWithCSNegotiation",
   278  	})
   279  
   280  	// Get client negotiation policy
   281  	clientPolicy := types.CSNegotiationRequireTCP
   282  	if len(conn.account.CSNegotiationPolicy) > 0 {
   283  		clientPolicy = conn.account.CSNegotiationPolicy
   284  	}
   285  
   286  	// Send a startup message
   287  	logger.Debug("Start up a connection with CS Negotiation")
   288  
   289  	startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, true)
   290  	err := conn.RequestWithoutResponse(startup)
   291  	if err != nil {
   292  		return nil, xerrors.Errorf("failed to send startup (%s): %w", err.Error(), types.NewConnectionError())
   293  	}
   294  
   295  	// Server responds with negotiation response
   296  	negotiationMessage, err := conn.ReadMessage(nil)
   297  	if err != nil {
   298  		return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError())
   299  	}
   300  
   301  	if negotiationMessage.Body == nil {
   302  		return nil, xerrors.Errorf("failed to receive negotiation message body: %w", types.NewConnectionError())
   303  	}
   304  
   305  	if negotiationMessage.Body.Type == message.RODS_MESSAGE_VERSION_TYPE {
   306  		// this happens when an error occur
   307  		// Server responds with version
   308  		version := message.IRODSMessageVersion{}
   309  		err = version.FromMessage(negotiationMessage)
   310  		if err != nil {
   311  			return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError())
   312  		}
   313  
   314  		return version.GetVersion(), nil
   315  	} else if negotiationMessage.Body.Type == message.RODS_MESSAGE_CS_NEG_TYPE {
   316  		// Server responds with its own negotiation policy
   317  		logger.Debug("Start up CS Negotiation")
   318  
   319  		negotiation := message.IRODSMessageCSNegotiation{}
   320  		err = negotiation.FromMessage(negotiationMessage)
   321  		if err != nil {
   322  			return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError())
   323  		}
   324  
   325  		serverPolicy, err := types.GetCSNegotiationRequire(negotiation.Result)
   326  		if err != nil {
   327  			return nil, xerrors.Errorf("failed to parse server policy (%s): %w", err.Error(), types.NewConnectionError())
   328  		}
   329  
   330  		logger.Debugf("Client policy - %s, server policy - %s", clientPolicy, serverPolicy)
   331  
   332  		// Perform the negotiation
   333  		policyResult := types.PerformCSNegotiation(clientPolicy, serverPolicy)
   334  
   335  		// If negotiation failed we're done
   336  		if policyResult == types.CSNegotiationFailure {
   337  			return nil, xerrors.Errorf("client-server negotiation failed - %s, %s: %w", string(clientPolicy), string(serverPolicy), types.NewConnectionError())
   338  		}
   339  
   340  		// Send negotiation result to server
   341  		negotiationResult := message.NewIRODSMessageCSNegotiation(policyResult)
   342  		version := message.IRODSMessageVersion{}
   343  		err = conn.Request(negotiationResult, &version, nil)
   344  		if err != nil {
   345  			return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError())
   346  		}
   347  
   348  		if policyResult == types.CSNegotiationUseSSL {
   349  			err := conn.sslStartup()
   350  			if err != nil {
   351  				return nil, xerrors.Errorf("failed to start up SSL: %w", err)
   352  			}
   353  		}
   354  
   355  		return version.GetVersion(), nil
   356  	}
   357  
   358  	return nil, xerrors.Errorf("unknown response message '%s': %w", negotiationMessage.Body.Type, types.NewConnectionError())
   359  }
   360  
   361  func (conn *IRODSConnection) connectWithoutCSNegotiation() (*types.IRODSVersion, error) {
   362  	logger := log.WithFields(log.Fields{
   363  		"package":  "connection",
   364  		"struct":   "IRODSConnection",
   365  		"function": "connectWithoutCSNegotiation",
   366  	})
   367  
   368  	// No client-server negotiation
   369  	// Send a startup message
   370  	logger.Debug("Start up connection without CS Negotiation")
   371  
   372  	startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, false)
   373  	version := message.IRODSMessageVersion{}
   374  	err := conn.Request(startup, &version, nil)
   375  	if err != nil {
   376  		return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError())
   377  	}
   378  
   379  	return version.GetVersion(), nil
   380  }
   381  
   382  func (conn *IRODSConnection) sslStartup() error {
   383  	logger := log.WithFields(log.Fields{
   384  		"package":  "connection",
   385  		"struct":   "IRODSConnection",
   386  		"function": "sslStartup",
   387  	})
   388  
   389  	logger.Debug("Start up SSL")
   390  
   391  	irodsSSLConfig := conn.account.SSLConfiguration
   392  	if irodsSSLConfig == nil {
   393  		return xerrors.Errorf("SSL Configuration is not set: %w", types.NewConnectionConfigError(conn.account))
   394  	}
   395  
   396  	caCertPool := x509.NewCertPool()
   397  	caCert, err := irodsSSLConfig.ReadCACert()
   398  	if err != nil {
   399  		logger.WithError(err).Warn("failed to read CA cert, ignoring...")
   400  	} else {
   401  		caCertPool.AppendCertsFromPEM(caCert)
   402  	}
   403  
   404  	sslConf := &tls.Config{
   405  		RootCAs:            caCertPool,
   406  		ServerName:         conn.account.Host,
   407  		InsecureSkipVerify: true,
   408  	}
   409  
   410  	// Create a side connection using the existing socket
   411  	sslSocket := tls.Client(conn.socket, sslConf)
   412  
   413  	err = sslSocket.Handshake()
   414  	if err != nil {
   415  		return xerrors.Errorf("SSL Handshake error (%s): %w", err.Error(), types.NewConnectionError())
   416  	}
   417  
   418  	// from now on use ssl socket
   419  	conn.socket = sslSocket
   420  
   421  	// Generate a key (shared secret)
   422  	encryptionKey := make([]byte, irodsSSLConfig.EncryptionKeySize)
   423  	_, err = rand.Read(encryptionKey)
   424  	if err != nil {
   425  		return xerrors.Errorf("failed to generate shared secret (%s): %w", err.Error(), types.NewConnectionError())
   426  	}
   427  
   428  	// Send a ssl setting
   429  	sslSetting := message.NewIRODSMessageSSLSettings(irodsSSLConfig.EncryptionAlgorithm, irodsSSLConfig.EncryptionKeySize, irodsSSLConfig.SaltSize, irodsSSLConfig.HashRounds)
   430  	err = conn.RequestWithoutResponse(sslSetting)
   431  	if err != nil {
   432  		return xerrors.Errorf("failed to send ssl setting message (%s): %w", err.Error(), types.NewConnectionError())
   433  	}
   434  
   435  	// Send a shared secret
   436  	sslSharedSecret := message.NewIRODSMessageSSLSharedSecret(encryptionKey)
   437  	err = conn.RequestWithoutResponseNoXML(sslSharedSecret)
   438  	if err != nil {
   439  		return xerrors.Errorf("failed to send ssl shared secret message (%s): %w", err.Error(), types.NewConnectionError())
   440  	}
   441  
   442  	return nil
   443  }
   444  
   445  func (conn *IRODSConnection) login(password string) error {
   446  	// authenticate
   447  	authRequest := message.NewIRODSMessageAuthRequest()
   448  	authChallenge := message.IRODSMessageAuthChallengeResponse{}
   449  	err := conn.Request(authRequest, &authChallenge, nil)
   450  	if err != nil {
   451  		return xerrors.Errorf("failed to receive authentication challenge message body (%s): %w", err.Error(), types.NewAuthError(conn.account))
   452  	}
   453  
   454  	challengeBytes, err := authChallenge.GetChallenge()
   455  	if err != nil {
   456  		return xerrors.Errorf("failed to get authentication challenge (%s): %w", err.Error(), types.NewAuthError(conn.account))
   457  	}
   458  
   459  	// save client signature
   460  	conn.clientSignature = conn.createClientSignature(challengeBytes)
   461  
   462  	encodedPassword := auth.GenerateAuthResponse(challengeBytes, password)
   463  
   464  	authResponse := message.NewIRODSMessageAuthResponse(encodedPassword, conn.account.ProxyUser)
   465  	authResult := message.IRODSMessageAuthResult{}
   466  	err = conn.RequestAndCheck(authResponse, &authResult, nil)
   467  	if err != nil {
   468  		return xerrors.Errorf("received irods authentication error (%s): %w", err.Error(), types.NewAuthError(conn.account))
   469  	}
   470  	return nil
   471  }
   472  
   473  func (conn *IRODSConnection) loginNative(password string) error {
   474  	logger := log.WithFields(log.Fields{
   475  		"package":  "connection",
   476  		"struct":   "IRODSConnection",
   477  		"function": "loginNative",
   478  	})
   479  
   480  	logger.Debug("Logging in using native authentication method")
   481  	return conn.login(password)
   482  }
   483  
// loginGSI is a placeholder for GSI authentication, which is not
// implemented; it always returns an error wrapping a types auth error.
func (conn *IRODSConnection) loginGSI() error {
	return xerrors.Errorf("GSI login is not yet implemented: %w", types.NewAuthError(conn.account))
}
   487  
   488  func (conn *IRODSConnection) loginPAM() error {
   489  	logger := log.WithFields(log.Fields{
   490  		"package":  "connection",
   491  		"struct":   "IRODSConnection",
   492  		"function": "loginPAM",
   493  	})
   494  
   495  	logger.Debug("Logging in using pam authentication method")
   496  
   497  	// Check whether ssl has already started, if not, start ssl.
   498  	if _, ok := conn.socket.(*tls.Conn); !ok {
   499  		return xerrors.Errorf("connection should be using SSL: %w", types.NewConnectionError())
   500  	}
   501  
   502  	ttl := conn.account.PamTTL
   503  	if ttl <= 0 {
   504  		ttl = 1
   505  	}
   506  
   507  	// authenticate
   508  	pamAuthRequest := message.NewIRODSMessagePamAuthRequest(conn.account.ClientUser, conn.account.Password, ttl)
   509  	pamAuthResponse := message.IRODSMessagePamAuthResponse{}
   510  	err := conn.Request(pamAuthRequest, &pamAuthResponse, nil)
   511  	if err != nil {
   512  		return xerrors.Errorf("failed to receive an authentication challenge message (%s): %w", err.Error(), types.NewAuthError(conn.account))
   513  	}
   514  
   515  	// save irods generated password for possible future use
   516  	conn.generatedPasswordForPAM = pamAuthResponse.GeneratedPassword
   517  
   518  	// retry native auth with generated password
   519  	return conn.login(conn.generatedPasswordForPAM)
   520  }
   521  
   522  // Disconnect disconnects
   523  func (conn *IRODSConnection) disconnectNow() error {
   524  	conn.connected = false
   525  	var err error
   526  	if conn.socket != nil {
   527  		err = conn.socket.Close()
   528  		conn.socket = nil
   529  	}
   530  
   531  	if conn.metrics != nil {
   532  		conn.metrics.DecreaseConnectionsOpened(1)
   533  	}
   534  
   535  	if err == nil {
   536  		return nil
   537  	}
   538  
   539  	return xerrors.Errorf("failed to close socket: %w", err)
   540  }
   541  
   542  // Disconnect disconnects
   543  func (conn *IRODSConnection) Disconnect() error {
   544  	logger := log.WithFields(log.Fields{
   545  		"package":  "connection",
   546  		"struct":   "IRODSConnection",
   547  		"function": "Disconnect",
   548  	})
   549  
   550  	logger.Debug("Disconnecting the connection")
   551  
   552  	// lock the connection
   553  	conn.Lock()
   554  	defer conn.Unlock()
   555  
   556  	disconnect := message.NewIRODSMessageDisconnect()
   557  	err := conn.RequestWithoutResponse(disconnect)
   558  
   559  	conn.lastSuccessfulAccess = time.Now()
   560  
   561  	err2 := conn.disconnectNow()
   562  	if err2 != nil {
   563  		return err2
   564  	}
   565  
   566  	if err != nil {
   567  		return err
   568  	}
   569  
   570  	return nil
   571  }
   572  
   573  func (conn *IRODSConnection) socketFail() {
   574  	if conn.metrics != nil {
   575  		conn.metrics.IncreaseCounterForConnectionFailures(1)
   576  	}
   577  
   578  	conn.disconnectNow()
   579  }
   580  
// Send sends data
//
// It is SendWithTrackerCallBack with a nil callback, i.e. without progress
// tracking.
func (conn *IRODSConnection) Send(buffer []byte, size int) error {
	return conn.SendWithTrackerCallBack(buffer, size, nil)
}
   585  
   586  // SendWithTrackerCallBack sends data
   587  func (conn *IRODSConnection) SendWithTrackerCallBack(buffer []byte, size int, callback common.TrackerCallBack) error {
   588  	if conn.socket == nil {
   589  		return xerrors.Errorf("failed to send data - socket closed")
   590  	}
   591  
   592  	if !conn.locked {
   593  		return xerrors.Errorf("connection must be locked before use")
   594  	}
   595  
   596  	// use sslSocket
   597  	if conn.requestTimeout > 0 {
   598  		conn.socket.SetWriteDeadline(time.Now().Add(conn.requestTimeout))
   599  	}
   600  
   601  	err := util.WriteBytesWithTrackerCallBack(conn.socket, buffer, size, callback)
   602  	if err != nil {
   603  		conn.socketFail()
   604  		return xerrors.Errorf("failed to send data: %w", err)
   605  	}
   606  
   607  	if size > 0 {
   608  		if conn.metrics != nil {
   609  			conn.metrics.IncreaseBytesSent(uint64(size))
   610  		}
   611  	}
   612  
   613  	conn.lastSuccessfulAccess = time.Now()
   614  
   615  	return nil
   616  }
   617  
   618  // SendFromReader sends data from Reader
   619  func (conn *IRODSConnection) SendFromReader(src io.Reader, size int) error {
   620  	if conn.socket == nil {
   621  		return xerrors.Errorf("failed to send data - socket closed")
   622  	}
   623  
   624  	if !conn.locked {
   625  		return xerrors.Errorf("connection must be locked before use")
   626  	}
   627  
   628  	// use sslSocket
   629  	if conn.requestTimeout > 0 {
   630  		conn.socket.SetWriteDeadline(time.Now().Add(conn.requestTimeout))
   631  	}
   632  
   633  	copyLen, err := io.CopyN(conn.socket, src, int64(size))
   634  	if err != nil {
   635  		conn.socketFail()
   636  		return xerrors.Errorf("failed to send data: %w", err)
   637  	}
   638  
   639  	if copyLen != int64(size) {
   640  		conn.socketFail()
   641  		return xerrors.Errorf("failed to send data. failed to send data fully (requested %d vs sent %d)", size, copyLen)
   642  	}
   643  
   644  	if copyLen > 0 {
   645  		if conn.metrics != nil {
   646  			conn.metrics.IncreaseBytesSent(uint64(copyLen))
   647  		}
   648  	}
   649  
   650  	conn.lastSuccessfulAccess = time.Now()
   651  
   652  	return nil
   653  }
   654  
// Recv receives a message
//
// It is RecvWithTrackerCallBack with a nil callback, i.e. without progress
// tracking, and returns the number of bytes read.
func (conn *IRODSConnection) Recv(buffer []byte, size int) (int, error) {
	return conn.RecvWithTrackerCallBack(buffer, size, nil)
}
   659  
   660  // Recv receives a message
   661  func (conn *IRODSConnection) RecvWithTrackerCallBack(buffer []byte, size int, callback common.TrackerCallBack) (int, error) {
   662  	if conn.socket == nil {
   663  		return 0, xerrors.Errorf("failed to receive data - socket closed")
   664  	}
   665  
   666  	if !conn.locked {
   667  		return 0, xerrors.Errorf("connection must be locked before use")
   668  	}
   669  
   670  	if conn.requestTimeout > 0 {
   671  		conn.socket.SetReadDeadline(time.Now().Add(conn.requestTimeout))
   672  	}
   673  
   674  	readLen, err := util.ReadBytesWithTrackerCallBack(conn.socket, buffer, size, callback)
   675  	if err != nil {
   676  		conn.socketFail()
   677  		return readLen, xerrors.Errorf("failed to receive data: %w", err)
   678  	}
   679  
   680  	if readLen > 0 {
   681  		if conn.metrics != nil {
   682  			conn.metrics.IncreaseBytesReceived(uint64(readLen))
   683  		}
   684  	}
   685  
   686  	conn.lastSuccessfulAccess = time.Now()
   687  
   688  	return readLen, nil
   689  }
   690  
   691  // RecvToWriter receives a message to Writer
   692  func (conn *IRODSConnection) RecvToWriter(writer io.Writer, size int) (int, error) {
   693  	if conn.socket == nil {
   694  		return 0, xerrors.Errorf("failed to receive data - socket closed")
   695  	}
   696  
   697  	if !conn.locked {
   698  		return 0, xerrors.Errorf("connection must be locked before use")
   699  	}
   700  
   701  	if conn.requestTimeout > 0 {
   702  		conn.socket.SetReadDeadline(time.Now().Add(conn.requestTimeout))
   703  	}
   704  
   705  	copyLen, err := io.CopyN(writer, conn.socket, int64(size))
   706  	if err != nil {
   707  		conn.socketFail()
   708  		return int(copyLen), xerrors.Errorf("failed to receive data: %w", err)
   709  	}
   710  
   711  	if copyLen > 0 {
   712  		if conn.metrics != nil {
   713  			conn.metrics.IncreaseBytesReceived(uint64(copyLen))
   714  		}
   715  	}
   716  
   717  	conn.lastSuccessfulAccess = time.Now()
   718  
   719  	return int(copyLen), nil
   720  }
   721  
// SendMessage makes the message into bytes
//
// It serializes msg and sends it to the server without progress tracking
// (SendMessageWithTrackerCallBack with a nil callback).
func (conn *IRODSConnection) SendMessage(msg *message.IRODSMessage) error {
	return conn.SendMessageWithTrackerCallBack(msg, nil)
}
   726  
   727  // SendMessageWithTrackerCallBack makes the message into bytes
   728  func (conn *IRODSConnection) SendMessageWithTrackerCallBack(msg *message.IRODSMessage, callback common.TrackerCallBack) error {
   729  	if !conn.locked {
   730  		return xerrors.Errorf("connection must be locked before use")
   731  	}
   732  
   733  	messageBuffer := new(bytes.Buffer)
   734  
   735  	if msg.Header == nil && msg.Body == nil {
   736  		return xerrors.Errorf("header and body cannot be nil")
   737  	}
   738  
   739  	var headerBytes []byte
   740  	var err error
   741  
   742  	messageLen := 0
   743  	errorLen := 0
   744  	bsLen := 0
   745  
   746  	if msg.Body != nil {
   747  		if msg.Body.Message != nil {
   748  			messageLen = len(msg.Body.Message)
   749  		}
   750  
   751  		if msg.Body.Error != nil {
   752  			errorLen = len(msg.Body.Error)
   753  		}
   754  
   755  		if msg.Body.Bs != nil {
   756  			bsLen = len(msg.Body.Bs)
   757  		}
   758  
   759  		if msg.Header == nil {
   760  			h := message.MakeIRODSMessageHeader(msg.Body.Type, uint32(messageLen), uint32(errorLen), uint32(bsLen), msg.Body.IntInfo)
   761  			headerBytes, err = h.GetBytes()
   762  			if err != nil {
   763  				return err
   764  			}
   765  		}
   766  	}
   767  
   768  	if msg.Header != nil {
   769  		headerBytes, err = msg.Header.GetBytes()
   770  		if err != nil {
   771  			return err
   772  		}
   773  	}
   774  
   775  	// pack length - Big Endian
   776  	headerLenBuffer := make([]byte, 4)
   777  	binary.BigEndian.PutUint32(headerLenBuffer, uint32(len(headerBytes)))
   778  
   779  	// header
   780  	messageBuffer.Write(headerLenBuffer)
   781  	messageBuffer.Write(headerBytes)
   782  
   783  	if msg.Body != nil {
   784  		bodyBytes, err := msg.Body.GetBytesWithoutBS()
   785  		if err != nil {
   786  			return err
   787  		}
   788  
   789  		// body
   790  		messageBuffer.Write(bodyBytes)
   791  	}
   792  
   793  	// send
   794  	bytes := messageBuffer.Bytes()
   795  	err = conn.Send(bytes, len(bytes))
   796  	if err != nil {
   797  		return xerrors.Errorf("failed to send message: %w", err)
   798  	}
   799  
   800  	// send body-bs
   801  	if msg.Body != nil {
   802  		if msg.Body.Bs != nil {
   803  			conn.SendWithTrackerCallBack(msg.Body.Bs, len(msg.Body.Bs), callback)
   804  		}
   805  	}
   806  	return nil
   807  }
   808  
   809  // readMessageHeader reads data from the given connection and returns iRODSMessageHeader
   810  func (conn *IRODSConnection) readMessageHeader() (*message.IRODSMessageHeader, error) {
   811  	// read header size
   812  	headerLenBuffer := make([]byte, 4)
   813  	readLen, err := conn.Recv(headerLenBuffer, 4)
   814  	if err != nil {
   815  		return nil, xerrors.Errorf("failed to read header size: %w", err)
   816  	}
   817  
   818  	if readLen != 4 {
   819  		return nil, xerrors.Errorf("failed to read header size, read %d", readLen)
   820  	}
   821  
   822  	headerSize := binary.BigEndian.Uint32(headerLenBuffer)
   823  	if headerSize <= 0 {
   824  		return nil, xerrors.Errorf("invalid header size returned - len = %d", headerSize)
   825  	}
   826  
   827  	// read header
   828  	headerBuffer := make([]byte, headerSize)
   829  	readLen, err = conn.Recv(headerBuffer, int(headerSize))
   830  	if err != nil {
   831  		return nil, xerrors.Errorf("failed to read header: %w", err)
   832  	}
   833  	if readLen != int(headerSize) {
   834  		return nil, xerrors.Errorf("failed to read header fully - %d requested but %d read", headerSize, readLen)
   835  	}
   836  
   837  	header := message.IRODSMessageHeader{}
   838  	err = header.FromBytes(headerBuffer)
   839  	if err != nil {
   840  		return nil, err
   841  	}
   842  
   843  	return &header, nil
   844  }
   845  
// ReadMessage reads data from the given socket and returns IRODSMessage
// if bsBuffer is given, bs data will be written directly to the bsBuffer
// if not given, a new buffer will be allocated.
//
// A given bsBuffer must be at least as long as the message's BS part.
// No progress tracking is done (ReadMessageWithTrackerCallBack with a nil
// callback).
func (conn *IRODSConnection) ReadMessage(bsBuffer []byte) (*message.IRODSMessage, error) {
	return conn.ReadMessageWithTrackerCallBack(bsBuffer, nil)
}
   852  
   853  func (conn *IRODSConnection) ReadMessageWithTrackerCallBack(bsBuffer []byte, callback common.TrackerCallBack) (*message.IRODSMessage, error) {
   854  	if !conn.locked {
   855  		return nil, xerrors.Errorf("connection must be locked before use")
   856  	}
   857  
   858  	header, err := conn.readMessageHeader()
   859  	if err != nil {
   860  		return nil, err
   861  	}
   862  
   863  	// read body
   864  	bodyLen := header.MessageLen + header.ErrorLen
   865  	bodyBuffer := make([]byte, bodyLen)
   866  	if bsBuffer == nil {
   867  		bsBuffer = make([]byte, int(header.BsLen))
   868  	} else if len(bsBuffer) < int(header.BsLen) {
   869  		return nil, xerrors.Errorf("provided bs buffer is too short, %d size is given, but %d size is required", len(bsBuffer), int(header.BsLen))
   870  	}
   871  
   872  	bodyReadLen, err := conn.Recv(bodyBuffer, int(bodyLen))
   873  	if err != nil {
   874  		return nil, xerrors.Errorf("failed to read body: %w", err)
   875  	}
   876  	if bodyReadLen != int(bodyLen) {
   877  		return nil, xerrors.Errorf("failed to read body fully - %d requested but %d read", bodyLen, bodyReadLen)
   878  	}
   879  
   880  	bsReadLen, err := conn.RecvWithTrackerCallBack(bsBuffer, int(header.BsLen), callback)
   881  	if err != nil {
   882  		return nil, xerrors.Errorf("failed to read body (BS): %w", err)
   883  	}
   884  	if bsReadLen != int(header.BsLen) {
   885  		return nil, xerrors.Errorf("failed to read body (BS) fully - %d requested but %d read", int(header.BsLen), bsReadLen)
   886  	}
   887  
   888  	body := message.IRODSMessageBody{}
   889  	err = body.FromBytes(header, bodyBuffer, bsBuffer[:int(header.BsLen)])
   890  	if err != nil {
   891  		return nil, err
   892  	}
   893  
   894  	body.Type = header.Type
   895  	body.IntInfo = header.IntInfo
   896  
   897  	return &message.IRODSMessage{
   898  		Header: header,
   899  		Body:   &body,
   900  	}, nil
   901  }
   902  
   903  // Commit a transaction. This is useful in combination with the NO_COMMIT_FLAG.
   904  // Usage is limited to privileged accounts.
   905  func (conn *IRODSConnection) Commit() error {
   906  	if !conn.locked {
   907  		return xerrors.Errorf("connection must be locked before use")
   908  	}
   909  
   910  	return conn.endTransaction(true)
   911  }
   912  
   913  // Rollback a transaction. This is useful in combination with the NO_COMMIT_FLAG.
   914  // It can also be used to clear the current database transaction if there are no staged operations,
   915  // just to refresh the view on the database for future queries.
   916  // Usage is limited to privileged accounts.
   917  func (conn *IRODSConnection) Rollback() error {
   918  	if !conn.locked {
   919  		return xerrors.Errorf("connection must be locked before use")
   920  	}
   921  
   922  	return conn.endTransaction(false)
   923  }
   924  
   925  // PoorMansRollback rolls back a transaction as a nonprivileged account, bypassing API limitations.
   926  // A nonprivileged account cannot have staged operations, so rollback is always a no-op.
   927  // The usage for this function, is that rolling back the current database transaction still will start
   928  // a new one, so that future queries will see all changes that where made up to calling this function.
   929  func (conn *IRODSConnection) PoorMansRollback() error {
   930  	if !conn.locked {
   931  		return xerrors.Errorf("connection must be locked before use")
   932  	}
   933  
   934  	dummyCol := fmt.Sprintf("/%s/home/%s", conn.account.ClientZone, conn.account.ClientUser)
   935  
   936  	return conn.poorMansEndTransaction(dummyCol, false)
   937  }
   938  
   939  func (conn *IRODSConnection) endTransaction(commit bool) error {
   940  	request := message.NewIRODSMessageEndTransactionRequest(commit)
   941  	response := message.IRODSMessageEndTransactionResponse{}
   942  	return conn.RequestAndCheck(request, &response, nil)
   943  }
   944  
   945  func (conn *IRODSConnection) poorMansEndTransaction(dummyCol string, commit bool) error {
   946  	request := message.NewIRODSMessageModifyCollectionRequest(dummyCol)
   947  	if commit {
   948  		request.AddKeyVal(common.COLLECTION_TYPE_KW, "NULL_SPECIAL_VALUE")
   949  	}
   950  	response := message.IRODSMessageModifyCollectionResponse{}
   951  	err := conn.Request(request, &response, nil)
   952  	if err != nil {
   953  		return xerrors.Errorf("failed to make a poor mans end transaction")
   954  	}
   955  
   956  	if !commit {
   957  		// We do expect an error on rollback because we didn't supply enough parameters
   958  		if common.ErrorCode(response.Result) == common.CAT_INVALID_ARGUMENT {
   959  			return nil
   960  		}
   961  
   962  		if response.Result == 0 {
   963  			return xerrors.Errorf("expected an error, but transaction completed successfully")
   964  		}
   965  	}
   966  
   967  	err = response.CheckError()
   968  	if err != nil {
   969  		return xerrors.Errorf("received irods error: %w", err)
   970  	}
   971  	return nil
   972  }
   973  
   974  // RawBind binds an IRODSConnection to a raw net.Conn socket - to be used for e.g. a proxy server setup
   975  func (conn *IRODSConnection) RawBind(socket net.Conn) {
   976  	conn.connected = true
   977  	conn.socket = socket
   978  }
   979  
   980  // GetMetrics returns metrics
   981  func (conn *IRODSConnection) GetMetrics() *metrics.IRODSMetrics {
   982  	return conn.metrics
   983  }
   984  
   985  // createClientSignature creates a client signature from auth challenge
   986  func (conn *IRODSConnection) createClientSignature(challenge []byte) string {
   987  	if len(challenge) > 16 {
   988  		challenge = challenge[:16]
   989  	}
   990  
   991  	signature := hex.EncodeToString(challenge)
   992  	return signature
   993  }