github.com/coreos/mantle@v0.13.0/network/mockssh/mockssh_test.go (about)

     1  // Copyright 2017 CoreOS, Inc.
     2  // Copyright 2014 The Go Authors.
     3  //
     4  // Licensed under the Apache License, Version 2.0 (the "License");
     5  // you may not use this file except in compliance with the License.
     6  // You may obtain a copy of the License at
     7  //
     8  //     http://www.apache.org/licenses/LICENSE-2.0
     9  //
    10  // Unless required by applicable law or agreed to in writing, software
    11  // distributed under the License is distributed on an "AS IS" BASIS,
    12  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  // See the License for the specific language governing permissions and
    14  // limitations under the License.
    15  
    16  package mockssh
    17  
    18  import (
    19  	"bytes"
    20  	"fmt"
    21  	"io"
    22  	"os"
    23  	"reflect"
    24  	"runtime"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  
    29  	"golang.org/x/crypto/ssh"
    30  )
    31  
    32  func TestExitStatusZero(t *testing.T) {
    33  	client := NewMockClient(func(s *Session) {
    34  		s.Exit(0)
    35  	})
    36  	defer client.Close()
    37  
    38  	session, err := client.NewSession()
    39  	if err != nil {
    40  		t.Fatal(err)
    41  	}
    42  
    43  	if err := session.Run(""); err != nil {
    44  		t.Fatal(err)
    45  	}
    46  }
    47  
    48  func TestExitStatusNonzero(t *testing.T) {
    49  	client := NewMockClient(func(s *Session) {
    50  		s.Exit(42)
    51  	})
    52  	defer client.Close()
    53  
    54  	session, err := client.NewSession()
    55  	if err != nil {
    56  		t.Fatal(err)
    57  	}
    58  
    59  	err = session.Run("")
    60  	if ex, ok := err.(*ssh.ExitError); !ok {
    61  		t.Fatalf("unexpected error: %v", err)
    62  	} else if ex.ExitStatus() != 42 {
    63  		t.Fatalf("unexpected exit: %v", err)
    64  	}
    65  }
    66  
    67  func TestExitStatusMissing(t *testing.T) {
    68  	client := NewMockClient(func(s *Session) {
    69  		s.Close()
    70  	})
    71  	defer client.Close()
    72  
    73  	session, err := client.NewSession()
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  
    78  	err = session.Run("")
    79  	if _, ok := err.(*ssh.ExitMissingError); !ok {
    80  		t.Fatalf("unexpected error: %v", err)
    81  	}
    82  }
    83  
    84  func TestCat(t *testing.T) {
    85  	client := NewMockClient(func(s *Session) {
    86  		if _, err := io.Copy(s.Stdout, s.Stdin); err != nil {
    87  			t.Error(err)
    88  		}
    89  		s.Exit(0)
    90  	})
    91  	defer client.Close()
    92  
    93  	session, err := client.NewSession()
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  	const data = "hello world\n"
    98  	session.Stdin = strings.NewReader("hello world\n")
    99  
   100  	out, err := session.Output("")
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	if string(out) != data {
   105  		t.Fatalf("unexpected output: %q; wanted %q", out, data)
   106  	}
   107  }
   108  
   109  func TestStderr(t *testing.T) {
   110  	client := NewMockClient(func(s *Session) {
   111  		io.WriteString(s.Stdout, "Stdout")
   112  		io.WriteString(s.Stderr, "Stderr")
   113  		s.Exit(0)
   114  	})
   115  	defer client.Close()
   116  
   117  	session, err := client.NewSession()
   118  	if err != nil {
   119  		t.Fatal(err)
   120  	}
   121  	var stdout, stderr bytes.Buffer
   122  	session.Stdout = &stdout
   123  	session.Stderr = &stderr
   124  
   125  	if err := session.Run(""); err != nil {
   126  		t.Fatal(err)
   127  	}
   128  
   129  	if stdout.String() != "Stdout" {
   130  		t.Errorf("got %q wanted %q", stdout.String(), "Stdout")
   131  	}
   132  	if stderr.String() != "Stderr" {
   133  		t.Errorf("got %q wanted %q", stderr.String(), "Stderr")
   134  	}
   135  }
   136  
   137  func TestExec(t *testing.T) {
   138  	const cmd = "test command"
   139  	client := NewMockClient(func(s *Session) {
   140  		if s.Exec != cmd {
   141  			t.Errorf("got %q wanted %q", s.Exec, cmd)
   142  		}
   143  		s.Exit(0)
   144  	})
   145  	defer client.Close()
   146  
   147  	session, err := client.NewSession()
   148  	if err != nil {
   149  		t.Fatal(err)
   150  	}
   151  
   152  	if err := session.Run(cmd); err != nil {
   153  		t.Fatal(err)
   154  	}
   155  }
   156  
   157  func TestEnv(t *testing.T) {
   158  	expect := []string{"VAR1=VALUE1", "VAR2=VALUE2"}
   159  	client := NewMockClient(func(s *Session) {
   160  		if !reflect.DeepEqual(s.Env, expect) {
   161  			t.Errorf("got %v wanted %v", s.Env, expect)
   162  		}
   163  		s.Exit(0)
   164  	})
   165  	defer client.Close()
   166  
   167  	session, err := client.NewSession()
   168  	if err != nil {
   169  		t.Fatal(err)
   170  	}
   171  
   172  	if err := session.Setenv("VAR1", "VALUE1"); err != nil {
   173  		t.Fatal(err)
   174  	}
   175  
   176  	if err := session.Setenv("VAR2", "VALUE2"); err != nil {
   177  		t.Fatal(err)
   178  	}
   179  
   180  	if err := session.Run(""); err != nil {
   181  		t.Fatal(err)
   182  	}
   183  }
   184  
   185  // shell is not implemented, confirm it fails.
   186  func TestShell(t *testing.T) {
   187  	client := NewMockClient(func(s *Session) {
   188  		t.Errorf("executed shell")
   189  		s.Exit(0)
   190  	})
   191  	defer client.Close()
   192  
   193  	session, err := client.NewSession()
   194  	if err != nil {
   195  		t.Fatal(err)
   196  	}
   197  
   198  	if err := session.Shell(); err == nil {
   199  		t.Fatal("shell succeeded")
   200  	}
   201  }
   202  
   203  // The server code spawns a lot of goroutines and all of them need to
   204  // gracefully terminate when the client is closed.
   205  func TestMain(m *testing.M) {
   206  	g0 := runtime.NumGoroutine()
   207  
   208  	code := m.Run()
   209  	if code != 0 {
   210  		os.Exit(code)
   211  	}
   212  
   213  	// Check that there are no goroutines left behind.
   214  	t0 := time.Now()
   215  	stacks := make([]byte, 1<<20)
   216  	for {
   217  		g1 := runtime.NumGoroutine()
   218  		if g1 == g0 {
   219  			return
   220  		}
   221  		stacks = stacks[:runtime.Stack(stacks, true)]
   222  		time.Sleep(50 * time.Millisecond)
   223  		if time.Since(t0) > 2*time.Second {
   224  			fmt.Fprintf(os.Stderr, "Unexpected leftover goroutines detected: %v -> %v\n%s\n", g0, g1, stacks)
   225  			os.Exit(1)
   226  		}
   227  	}
   228  }