9fans.net/go@v0.0.5/plan9/client/conn.go (about)

     1  //go:build !plan9
     2  // +build !plan9
     3  
     4  package client // import "9fans.net/go/plan9/client"
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"sync"
    10  	"sync/atomic"
    11  
    12  	"9fans.net/go/plan9"
    13  )
    14  
    15  type Error string
    16  
    17  func (e Error) Error() string { return string(e) }
    18  
// Conn is the exported handle on a 9P client connection created by NewConn.
// Its methods take mu before touching the other fields.
type Conn struct {
	// We wrap the underlying conn type so that
	// there's a clear distinction between Close,
	// which forces a close of the underlying rwc,
	// and Release, which lets the Fids take control
	// of when the conn is actually closed.
	//
	// _c is set to nil by both Close and Release; released
	// is set only by Release, so that a later Close can
	// report that the connection was already given up.
	mu       sync.Mutex
	_c       *conn
	released bool
}
    29  
    30  var errClosed = fmt.Errorf("connection has been closed")
    31  
    32  // Close forces a close of the connection and all Fids derived
    33  // from it.
    34  func (c *Conn) Close() error {
    35  	c.mu.Lock()
    36  	defer c.mu.Unlock()
    37  	if c._c == nil {
    38  		if c.released {
    39  			return fmt.Errorf("cannot close connection after it's been released")
    40  		}
    41  		return nil
    42  	}
    43  	rwc := c._c.rwc
    44  	c._c = nil
    45  	// TODO perhaps we shouldn't hold the mutex while closing?
    46  	return rwc.Close()
    47  }
    48  
    49  func (c *Conn) conn() (*conn, error) {
    50  	c.mu.Lock()
    51  	defer c.mu.Unlock()
    52  	if c._c == nil {
    53  		return nil, errClosed
    54  	}
    55  	return c._c, nil
    56  }
    57  
    58  // Release marks the connection so that it will
    59  // close automatically when the last Fid derived
    60  // from it is closed.
    61  //
    62  // If there are no current Fids, it closes immediately.
    63  // After calling Release, c.Attach, c.Auth and c.Close will return
    64  // an error.
    65  func (c *Conn) Release() error {
    66  	c.mu.Lock()
    67  	defer c.mu.Unlock()
    68  	if c._c == nil {
    69  		return nil
    70  	}
    71  	conn := c._c
    72  	c._c = nil
    73  	c.released = true
    74  	return conn.release()
    75  }
    76  
// conn holds the state shared by a Conn and all the Fids derived from it.
type conn struct {
	// rwc is the transport that 9P messages are written to and read from.
	rwc      io.ReadWriteCloser
	// err is the sticky error: once set by a failed read or write (or by
	// the final release), read and write return it without touching rwc.
	err      error
	// tagmap maps the tag of each pending request to the channel on which
	// its rpc call is waiting.
	tagmap   map[uint16]chan *plan9.Fcall
	// freetag and freefid hold tag and fid numbers that have been handed
	// back and can be reused before a never-used number is allocated.
	freetag  map[uint16]bool
	freefid  map[uint32]bool
	// nexttag and nextfid are the next never-used tag and fid numbers;
	// NewConn starts both at 1.
	nexttag  uint16
	nextfid  uint32
	// msize starts as the size proposed in Tversion and is replaced by the
	// value the server returns in Rversion.
	msize    uint32
	// version is the protocol version string sent in Tversion.
	version  string
	// w is held by rpc while a request is written; x guards err, the tag
	// and fid tables above, and muxer.
	w, x     sync.Mutex
	// muxer is true while one of the pending rpc calls has been told
	// (via yourTurn) to read replies from rwc.
	muxer    bool
	// refCount starts at 1 for the Conn itself and is raised by acquire
	// and lowered by release; rwc is closed when it reaches zero.
	refCount int32 // atomic
}
    91  
    92  func NewConn(rwc io.ReadWriteCloser) (*Conn, error) {
    93  	c := &conn{
    94  		rwc:      rwc,
    95  		tagmap:   make(map[uint16]chan *plan9.Fcall),
    96  		freetag:  make(map[uint16]bool),
    97  		freefid:  make(map[uint32]bool),
    98  		nexttag:  1,
    99  		nextfid:  1,
   100  		msize:    131072,
   101  		version:  "9P2000",
   102  		refCount: 1,
   103  	}
   104  
   105  	//	XXX raw messages, not c.rpc
   106  	tx := &plan9.Fcall{Type: plan9.Tversion, Tag: plan9.NOTAG, Msize: c.msize, Version: c.version}
   107  	err := c.write(tx)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	rx, err := c.read()
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  	if rx.Type != plan9.Rversion || rx.Tag != plan9.NOTAG {
   116  		return nil, plan9.ProtocolError(fmt.Sprintf("invalid type/tag in Tversion exchange: %v %v", rx.Type, rx.Tag))
   117  	}
   118  
   119  	if rx.Msize > c.msize {
   120  		return nil, plan9.ProtocolError(fmt.Sprintf("invalid msize %d in Rversion", rx.Msize))
   121  	}
   122  	c.msize = rx.Msize
   123  	if rx.Version != "9P2000" {
   124  		return nil, plan9.ProtocolError(fmt.Sprintf("invalid version %s in Rversion", rx.Version))
   125  	}
   126  	return &Conn{
   127  		_c: c,
   128  	}, nil
   129  }
   130  
   131  func (c *conn) newFid(fid uint32, qid plan9.Qid) *Fid {
   132  	c.acquire()
   133  	return &Fid{
   134  		_c:  c,
   135  		fid: fid,
   136  		qid: qid,
   137  	}
   138  }
   139  
   140  func (c *conn) newfidnum() (uint32, error) {
   141  	c.x.Lock()
   142  	defer c.x.Unlock()
   143  	for fidnum := range c.freefid {
   144  		delete(c.freefid, fidnum)
   145  		return fidnum, nil
   146  	}
   147  	fidnum := c.nextfid
   148  	if c.nextfid == plan9.NOFID {
   149  		return 0, plan9.ProtocolError("out of fids")
   150  	}
   151  	c.nextfid++
   152  	return fidnum, nil
   153  }
   154  
// putfidnum hands a fid number back to the connection so that a later
// newfidnum call can reuse it.
func (c *conn) putfidnum(fid uint32) {
	c.x.Lock()
	defer c.x.Unlock()
	c.freefid[fid] = true
}
   160  
   161  func (c *conn) newtag(ch chan *plan9.Fcall) (uint16, error) {
   162  	c.x.Lock()
   163  	defer c.x.Unlock()
   164  	var tagnum uint16
   165  	for tagnum = range c.freetag {
   166  		delete(c.freetag, tagnum)
   167  		goto found
   168  	}
   169  	tagnum = c.nexttag
   170  	if c.nexttag == plan9.NOTAG {
   171  		return 0, plan9.ProtocolError("out of tags")
   172  	}
   173  	c.nexttag++
   174  found:
   175  	c.tagmap[tagnum] = ch
   176  	if !c.muxer {
   177  		c.muxer = true
   178  		ch <- &yourTurn
   179  	}
   180  	return tagnum, nil
   181  }
   182  
   183  func (c *conn) puttag(tag uint16) chan *plan9.Fcall {
   184  	c.x.Lock()
   185  	defer c.x.Unlock()
   186  	ch := c.tagmap[tag]
   187  	delete(c.tagmap, tag)
   188  	c.freetag[tag] = true
   189  	return ch
   190  }
   191  
   192  func (c *conn) mux(rx *plan9.Fcall) {
   193  	c.x.Lock()
   194  	defer c.x.Unlock()
   195  
   196  	ch := c.tagmap[rx.Tag]
   197  	delete(c.tagmap, rx.Tag)
   198  	c.freetag[rx.Tag] = true
   199  	c.muxer = false
   200  	for _, ch2 := range c.tagmap {
   201  		c.muxer = true
   202  		ch2 <- &yourTurn
   203  		break
   204  	}
   205  	ch <- rx
   206  }
   207  
   208  func (c *conn) read() (*plan9.Fcall, error) {
   209  	if err := c.getErr(); err != nil {
   210  		return nil, err
   211  	}
   212  	f, err := plan9.ReadFcall(c.rwc)
   213  	if err != nil {
   214  		c.setErr(err)
   215  		return nil, err
   216  	}
   217  	return f, nil
   218  }
   219  
   220  func (c *conn) write(f *plan9.Fcall) error {
   221  	if err := c.getErr(); err != nil {
   222  		return err
   223  	}
   224  	err := plan9.WriteFcall(c.rwc, f)
   225  	if err != nil {
   226  		c.setErr(err)
   227  	}
   228  	return err
   229  }
   230  
// yourTurn is a sentinel: its address is sent on a pending request's
// channel (by newtag and mux) to tell that rpc call to read the next
// reply from the connection. It is compared by pointer and never sent
// over the wire.
var yourTurn plan9.Fcall
   232  
   233  func (c *conn) rpc(tx *plan9.Fcall, clunkFid *Fid) (rx *plan9.Fcall, err error) {
   234  	ch := make(chan *plan9.Fcall, 1)
   235  	tx.Tag, err = c.newtag(ch)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	c.w.Lock()
   240  	err = c.write(tx)
   241  	// Mark the fid as clunked inside the write lock so that we're
   242  	// sure that we don't reuse it after the sending the message
   243  	// that will clunk it, even in the presence of concurrent method
   244  	// calls on Fid.
   245  	if clunkFid != nil {
   246  		// Closing the Fid might release the conn, which
   247  		// would close the underlying rwc connection,
   248  		// which would prevent us from being able to receive the
   249  		// reply, so make sure that doesn't happen until the end
   250  		// by acquiring a reference for the duration of the call.
   251  		c.acquire()
   252  		defer c.release()
   253  		if err := clunkFid.clunked(); err != nil {
   254  			// This can happen if two clunking operations
   255  			// (e.g. Close and Remove) are invoked concurrently
   256  			c.w.Unlock()
   257  			return nil, err
   258  		}
   259  	}
   260  	c.w.Unlock()
   261  	if err != nil {
   262  		return nil, err
   263  	}
   264  
   265  	for rx = range ch {
   266  		if rx != &yourTurn {
   267  			break
   268  		}
   269  		rx, err = c.read()
   270  		if err != nil {
   271  			break
   272  		}
   273  		c.mux(rx)
   274  	}
   275  
   276  	if rx == nil {
   277  		return nil, c.getErr()
   278  	}
   279  	if rx.Type == plan9.Rerror {
   280  		return nil, Error(rx.Ename)
   281  	}
   282  	if rx.Type != tx.Type+1 {
   283  		return nil, plan9.ProtocolError("packet type mismatch")
   284  	}
   285  	return rx, nil
   286  }
   287  
// acquire adds one reference to the connection. Each acquire must be
// balanced by a call to release.
func (c *conn) acquire() {
	atomic.AddInt32(&c.refCount, 1)
}
   291  
   292  func (c *conn) release() error {
   293  	if atomic.AddInt32(&c.refCount, -1) != 0 {
   294  		return nil
   295  	}
   296  	err := c.rwc.Close()
   297  	c.setErr(errClosed)
   298  	return err
   299  }
   300  
   301  func (c *conn) getErr() error {
   302  	c.x.Lock()
   303  	defer c.x.Unlock()
   304  	return c.err
   305  }
   306  
   307  func (c *conn) setErr(err error) {
   308  	c.x.Lock()
   309  	defer c.x.Unlock()
   310  	c.err = err
   311  }