github.com/sagernet/netlink@v0.0.0-20240612041022-b9a21c07ac6a/nl/nl_linux_test.go (about)

     1  package nl
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"encoding/binary"
     7  	"reflect"
     8  	"testing"
     9  	"time"
    10  
    11  	"golang.org/x/sys/unix"
    12  )
    13  
    14  type testSerializer interface {
    15  	serializeSafe() []byte
    16  	Serialize() []byte
    17  }
    18  
    19  func testDeserializeSerialize(t *testing.T, orig []byte, safemsg testSerializer, msg testSerializer) {
    20  	if !reflect.DeepEqual(safemsg, msg) {
    21  		t.Fatal("Deserialization failed.\n", safemsg, "\n", msg)
    22  	}
    23  	safe := msg.serializeSafe()
    24  	if !bytes.Equal(safe, orig) {
    25  		t.Fatal("Safe serialization failed.\n", safe, "\n", orig)
    26  	}
    27  	b := msg.Serialize()
    28  	if !bytes.Equal(b, safe) {
    29  		t.Fatal("Serialization failed.\n", b, "\n", safe)
    30  	}
    31  }
    32  
    33  func (msg *IfInfomsg) write(b []byte) {
    34  	native := NativeEndian()
    35  	b[0] = msg.Family
    36  	// pad byte is skipped because it is not exported on linux/s390x
    37  	native.PutUint16(b[2:4], msg.Type)
    38  	native.PutUint32(b[4:8], uint32(msg.Index))
    39  	native.PutUint32(b[8:12], msg.Flags)
    40  	native.PutUint32(b[12:16], msg.Change)
    41  }
    42  
    43  func (msg *IfInfomsg) serializeSafe() []byte {
    44  	length := unix.SizeofIfInfomsg
    45  	b := make([]byte, length)
    46  	msg.write(b)
    47  	return b
    48  }
    49  
    50  func deserializeIfInfomsgSafe(b []byte) *IfInfomsg {
    51  	var msg = IfInfomsg{}
    52  	binary.Read(bytes.NewReader(b[0:unix.SizeofIfInfomsg]), NativeEndian(), &msg)
    53  	return &msg
    54  }
    55  
    56  func TestIfInfomsgDeserializeSerialize(t *testing.T) {
    57  	var orig = make([]byte, unix.SizeofIfInfomsg)
    58  	rand.Read(orig)
    59  	// zero out the pad byte
    60  	orig[1] = 0
    61  	safemsg := deserializeIfInfomsgSafe(orig)
    62  	msg := DeserializeIfInfomsg(orig)
    63  	testDeserializeSerialize(t, orig, safemsg, msg)
    64  }
    65  
    66  func TestIfSocketCloses(t *testing.T) {
    67  	nlSock, err := Subscribe(unix.NETLINK_ROUTE, unix.RTNLGRP_NEIGH)
    68  	if err != nil {
    69  		t.Fatalf("Error on creating the socket: %v", err)
    70  	}
    71  	endCh := make(chan error)
    72  	go func(sk *NetlinkSocket, endCh chan error) {
    73  		endCh <- nil
    74  		for {
    75  			_, _, err := sk.Receive()
    76  			if err == unix.EAGAIN {
    77  				endCh <- err
    78  				return
    79  			}
    80  		}
    81  	}(nlSock, endCh)
    82  
    83  	// first receive nil
    84  	if msg := <-endCh; msg != nil {
    85  		t.Fatalf("Expected nil instead got: %v", msg)
    86  	}
    87  	// this to guarantee that the receive is invoked before the close
    88  	time.Sleep(4 * time.Second)
    89  
    90  	// Close the socket
    91  	nlSock.Close()
    92  
    93  	// Expect to have an error
    94  	msg := <-endCh
    95  	if msg == nil {
    96  		t.Fatalf("Expected error instead received nil")
    97  	}
    98  }
    99  
   100  func (msg *CnMsgOp) write(b []byte) {
   101  	native := NativeEndian()
   102  	native.PutUint32(b[0:4], msg.ID.Idx)
   103  	native.PutUint32(b[4:8], msg.ID.Val)
   104  	native.PutUint32(b[8:12], msg.Seq)
   105  	native.PutUint32(b[12:16], msg.Ack)
   106  	native.PutUint16(b[16:18], msg.Length)
   107  	native.PutUint16(b[18:20], msg.Flags)
   108  	native.PutUint32(b[20:24], msg.Op)
   109  }
   110  
   111  func (msg *CnMsgOp) serializeSafe() []byte {
   112  	length := msg.Len()
   113  	b := make([]byte, length)
   114  	msg.write(b)
   115  	return b
   116  }
   117  
   118  func deserializeCnMsgOpSafe(b []byte) *CnMsgOp {
   119  	var msg = CnMsgOp{}
   120  	binary.Read(bytes.NewReader(b[0:SizeofCnMsgOp]), NativeEndian(), &msg)
   121  	return &msg
   122  }
   123  
   124  func TestCnMsgOpDeserializeSerialize(t *testing.T) {
   125  	var orig = make([]byte, SizeofCnMsgOp)
   126  	rand.Read(orig)
   127  	safemsg := deserializeCnMsgOpSafe(orig)
   128  	msg := DeserializeCnMsgOp(orig)
   129  	testDeserializeSerialize(t, orig, safemsg, msg)
   130  }