github.com/phobos182/packer@v0.2.3-0.20130819023704-c84d2aeffc68/packer/rpc/communicator.go (about)

     1  package rpc
     2  
     3  import (
     4  	"encoding/gob"
     5  	"errors"
     6  	"github.com/mitchellh/packer/packer"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"net/rpc"
    11  )
    12  
    13  // An implementation of packer.Communicator where the communicator is actually
    14  // executed over an RPC connection.
    15  type communicator struct {
    16  	client *rpc.Client
    17  }
    18  
    19  // CommunicatorServer wraps a packer.Communicator implementation and makes
    20  // it exportable as part of a Golang RPC server.
    21  type CommunicatorServer struct {
    22  	c packer.Communicator
    23  }
    24  
    25  type CommandFinished struct {
    26  	ExitStatus int
    27  }
    28  
    29  type CommunicatorStartArgs struct {
    30  	Command         string
    31  	StdinAddress    string
    32  	StdoutAddress   string
    33  	StderrAddress   string
    34  	ResponseAddress string
    35  }
    36  
    37  type CommunicatorDownloadArgs struct {
    38  	Path          string
    39  	WriterAddress string
    40  }
    41  
    42  type CommunicatorUploadArgs struct {
    43  	Path          string
    44  	ReaderAddress string
    45  }
    46  
    47  func Communicator(client *rpc.Client) *communicator {
    48  	return &communicator{client}
    49  }
    50  
    51  func (c *communicator) Start(cmd *packer.RemoteCmd) (err error) {
    52  	var args CommunicatorStartArgs
    53  	args.Command = cmd.Command
    54  
    55  	if cmd.Stdin != nil {
    56  		stdinL := netListenerInRange(portRangeMin, portRangeMax)
    57  		args.StdinAddress = stdinL.Addr().String()
    58  		go serveSingleCopy("stdin", stdinL, nil, cmd.Stdin)
    59  	}
    60  
    61  	if cmd.Stdout != nil {
    62  		stdoutL := netListenerInRange(portRangeMin, portRangeMax)
    63  		args.StdoutAddress = stdoutL.Addr().String()
    64  		go serveSingleCopy("stdout", stdoutL, cmd.Stdout, nil)
    65  	}
    66  
    67  	if cmd.Stderr != nil {
    68  		stderrL := netListenerInRange(portRangeMin, portRangeMax)
    69  		args.StderrAddress = stderrL.Addr().String()
    70  		go serveSingleCopy("stderr", stderrL, cmd.Stderr, nil)
    71  	}
    72  
    73  	responseL := netListenerInRange(portRangeMin, portRangeMax)
    74  	args.ResponseAddress = responseL.Addr().String()
    75  
    76  	go func() {
    77  		defer responseL.Close()
    78  
    79  		conn, err := responseL.Accept()
    80  		if err != nil {
    81  			log.Panic(err)
    82  		}
    83  
    84  		defer conn.Close()
    85  
    86  		decoder := gob.NewDecoder(conn)
    87  
    88  		var finished CommandFinished
    89  		if err := decoder.Decode(&finished); err != nil {
    90  			log.Panic(err)
    91  		}
    92  
    93  		cmd.SetExited(finished.ExitStatus)
    94  	}()
    95  
    96  	err = c.client.Call("Communicator.Start", &args, new(interface{}))
    97  	return
    98  }
    99  
   100  func (c *communicator) Upload(path string, r io.Reader) (err error) {
   101  	// We need to create a server that can proxy the reader data
   102  	// over because we can't simply gob encode an io.Reader
   103  	readerL := netListenerInRange(portRangeMin, portRangeMax)
   104  	if readerL == nil {
   105  		err = errors.New("couldn't allocate listener for upload reader")
   106  		return
   107  	}
   108  
   109  	// Make sure at the end of this call, we close the listener
   110  	defer readerL.Close()
   111  
   112  	// Pipe the reader through to the connection
   113  	go serveSingleCopy("uploadReader", readerL, nil, r)
   114  
   115  	args := CommunicatorUploadArgs{
   116  		path,
   117  		readerL.Addr().String(),
   118  	}
   119  
   120  	err = c.client.Call("Communicator.Upload", &args, new(interface{}))
   121  	return
   122  }
   123  
   124  func (c *communicator) Download(path string, w io.Writer) (err error) {
   125  	// We need to create a server that can proxy that data downloaded
   126  	// into the writer because we can't gob encode a writer directly.
   127  	writerL := netListenerInRange(portRangeMin, portRangeMax)
   128  	if writerL == nil {
   129  		err = errors.New("couldn't allocate listener for download writer")
   130  		return
   131  	}
   132  
   133  	// Make sure we close the listener once we're done because we'll be done
   134  	defer writerL.Close()
   135  
   136  	// Serve a single connection and a single copy
   137  	go serveSingleCopy("downloadWriter", writerL, w, nil)
   138  
   139  	args := CommunicatorDownloadArgs{
   140  		path,
   141  		writerL.Addr().String(),
   142  	}
   143  
   144  	err = c.client.Call("Communicator.Download", &args, new(interface{}))
   145  	return
   146  }
   147  
   148  func (c *CommunicatorServer) Start(args *CommunicatorStartArgs, reply *interface{}) (err error) {
   149  	// Build the RemoteCmd on this side so that it all pipes over
   150  	// to the remote side.
   151  	var cmd packer.RemoteCmd
   152  	cmd.Command = args.Command
   153  
   154  	toClose := make([]net.Conn, 0)
   155  	if args.StdinAddress != "" {
   156  		stdinC, err := net.Dial("tcp", args.StdinAddress)
   157  		if err != nil {
   158  			return err
   159  		}
   160  
   161  		toClose = append(toClose, stdinC)
   162  		cmd.Stdin = stdinC
   163  	}
   164  
   165  	if args.StdoutAddress != "" {
   166  		stdoutC, err := net.Dial("tcp", args.StdoutAddress)
   167  		if err != nil {
   168  			return err
   169  		}
   170  
   171  		toClose = append(toClose, stdoutC)
   172  		cmd.Stdout = stdoutC
   173  	}
   174  
   175  	if args.StderrAddress != "" {
   176  		stderrC, err := net.Dial("tcp", args.StderrAddress)
   177  		if err != nil {
   178  			return err
   179  		}
   180  
   181  		toClose = append(toClose, stderrC)
   182  		cmd.Stderr = stderrC
   183  	}
   184  
   185  	// Connect to the response address so we can write our result to it
   186  	// when ready.
   187  	responseC, err := net.Dial("tcp", args.ResponseAddress)
   188  	if err != nil {
   189  		return err
   190  	}
   191  
   192  	responseWriter := gob.NewEncoder(responseC)
   193  
   194  	// Start the actual command
   195  	err = c.c.Start(&cmd)
   196  
   197  	// Start a goroutine to spin and wait for the process to actual
   198  	// exit. When it does, report it back to caller...
   199  	go func() {
   200  		defer responseC.Close()
   201  		for _, conn := range toClose {
   202  			defer conn.Close()
   203  		}
   204  
   205  		cmd.Wait()
   206  		responseWriter.Encode(&CommandFinished{cmd.ExitStatus})
   207  	}()
   208  
   209  	return
   210  }
   211  
   212  func (c *CommunicatorServer) Upload(args *CommunicatorUploadArgs, reply *interface{}) (err error) {
   213  	readerC, err := net.Dial("tcp", args.ReaderAddress)
   214  	if err != nil {
   215  		return
   216  	}
   217  
   218  	defer readerC.Close()
   219  
   220  	err = c.c.Upload(args.Path, readerC)
   221  	return
   222  }
   223  
   224  func (c *CommunicatorServer) Download(args *CommunicatorDownloadArgs, reply *interface{}) (err error) {
   225  	writerC, err := net.Dial("tcp", args.WriterAddress)
   226  	if err != nil {
   227  		return
   228  	}
   229  
   230  	defer writerC.Close()
   231  
   232  	err = c.c.Download(args.Path, writerC)
   233  	return
   234  }
   235  
   236  func serveSingleCopy(name string, l net.Listener, dst io.Writer, src io.Reader) {
   237  	defer l.Close()
   238  
   239  	conn, err := l.Accept()
   240  	if err != nil {
   241  		log.Printf("'%s' accept error: %s", name, err)
   242  		return
   243  	}
   244  
   245  	// Be sure to close the connection after we're done copying so
   246  	// that an EOF will successfully be sent to the remote side
   247  	defer conn.Close()
   248  
   249  	// The connection is the destination/source that is nil
   250  	if dst == nil {
   251  		dst = conn
   252  	} else {
   253  		src = conn
   254  	}
   255  
   256  	written, err := io.Copy(dst, src)
   257  	log.Printf("%d bytes written for '%s'", written, name)
   258  	if err != nil {
   259  		log.Printf("'%s' copy error: %s", name, err)
   260  	}
   261  }