github.com/nsqio/nsq@v1.3.0/nsqd/protocol_v2.go (about)

     1  package nsqd
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"math/rand"
    11  	"net"
    12  	"sync/atomic"
    13  	"time"
    14  	"unsafe"
    15  
    16  	"github.com/nsqio/nsq/internal/protocol"
    17  	"github.com/nsqio/nsq/internal/version"
    18  )
    19  
    20  const (
    21  	frameTypeResponse int32 = 0
    22  	frameTypeError    int32 = 1
    23  	frameTypeMessage  int32 = 2
    24  )
    25  
    26  var separatorBytes = []byte(" ")
    27  var heartbeatBytes = []byte("_heartbeat_")
    28  var okBytes = []byte("OK")
    29  
    30  type protocolV2 struct {
    31  	nsqd *NSQD
    32  }
    33  
    34  func (p *protocolV2) NewClient(conn net.Conn) protocol.Client {
    35  	clientID := atomic.AddInt64(&p.nsqd.clientIDSequence, 1)
    36  	return newClientV2(clientID, conn, p.nsqd)
    37  }
    38  
    39  func (p *protocolV2) IOLoop(c protocol.Client) error {
    40  	var err error
    41  	var line []byte
    42  	var zeroTime time.Time
    43  
    44  	client := c.(*clientV2)
    45  
    46  	// synchronize the startup of messagePump in order
    47  	// to guarantee that it gets a chance to initialize
    48  	// goroutine local state derived from client attributes
    49  	// and avoid a potential race with IDENTIFY (where a client
    50  	// could have changed or disabled said attributes)
    51  	messagePumpStartedChan := make(chan bool)
    52  	go p.messagePump(client, messagePumpStartedChan)
    53  	<-messagePumpStartedChan
    54  
    55  	for {
    56  		if client.HeartbeatInterval > 0 {
    57  			client.SetReadDeadline(time.Now().Add(client.HeartbeatInterval * 2))
    58  		} else {
    59  			client.SetReadDeadline(zeroTime)
    60  		}
    61  
    62  		// ReadSlice does not allocate new space for the data each request
    63  		// ie. the returned slice is only valid until the next call to it
    64  		line, err = client.Reader.ReadSlice('\n')
    65  		if err != nil {
    66  			if err == io.EOF {
    67  				err = nil
    68  			} else {
    69  				err = fmt.Errorf("failed to read command - %s", err)
    70  			}
    71  			break
    72  		}
    73  
    74  		// trim the '\n'
    75  		line = line[:len(line)-1]
    76  		// optionally trim the '\r'
    77  		if len(line) > 0 && line[len(line)-1] == '\r' {
    78  			line = line[:len(line)-1]
    79  		}
    80  		params := bytes.Split(line, separatorBytes)
    81  
    82  		p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %s", client, params)
    83  
    84  		var response []byte
    85  		response, err = p.Exec(client, params)
    86  		if err != nil {
    87  			ctx := ""
    88  			if parentErr := err.(protocol.ChildErr).Parent(); parentErr != nil {
    89  				ctx = " - " + parentErr.Error()
    90  			}
    91  			p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, err, ctx)
    92  
    93  			sendErr := p.Send(client, frameTypeError, []byte(err.Error()))
    94  			if sendErr != nil {
    95  				p.nsqd.logf(LOG_ERROR, "[%s] - %s%s", client, sendErr, ctx)
    96  				break
    97  			}
    98  
    99  			// errors of type FatalClientErr should forceably close the connection
   100  			if _, ok := err.(*protocol.FatalClientErr); ok {
   101  				break
   102  			}
   103  			continue
   104  		}
   105  
   106  		if response != nil {
   107  			err = p.Send(client, frameTypeResponse, response)
   108  			if err != nil {
   109  				err = fmt.Errorf("failed to send response - %s", err)
   110  				break
   111  			}
   112  		}
   113  	}
   114  
   115  	p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting ioloop", client)
   116  	close(client.ExitChan)
   117  	if client.Channel != nil {
   118  		client.Channel.RemoveClient(client.ID)
   119  	}
   120  
   121  	return err
   122  }
   123  
   124  func (p *protocolV2) SendMessage(client *clientV2, msg *Message) error {
   125  	p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): writing msg(%s) to client(%s) - %s", msg.ID, client, msg.Body)
   126  
   127  	buf := bufferPoolGet()
   128  	defer bufferPoolPut(buf)
   129  
   130  	_, err := msg.WriteTo(buf)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	err = p.Send(client, frameTypeMessage, buf.Bytes())
   136  	if err != nil {
   137  		return err
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (p *protocolV2) Send(client *clientV2, frameType int32, data []byte) error {
   144  	client.writeLock.Lock()
   145  
   146  	var zeroTime time.Time
   147  	if client.HeartbeatInterval > 0 {
   148  		client.SetWriteDeadline(time.Now().Add(client.HeartbeatInterval))
   149  	} else {
   150  		client.SetWriteDeadline(zeroTime)
   151  	}
   152  
   153  	_, err := protocol.SendFramedResponse(client.Writer, frameType, data)
   154  	if err != nil {
   155  		client.writeLock.Unlock()
   156  		return err
   157  	}
   158  
   159  	if frameType != frameTypeMessage {
   160  		err = client.Flush()
   161  	}
   162  
   163  	client.writeLock.Unlock()
   164  
   165  	return err
   166  }
   167  
   168  func (p *protocolV2) Exec(client *clientV2, params [][]byte) ([]byte, error) {
   169  	if bytes.Equal(params[0], []byte("IDENTIFY")) {
   170  		return p.IDENTIFY(client, params)
   171  	}
   172  	err := enforceTLSPolicy(client, p, params[0])
   173  	if err != nil {
   174  		return nil, err
   175  	}
   176  	switch {
   177  	case bytes.Equal(params[0], []byte("FIN")):
   178  		return p.FIN(client, params)
   179  	case bytes.Equal(params[0], []byte("RDY")):
   180  		return p.RDY(client, params)
   181  	case bytes.Equal(params[0], []byte("REQ")):
   182  		return p.REQ(client, params)
   183  	case bytes.Equal(params[0], []byte("PUB")):
   184  		return p.PUB(client, params)
   185  	case bytes.Equal(params[0], []byte("MPUB")):
   186  		return p.MPUB(client, params)
   187  	case bytes.Equal(params[0], []byte("DPUB")):
   188  		return p.DPUB(client, params)
   189  	case bytes.Equal(params[0], []byte("NOP")):
   190  		return p.NOP(client, params)
   191  	case bytes.Equal(params[0], []byte("TOUCH")):
   192  		return p.TOUCH(client, params)
   193  	case bytes.Equal(params[0], []byte("SUB")):
   194  		return p.SUB(client, params)
   195  	case bytes.Equal(params[0], []byte("CLS")):
   196  		return p.CLS(client, params)
   197  	case bytes.Equal(params[0], []byte("AUTH")):
   198  		return p.AUTH(client, params)
   199  	}
   200  	return nil, protocol.NewFatalClientErr(nil, "E_INVALID", fmt.Sprintf("invalid command %s", params[0]))
   201  }
   202  
   203  func (p *protocolV2) messagePump(client *clientV2, startedChan chan bool) {
   204  	var err error
   205  	var memoryMsgChan chan *Message
   206  	var backendMsgChan <-chan []byte
   207  	var subChannel *Channel
   208  	// NOTE: `flusherChan` is used to bound message latency for
   209  	// the pathological case of a channel on a low volume topic
   210  	// with >1 clients having >1 RDY counts
   211  	var flusherChan <-chan time.Time
   212  	var sampleRate int32
   213  
   214  	subEventChan := client.SubEventChan
   215  	identifyEventChan := client.IdentifyEventChan
   216  	outputBufferTicker := time.NewTicker(client.OutputBufferTimeout)
   217  	heartbeatTicker := time.NewTicker(client.HeartbeatInterval)
   218  	heartbeatChan := heartbeatTicker.C
   219  	msgTimeout := client.MsgTimeout
   220  
   221  	// v2 opportunistically buffers data to clients to reduce write system calls
   222  	// we force flush in two cases:
   223  	//    1. when the client is not ready to receive messages
   224  	//    2. we're buffered and the channel has nothing left to send us
   225  	//       (ie. we would block in this loop anyway)
   226  	//
   227  	flushed := true
   228  
   229  	// signal to the goroutine that started the messagePump
   230  	// that we've started up
   231  	close(startedChan)
   232  
   233  	for {
   234  		if subChannel == nil || !client.IsReadyForMessages() {
   235  			// the client is not ready to receive messages...
   236  			memoryMsgChan = nil
   237  			backendMsgChan = nil
   238  			flusherChan = nil
   239  			// force flush
   240  			client.writeLock.Lock()
   241  			err = client.Flush()
   242  			client.writeLock.Unlock()
   243  			if err != nil {
   244  				goto exit
   245  			}
   246  			flushed = true
   247  		} else if flushed {
   248  			// last iteration we flushed...
   249  			// do not select on the flusher ticker channel
   250  			memoryMsgChan = subChannel.memoryMsgChan
   251  			backendMsgChan = subChannel.backend.ReadChan()
   252  			flusherChan = nil
   253  		} else {
   254  			// we're buffered (if there isn't any more data we should flush)...
   255  			// select on the flusher ticker channel, too
   256  			memoryMsgChan = subChannel.memoryMsgChan
   257  			backendMsgChan = subChannel.backend.ReadChan()
   258  			flusherChan = outputBufferTicker.C
   259  		}
   260  
   261  		select {
   262  		case <-flusherChan:
   263  			// if this case wins, we're either starved
   264  			// or we won the race between other channels...
   265  			// in either case, force flush
   266  			client.writeLock.Lock()
   267  			err = client.Flush()
   268  			client.writeLock.Unlock()
   269  			if err != nil {
   270  				goto exit
   271  			}
   272  			flushed = true
   273  		case <-client.ReadyStateChan:
   274  		case subChannel = <-subEventChan:
   275  			// you can't SUB anymore
   276  			subEventChan = nil
   277  		case identifyData := <-identifyEventChan:
   278  			// you can't IDENTIFY anymore
   279  			identifyEventChan = nil
   280  
   281  			outputBufferTicker.Stop()
   282  			if identifyData.OutputBufferTimeout > 0 {
   283  				outputBufferTicker = time.NewTicker(identifyData.OutputBufferTimeout)
   284  			}
   285  
   286  			heartbeatTicker.Stop()
   287  			heartbeatChan = nil
   288  			if identifyData.HeartbeatInterval > 0 {
   289  				heartbeatTicker = time.NewTicker(identifyData.HeartbeatInterval)
   290  				heartbeatChan = heartbeatTicker.C
   291  			}
   292  
   293  			if identifyData.SampleRate > 0 {
   294  				sampleRate = identifyData.SampleRate
   295  			}
   296  
   297  			msgTimeout = identifyData.MsgTimeout
   298  		case <-heartbeatChan:
   299  			err = p.Send(client, frameTypeResponse, heartbeatBytes)
   300  			if err != nil {
   301  				goto exit
   302  			}
   303  		case b := <-backendMsgChan:
   304  			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
   305  				continue
   306  			}
   307  
   308  			msg, err := decodeMessage(b)
   309  			if err != nil {
   310  				p.nsqd.logf(LOG_ERROR, "failed to decode message - %s", err)
   311  				continue
   312  			}
   313  			msg.Attempts++
   314  
   315  			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
   316  			client.SendingMessage()
   317  			err = p.SendMessage(client, msg)
   318  			if err != nil {
   319  				goto exit
   320  			}
   321  			flushed = false
   322  		case msg := <-memoryMsgChan:
   323  			if sampleRate > 0 && rand.Int31n(100) > sampleRate {
   324  				continue
   325  			}
   326  			msg.Attempts++
   327  
   328  			subChannel.StartInFlightTimeout(msg, client.ID, msgTimeout)
   329  			client.SendingMessage()
   330  			err = p.SendMessage(client, msg)
   331  			if err != nil {
   332  				goto exit
   333  			}
   334  			flushed = false
   335  		case <-client.ExitChan:
   336  			goto exit
   337  		}
   338  	}
   339  
   340  exit:
   341  	p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] exiting messagePump", client)
   342  	heartbeatTicker.Stop()
   343  	outputBufferTicker.Stop()
   344  	if err != nil {
   345  		p.nsqd.logf(LOG_ERROR, "PROTOCOL(V2): [%s] messagePump error - %s", client, err)
   346  	}
   347  }
   348  
   349  func (p *protocolV2) IDENTIFY(client *clientV2, params [][]byte) ([]byte, error) {
   350  	var err error
   351  
   352  	if atomic.LoadInt32(&client.State) != stateInit {
   353  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot IDENTIFY in current state")
   354  	}
   355  
   356  	bodyLen, err := readLen(client.Reader, client.lenSlice)
   357  	if err != nil {
   358  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body size")
   359  	}
   360  
   361  	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
   362  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   363  			fmt.Sprintf("IDENTIFY body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
   364  	}
   365  
   366  	if bodyLen <= 0 {
   367  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   368  			fmt.Sprintf("IDENTIFY invalid body size %d", bodyLen))
   369  	}
   370  
   371  	body := make([]byte, bodyLen)
   372  	_, err = io.ReadFull(client.Reader, body)
   373  	if err != nil {
   374  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to read body")
   375  	}
   376  
   377  	// body is a json structure with producer information
   378  	var identifyData identifyDataV2
   379  	err = json.Unmarshal(body, &identifyData)
   380  	if err != nil {
   381  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY failed to decode JSON body")
   382  	}
   383  
   384  	p.nsqd.logf(LOG_DEBUG, "PROTOCOL(V2): [%s] %+v", client, identifyData)
   385  
   386  	err = client.Identify(identifyData)
   387  	if err != nil {
   388  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "IDENTIFY "+err.Error())
   389  	}
   390  
   391  	// bail out early if we're not negotiating features
   392  	if !identifyData.FeatureNegotiation {
   393  		return okBytes, nil
   394  	}
   395  
   396  	tlsv1 := p.nsqd.tlsConfig != nil && identifyData.TLSv1
   397  	deflate := p.nsqd.getOpts().DeflateEnabled && identifyData.Deflate
   398  	deflateLevel := 6
   399  	if deflate && identifyData.DeflateLevel > 0 {
   400  		deflateLevel = identifyData.DeflateLevel
   401  	}
   402  	if max := p.nsqd.getOpts().MaxDeflateLevel; max < deflateLevel {
   403  		deflateLevel = max
   404  	}
   405  	snappy := p.nsqd.getOpts().SnappyEnabled && identifyData.Snappy
   406  
   407  	if deflate && snappy {
   408  		return nil, protocol.NewFatalClientErr(nil, "E_IDENTIFY_FAILED", "cannot enable both deflate and snappy compression")
   409  	}
   410  
   411  	resp, err := json.Marshal(struct {
   412  		MaxRdyCount         int64  `json:"max_rdy_count"`
   413  		Version             string `json:"version"`
   414  		MaxMsgTimeout       int64  `json:"max_msg_timeout"`
   415  		MsgTimeout          int64  `json:"msg_timeout"`
   416  		TLSv1               bool   `json:"tls_v1"`
   417  		Deflate             bool   `json:"deflate"`
   418  		DeflateLevel        int    `json:"deflate_level"`
   419  		MaxDeflateLevel     int    `json:"max_deflate_level"`
   420  		Snappy              bool   `json:"snappy"`
   421  		SampleRate          int32  `json:"sample_rate"`
   422  		AuthRequired        bool   `json:"auth_required"`
   423  		OutputBufferSize    int    `json:"output_buffer_size"`
   424  		OutputBufferTimeout int64  `json:"output_buffer_timeout"`
   425  	}{
   426  		MaxRdyCount:         p.nsqd.getOpts().MaxRdyCount,
   427  		Version:             version.Binary,
   428  		MaxMsgTimeout:       int64(p.nsqd.getOpts().MaxMsgTimeout / time.Millisecond),
   429  		MsgTimeout:          int64(client.MsgTimeout / time.Millisecond),
   430  		TLSv1:               tlsv1,
   431  		Deflate:             deflate,
   432  		DeflateLevel:        deflateLevel,
   433  		MaxDeflateLevel:     p.nsqd.getOpts().MaxDeflateLevel,
   434  		Snappy:              snappy,
   435  		SampleRate:          client.SampleRate,
   436  		AuthRequired:        p.nsqd.IsAuthEnabled(),
   437  		OutputBufferSize:    client.OutputBufferSize,
   438  		OutputBufferTimeout: int64(client.OutputBufferTimeout / time.Millisecond),
   439  	})
   440  	if err != nil {
   441  		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   442  	}
   443  
   444  	err = p.Send(client, frameTypeResponse, resp)
   445  	if err != nil {
   446  		return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   447  	}
   448  
   449  	if tlsv1 {
   450  		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to TLS", client)
   451  		err = client.UpgradeTLS()
   452  		if err != nil {
   453  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   454  		}
   455  
   456  		err = p.Send(client, frameTypeResponse, okBytes)
   457  		if err != nil {
   458  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   459  		}
   460  	}
   461  
   462  	if snappy {
   463  		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to snappy", client)
   464  		err = client.UpgradeSnappy()
   465  		if err != nil {
   466  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   467  		}
   468  
   469  		err = p.Send(client, frameTypeResponse, okBytes)
   470  		if err != nil {
   471  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   472  		}
   473  	}
   474  
   475  	if deflate {
   476  		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] upgrading connection to deflate (level %d)", client, deflateLevel)
   477  		err = client.UpgradeDeflate(deflateLevel)
   478  		if err != nil {
   479  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   480  		}
   481  
   482  		err = p.Send(client, frameTypeResponse, okBytes)
   483  		if err != nil {
   484  			return nil, protocol.NewFatalClientErr(err, "E_IDENTIFY_FAILED", "IDENTIFY failed "+err.Error())
   485  		}
   486  	}
   487  
   488  	return nil, nil
   489  }
   490  
   491  func (p *protocolV2) AUTH(client *clientV2, params [][]byte) ([]byte, error) {
   492  	if atomic.LoadInt32(&client.State) != stateInit {
   493  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot AUTH in current state")
   494  	}
   495  
   496  	if len(params) != 1 {
   497  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH invalid number of parameters")
   498  	}
   499  
   500  	bodyLen, err := readLen(client.Reader, client.lenSlice)
   501  	if err != nil {
   502  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body size")
   503  	}
   504  
   505  	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
   506  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   507  			fmt.Sprintf("AUTH body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
   508  	}
   509  
   510  	if bodyLen <= 0 {
   511  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   512  			fmt.Sprintf("AUTH invalid body size %d", bodyLen))
   513  	}
   514  
   515  	body := make([]byte, bodyLen)
   516  	_, err = io.ReadFull(client.Reader, body)
   517  	if err != nil {
   518  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "AUTH failed to read body")
   519  	}
   520  
   521  	if client.HasAuthorizations() {
   522  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "AUTH already set")
   523  	}
   524  
   525  	if !client.nsqd.IsAuthEnabled() {
   526  		return nil, protocol.NewFatalClientErr(err, "E_AUTH_DISABLED", "AUTH disabled")
   527  	}
   528  
   529  	if err := client.Auth(string(body)); err != nil {
   530  		// we don't want to leak errors contacting the auth server to untrusted clients
   531  		p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
   532  		return nil, protocol.NewFatalClientErr(err, "E_AUTH_FAILED", "AUTH failed")
   533  	}
   534  
   535  	if !client.HasAuthorizations() {
   536  		return nil, protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED", "AUTH no authorizations found")
   537  	}
   538  
   539  	resp, err := json.Marshal(struct {
   540  		Identity        string `json:"identity"`
   541  		IdentityURL     string `json:"identity_url"`
   542  		PermissionCount int    `json:"permission_count"`
   543  	}{
   544  		Identity:        client.AuthState.Identity,
   545  		IdentityURL:     client.AuthState.IdentityURL,
   546  		PermissionCount: len(client.AuthState.Authorizations),
   547  	})
   548  	if err != nil {
   549  		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
   550  	}
   551  
   552  	err = p.Send(client, frameTypeResponse, resp)
   553  	if err != nil {
   554  		return nil, protocol.NewFatalClientErr(err, "E_AUTH_ERROR", "AUTH error "+err.Error())
   555  	}
   556  
   557  	return nil, nil
   558  
   559  }
   560  
   561  func (p *protocolV2) CheckAuth(client *clientV2, cmd, topicName, channelName string) error {
   562  	// if auth is enabled, the client must have authorized already
   563  	// compare topic/channel against cached authorization data (refetching if expired)
   564  	if client.nsqd.IsAuthEnabled() {
   565  		if !client.HasAuthorizations() {
   566  			return protocol.NewFatalClientErr(nil, "E_AUTH_FIRST",
   567  				fmt.Sprintf("AUTH required before %s", cmd))
   568  		}
   569  		ok, err := client.IsAuthorized(topicName, channelName)
   570  		if err != nil {
   571  			// we don't want to leak errors contacting the auth server to untrusted clients
   572  			p.nsqd.logf(LOG_WARN, "PROTOCOL(V2): [%s] AUTH failed %s", client, err)
   573  			return protocol.NewFatalClientErr(nil, "E_AUTH_FAILED", "AUTH failed")
   574  		}
   575  		if !ok {
   576  			return protocol.NewFatalClientErr(nil, "E_UNAUTHORIZED",
   577  				fmt.Sprintf("AUTH failed for %s on %q %q", cmd, topicName, channelName))
   578  		}
   579  	}
   580  	return nil
   581  }
   582  
   583  func (p *protocolV2) SUB(client *clientV2, params [][]byte) ([]byte, error) {
   584  	if atomic.LoadInt32(&client.State) != stateInit {
   585  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB in current state")
   586  	}
   587  
   588  	if client.HeartbeatInterval <= 0 {
   589  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot SUB with heartbeats disabled")
   590  	}
   591  
   592  	if len(params) < 3 {
   593  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "SUB insufficient number of parameters")
   594  	}
   595  
   596  	topicName := string(params[1])
   597  	if !protocol.IsValidTopicName(topicName) {
   598  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
   599  			fmt.Sprintf("SUB topic name %q is not valid", topicName))
   600  	}
   601  
   602  	channelName := string(params[2])
   603  	if !protocol.IsValidChannelName(channelName) {
   604  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_CHANNEL",
   605  			fmt.Sprintf("SUB channel name %q is not valid", channelName))
   606  	}
   607  
   608  	if err := p.CheckAuth(client, "SUB", topicName, channelName); err != nil {
   609  		return nil, err
   610  	}
   611  
   612  	// This retry-loop is a work-around for a race condition, where the
   613  	// last client can leave the channel between GetChannel() and AddClient().
   614  	// Avoid adding a client to an ephemeral channel / topic which has started exiting.
   615  	var channel *Channel
   616  	for i := 1; ; i++ {
   617  		topic := p.nsqd.GetTopic(topicName)
   618  		channel = topic.GetChannel(channelName)
   619  		if err := channel.AddClient(client.ID, client); err != nil {
   620  			return nil, protocol.NewFatalClientErr(err, "E_SUB_FAILED", "SUB failed "+err.Error())
   621  		}
   622  
   623  		if (channel.ephemeral && channel.Exiting()) || (topic.ephemeral && topic.Exiting()) {
   624  			channel.RemoveClient(client.ID)
   625  			if i < 2 {
   626  				time.Sleep(100 * time.Millisecond)
   627  				continue
   628  			}
   629  			return nil, protocol.NewFatalClientErr(nil, "E_SUB_FAILED", "SUB failed to deleted topic/channel")
   630  		}
   631  		break
   632  	}
   633  	atomic.StoreInt32(&client.State, stateSubscribed)
   634  	client.Channel = channel
   635  	// update message pump
   636  	client.SubEventChan <- channel
   637  
   638  	return okBytes, nil
   639  }
   640  
   641  func (p *protocolV2) RDY(client *clientV2, params [][]byte) ([]byte, error) {
   642  	state := atomic.LoadInt32(&client.State)
   643  
   644  	if state == stateClosing {
   645  		// just ignore ready changes on a closing channel
   646  		p.nsqd.logf(LOG_INFO,
   647  			"PROTOCOL(V2): [%s] ignoring RDY after CLS in state ClientStateV2Closing",
   648  			client)
   649  		return nil, nil
   650  	}
   651  
   652  	if state != stateSubscribed {
   653  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot RDY in current state")
   654  	}
   655  
   656  	count := int64(1)
   657  	if len(params) > 1 {
   658  		b10, err := protocol.ByteToBase10(params[1])
   659  		if err != nil {
   660  			return nil, protocol.NewFatalClientErr(err, "E_INVALID",
   661  				fmt.Sprintf("RDY could not parse count %s", params[1]))
   662  		}
   663  		count = int64(b10)
   664  	}
   665  
   666  	if count < 0 || count > p.nsqd.getOpts().MaxRdyCount {
   667  		// this needs to be a fatal error otherwise clients would have
   668  		// inconsistent state
   669  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
   670  			fmt.Sprintf("RDY count %d out of range 0-%d", count, p.nsqd.getOpts().MaxRdyCount))
   671  	}
   672  
   673  	client.SetReadyCount(count)
   674  
   675  	return nil, nil
   676  }
   677  
   678  func (p *protocolV2) FIN(client *clientV2, params [][]byte) ([]byte, error) {
   679  	state := atomic.LoadInt32(&client.State)
   680  	if state != stateSubscribed && state != stateClosing {
   681  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot FIN in current state")
   682  	}
   683  
   684  	if len(params) < 2 {
   685  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "FIN insufficient number of params")
   686  	}
   687  
   688  	id, err := getMessageID(params[1])
   689  	if err != nil {
   690  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
   691  	}
   692  
   693  	err = client.Channel.FinishMessage(client.ID, *id)
   694  	if err != nil {
   695  		return nil, protocol.NewClientErr(err, "E_FIN_FAILED",
   696  			fmt.Sprintf("FIN %s failed %s", *id, err.Error()))
   697  	}
   698  
   699  	client.FinishedMessage()
   700  
   701  	return nil, nil
   702  }
   703  
   704  func (p *protocolV2) REQ(client *clientV2, params [][]byte) ([]byte, error) {
   705  	state := atomic.LoadInt32(&client.State)
   706  	if state != stateSubscribed && state != stateClosing {
   707  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot REQ in current state")
   708  	}
   709  
   710  	if len(params) < 3 {
   711  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "REQ insufficient number of params")
   712  	}
   713  
   714  	id, err := getMessageID(params[1])
   715  	if err != nil {
   716  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
   717  	}
   718  
   719  	timeoutMs, err := protocol.ByteToBase10(params[2])
   720  	if err != nil {
   721  		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
   722  			fmt.Sprintf("REQ could not parse timeout %s", params[2]))
   723  	}
   724  	timeoutDuration := time.Duration(timeoutMs) * time.Millisecond
   725  
   726  	maxReqTimeout := p.nsqd.getOpts().MaxReqTimeout
   727  	clampedTimeout := timeoutDuration
   728  
   729  	if timeoutDuration < 0 {
   730  		clampedTimeout = 0
   731  	} else if timeoutDuration > maxReqTimeout {
   732  		clampedTimeout = maxReqTimeout
   733  	}
   734  	if clampedTimeout != timeoutDuration {
   735  		p.nsqd.logf(LOG_INFO, "PROTOCOL(V2): [%s] REQ timeout %d out of range 0-%d. Setting to %d",
   736  			client, timeoutDuration, maxReqTimeout, clampedTimeout)
   737  		timeoutDuration = clampedTimeout
   738  	}
   739  
   740  	err = client.Channel.RequeueMessage(client.ID, *id, timeoutDuration)
   741  	if err != nil {
   742  		return nil, protocol.NewClientErr(err, "E_REQ_FAILED",
   743  			fmt.Sprintf("REQ %s failed %s", *id, err.Error()))
   744  	}
   745  
   746  	client.RequeuedMessage()
   747  
   748  	return nil, nil
   749  }
   750  
   751  func (p *protocolV2) CLS(client *clientV2, params [][]byte) ([]byte, error) {
   752  	if atomic.LoadInt32(&client.State) != stateSubscribed {
   753  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot CLS in current state")
   754  	}
   755  
   756  	client.StartClose()
   757  
   758  	return []byte("CLOSE_WAIT"), nil
   759  }
   760  
   761  func (p *protocolV2) NOP(client *clientV2, params [][]byte) ([]byte, error) {
   762  	return nil, nil
   763  }
   764  
   765  func (p *protocolV2) PUB(client *clientV2, params [][]byte) ([]byte, error) {
   766  	var err error
   767  
   768  	if len(params) < 2 {
   769  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "PUB insufficient number of parameters")
   770  	}
   771  
   772  	topicName := string(params[1])
   773  	if !protocol.IsValidTopicName(topicName) {
   774  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
   775  			fmt.Sprintf("PUB topic name %q is not valid", topicName))
   776  	}
   777  
   778  	bodyLen, err := readLen(client.Reader, client.lenSlice)
   779  	if err != nil {
   780  		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body size")
   781  	}
   782  
   783  	if bodyLen <= 0 {
   784  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   785  			fmt.Sprintf("PUB invalid message body size %d", bodyLen))
   786  	}
   787  
   788  	if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
   789  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   790  			fmt.Sprintf("PUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
   791  	}
   792  
   793  	messageBody := make([]byte, bodyLen)
   794  	_, err = io.ReadFull(client.Reader, messageBody)
   795  	if err != nil {
   796  		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "PUB failed to read message body")
   797  	}
   798  
   799  	if err := p.CheckAuth(client, "PUB", topicName, ""); err != nil {
   800  		return nil, err
   801  	}
   802  
   803  	topic := p.nsqd.GetTopic(topicName)
   804  	msg := NewMessage(topic.GenerateID(), messageBody)
   805  	err = topic.PutMessage(msg)
   806  	if err != nil {
   807  		return nil, protocol.NewFatalClientErr(err, "E_PUB_FAILED", "PUB failed "+err.Error())
   808  	}
   809  
   810  	client.PublishedMessage(topicName, 1)
   811  
   812  	return okBytes, nil
   813  }
   814  
   815  func (p *protocolV2) MPUB(client *clientV2, params [][]byte) ([]byte, error) {
   816  	var err error
   817  
   818  	if len(params) < 2 {
   819  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "MPUB insufficient number of parameters")
   820  	}
   821  
   822  	topicName := string(params[1])
   823  	if !protocol.IsValidTopicName(topicName) {
   824  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
   825  			fmt.Sprintf("E_BAD_TOPIC MPUB topic name %q is not valid", topicName))
   826  	}
   827  
   828  	if err := p.CheckAuth(client, "MPUB", topicName, ""); err != nil {
   829  		return nil, err
   830  	}
   831  
   832  	topic := p.nsqd.GetTopic(topicName)
   833  
   834  	bodyLen, err := readLen(client.Reader, client.lenSlice)
   835  	if err != nil {
   836  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read body size")
   837  	}
   838  
   839  	if bodyLen <= 0 {
   840  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   841  			fmt.Sprintf("MPUB invalid body size %d", bodyLen))
   842  	}
   843  
   844  	if int64(bodyLen) > p.nsqd.getOpts().MaxBodySize {
   845  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_BODY",
   846  			fmt.Sprintf("MPUB body too big %d > %d", bodyLen, p.nsqd.getOpts().MaxBodySize))
   847  	}
   848  
   849  	messages, err := readMPUB(client.Reader, client.lenSlice, topic,
   850  		p.nsqd.getOpts().MaxMsgSize, p.nsqd.getOpts().MaxBodySize)
   851  	if err != nil {
   852  		return nil, err
   853  	}
   854  
   855  	// if we've made it this far we've validated all the input,
   856  	// the only possible error is that the topic is exiting during
   857  	// this next call (and no messages will be queued in that case)
   858  	err = topic.PutMessages(messages)
   859  	if err != nil {
   860  		return nil, protocol.NewFatalClientErr(err, "E_MPUB_FAILED", "MPUB failed "+err.Error())
   861  	}
   862  
   863  	client.PublishedMessage(topicName, uint64(len(messages)))
   864  
   865  	return okBytes, nil
   866  }
   867  
   868  func (p *protocolV2) DPUB(client *clientV2, params [][]byte) ([]byte, error) {
   869  	var err error
   870  
   871  	if len(params) < 3 {
   872  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "DPUB insufficient number of parameters")
   873  	}
   874  
   875  	topicName := string(params[1])
   876  	if !protocol.IsValidTopicName(topicName) {
   877  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_TOPIC",
   878  			fmt.Sprintf("DPUB topic name %q is not valid", topicName))
   879  	}
   880  
   881  	timeoutMs, err := protocol.ByteToBase10(params[2])
   882  	if err != nil {
   883  		return nil, protocol.NewFatalClientErr(err, "E_INVALID",
   884  			fmt.Sprintf("DPUB could not parse timeout %s", params[2]))
   885  	}
   886  	timeoutDuration := time.Duration(timeoutMs) * time.Millisecond
   887  
   888  	if timeoutDuration < 0 || timeoutDuration > p.nsqd.getOpts().MaxReqTimeout {
   889  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID",
   890  			fmt.Sprintf("DPUB timeout %d out of range 0-%d",
   891  				timeoutMs, p.nsqd.getOpts().MaxReqTimeout/time.Millisecond))
   892  	}
   893  
   894  	bodyLen, err := readLen(client.Reader, client.lenSlice)
   895  	if err != nil {
   896  		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body size")
   897  	}
   898  
   899  	if bodyLen <= 0 {
   900  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   901  			fmt.Sprintf("DPUB invalid message body size %d", bodyLen))
   902  	}
   903  
   904  	if int64(bodyLen) > p.nsqd.getOpts().MaxMsgSize {
   905  		return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   906  			fmt.Sprintf("DPUB message too big %d > %d", bodyLen, p.nsqd.getOpts().MaxMsgSize))
   907  	}
   908  
   909  	messageBody := make([]byte, bodyLen)
   910  	_, err = io.ReadFull(client.Reader, messageBody)
   911  	if err != nil {
   912  		return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "DPUB failed to read message body")
   913  	}
   914  
   915  	if err := p.CheckAuth(client, "DPUB", topicName, ""); err != nil {
   916  		return nil, err
   917  	}
   918  
   919  	topic := p.nsqd.GetTopic(topicName)
   920  	msg := NewMessage(topic.GenerateID(), messageBody)
   921  	msg.deferred = timeoutDuration
   922  	err = topic.PutMessage(msg)
   923  	if err != nil {
   924  		return nil, protocol.NewFatalClientErr(err, "E_DPUB_FAILED", "DPUB failed "+err.Error())
   925  	}
   926  
   927  	client.PublishedMessage(topicName, 1)
   928  
   929  	return okBytes, nil
   930  }
   931  
   932  func (p *protocolV2) TOUCH(client *clientV2, params [][]byte) ([]byte, error) {
   933  	state := atomic.LoadInt32(&client.State)
   934  	if state != stateSubscribed && state != stateClosing {
   935  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "cannot TOUCH in current state")
   936  	}
   937  
   938  	if len(params) < 2 {
   939  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", "TOUCH insufficient number of params")
   940  	}
   941  
   942  	id, err := getMessageID(params[1])
   943  	if err != nil {
   944  		return nil, protocol.NewFatalClientErr(nil, "E_INVALID", err.Error())
   945  	}
   946  
   947  	client.writeLock.RLock()
   948  	msgTimeout := client.MsgTimeout
   949  	client.writeLock.RUnlock()
   950  	err = client.Channel.TouchMessage(client.ID, *id, msgTimeout)
   951  	if err != nil {
   952  		return nil, protocol.NewClientErr(err, "E_TOUCH_FAILED",
   953  			fmt.Sprintf("TOUCH %s failed %s", *id, err.Error()))
   954  	}
   955  
   956  	return nil, nil
   957  }
   958  
   959  func readMPUB(r io.Reader, tmp []byte, topic *Topic, maxMessageSize int64, maxBodySize int64) ([]*Message, error) {
   960  	numMessages, err := readLen(r, tmp)
   961  	if err != nil {
   962  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY", "MPUB failed to read message count")
   963  	}
   964  
   965  	// 4 == total num, 5 == length + min 1
   966  	maxMessages := (maxBodySize - 4) / 5
   967  	if numMessages <= 0 || int64(numMessages) > maxMessages {
   968  		return nil, protocol.NewFatalClientErr(err, "E_BAD_BODY",
   969  			fmt.Sprintf("MPUB invalid message count %d", numMessages))
   970  	}
   971  
   972  	messages := make([]*Message, 0, numMessages)
   973  	for i := int32(0); i < numMessages; i++ {
   974  		messageSize, err := readLen(r, tmp)
   975  		if err != nil {
   976  			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE",
   977  				fmt.Sprintf("MPUB failed to read message(%d) body size", i))
   978  		}
   979  
   980  		if messageSize <= 0 {
   981  			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   982  				fmt.Sprintf("MPUB invalid message(%d) body size %d", i, messageSize))
   983  		}
   984  
   985  		if int64(messageSize) > maxMessageSize {
   986  			return nil, protocol.NewFatalClientErr(nil, "E_BAD_MESSAGE",
   987  				fmt.Sprintf("MPUB message too big %d > %d", messageSize, maxMessageSize))
   988  		}
   989  
   990  		msgBody := make([]byte, messageSize)
   991  		_, err = io.ReadFull(r, msgBody)
   992  		if err != nil {
   993  			return nil, protocol.NewFatalClientErr(err, "E_BAD_MESSAGE", "MPUB failed to read message body")
   994  		}
   995  
   996  		messages = append(messages, NewMessage(topic.GenerateID(), msgBody))
   997  	}
   998  
   999  	return messages, nil
  1000  }
  1001  
  1002  // validate and cast the bytes on the wire to a message ID
  1003  func getMessageID(p []byte) (*MessageID, error) {
  1004  	if len(p) != MsgIDLength {
  1005  		return nil, errors.New("invalid message ID")
  1006  	}
  1007  	return (*MessageID)(unsafe.Pointer(&p[0])), nil
  1008  }
  1009  
  1010  func readLen(r io.Reader, tmp []byte) (int32, error) {
  1011  	_, err := io.ReadFull(r, tmp)
  1012  	if err != nil {
  1013  		return 0, err
  1014  	}
  1015  	return int32(binary.BigEndian.Uint32(tmp)), nil
  1016  }
  1017  
  1018  func enforceTLSPolicy(client *clientV2, p *protocolV2, command []byte) error {
  1019  	if p.nsqd.getOpts().TLSRequired != TLSNotRequired && atomic.LoadInt32(&client.TLS) != 1 {
  1020  		return protocol.NewFatalClientErr(nil, "E_INVALID",
  1021  			fmt.Sprintf("cannot %s in current state (TLS required)", command))
  1022  	}
  1023  	return nil
  1024  }