github.com/simpleiot/simpleiot@v0.18.3/modbus/tcp.go (about)

     1  package modbus
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"log"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  // TCPADU defines an ADU for TCP packets
    13  type TCPADU struct {
    14  	PDU
    15  	Address byte
    16  }
    17  
    18  // TCP defines an TCP connection
    19  type TCP struct {
    20  	sock         net.Conn
    21  	txID         uint16
    22  	timeout      time.Duration
    23  	clientServer TransportClientServer
    24  }
    25  
    26  // NewTCP creates a new TCP transport
    27  func NewTCP(sock net.Conn, timeout time.Duration, clientServer TransportClientServer) *TCP {
    28  	return &TCP{
    29  		sock:         sock,
    30  		timeout:      timeout,
    31  		clientServer: clientServer,
    32  	}
    33  }
    34  
    35  func (t *TCP) Read(p []byte) (int, error) {
    36  	err := t.sock.SetDeadline(time.Now().Add(t.timeout))
    37  	if err != nil {
    38  		return 0, err
    39  	}
    40  	return t.sock.Read(p)
    41  }
    42  
    43  func (t *TCP) Write(p []byte) (int, error) {
    44  	err := t.sock.SetDeadline(time.Now().Add(t.timeout))
    45  	if err != nil {
    46  		return 0, err
    47  	}
    48  	return t.sock.Write(p)
    49  }
    50  
    51  // Close connection
    52  func (t *TCP) Close() error {
    53  	return t.sock.Close()
    54  }
    55  
    56  // Encode encodes a TCP packet
    57  func (t *TCP) Encode(id byte, pdu PDU) ([]byte, error) {
    58  	// increment transaction ID
    59  	if t.clientServer == TransportClient {
    60  		t.txID++
    61  	}
    62  
    63  	// bytes 0,1 transaction ID
    64  	ret := make([]byte, len(pdu.Data)+8)
    65  	binary.BigEndian.PutUint16(ret[0:], t.txID)
    66  
    67  	// bytes 2,3 protocol identifier
    68  
    69  	// bytes 4,5 length
    70  	binary.BigEndian.PutUint16(ret[4:], uint16(len(pdu.Data)+2))
    71  
    72  	// byte 6 unit identifier
    73  	ret[6] = id
    74  
    75  	// byte 7 function code
    76  	ret[7] = byte(pdu.FunctionCode)
    77  
    78  	// byte 8: data
    79  	copy(ret[8:], pdu.Data)
    80  	return ret, nil
    81  }
    82  
    83  // Decode decodes a TCP packet
    84  func (t *TCP) Decode(packet []byte) (byte, PDU, error) {
    85  	if len(packet) < 9 {
    86  		return 0, PDU{}, fmt.Errorf("Not enough data for TCP packet: %v", len(packet))
    87  	}
    88  
    89  	txID := binary.BigEndian.Uint16(packet[:2])
    90  
    91  	switch t.clientServer {
    92  	case TransportClient:
    93  		// need to check that echo'd tx is correct
    94  		if txID != t.txID {
    95  			return 0, PDU{}, fmt.Errorf("Transaction id not correct, expected: 0x%x, got 0x%x", t.txID, txID)
    96  		}
    97  	case TransportServer:
    98  		// need to store tx to echo back to client on Encode
    99  		t.txID = txID
   100  	}
   101  
   102  	id := packet[6]
   103  
   104  	pdu := PDU{}
   105  	pdu.FunctionCode = FunctionCode(packet[7])
   106  	pdu.Data = packet[8:]
   107  
   108  	return id, pdu, nil
   109  }
   110  
   111  // Type returns TransportType
   112  func (t *TCP) Type() TransportType {
   113  	return TransportTypeTCP
   114  }
   115  
   116  // TCPServer listens for new connections and then starts a modbus listener
   117  // on the port.
   118  type TCPServer struct {
   119  	// config
   120  	id         int
   121  	maxClients int
   122  	port       string
   123  	regs       *Regs
   124  	debug      int
   125  
   126  	// state
   127  	listener net.Listener
   128  	servers  []*Server
   129  	lock     sync.Mutex
   130  	stopped  bool
   131  }
   132  
   133  // NewTCPServer starts a new TCP modbus server
   134  func NewTCPServer(id, maxClients int, port string, regs *Regs, debug int) (*TCPServer, error) {
   135  	listener, err := net.Listen("tcp", ":"+port)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return &TCPServer{
   141  		id:         id,
   142  		maxClients: maxClients,
   143  		port:       port,
   144  		regs:       regs,
   145  		listener:   listener,
   146  		debug:      debug,
   147  	}, nil
   148  }
   149  
   150  // Listen starts the server and listens for modbus requests
   151  // this function does not return unless an error occurs
   152  // The listen function supports various debug levels:
   153  // 1 - dump packets
   154  // 9 - dump raw data
   155  func (ts *TCPServer) Listen(errorCallback func(error),
   156  	changesCallback func(), done func()) {
   157  	for {
   158  		sock, err := ts.listener.Accept()
   159  		if err != nil {
   160  			if ts.stopped {
   161  				if ts.debug > 0 {
   162  					log.Println("Modbus TCPServer, stopping listen")
   163  				}
   164  				done()
   165  				return
   166  			}
   167  			log.Println("Modbus TCP server: failed to accept connection:", err)
   168  		}
   169  
   170  		if ts.debug > 0 {
   171  			log.Println("New Modbus TCP connection")
   172  		}
   173  
   174  		ts.lock.Lock()
   175  		if len(ts.servers) < ts.maxClients {
   176  			transport := NewTCP(sock, 500*time.Millisecond, TransportServer)
   177  			server := NewServer(byte(ts.id), transport, ts.regs, ts.debug)
   178  			ts.servers = append(ts.servers, server)
   179  			go server.Listen(errorCallback,
   180  				changesCallback, func() {
   181  					// TCP server client has disconnected, remove from list
   182  					ts.lock.Lock()
   183  					for i := range ts.servers {
   184  						if ts.servers[i] == server {
   185  							ts.servers[i] = ts.servers[len(ts.servers)-1]
   186  							ts.servers = ts.servers[:len(ts.servers)-1]
   187  							break
   188  						}
   189  					}
   190  					ts.lock.Unlock()
   191  				})
   192  		} else {
   193  			log.Println("Modbus TCP server: warning reached max conn")
   194  		}
   195  		ts.lock.Unlock()
   196  	}
   197  }
   198  
   199  // Close stops the server and closes all connections
   200  func (ts *TCPServer) Close() error {
   201  	if ts.debug > 0 {
   202  		log.Println("Modbus TCPServer closing ...")
   203  	}
   204  
   205  	ts.lock.Lock()
   206  	defer ts.lock.Unlock()
   207  	ts.stopped = true
   208  
   209  	var retErr error
   210  
   211  	for _, server := range ts.servers {
   212  		err := server.Close()
   213  		if err != nil {
   214  			retErr = err
   215  		}
   216  	}
   217  
   218  	err := ts.listener.Close()
   219  
   220  	if err != nil {
   221  		retErr = err
   222  	}
   223  
   224  	return retErr
   225  }