github.com/vmware/govmomi@v0.37.2/toolbox/service_test.go (about)

     1  /*
     2  Copyright (c) 2017 VMware, Inc. All Rights Reserved.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package toolbox
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"flag"
    24  	"io"
    25  	"log"
    26  	"net"
    27  	"net/http"
    28  	"net/http/httptest"
    29  	"net/url"
    30  	"os"
    31  	"sync"
    32  	"testing"
    33  	"time"
    34  
    35  	"github.com/vmware/govmomi/toolbox/hgfs"
    36  	"github.com/vmware/govmomi/toolbox/process"
    37  	"github.com/vmware/govmomi/toolbox/vix"
    38  	"github.com/vmware/govmomi/vim25/types"
    39  )
    40  
    41  func TestDefaultIP(t *testing.T) {
    42  	ip := DefaultIP()
    43  	if ip == "" {
    44  		t.Error("failed to get a default IP address")
    45  	}
    46  }
    47  
// testRPC is one scripted request/reply pair for mockChannelIn:
// cmd is the payload mockChannelIn.Receive hands out, and expect is the
// reply that mockChannelIn.Send requires to be sent back for it.
type testRPC struct {
	cmd    string
	expect string
}
    52  
// mockChannelIn is a scripted inbound channel used by the service tests.
//
// Fields:
//   - t: test handle used by Send to report reply mismatches.
//   - service: the Service under test (assigned by the tests; not read by the
//     mock's own methods in this file).
//   - rpc: queue of scripted requests; Receive serves rpc[0] and Send pops it
//     once the reply has been checked.
//   - wg: released (Done) by Receive when the queue has been drained and by
//     Send each time it injects a send error; tests Wait on it.
//   - start: error returned by Start.
//   - sendErr: when > 0, every sendErr-th call to Send fails.
//   - count: call counters; send is only incremented while sendErr > 0.
type mockChannelIn struct {
	t       *testing.T
	service *Service
	rpc     []*testRPC
	wg      sync.WaitGroup
	start   error
	sendErr int
	count   struct {
		send  int
		stop  int
		start int
	}
}
    66  
// Start counts the call and returns the preconfigured start error
// (nil unless a test has set c.start).
func (c *mockChannelIn) Start() error {
	c.count.start++
	return c.start
}
    71  
// Stop counts the call and always succeeds.
func (c *mockChannelIn) Stop() error {
	c.count.stop++
	return nil
}
    76  
    77  func (c *mockChannelIn) Receive() ([]byte, error) {
    78  	if len(c.rpc) == 0 {
    79  		if c.rpc != nil {
    80  			// All test RPC requests have been consumed
    81  			c.wg.Done()
    82  			c.rpc = nil
    83  		}
    84  		return nil, io.EOF
    85  	}
    86  
    87  	return []byte(c.rpc[0].cmd), nil
    88  }
    89  
    90  func (c *mockChannelIn) Send(buf []byte) error {
    91  	if c.sendErr > 0 {
    92  		c.count.send++
    93  		if c.count.send%c.sendErr == 0 {
    94  			c.wg.Done()
    95  			return errors.New("rpci send error")
    96  		}
    97  	}
    98  
    99  	if buf == nil {
   100  		return nil
   101  	}
   102  
   103  	expect := c.rpc[0].expect
   104  	if string(buf) != expect {
   105  		c.t.Errorf("expected %q reply for request %q, got: %q", expect, c.rpc[0].cmd, buf)
   106  	}
   107  
   108  	c.rpc = c.rpc[1:]
   109  
   110  	return nil
   111  }
   112  
   113  // discard rpc out for now
   114  type mockChannelOut struct {
   115  	reply [][]byte
   116  	start error
   117  }
   118  
   119  func (c *mockChannelOut) Start() error {
   120  	return c.start
   121  }
   122  
   123  func (c *mockChannelOut) Stop() error {
   124  	return nil
   125  }
   126  
   127  func (c *mockChannelOut) Receive() ([]byte, error) {
   128  	if len(c.reply) == 0 {
   129  		return nil, io.EOF
   130  	}
   131  	reply := c.reply[0]
   132  	c.reply = c.reply[1:]
   133  	return reply, nil
   134  }
   135  
   136  func (c *mockChannelOut) Send(buf []byte) error {
   137  	if len(buf) == 0 {
   138  		return io.ErrShortBuffer
   139  	}
   140  	return nil
   141  }
   142  
   143  func TestServiceRun(t *testing.T) {
   144  	in := new(mockChannelIn)
   145  	out := new(mockChannelOut)
   146  
   147  	service := NewService(in, out)
   148  
   149  	in.rpc = []*testRPC{
   150  		{"reset", "OK ATR toolbox"},
   151  		{"ping", "OK "},
   152  		{"Set_Option synctime 0", "OK "},
   153  		{"Capabilities_Register", "OK "},
   154  		{"Set_Option broadcastIP 1", "OK "},
   155  	}
   156  
   157  	in.wg.Add(1)
   158  
   159  	// replies to register capabilities
   160  	for i := 0; i < len(capabilities); i++ {
   161  		out.reply = append(out.reply, rpciOK)
   162  	}
   163  
   164  	out.reply = append(out.reply,
   165  		rpciOK, // reply to SendGuestInfo call in Reset()
   166  		rpciOK, // reply to IP broadcast
   167  	)
   168  
   169  	in.service = service
   170  
   171  	in.t = t
   172  
   173  	err := service.Start()
   174  	if err != nil {
   175  		t.Fatal(err)
   176  	}
   177  
   178  	in.wg.Wait()
   179  
   180  	service.Stop()
   181  	service.Wait()
   182  
   183  	// verify we don't set delay > maxDelay
   184  	for i := 0; i <= maxDelay+1; i++ {
   185  		service.backoff()
   186  	}
   187  
   188  	if service.delay != maxDelay {
   189  		t.Errorf("delay=%d", service.delay)
   190  	}
   191  }
   192  
   193  func TestServiceErrors(t *testing.T) {
   194  	Trace = true
   195  	if !testing.Verbose() {
   196  		// cover TraceChannel but discard output
   197  		traceLog = io.Discard
   198  	}
   199  
   200  	netInterfaceAddrs = func() ([]net.Addr, error) {
   201  		return nil, io.EOF
   202  	}
   203  
   204  	in := new(mockChannelIn)
   205  	out := new(mockChannelOut)
   206  
   207  	service := NewService(in, out)
   208  
   209  	service.RegisterHandler("Sorry", func([]byte) ([]byte, error) {
   210  		return nil, errors.New("i am so sorry")
   211  	})
   212  
   213  	ip := ""
   214  	service.PrimaryIP = func() string {
   215  		if ip == "" {
   216  			ip = "127"
   217  		} else if ip == "127" {
   218  			ip = "127.0.0.1"
   219  		} else if ip == "127.0.0.1" {
   220  			ip = ""
   221  		}
   222  		return ip
   223  	}
   224  
   225  	in.rpc = []*testRPC{
   226  		{"Capabilities_Register", "OK "},
   227  		{"Set_Option broadcastIP 1", "ERR "},
   228  		{"Set_Option broadcastIP 1", "OK "},
   229  		{"Set_Option broadcastIP 1", "OK "},
   230  		{"NOPE", "Unknown Command"},
   231  		{"Sorry", "ERR "},
   232  	}
   233  
   234  	in.wg.Add(1)
   235  
   236  	// replies to register capabilities
   237  	for i := 0; i < len(capabilities); i++ {
   238  		out.reply = append(out.reply, rpciERR)
   239  	}
   240  
   241  	foo := []byte("foo")
   242  	out.reply = append(
   243  		out.reply,
   244  		rpciERR,
   245  		rpciOK,
   246  		rpciOK,
   247  		append(rpciOK, foo...),
   248  		rpciERR,
   249  	)
   250  
   251  	in.service = service
   252  
   253  	in.t = t
   254  
   255  	err := service.Start()
   256  	if err != nil {
   257  		t.Fatal(err)
   258  	}
   259  
   260  	in.wg.Wait()
   261  
   262  	// Done serving RPCs, test ChannelOut errors
   263  	reply, err := service.out.Request(rpciOK)
   264  	if err != nil {
   265  		t.Error(err)
   266  	}
   267  
   268  	if !bytes.Equal(reply, foo) {
   269  		t.Errorf("reply=%q", foo)
   270  	}
   271  
   272  	_, err = service.out.Request(rpciOK)
   273  	if err == nil {
   274  		t.Error("expected error")
   275  	}
   276  
   277  	_, err = service.out.Request(nil)
   278  	if err == nil {
   279  		t.Error("expected error")
   280  	}
   281  
   282  	service.Stop()
   283  	service.Wait()
   284  
   285  	// cover service start error paths
   286  	start := errors.New("fail")
   287  
   288  	in.start = start
   289  	err = service.Start()
   290  	if err != start {
   291  		t.Error("expected error")
   292  	}
   293  
   294  	in.start = nil
   295  	out.start = start
   296  	err = service.Start()
   297  	if err != start {
   298  		t.Error("expected error")
   299  	}
   300  }
   301  
   302  func TestServiceResetChannel(t *testing.T) {
   303  	in := new(mockChannelIn)
   304  	out := new(mockChannelOut)
   305  
   306  	service := NewService(in, out)
   307  
   308  	resetDelay = maxDelay
   309  
   310  	fails := 2
   311  	in.wg.Add(fails)
   312  	in.sendErr = 10
   313  
   314  	err := service.Start()
   315  	if err != nil {
   316  		t.Fatal(err)
   317  	}
   318  
   319  	in.wg.Wait()
   320  
   321  	service.Stop()
   322  	service.Wait()
   323  
   324  	expect := fails
   325  	if in.count.start != expect || in.count.stop != expect {
   326  		t.Errorf("count=%#v", in.count)
   327  	}
   328  }
   329  
// Command-line flags controlling TestServiceRunESX:
//   - testESX: the test is skipped unless this is set.
//   - testPID: when non-zero, the test overrides the process start command
//     and returns this PID for the "/bin/date" program.
//   - testOn: when equal to the poweredOff power state, the test also waits
//     for a power-on event.
var (
	testESX = flag.Bool("toolbox.testesx", false, "Test toolbox service against ESX (vmtoolsd must not be running)")
	testPID = flag.Int64("toolbox.testpid", 0, "PID to return from toolbox start command")
	testOn  = flag.String("toolbox.powerState", "", "Power state of VM prior to starting the test")
)
   335  
   336  // echoHandler for testing hgfs.FileHandler
   337  type echoHandler struct{}
   338  
   339  func (e *echoHandler) Stat(u *url.URL) (os.FileInfo, error) {
   340  	if u.RawQuery == "" {
   341  		return nil, errors.New("no query")
   342  	}
   343  
   344  	if u.Query().Get("foo") != "bar" {
   345  		return nil, errors.New("invalid query")
   346  	}
   347  
   348  	return os.Stat(u.Path)
   349  }
   350  
   351  func (e *echoHandler) Open(u *url.URL, mode int32) (hgfs.File, error) {
   352  	_, err := e.Stat(u)
   353  	if err != nil {
   354  		return nil, err
   355  	}
   356  
   357  	return os.Open(u.Path)
   358  }
   359  
   360  func TestServiceRunESX(t *testing.T) {
   361  	if *testESX == false {
   362  		t.SkipNow()
   363  	}
   364  
   365  	Trace = testing.Verbose()
   366  
   367  	// A server that echos HTTP requests, for testing toolbox's http.RoundTripper
   368  	echo := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   369  		_ = r.Write(w)
   370  	}))
   371  	// Client side can use 'govc guest.getenv' to get the URL w/ random port
   372  	_ = os.Setenv("TOOLBOX_ECHO_SERVER", echo.URL)
   373  
   374  	var wg sync.WaitGroup
   375  
   376  	in := NewBackdoorChannelIn()
   377  	out := NewBackdoorChannelOut()
   378  
   379  	service := NewService(in, out)
   380  	service.Command.FileServer.RegisterFileHandler("echo", new(echoHandler))
   381  
   382  	ping := sync.NewCond(new(sync.Mutex))
   383  
   384  	service.RegisterHandler("ping", func(b []byte) ([]byte, error) {
   385  		ping.Broadcast()
   386  		return service.Ping(b)
   387  	})
   388  
   389  	// assert that reset, ping, Set_Option and Capabilities_Register are called at least once
   390  	for _, name := range []string{"reset", "ping", "Set_Option", "Capabilities_Register"} {
   391  		n := name
   392  		h := service.handlers[name]
   393  		wg.Add(1)
   394  
   395  		service.handlers[name] = func(b []byte) ([]byte, error) {
   396  			defer wg.Done()
   397  
   398  			service.handlers[n] = h // reset
   399  
   400  			return h(b)
   401  		}
   402  	}
   403  
   404  	if *testOn == string(types.VirtualMachinePowerStatePoweredOff) {
   405  		wg.Add(1)
   406  		service.Power.PowerOn.Handler = func() error {
   407  			defer wg.Done()
   408  			log.Print("power on event")
   409  			return nil
   410  		}
   411  	} else {
   412  		log.Print("skipping power on test")
   413  	}
   414  
   415  	if *testPID != 0 {
   416  		service.Command.ProcessStartCommand = func(m *process.Manager, r *vix.StartProgramRequest) (int64, error) {
   417  			wg.Add(1)
   418  			defer wg.Done()
   419  
   420  			switch r.ProgramPath {
   421  			case "/bin/date":
   422  				return *testPID, nil
   423  			case "sleep":
   424  				p := process.NewFunc(func(ctx context.Context, arg string) error {
   425  					d, err := time.ParseDuration(arg)
   426  					if err != nil {
   427  						return err
   428  					}
   429  
   430  					select {
   431  					case <-ctx.Done():
   432  						return &process.Error{Err: ctx.Err(), ExitCode: 42}
   433  					case <-time.After(d):
   434  					}
   435  
   436  					return nil
   437  				})
   438  				return m.Start(r, p)
   439  			default:
   440  				return DefaultStartCommand(m, r)
   441  			}
   442  		}
   443  	}
   444  
   445  	service.PrimaryIP = func() string {
   446  		log.Print("broadcasting IP")
   447  		return DefaultIP()
   448  	}
   449  
   450  	log.Print("starting toolbox service")
   451  	err := service.Start()
   452  	if err != nil {
   453  		log.Fatal(err)
   454  	}
   455  
   456  	wg.Wait()
   457  
   458  	// wait for 1 last ping to make sure the final response has reached the client before stopping
   459  	ping.L.Lock()
   460  	ping.Wait()
   461  	ping.L.Unlock()
   462  
   463  	service.Stop()
   464  	service.Wait()
   465  }