github.com/jlmucb/cloudproxy@v0.0.0-20170830161738-b5aa0b619bc4/go/apps/mixnet/queue.go (about)

     1  // Copyright (c) 2015, Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mixnet
    16  
    17  import (
    18  	"container/list"
    19  	"crypto/rand"
    20  	"encoding/binary"
    21  	"io"
    22  	"log"
    23  	"net"
    24  	"time"
    25  
    26  	"github.com/golang/glog"
    27  	"github.com/golang/protobuf/proto"
    28  	"github.com/jlmucb/cloudproxy/go/tao"
    29  )
    30  
    31  // The Queueable object is passed through a channel and mutates the state of
    32  // the Queue in some manner; for example, it can set the destination
    33  // adddress or connection of a sender, add a message or request for reply
    34  // to the queue, or destroy any resources associated with the connection.
    35  type Queueable struct {
    36  	id      uint64 // circuit id
    37  	msg     []byte
    38  	conn    net.Conn
    39  	errConn net.Conn
    40  	remove  bool
    41  	destroy bool
    42  }
    43  
    44  type sendQueueError struct {
    45  	id   uint64
    46  	conn net.Conn // where to send the error to
    47  	error
    48  }
    49  
    50  // The Queue structure maps a circuit identifier corresponding to a sender
    51  // (in the router context) to a destination. It also maintains a message buffer
    52  // for each sender. Once messages are ready on enough buffers, a batch of
    53  // messages are transmitted simultaneously.
    54  type Queue struct {
    55  	batchSize int // Number of messages to transmit in a round.
    56  	ct        int // Current number of buffers with messages ready.
    57  	// Tao to get the random bytes
    58  	// Might be okay to just use crypto/rand..
    59  	t tao.Tao
    60  
    61  	network string        // Network protocol, e.g. "tcp".
    62  	timeout time.Duration // Timeout on dial/read/write.
    63  
    64  	sendBuffer map[uint64]*list.List // Message buffer of sender.
    65  
    66  	queue chan *Queueable     // Channel for queueing messages/directives.
    67  	err   chan sendQueueError // Channel for handling errors.
    68  }
    69  
    70  // NewQueue creates a new Queue structure.
    71  func NewQueue(network string, t tao.Tao, batchSize int, timeout time.Duration) (sq *Queue) {
    72  	sq = new(Queue)
    73  	sq.batchSize = batchSize
    74  	sq.network = network
    75  	sq.t = t
    76  	sq.timeout = timeout
    77  
    78  	sq.sendBuffer = make(map[uint64]*list.List)
    79  
    80  	sq.queue = make(chan *Queueable)
    81  	sq.err = make(chan sendQueueError)
    82  	return sq
    83  }
    84  
// Enqueue hands q to the queue by sending the pointer itself on the queue's
// channel. Neither q nor q.msg is copied, so the caller must not modify them
// afterwards; use EnqueueMsg when a private copy of the payload is needed.
// The channel is unbuffered, so the call blocks until the item is received
// (DoQueue is the receiver in this file).
func (sq *Queue) Enqueue(q *Queueable) {
	sq.queue <- q
}
    91  
    92  // EnqueueMsg copies a byte slice into a queueable object and adds it to
    93  // the queue.
    94  func (sq *Queue) EnqueueMsg(id uint64, msg []byte, conn, errConn net.Conn) {
    95  	q := new(Queueable)
    96  	q.id = id
    97  	q.msg = make([]byte, len(msg))
    98  	copy(q.msg, msg)
    99  	q.conn = conn
   100  	q.errConn = errConn
   101  	sq.queue <- q
   102  }
   103  
   104  // Close creates a queueable object that sends the last msg in the circuit,
   105  // closes the connection and deletes all associated resources.
   106  func (sq *Queue) Close(id uint64, msg []byte, destroy bool, conn, errConn net.Conn) {
   107  	q := new(Queueable)
   108  	q.id = id
   109  	if msg != nil {
   110  		q.msg = make([]byte, len(msg))
   111  		copy(q.msg, msg)
   112  	}
   113  	q.remove = true
   114  	q.destroy = destroy
   115  	q.conn = conn
   116  	q.errConn = errConn
   117  	sq.queue <- q
   118  }
   119  
   120  func (sq *Queue) delete(q *Queueable) {
   121  	// Close the connection and delete all resources. Any subsequent
   122  	// messages or reply requests will cause an error.
   123  	if q.destroy {
   124  		// Wait for the client to kill the connection or timeout
   125  		if q.msg == nil {
   126  			q.conn.Close()
   127  		} else {
   128  			q.conn.SetDeadline(time.Now().Add(sq.timeout))
   129  			_, err := q.conn.Read([]byte{0})
   130  			if err != nil {
   131  				e, ok := err.(net.Error)
   132  				if err == io.EOF || (ok && e.Timeout()) {
   133  					// If it times out, and the connection
   134  					// is supposed to be closed,
   135  					// ignore it..
   136  					q.conn.Close()
   137  				}
   138  			}
   139  		}
   140  	}
   141  	if _, def := sq.sendBuffer[q.id]; def {
   142  		delete(sq.sendBuffer, q.id)
   143  	}
   144  }
   145  
   146  // DoQueue adds messages to a queue and transmits messages in batches. It also
   147  // provides an interface for receiving messages from a server. Typically a
   148  // message is a cell, but when the calling router is an exit point, the message
   149  // length is arbitrary. A batch is transmitted when there are messages on
   150  // batchSize distinct sender channels.
   151  func (sq *Queue) DoQueue(kill <-chan bool) {
   152  	for {
   153  		select {
   154  		case <-kill:
   155  			return
   156  		case q := <-sq.queue:
   157  			if q.msg != nil {
   158  				// Create a send buffer for the sender ID if it doesn't exist.
   159  				if _, def := sq.sendBuffer[q.id]; !def {
   160  					sq.sendBuffer[q.id] = list.New()
   161  				}
   162  				buf := sq.sendBuffer[q.id]
   163  
   164  				// The buffer was empty but now has a message ready; increment
   165  				// the counter.
   166  				if buf.Len() == 0 {
   167  					sq.ct++
   168  				}
   169  
   170  				// Add message to send buffer.
   171  				buf.PushBack(q)
   172  			} else if q.remove {
   173  				sq.delete(q)
   174  			}
   175  
   176  			// Transmit batches of messages.
   177  			for sq.ct >= sq.batchSize {
   178  				sq.dequeue()
   179  			}
   180  		}
   181  	}
   182  }
   183  
   184  // DoQueueErrorHandler handles errors produced by DoQueue by enqueing onto
   185  // queue a directive containing the error message.
   186  func (sq *Queue) DoQueueErrorHandler(queue *Queue, kill <-chan bool) {
   187  	for {
   188  		select {
   189  		case <-kill:
   190  			return
   191  		case err := <-sq.err:
   192  			if err.conn != nil {
   193  				var d Directive
   194  				d.Type = DirectiveType_ERROR.Enum()
   195  				d.Error = proto.String(err.Error())
   196  				cell, e := marshalDirective(err.id, &d)
   197  				if e != nil {
   198  					glog.Errorf("queue: %s\n", e)
   199  					return
   200  				}
   201  				queue.EnqueueMsg(err.id, cell, err.conn, nil)
   202  			} else {
   203  				glog.Errorf("client no. %d: %s\n", err.id, err)
   204  			}
   205  		}
   206  	}
   207  }
   208  
   209  // dequeue sends one message from each send buffer for each serial ID in a
   210  // random order. This is called by DoQueue and is not safe to call directly
   211  // elsewhere.
   212  func (sq *Queue) dequeue() {
   213  
   214  	// Shuffle the serial IDs.
   215  	pi := make([]int, sq.ct)
   216  	for i := 0; i < sq.ct; i++ { // Initialize a trivial permutation
   217  		pi[i] = i
   218  	}
   219  
   220  	for i := sq.ct - 1; i > 0; i-- { // Shuffle by random swaps
   221  		var b []byte
   222  		var err error = nil
   223  		if sq.t != nil {
   224  			b, err = sq.t.GetRandomBytes(8)
   225  			if err != nil {
   226  				glog.Error("Could not read random bytes from Tao")
   227  			}
   228  		}
   229  		if err != nil || sq.t == nil {
   230  			b = make([]byte, 8)
   231  			if _, err := rand.Read(b); err != nil {
   232  				// if we can't even get crypto/rand, fatal error
   233  				log.Fatal(err)
   234  			}
   235  		}
   236  		j := int(binary.LittleEndian.Uint64(b) % uint64(i+1))
   237  		if j != i {
   238  			tmp := pi[j]
   239  			pi[j] = pi[i]
   240  			pi[i] = tmp
   241  		}
   242  	}
   243  
   244  	ids := make([]uint64, sq.ct)
   245  	i := 0
   246  	for id, buf := range sq.sendBuffer {
   247  		if buf.Len() > 0 {
   248  			ids[pi[i]] = id
   249  			i++
   250  		}
   251  	}
   252  
   253  	// Issue a sendWorker thread for each message to be sent.
   254  	ch := make(chan senderResult)
   255  	for _, id := range ids[:sq.batchSize] {
   256  		q := sq.sendBuffer[id].Front().Value.(*Queueable)
   257  		go senderWorker(sq.network, q, ch, sq.err, sq.timeout)
   258  	}
   259  
   260  	// Wait for workers to finish.
   261  	for _ = range ids[:sq.batchSize] {
   262  		res := <-ch
   263  
   264  		// If this was close with a message, then remove q here
   265  		q := sq.sendBuffer[res.id].Front().Value.(*Queueable)
   266  		if q.remove {
   267  			sq.delete(q)
   268  		}
   269  
   270  		// Pop the message from the buffer and decrement the counter
   271  		// if the buffer is empty.
   272  		// Resource might be removed (circuit destroyed); check first
   273  		if buf, ok := sq.sendBuffer[res.id]; ok {
   274  			buf.Remove(buf.Front())
   275  			if buf.Len() == 0 {
   276  				sq.ct--
   277  			}
   278  		} else {
   279  			sq.ct--
   280  		}
   281  	}
   282  }
   283  
   284  type senderResult struct {
   285  	c  net.Conn
   286  	id uint64
   287  }
   288  
   289  func senderWorker(network string, q *Queueable,
   290  	res chan<- senderResult, err chan<- sendQueueError, timeout time.Duration) {
   291  	// Wait to connect until the queue is dequeued in order to prevent
   292  	// an observer from correlating an incoming cell with the handshake
   293  	// with the destination server.
   294  
   295  	q.conn.SetDeadline(time.Now().Add(timeout))
   296  	if q.msg != nil { // Send the message.
   297  		// if q.msg[TYPE] == 1 {
   298  		// 	var d Directive
   299  		// 	unmarshalDirective(q.msg, &d)
   300  		// 	fmt.Println("directive", q.id, *d.Type)
   301  		// 	if d.Error != nil {
   302  		// 		fmt.Println("error:", *d.Error)
   303  		// 	}
   304  		// }
   305  		if _, e := q.conn.Write(q.msg); e != nil {
   306  			err <- sendQueueError{q.id, q.errConn, e}
   307  			res <- senderResult{q.conn, q.id}
   308  			return
   309  		}
   310  	}
   311  
   312  	res <- senderResult{q.conn, q.id}
   313  }