github.com/viant/toolbox@v0.34.5/ssh/tunnel.go (about)

     1  package ssh
     2  
     3  import (
     4  	"fmt"
     5  	"golang.org/x/crypto/ssh"
     6  	"io"
     7  	"log"
     8  	"net"
     9  	"sync"
    10  	"sync/atomic"
    11  )
    12  
    13  //Tunnel represents a SSH forwarding link
    14  type Tunnel struct {
    15  	RemoteAddress string
    16  	client        *ssh.Client
    17  	Local         net.Listener
    18  	Connections   []net.Conn
    19  	mutex         *sync.Mutex
    20  	closed        int32
    21  }
    22  
    23  func (f *Tunnel) tunnelTraffic(local, remote net.Conn) {
    24  	defer local.Close()
    25  	defer remote.Close()
    26  	completionChannel := make(chan bool)
    27  	go func() {
    28  		_, err := io.Copy(local, remote)
    29  		if err != nil {
    30  			log.Printf("failed to copy remote to local: %v", err)
    31  		}
    32  		completionChannel <- true
    33  	}()
    34  
    35  	go func() {
    36  		_, _ = io.Copy(remote, local)
    37  		//if err != nil {
    38  		//	log.Printf("failed to copy local to remote: %v", err)
    39  		//}
    40  		completionChannel <- true
    41  	}()
    42  	<-completionChannel
    43  }
    44  
    45  //Handle listen on local client to create tunnel with remote address.
    46  func (f *Tunnel) Handle() error {
    47  	for {
    48  		if atomic.LoadInt32(&f.closed) == 1 {
    49  			return nil
    50  		}
    51  		localclient, err := f.Local.Accept()
    52  		if err != nil {
    53  			return err
    54  		}
    55  		remote, err := f.client.Dial("tcp", f.RemoteAddress)
    56  		if err != nil {
    57  			return fmt.Errorf("failed to connect to remote: %v %v", f.RemoteAddress, err)
    58  		}
    59  		f.Connections = append(f.Connections, remote)
    60  		f.Connections = append(f.Connections, localclient)
    61  		go f.tunnelTraffic(localclient, remote)
    62  	}
    63  	return nil
    64  }
    65  
    66  //Close closes forwarding link
    67  func (f *Tunnel) Close() error {
    68  	atomic.StoreInt32(&f.closed, 1)
    69  	_ = f.Local.Close()
    70  	for _, remote := range f.Connections {
    71  		_ = remote.Close()
    72  	}
    73  	return nil
    74  }
    75  
    76  //NewForwarding creates a new ssh forwarding link
    77  func NewForwarding(client *ssh.Client, remoteAddress string, local net.Listener) *Tunnel {
    78  	return &Tunnel{
    79  		client:        client,
    80  		RemoteAddress: remoteAddress,
    81  		Connections:   make([]net.Conn, 0),
    82  		Local:         local,
    83  		mutex:         &sync.Mutex{},
    84  	}
    85  }