github.com/cloudfoundry-attic/ltc@v0.0.0-20151123212628-098adc7919fc/ssh/ssh.go (about)

     1  package ssh
     2  
     3  import (
     4  	"errors"
     5  	"io"
     6  	"os"
     7  	"os/signal"
     8  	"runtime"
     9  	"syscall"
    10  	"time"
    11  
    12  	config_package "github.com/cloudfoundry-incubator/ltc/config"
    13  	"github.com/cloudfoundry-incubator/ltc/exit_handler"
    14  	"github.com/cloudfoundry-incubator/ltc/ssh/sigwinch"
    15  	"github.com/cloudfoundry-incubator/ltc/terminal"
    16  	"github.com/docker/docker/pkg/term"
    17  )
    18  
// Listener accepts local connections for port forwarding. Listen returns a
// channel of accepted connections and a channel of errors; Forward returns
// when either channel is closed or a value arrives on the error channel.
// New uses *ChannelListener as the default implementation.
//
//go:generate counterfeiter -o mocks/fake_listener.go . Listener
type Listener interface {
	Listen(network, laddr string) (<-chan io.ReadWriteCloser, <-chan error)
}
    23  
// ClientDialer opens a Client to instance instanceIndex of the app named
// appName using config. Connect stores the returned Client, which Forward
// and Shell then use. New uses *AppDialer as the default implementation.
//
//go:generate counterfeiter -o mocks/fake_client_dialer.go . ClientDialer
type ClientDialer interface {
	Dial(appName string, instanceIndex int, config *config_package.Config) (Client, error)
}
    28  
// Term abstracts the terminal operations used by Shell and resize, each
// keyed by a file descriptor. The *term.State returned by SetRawTerminal is
// later passed back to RestoreTerminal. New uses *terminal.DockerTerm as the
// default implementation.
//
//go:generate counterfeiter -o mocks/fake_term.go . Term
type Term interface {
	// SetRawTerminal switches fd to raw mode; the returned state is what
	// RestoreTerminal should be given afterwards.
	SetRawTerminal(fd uintptr) (*term.State, error)
	RestoreTerminal(fd uintptr, state *term.State) error
	// GetWinsize reports the current size of the terminal behind fd.
	GetWinsize(fd uintptr) (width int, height int)
	IsTTY(fd uintptr) bool
}
    36  
// SessionFactory creates a Session on client with an initial terminal size
// of width x height; desirePTY indicates whether a pseudo-terminal is wanted.
// New uses *SSHAPISessionFactory as the default implementation.
//
//go:generate counterfeiter -o mocks/fake_session_factory.go . SessionFactory
type SessionFactory interface {
	New(client Client, width, height int, desirePTY bool) (Session, error)
}
    41  
// SSH connects to an app instance (Connect) and then either forwards local
// TCP connections to it (Forward) or runs a shell/command session on it
// (Shell). The exported fields are the collaborators it delegates to; New
// fills them with the default implementations.
type SSH struct {
	// Listener accepts local connections for Forward.
	Listener Listener

	// ClientDialer is used by Connect to open the client.
	ClientDialer ClientDialer

	// Term provides TTY detection, window size and raw-mode control for Shell.
	Term Term

	// SessionFactory creates the session Shell runs on.
	SessionFactory SessionFactory

	// SigWinchChannel receives a value whenever the terminal size should be
	// re-checked; Shell closes it when it finishes.
	SigWinchChannel chan os.Signal

	// ExitHandler lets Shell register terminal restoration on process exit.
	ExitHandler exit_handler.ExitHandler

	// client is set by Connect and used by Forward and Shell.
	client Client
}
    51  
    52  func New(exitHandler exit_handler.ExitHandler) *SSH {
    53  	return &SSH{
    54  		Listener:        &ChannelListener{},
    55  		ClientDialer:    &AppDialer{},
    56  		Term:            &terminal.DockerTerm{},
    57  		SessionFactory:  &SSHAPISessionFactory{},
    58  		SigWinchChannel: make(chan os.Signal),
    59  		ExitHandler:     exitHandler,
    60  	}
    61  }
    62  
    63  func (s *SSH) Connect(appName string, instanceIndex int, config *config_package.Config) error {
    64  	if s.client != nil {
    65  		return errors.New("already connected")
    66  	}
    67  	var err error
    68  	s.client, err = s.ClientDialer.Dial(appName, instanceIndex, config)
    69  	if err != nil {
    70  		return err
    71  	}
    72  	return nil
    73  }
    74  
    75  func (s *SSH) Forward(localAddress, remoteAddress string) error {
    76  	acceptChan, errorChan := s.Listener.Listen("tcp", localAddress)
    77  
    78  	for {
    79  		select {
    80  		case conn, ok := <-acceptChan:
    81  			if !ok {
    82  				return nil
    83  			}
    84  
    85  			if err := s.client.Forward(conn, remoteAddress); err != nil {
    86  				return err
    87  			}
    88  		case err, ok := <-errorChan:
    89  			if !ok {
    90  				return nil
    91  			}
    92  
    93  			return err
    94  		}
    95  	}
    96  }
    97  
    98  func (s *SSH) Shell(command string, desirePTY bool) error {
    99  	if desirePTY {
   100  		desirePTY = s.Term.IsTTY(os.Stdin.Fd())
   101  	}
   102  
   103  	width, height := s.Term.GetWinsize(os.Stdout.Fd())
   104  	session, err := s.SessionFactory.New(s.client, width, height, desirePTY)
   105  	if err != nil {
   106  		return err
   107  	}
   108  	defer session.Close()
   109  
   110  	if desirePTY {
   111  		if state, err := s.Term.SetRawTerminal(os.Stdin.Fd()); err == nil {
   112  			defer s.Term.RestoreTerminal(os.Stdin.Fd(), state)
   113  
   114  			s.ExitHandler.OnExit(func() {
   115  				s.Term.RestoreTerminal(os.Stdin.Fd(), state)
   116  			})
   117  		}
   118  	}
   119  
   120  	if runtime.GOOS == "windows" {
   121  		ticker := time.NewTicker(250 * time.Millisecond)
   122  		defer ticker.Stop()
   123  
   124  		go func() {
   125  			for _ = range ticker.C {
   126  				s.SigWinchChannel <- syscall.Signal(-1)
   127  			}
   128  			close(s.SigWinchChannel)
   129  		}()
   130  	} else {
   131  		signal.Notify(s.SigWinchChannel, sigwinch.SIGWINCH())
   132  		defer func() {
   133  			signal.Stop(s.SigWinchChannel)
   134  			close(s.SigWinchChannel)
   135  		}()
   136  	}
   137  
   138  	go s.resize(session, os.Stdout.Fd(), width, height)
   139  
   140  	defer close(session.KeepAlive())
   141  
   142  	if command == "" {
   143  		session.Shell()
   144  		session.Wait()
   145  	} else {
   146  		session.Run(command)
   147  	}
   148  
   149  	return nil
   150  }
   151  
   152  func (s *SSH) resize(session Session, terminalFd uintptr, initialWidth, initialHeight int) {
   153  	previousWidth := initialWidth
   154  	previousHeight := initialHeight
   155  
   156  	for range s.SigWinchChannel {
   157  		width, height := s.Term.GetWinsize(terminalFd)
   158  
   159  		if width == previousWidth && height == previousHeight {
   160  			continue
   161  		}
   162  
   163  		session.Resize(width, height)
   164  
   165  		previousWidth = width
   166  		previousHeight = height
   167  	}
   168  }