github.com/cyverse/go-irodsclient@v0.13.2/irods/connection/connection.go (about) 1 package connection 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "crypto/tls" 8 "crypto/x509" 9 "encoding/binary" 10 "encoding/hex" 11 "fmt" 12 "io" 13 "net" 14 "sync" 15 "time" 16 17 "github.com/cyverse/go-irodsclient/irods/auth" 18 "github.com/cyverse/go-irodsclient/irods/common" 19 "github.com/cyverse/go-irodsclient/irods/message" 20 "github.com/cyverse/go-irodsclient/irods/metrics" 21 "github.com/cyverse/go-irodsclient/irods/types" 22 "github.com/cyverse/go-irodsclient/irods/util" 23 "golang.org/x/xerrors" 24 25 log "github.com/sirupsen/logrus" 26 ) 27 28 const ( 29 TCPBufferSizeDefault int = 4 * 1024 * 1024 30 ) 31 32 // IRODSConnection connects to iRODS 33 type IRODSConnection struct { 34 account *types.IRODSAccount 35 requestTimeout time.Duration 36 tcpBufferSize int 37 applicationName string 38 39 connected bool 40 socket net.Conn 41 serverVersion *types.IRODSVersion 42 generatedPasswordForPAM string // used for PAM auth 43 creationTime time.Time 44 lastSuccessfulAccess time.Time 45 clientSignature string 46 dirtyTransaction bool 47 mutex sync.Mutex 48 locked bool // true if mutex is locked 49 50 metrics *metrics.IRODSMetrics 51 } 52 53 // NewIRODSConnection create a IRODSConnection 54 func NewIRODSConnection(account *types.IRODSAccount, requestTimeout time.Duration, applicationName string) *IRODSConnection { 55 return &IRODSConnection{ 56 account: account, 57 requestTimeout: requestTimeout, 58 tcpBufferSize: TCPBufferSizeDefault, 59 applicationName: applicationName, 60 61 creationTime: time.Now(), 62 clientSignature: "", 63 dirtyTransaction: false, 64 mutex: sync.Mutex{}, 65 66 metrics: &metrics.IRODSMetrics{}, 67 } 68 } 69 70 // NewIRODSConnectionWithMetrics create a IRODSConnection 71 func NewIRODSConnectionWithMetrics(account *types.IRODSAccount, requestTimeout time.Duration, applicationName string, metrics *metrics.IRODSMetrics) *IRODSConnection { 72 return &IRODSConnection{ 73 account: account, 74 requestTimeout: requestTimeout, 75 tcpBufferSize: TCPBufferSizeDefault, 76 applicationName: applicationName, 77 78 creationTime: time.Now(), 79 clientSignature: "", 80 dirtyTransaction: false, 81 mutex: sync.Mutex{}, 82 83 metrics: metrics, 84 } 85 } 86 87 // Lock locks connection 88 func (conn *IRODSConnection) Lock() { 89 conn.mutex.Lock() 90 conn.locked = true 91 } 92 93 // Unlock unlocks connection 94 func (conn *IRODSConnection) Unlock() { 95 conn.locked = false 96 conn.mutex.Unlock() 97 } 98 99 // GetAccount returns iRODSAccount 100 func (conn *IRODSConnection) GetAccount() *types.IRODSAccount { 101 return conn.account 102 } 103 104 // GetVersion returns iRODS version 105 func (conn *IRODSConnection) GetVersion() *types.IRODSVersion { 106 return conn.serverVersion 107 } 108 109 // SetTCPBufferSize sets TCP Buffer Size 110 func (conn *IRODSConnection) SetTCPBufferSize(bufferSize int) { 111 conn.tcpBufferSize = bufferSize 112 } 113 114 // SupportParallelUpload checks if the server supports parallel upload 115 // available from 4.2.9 116 func (conn *IRODSConnection) SupportParallelUpload() bool { 117 return conn.serverVersion.HasHigherVersionThan(4, 2, 9) 118 } 119 120 func (conn *IRODSConnection) requiresCSNegotiation() bool { 121 return conn.account.ClientServerNegotiation 122 } 123 124 // GetGeneratedPasswordForPAMAuth returns generated Password For PAM Auth 125 func (conn *IRODSConnection) GetGeneratedPasswordForPAMAuth() string { 126 return conn.generatedPasswordForPAM 127 } 128 129 // IsConnected returns if the connection is live 130 func (conn *IRODSConnection) IsConnected() bool { 131 return conn.connected 132 } 133 134 // GetCreationTime returns creation time 135 func (conn *IRODSConnection) GetCreationTime() time.Time { 136 return conn.creationTime 137 } 138 139 // GetLastSuccessfulAccess returns last successful access time 140 func (conn *IRODSConnection) GetLastSuccessfulAccess() time.Time { 141 return conn.lastSuccessfulAccess 142 } 143 144 // GetClientSignature returns client signature to be used in password obfuscation 145 func (conn *IRODSConnection) GetClientSignature() string { 146 return conn.clientSignature 147 } 148 149 // SetTransactionDirty sets if transaction is dirty 150 func (conn *IRODSConnection) SetTransactionDirty(dirtyTransaction bool) { 151 conn.dirtyTransaction = dirtyTransaction 152 } 153 154 // IsTransactionDirty returns true if transaction is dirty 155 func (conn *IRODSConnection) IsTransactionDirty() bool { 156 return conn.dirtyTransaction 157 } 158 159 // Connect connects to iRODS 160 func (conn *IRODSConnection) Connect() error { 161 logger := log.WithFields(log.Fields{ 162 "package": "connection", 163 "struct": "IRODSConnection", 164 "function": "Connect", 165 }) 166 167 conn.connected = false 168 169 conn.account.FixAuthConfiguration() 170 171 err := conn.account.Validate() 172 if err != nil { 173 return xerrors.Errorf("invalid account (%s): %w", err.Error(), types.NewConnectionConfigError(conn.account)) 174 } 175 176 // lock the connection 177 conn.Lock() 178 defer conn.Unlock() 179 180 server := fmt.Sprintf("%s:%d", conn.account.Host, conn.account.Port) 181 logger.Debugf("Connecting to %s", server) 182 183 // must connect to the server in 10 sec 184 var dialer net.Dialer 185 ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Second) 186 defer cancelFunc() 187 188 socket, err := dialer.DialContext(ctx, "tcp", server) 189 if err != nil { 190 connErr := xerrors.Errorf("failed to connect to specified host %s and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError()) 191 logger.Errorf("%+v", connErr) 192 193 if conn.metrics != nil { 194 conn.metrics.IncreaseCounterForConnectionFailures(1) 195 } 196 return connErr 197 } 198 199 if tcpSocket, ok := socket.(*net.TCPConn); ok { 200 sockErr := tcpSocket.SetReadBuffer(conn.tcpBufferSize) 201 if sockErr != nil { 202 sockBuffErr := xerrors.Errorf("failed to set tcp read buffer size %d: %w", conn.tcpBufferSize, sockErr) 203 logger.Errorf("%+v", sockBuffErr) 204 } 205 206 sockErr = tcpSocket.SetWriteBuffer(conn.tcpBufferSize) 207 if sockErr != nil { 208 sockBuffErr := xerrors.Errorf("failed to set tcp write buffer size %d: %w", conn.tcpBufferSize, sockErr) 209 logger.Errorf("%+v", sockBuffErr) 210 } 211 } 212 213 if conn.metrics != nil { 214 conn.metrics.IncreaseConnectionsOpened(1) 215 } 216 217 conn.socket = socket 218 var irodsVersion *types.IRODSVersion 219 220 if conn.requiresCSNegotiation() { 221 // client-server negotiation 222 irodsVersion, err = conn.connectWithCSNegotiation() 223 } else { 224 // No client-server negotiation 225 irodsVersion, err = conn.connectWithoutCSNegotiation() 226 } 227 228 if err != nil { 229 connErr := xerrors.Errorf("failed to startup an iRODS connection to server %s and port %d (%s): %w", conn.account.Host, conn.account.Port, err.Error(), types.NewConnectionError()) 230 logger.Errorf("%+v", connErr) 231 _ = conn.disconnectNow() 232 if conn.metrics != nil { 233 conn.metrics.IncreaseCounterForConnectionFailures(1) 234 } 235 return connErr 236 } 237 238 conn.serverVersion = irodsVersion 239 240 switch conn.account.AuthenticationScheme { 241 case types.AuthSchemeNative: 242 err = conn.loginNative(conn.account.Password) 243 case types.AuthSchemeGSI: 244 err = conn.loginGSI() 245 case types.AuthSchemePAM: 246 err = conn.loginPAM() 247 default: 248 logger.Errorf("unknown Authentication Scheme - %s", conn.account.AuthenticationScheme) 249 return xerrors.Errorf("unknown Authentication Scheme - %s: %w", conn.account.AuthenticationScheme, types.NewConnectionConfigError(conn.account)) 250 } 251 252 if err != nil { 253 connErr := xerrors.Errorf("failed to login to irods: %w", err) 254 logger.Errorf("%+v", connErr) 255 _ = conn.disconnectNow() 256 return connErr 257 } 258 259 if conn.account.UseTicket() { 260 req := message.NewIRODSMessageTicketAdminRequest("session", conn.account.Ticket) 261 err := conn.RequestAndCheck(req, &message.IRODSMessageAdminResponse{}, nil) 262 if err != nil { 263 return xerrors.Errorf("received supply ticket error (%s): %w", err.Error(), types.NewAuthError(conn.account)) 264 } 265 } 266 267 conn.connected = true 268 conn.lastSuccessfulAccess = time.Now() 269 270 return nil 271 } 272 273 func (conn *IRODSConnection) connectWithCSNegotiation() (*types.IRODSVersion, error) { 274 logger := log.WithFields(log.Fields{ 275 "package": "connection", 276 "struct": "IRODSConnection", 277 "function": "connectWithCSNegotiation", 278 }) 279 280 // Get client negotiation policy 281 clientPolicy := types.CSNegotiationRequireTCP 282 if len(conn.account.CSNegotiationPolicy) > 0 { 283 clientPolicy = conn.account.CSNegotiationPolicy 284 } 285 286 // Send a startup message 287 logger.Debug("Start up a connection with CS Negotiation") 288 289 startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, true) 290 err := conn.RequestWithoutResponse(startup) 291 if err != nil { 292 return nil, xerrors.Errorf("failed to send startup (%s): %w", err.Error(), types.NewConnectionError()) 293 } 294 295 // Server responds with negotiation response 296 negotiationMessage, err := conn.ReadMessage(nil) 297 if err != nil { 298 return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError()) 299 } 300 301 if negotiationMessage.Body == nil { 302 return nil, xerrors.Errorf("failed to receive negotiation message body: %w", types.NewConnectionError()) 303 } 304 305 if negotiationMessage.Body.Type == message.RODS_MESSAGE_VERSION_TYPE { 306 // this happens when an error occur 307 // Server responds with version 308 version := message.IRODSMessageVersion{} 309 err = version.FromMessage(negotiationMessage) 310 if err != nil { 311 return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError()) 312 } 313 314 return version.GetVersion(), nil 315 } else if negotiationMessage.Body.Type == message.RODS_MESSAGE_CS_NEG_TYPE { 316 // Server responds with its own negotiation policy 317 logger.Debug("Start up CS Negotiation") 318 319 negotiation := message.IRODSMessageCSNegotiation{} 320 err = negotiation.FromMessage(negotiationMessage) 321 if err != nil { 322 return nil, xerrors.Errorf("failed to receive negotiation message (%s): %w", err.Error(), types.NewConnectionError()) 323 } 324 325 serverPolicy, err := types.GetCSNegotiationRequire(negotiation.Result) 326 if err != nil { 327 return nil, xerrors.Errorf("failed to parse server policy (%s): %w", err.Error(), types.NewConnectionError()) 328 } 329 330 logger.Debugf("Client policy - %s, server policy - %s", clientPolicy, serverPolicy) 331 332 // Perform the negotiation 333 policyResult := types.PerformCSNegotiation(clientPolicy, serverPolicy) 334 335 // If negotiation failed we're done 336 if policyResult == types.CSNegotiationFailure { 337 return nil, xerrors.Errorf("client-server negotiation failed - %s, %s: %w", string(clientPolicy), string(serverPolicy), types.NewConnectionError()) 338 } 339 340 // Send negotiation result to server 341 negotiationResult := message.NewIRODSMessageCSNegotiation(policyResult) 342 version := message.IRODSMessageVersion{} 343 err = conn.Request(negotiationResult, &version, nil) 344 if err != nil { 345 return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError()) 346 } 347 348 if policyResult == types.CSNegotiationUseSSL { 349 err := conn.sslStartup() 350 if err != nil { 351 return nil, xerrors.Errorf("failed to start up SSL: %w", err) 352 } 353 } 354 355 return version.GetVersion(), nil 356 } 357 358 return nil, xerrors.Errorf("unknown response message '%s': %w", negotiationMessage.Body.Type, types.NewConnectionError()) 359 } 360 361 func (conn *IRODSConnection) connectWithoutCSNegotiation() (*types.IRODSVersion, error) { 362 logger := log.WithFields(log.Fields{ 363 "package": "connection", 364 "struct": "IRODSConnection", 365 "function": "connectWithoutCSNegotiation", 366 }) 367 368 // No client-server negotiation 369 // Send a startup message 370 logger.Debug("Start up connection without CS Negotiation") 371 372 startup := message.NewIRODSMessageStartupPack(conn.account, conn.applicationName, false) 373 version := message.IRODSMessageVersion{} 374 err := conn.Request(startup, &version, nil) 375 if err != nil { 376 return nil, xerrors.Errorf("failed to receive version message (%s): %w", err.Error(), types.NewConnectionError()) 377 } 378 379 return version.GetVersion(), nil 380 } 381 382 func (conn *IRODSConnection) sslStartup() error { 383 logger := log.WithFields(log.Fields{ 384 "package": "connection", 385 "struct": "IRODSConnection", 386 "function": "sslStartup", 387 }) 388 389 logger.Debug("Start up SSL") 390 391 irodsSSLConfig := conn.account.SSLConfiguration 392 if irodsSSLConfig == nil { 393 return xerrors.Errorf("SSL Configuration is not set: %w", types.NewConnectionConfigError(conn.account)) 394 } 395 396 caCertPool := x509.NewCertPool() 397 caCert, err := irodsSSLConfig.ReadCACert() 398 if err != nil { 399 logger.WithError(err).Warn("failed to read CA cert, ignoring...") 400 } else { 401 caCertPool.AppendCertsFromPEM(caCert) 402 } 403 404 sslConf := &tls.Config{ 405 RootCAs: caCertPool, 406 ServerName: conn.account.Host, 407 InsecureSkipVerify: true, 408 } 409 410 // Create a side connection using the existing socket 411 sslSocket := tls.Client(conn.socket, sslConf) 412 413 err = sslSocket.Handshake() 414 if err != nil { 415 return xerrors.Errorf("SSL Handshake error (%s): %w", err.Error(), types.NewConnectionError()) 416 } 417 418 // from now on use ssl socket 419 conn.socket = sslSocket 420 421 // Generate a key (shared secret) 422 encryptionKey := make([]byte, irodsSSLConfig.EncryptionKeySize) 423 _, err = rand.Read(encryptionKey) 424 if err != nil { 425 return xerrors.Errorf("failed to generate shared secret (%s): %w", err.Error(), types.NewConnectionError()) 426 } 427 428 // Send a ssl setting 429 sslSetting := message.NewIRODSMessageSSLSettings(irodsSSLConfig.EncryptionAlgorithm, irodsSSLConfig.EncryptionKeySize, irodsSSLConfig.SaltSize, irodsSSLConfig.HashRounds) 430 err = conn.RequestWithoutResponse(sslSetting) 431 if err != nil { 432 return xerrors.Errorf("failed to send ssl setting message (%s): %w", err.Error(), types.NewConnectionError()) 433 } 434 435 // Send a shared secret 436 sslSharedSecret := message.NewIRODSMessageSSLSharedSecret(encryptionKey) 437 err = conn.RequestWithoutResponseNoXML(sslSharedSecret) 438 if err != nil { 439 return xerrors.Errorf("failed to send ssl shared secret message (%s): %w", err.Error(), types.NewConnectionError()) 440 } 441 442 return nil 443 } 444 445 func (conn *IRODSConnection) login(password string) error { 446 // authenticate 447 authRequest := message.NewIRODSMessageAuthRequest() 448 authChallenge := message.IRODSMessageAuthChallengeResponse{} 449 err := conn.Request(authRequest, &authChallenge, nil) 450 if err != nil { 451 return xerrors.Errorf("failed to receive authentication challenge message body (%s): %w", err.Error(), types.NewAuthError(conn.account)) 452 } 453 454 challengeBytes, err := authChallenge.GetChallenge() 455 if err != nil { 456 return xerrors.Errorf("failed to get authentication challenge (%s): %w", err.Error(), types.NewAuthError(conn.account)) 457 } 458 459 // save client signature 460 conn.clientSignature = conn.createClientSignature(challengeBytes) 461 462 encodedPassword := auth.GenerateAuthResponse(challengeBytes, password) 463 464 authResponse := message.NewIRODSMessageAuthResponse(encodedPassword, conn.account.ProxyUser) 465 authResult := message.IRODSMessageAuthResult{} 466 err = conn.RequestAndCheck(authResponse, &authResult, nil) 467 if err != nil { 468 return xerrors.Errorf("received irods authentication error (%s): %w", err.Error(), types.NewAuthError(conn.account)) 469 } 470 return nil 471 } 472 473 func (conn *IRODSConnection) loginNative(password string) error { 474 logger := log.WithFields(log.Fields{ 475 "package": "connection", 476 "struct": "IRODSConnection", 477 "function": "loginNative", 478 }) 479 480 logger.Debug("Logging in using native authentication method") 481 return conn.login(password) 482 } 483 484 func (conn *IRODSConnection) loginGSI() error { 485 return xerrors.Errorf("GSI login is not yet implemented: %w", types.NewAuthError(conn.account)) 486 } 487 488 func (conn *IRODSConnection) loginPAM() error { 489 logger := log.WithFields(log.Fields{ 490 "package": "connection", 491 "struct": "IRODSConnection", 492 "function": "loginPAM", 493 }) 494 495 logger.Debug("Logging in using pam authentication method") 496 497 // Check whether ssl has already started, if not, start ssl. 498 if _, ok := conn.socket.(*tls.Conn); !ok { 499 return xerrors.Errorf("connection should be using SSL: %w", types.NewConnectionError()) 500 } 501 502 ttl := conn.account.PamTTL 503 if ttl <= 0 { 504 ttl = 1 505 } 506 507 // authenticate 508 pamAuthRequest := message.NewIRODSMessagePamAuthRequest(conn.account.ClientUser, conn.account.Password, ttl) 509 pamAuthResponse := message.IRODSMessagePamAuthResponse{} 510 err := conn.Request(pamAuthRequest, &pamAuthResponse, nil) 511 if err != nil { 512 return xerrors.Errorf("failed to receive an authentication challenge message (%s): %w", err.Error(), types.NewAuthError(conn.account)) 513 } 514 515 // save irods generated password for possible future use 516 conn.generatedPasswordForPAM = pamAuthResponse.GeneratedPassword 517 518 // retry native auth with generated password 519 return conn.login(conn.generatedPasswordForPAM) 520 } 521 522 // Disconnect disconnects 523 func (conn *IRODSConnection) disconnectNow() error { 524 conn.connected = false 525 var err error 526 if conn.socket != nil { 527 err = conn.socket.Close() 528 conn.socket = nil 529 } 530 531 if conn.metrics != nil { 532 conn.metrics.DecreaseConnectionsOpened(1) 533 } 534 535 if err == nil { 536 return nil 537 } 538 539 return xerrors.Errorf("failed to close socket: %w", err) 540 } 541 542 // Disconnect disconnects 543 func (conn *IRODSConnection) Disconnect() error { 544 logger := log.WithFields(log.Fields{ 545 "package": "connection", 546 "struct": "IRODSConnection", 547 "function": "Disconnect", 548 }) 549 550 logger.Debug("Disconnecting the connection") 551 552 // lock the connection 553 conn.Lock() 554 defer conn.Unlock() 555 556 disconnect := message.NewIRODSMessageDisconnect() 557 err := conn.RequestWithoutResponse(disconnect) 558 559 conn.lastSuccessfulAccess = time.Now() 560 561 err2 := conn.disconnectNow() 562 if err2 != nil { 563 return err2 564 } 565 566 if err != nil { 567 return err 568 } 569 570 return nil 571 } 572 573 func (conn *IRODSConnection) socketFail() { 574 if conn.metrics != nil { 575 conn.metrics.IncreaseCounterForConnectionFailures(1) 576 } 577 578 conn.disconnectNow() 579 } 580 581 // Send sends data 582 func (conn *IRODSConnection) Send(buffer []byte, size int) error { 583 return conn.SendWithTrackerCallBack(buffer, size, nil) 584 } 585 586 // SendWithTrackerCallBack sends data 587 func (conn *IRODSConnection) SendWithTrackerCallBack(buffer []byte, size int, callback common.TrackerCallBack) error { 588 if conn.socket == nil { 589 return xerrors.Errorf("failed to send data - socket closed") 590 } 591 592 if !conn.locked { 593 return xerrors.Errorf("connection must be locked before use") 594 } 595 596 // use sslSocket 597 if conn.requestTimeout > 0 { 598 conn.socket.SetWriteDeadline(time.Now().Add(conn.requestTimeout)) 599 } 600 601 err := util.WriteBytesWithTrackerCallBack(conn.socket, buffer, size, callback) 602 if err != nil { 603 conn.socketFail() 604 return xerrors.Errorf("failed to send data: %w", err) 605 } 606 607 if size > 0 { 608 if conn.metrics != nil { 609 conn.metrics.IncreaseBytesSent(uint64(size)) 610 } 611 } 612 613 conn.lastSuccessfulAccess = time.Now() 614 615 return nil 616 } 617 618 // SendFromReader sends data from Reader 619 func (conn *IRODSConnection) SendFromReader(src io.Reader, size int) error { 620 if conn.socket == nil { 621 return xerrors.Errorf("failed to send data - socket closed") 622 } 623 624 if !conn.locked { 625 return xerrors.Errorf("connection must be locked before use") 626 } 627 628 // use sslSocket 629 if conn.requestTimeout > 0 { 630 conn.socket.SetWriteDeadline(time.Now().Add(conn.requestTimeout)) 631 } 632 633 copyLen, err := io.CopyN(conn.socket, src, int64(size)) 634 if err != nil { 635 conn.socketFail() 636 return xerrors.Errorf("failed to send data: %w", err) 637 } 638 639 if copyLen != int64(size) { 640 conn.socketFail() 641 return xerrors.Errorf("failed to send data. failed to send data fully (requested %d vs sent %d)", size, copyLen) 642 } 643 644 if copyLen > 0 { 645 if conn.metrics != nil { 646 conn.metrics.IncreaseBytesSent(uint64(copyLen)) 647 } 648 } 649 650 conn.lastSuccessfulAccess = time.Now() 651 652 return nil 653 } 654 655 // Recv receives a message 656 func (conn *IRODSConnection) Recv(buffer []byte, size int) (int, error) { 657 return conn.RecvWithTrackerCallBack(buffer, size, nil) 658 } 659 660 // Recv receives a message 661 func (conn *IRODSConnection) RecvWithTrackerCallBack(buffer []byte, size int, callback common.TrackerCallBack) (int, error) { 662 if conn.socket == nil { 663 return 0, xerrors.Errorf("failed to receive data - socket closed") 664 } 665 666 if !conn.locked { 667 return 0, xerrors.Errorf("connection must be locked before use") 668 } 669 670 if conn.requestTimeout > 0 { 671 conn.socket.SetReadDeadline(time.Now().Add(conn.requestTimeout)) 672 } 673 674 readLen, err := util.ReadBytesWithTrackerCallBack(conn.socket, buffer, size, callback) 675 if err != nil { 676 conn.socketFail() 677 return readLen, xerrors.Errorf("failed to receive data: %w", err) 678 } 679 680 if readLen > 0 { 681 if conn.metrics != nil { 682 conn.metrics.IncreaseBytesReceived(uint64(readLen)) 683 } 684 } 685 686 conn.lastSuccessfulAccess = time.Now() 687 688 return readLen, nil 689 } 690 691 // RecvToWriter receives a message to Writer 692 func (conn *IRODSConnection) RecvToWriter(writer io.Writer, size int) (int, error) { 693 if conn.socket == nil { 694 return 0, xerrors.Errorf("failed to receive data - socket closed") 695 } 696 697 if !conn.locked { 698 return 0, xerrors.Errorf("connection must be locked before use") 699 } 700 701 if conn.requestTimeout > 0 { 702 conn.socket.SetReadDeadline(time.Now().Add(conn.requestTimeout)) 703 } 704 705 copyLen, err := io.CopyN(writer, conn.socket, int64(size)) 706 if err != nil { 707 conn.socketFail() 708 return int(copyLen), xerrors.Errorf("failed to receive data: %w", err) 709 } 710 711 if copyLen > 0 { 712 if conn.metrics != nil { 713 conn.metrics.IncreaseBytesReceived(uint64(copyLen)) 714 } 715 } 716 717 conn.lastSuccessfulAccess = time.Now() 718 719 return int(copyLen), nil 720 } 721 722 // SendMessage makes the message into bytes 723 func (conn *IRODSConnection) SendMessage(msg *message.IRODSMessage) error { 724 return conn.SendMessageWithTrackerCallBack(msg, nil) 725 } 726 727 // SendMessageWithTrackerCallBack makes the message into bytes 728 func (conn *IRODSConnection) SendMessageWithTrackerCallBack(msg *message.IRODSMessage, callback common.TrackerCallBack) error { 729 if !conn.locked { 730 return xerrors.Errorf("connection must be locked before use") 731 } 732 733 messageBuffer := new(bytes.Buffer) 734 735 if msg.Header == nil && msg.Body == nil { 736 return xerrors.Errorf("header and body cannot be nil") 737 } 738 739 var headerBytes []byte 740 var err error 741 742 messageLen := 0 743 errorLen := 0 744 bsLen := 0 745 746 if msg.Body != nil { 747 if msg.Body.Message != nil { 748 messageLen = len(msg.Body.Message) 749 } 750 751 if msg.Body.Error != nil { 752 errorLen = len(msg.Body.Error) 753 } 754 755 if msg.Body.Bs != nil { 756 bsLen = len(msg.Body.Bs) 757 } 758 759 if msg.Header == nil { 760 h := message.MakeIRODSMessageHeader(msg.Body.Type, uint32(messageLen), uint32(errorLen), uint32(bsLen), msg.Body.IntInfo) 761 headerBytes, err = h.GetBytes() 762 if err != nil { 763 return err 764 } 765 } 766 } 767 768 if msg.Header != nil { 769 headerBytes, err = msg.Header.GetBytes() 770 if err != nil { 771 return err 772 } 773 } 774 775 // pack length - Big Endian 776 headerLenBuffer := make([]byte, 4) 777 binary.BigEndian.PutUint32(headerLenBuffer, uint32(len(headerBytes))) 778 779 // header 780 messageBuffer.Write(headerLenBuffer) 781 messageBuffer.Write(headerBytes) 782 783 if msg.Body != nil { 784 bodyBytes, err := msg.Body.GetBytesWithoutBS() 785 if err != nil { 786 return err 787 } 788 789 // body 790 messageBuffer.Write(bodyBytes) 791 } 792 793 // send 794 bytes := messageBuffer.Bytes() 795 err = conn.Send(bytes, len(bytes)) 796 if err != nil { 797 return xerrors.Errorf("failed to send message: %w", err) 798 } 799 800 // send body-bs 801 if msg.Body != nil { 802 if msg.Body.Bs != nil { 803 conn.SendWithTrackerCallBack(msg.Body.Bs, len(msg.Body.Bs), callback) 804 } 805 } 806 return nil 807 } 808 809 // readMessageHeader reads data from the given connection and returns iRODSMessageHeader 810 func (conn *IRODSConnection) readMessageHeader() (*message.IRODSMessageHeader, error) { 811 // read header size 812 headerLenBuffer := make([]byte, 4) 813 readLen, err := conn.Recv(headerLenBuffer, 4) 814 if err != nil { 815 return nil, xerrors.Errorf("failed to read header size: %w", err) 816 } 817 818 if readLen != 4 { 819 return nil, xerrors.Errorf("failed to read header size, read %d", readLen) 820 } 821 822 headerSize := binary.BigEndian.Uint32(headerLenBuffer) 823 if headerSize <= 0 { 824 return nil, xerrors.Errorf("invalid header size returned - len = %d", headerSize) 825 } 826 827 // read header 828 headerBuffer := make([]byte, headerSize) 829 readLen, err = conn.Recv(headerBuffer, int(headerSize)) 830 if err != nil { 831 return nil, xerrors.Errorf("failed to read header: %w", err) 832 } 833 if readLen != int(headerSize) { 834 return nil, xerrors.Errorf("failed to read header fully - %d requested but %d read", headerSize, readLen) 835 } 836 837 header := message.IRODSMessageHeader{} 838 err = header.FromBytes(headerBuffer) 839 if err != nil { 840 return nil, err 841 } 842 843 return &header, nil 844 } 845 846 // ReadMessage reads data from the given socket and returns IRODSMessage 847 // if bsBuffer is given, bs data will be written directly to the bsBuffer 848 // if not given, a new buffer will be allocated. 849 func (conn *IRODSConnection) ReadMessage(bsBuffer []byte) (*message.IRODSMessage, error) { 850 return conn.ReadMessageWithTrackerCallBack(bsBuffer, nil) 851 } 852 853 func (conn *IRODSConnection) ReadMessageWithTrackerCallBack(bsBuffer []byte, callback common.TrackerCallBack) (*message.IRODSMessage, error) { 854 if !conn.locked { 855 return nil, xerrors.Errorf("connection must be locked before use") 856 } 857 858 header, err := conn.readMessageHeader() 859 if err != nil { 860 return nil, err 861 } 862 863 // read body 864 bodyLen := header.MessageLen + header.ErrorLen 865 bodyBuffer := make([]byte, bodyLen) 866 if bsBuffer == nil { 867 bsBuffer = make([]byte, int(header.BsLen)) 868 } else if len(bsBuffer) < int(header.BsLen) { 869 return nil, xerrors.Errorf("provided bs buffer is too short, %d size is given, but %d size is required", len(bsBuffer), int(header.BsLen)) 870 } 871 872 bodyReadLen, err := conn.Recv(bodyBuffer, int(bodyLen)) 873 if err != nil { 874 return nil, xerrors.Errorf("failed to read body: %w", err) 875 } 876 if bodyReadLen != int(bodyLen) { 877 return nil, xerrors.Errorf("failed to read body fully - %d requested but %d read", bodyLen, bodyReadLen) 878 } 879 880 bsReadLen, err := conn.RecvWithTrackerCallBack(bsBuffer, int(header.BsLen), callback) 881 if err != nil { 882 return nil, xerrors.Errorf("failed to read body (BS): %w", err) 883 } 884 if bsReadLen != int(header.BsLen) { 885 return nil, xerrors.Errorf("failed to read body (BS) fully - %d requested but %d read", int(header.BsLen), bsReadLen) 886 } 887 888 body := message.IRODSMessageBody{} 889 err = body.FromBytes(header, bodyBuffer, bsBuffer[:int(header.BsLen)]) 890 if err != nil { 891 return nil, err 892 } 893 894 body.Type = header.Type 895 body.IntInfo = header.IntInfo 896 897 return &message.IRODSMessage{ 898 Header: header, 899 Body: &body, 900 }, nil 901 } 902 903 // Commit a transaction. This is useful in combination with the NO_COMMIT_FLAG. 904 // Usage is limited to privileged accounts. 905 func (conn *IRODSConnection) Commit() error { 906 if !conn.locked { 907 return xerrors.Errorf("connection must be locked before use") 908 } 909 910 return conn.endTransaction(true) 911 } 912 913 // Rollback a transaction. This is useful in combination with the NO_COMMIT_FLAG. 914 // It can also be used to clear the current database transaction if there are no staged operations, 915 // just to refresh the view on the database for future queries. 916 // Usage is limited to privileged accounts. 917 func (conn *IRODSConnection) Rollback() error { 918 if !conn.locked { 919 return xerrors.Errorf("connection must be locked before use") 920 } 921 922 return conn.endTransaction(false) 923 } 924 925 // PoorMansRollback rolls back a transaction as a nonprivileged account, bypassing API limitations. 926 // A nonprivileged account cannot have staged operations, so rollback is always a no-op. 927 // The usage for this function, is that rolling back the current database transaction still will start 928 // a new one, so that future queries will see all changes that where made up to calling this function. 929 func (conn *IRODSConnection) PoorMansRollback() error { 930 if !conn.locked { 931 return xerrors.Errorf("connection must be locked before use") 932 } 933 934 dummyCol := fmt.Sprintf("/%s/home/%s", conn.account.ClientZone, conn.account.ClientUser) 935 936 return conn.poorMansEndTransaction(dummyCol, false) 937 } 938 939 func (conn *IRODSConnection) endTransaction(commit bool) error { 940 request := message.NewIRODSMessageEndTransactionRequest(commit) 941 response := message.IRODSMessageEndTransactionResponse{} 942 return conn.RequestAndCheck(request, &response, nil) 943 } 944 945 func (conn *IRODSConnection) poorMansEndTransaction(dummyCol string, commit bool) error { 946 request := message.NewIRODSMessageModifyCollectionRequest(dummyCol) 947 if commit { 948 request.AddKeyVal(common.COLLECTION_TYPE_KW, "NULL_SPECIAL_VALUE") 949 } 950 response := message.IRODSMessageModifyCollectionResponse{} 951 err := conn.Request(request, &response, nil) 952 if err != nil { 953 return xerrors.Errorf("failed to make a poor mans end transaction") 954 } 955 956 if !commit { 957 // We do expect an error on rollback because we didn't supply enough parameters 958 if common.ErrorCode(response.Result) == common.CAT_INVALID_ARGUMENT { 959 return nil 960 } 961 962 if response.Result == 0 { 963 return xerrors.Errorf("expected an error, but transaction completed successfully") 964 } 965 } 966 967 err = response.CheckError() 968 if err != nil { 969 return xerrors.Errorf("received irods error: %w", err) 970 } 971 return nil 972 } 973 974 // RawBind binds an IRODSConnection to a raw net.Conn socket - to be used for e.g. a proxy server setup 975 func (conn *IRODSConnection) RawBind(socket net.Conn) { 976 conn.connected = true 977 conn.socket = socket 978 } 979 980 // GetMetrics returns metrics 981 func (conn *IRODSConnection) GetMetrics() *metrics.IRODSMetrics { 982 return conn.metrics 983 } 984 985 // createClientSignature creates a client signature from auth challenge 986 func (conn *IRODSConnection) createClientSignature(challenge []byte) string { 987 if len(challenge) > 16 { 988 challenge = challenge[:16] 989 } 990 991 signature := hex.EncodeToString(challenge) 992 return signature 993 }