git.sr.ht/~pingoo/stdx@v0.0.0-20240218134121-094174641f6e/mmdb/reader.go (about)

// Package mmdb provides a reader for the MaxMind DB file format.
     2  package mmdb
     3  
     4  import (
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"reflect"
    10  )
    11  
    12  const (
    13  	// NotFound is returned by LookupOffset when a matched root record offset
    14  	// cannot be found.
    15  	NotFound = ^uintptr(0)
    16  
    17  	dataSectionSeparatorSize = 16
    18  )
    19  
    20  var metadataStartMarker = []byte("\xAB\xCD\xEFMaxMind.com")
    21  
// Reader holds the data corresponding to the MaxMind DB file. Its only public
// field is Metadata, which contains the metadata from the MaxMind DB file.
//
// All of the methods on Reader are thread-safe. The struct may be safely
// shared across goroutines.
type Reader struct {
	// nodeReader reads left/right records from the search-tree portion of
	// buffer; its concrete type is chosen from Metadata.RecordSize.
	nodeReader        nodeReader
	// buffer is the whole database file. The lookup and decode methods treat
	// a nil buffer as a closed database.
	buffer            []byte
	// decoder decodes values from the data section, i.e. the bytes between
	// the post-tree separator and the metadata start marker.
	decoder           decoder
	Metadata          Metadata
	// ipv4Start is the node (or record) from which IPv4 lookups begin. It is
	// only set for IPv6 databases; otherwise it stays 0 (the root).
	ipv4Start         uint
	// ipv4StartBitDepth is how many left branches were followed from the root
	// to reach ipv4Start (at most 96).
	ipv4StartBitDepth int
	// nodeOffsetMult is Metadata.RecordSize / 4, the number of bytes one node
	// occupies; node index * nodeOffsetMult is the node's byte offset.
	nodeOffsetMult    uint
	// hasMappedFile is not read in this file. NOTE(review): presumably it
	// records whether buffer is memory-mapped and must be unmapped when the
	// reader is closed — confirm against the rest of the package.
	hasMappedFile     bool
}
    37  
// Metadata holds the metadata decoded from the MaxMind DB file. In particular
// it has the format version, the build time as Unix epoch time, the database
// type and description, the IP version supported, and a slice of the natural
// languages included.
//
// Fields are filled by the package decoder, which matches them to metadata
// keys through the `maxminddb` struct tags.
type Metadata struct {
	// Description of the database. NOTE(review): presumably keyed by language
	// code — confirm against the format specification.
	Description              map[string]string `maxminddb:"description"`
	// DatabaseType names the kind of database.
	DatabaseType             string            `maxminddb:"database_type"`
	// Languages lists the natural languages included in the database.
	Languages                []string          `maxminddb:"languages"`
	// BinaryFormatMajorVersion and BinaryFormatMinorVersion are the file
	// format version.
	BinaryFormatMajorVersion uint              `maxminddb:"binary_format_major_version"`
	BinaryFormatMinorVersion uint              `maxminddb:"binary_format_minor_version"`
	// BuildEpoch is the build time as Unix epoch time.
	BuildEpoch               uint              `maxminddb:"build_epoch"`
	// IPVersion is the IP version supported. Readers compute an IPv4 start
	// node when it is 6 and reject IPv6 lookups when it is 4.
	IPVersion                uint              `maxminddb:"ip_version"`
	// NodeCount is the number of nodes in the search tree. A record equal to
	// NodeCount means "no data"; larger records point into the data section.
	NodeCount                uint              `maxminddb:"node_count"`
	// RecordSize selects the node layout; FromBytes accepts 24, 28 and 32.
	RecordSize               uint              `maxminddb:"record_size"`
}
    53  
    54  // FromBytes takes a byte slice corresponding to a MaxMind DB file and returns
    55  // a Reader structure or an error.
    56  func FromBytes(buffer []byte) (*Reader, error) {
    57  	metadataStart := bytes.LastIndex(buffer, metadataStartMarker)
    58  
    59  	if metadataStart == -1 {
    60  		return nil, newInvalidDatabaseError("error opening database: invalid MaxMind DB file")
    61  	}
    62  
    63  	metadataStart += len(metadataStartMarker)
    64  	metadataDecoder := decoder{buffer[metadataStart:]}
    65  
    66  	var metadata Metadata
    67  
    68  	rvMetdata := reflect.ValueOf(&metadata)
    69  	_, err := metadataDecoder.decode(0, rvMetdata, 0)
    70  	if err != nil {
    71  		return nil, err
    72  	}
    73  
    74  	searchTreeSize := metadata.NodeCount * metadata.RecordSize / 4
    75  	dataSectionStart := searchTreeSize + dataSectionSeparatorSize
    76  	dataSectionEnd := uint(metadataStart - len(metadataStartMarker))
    77  	if dataSectionStart > dataSectionEnd {
    78  		return nil, newInvalidDatabaseError("the MaxMind DB contains invalid metadata")
    79  	}
    80  	d := decoder{
    81  		buffer[searchTreeSize+dataSectionSeparatorSize : metadataStart-len(metadataStartMarker)],
    82  	}
    83  
    84  	nodeBuffer := buffer[:searchTreeSize]
    85  	var nodeReader nodeReader
    86  	switch metadata.RecordSize {
    87  	case 24:
    88  		nodeReader = nodeReader24{buffer: nodeBuffer}
    89  	case 28:
    90  		nodeReader = nodeReader28{buffer: nodeBuffer}
    91  	case 32:
    92  		nodeReader = nodeReader32{buffer: nodeBuffer}
    93  	default:
    94  		return nil, newInvalidDatabaseError("unknown record size: %d", metadata.RecordSize)
    95  	}
    96  
    97  	reader := &Reader{
    98  		buffer:         buffer,
    99  		nodeReader:     nodeReader,
   100  		decoder:        d,
   101  		Metadata:       metadata,
   102  		ipv4Start:      0,
   103  		nodeOffsetMult: metadata.RecordSize / 4,
   104  	}
   105  
   106  	reader.setIPv4Start()
   107  
   108  	return reader, err
   109  }
   110  
   111  func (r *Reader) setIPv4Start() {
   112  	if r.Metadata.IPVersion != 6 {
   113  		return
   114  	}
   115  
   116  	nodeCount := r.Metadata.NodeCount
   117  
   118  	node := uint(0)
   119  	i := 0
   120  	for ; i < 96 && node < nodeCount; i++ {
   121  		node = r.nodeReader.readLeft(node * r.nodeOffsetMult)
   122  	}
   123  	r.ipv4Start = node
   124  	r.ipv4StartBitDepth = i
   125  }
   126  
   127  // Lookup retrieves the database record for ip and stores it in the value
   128  // pointed to by result. If result is nil or not a pointer, an error is
   129  // returned. If the data in the database record cannot be stored in result
   130  // because of type differences, an UnmarshalTypeError is returned. If the
   131  // database is invalid or otherwise cannot be read, an InvalidDatabaseError
   132  // is returned.
   133  func (r *Reader) Lookup(ip net.IP, result any) error {
   134  	if r.buffer == nil {
   135  		return errors.New("cannot call Lookup on a closed database")
   136  	}
   137  	pointer, _, _, err := r.lookupPointer(ip)
   138  	if pointer == 0 || err != nil {
   139  		return err
   140  	}
   141  	return r.retrieveData(pointer, result)
   142  }
   143  
   144  // LookupNetwork retrieves the database record for ip and stores it in the
   145  // value pointed to by result. The network returned is the network associated
   146  // with the data record in the database. The ok return value indicates whether
   147  // the database contained a record for the ip.
   148  //
   149  // If result is nil or not a pointer, an error is returned. If the data in the
   150  // database record cannot be stored in result because of type differences, an
   151  // UnmarshalTypeError is returned. If the database is invalid or otherwise
   152  // cannot be read, an InvalidDatabaseError is returned.
   153  func (r *Reader) LookupNetwork(
   154  	ip net.IP,
   155  	result any,
   156  ) (network *net.IPNet, ok bool, err error) {
   157  	if r.buffer == nil {
   158  		return nil, false, errors.New("cannot call Lookup on a closed database")
   159  	}
   160  	pointer, prefixLength, ip, err := r.lookupPointer(ip)
   161  
   162  	network = r.cidr(ip, prefixLength)
   163  	if pointer == 0 || err != nil {
   164  		return network, false, err
   165  	}
   166  
   167  	return network, true, r.retrieveData(pointer, result)
   168  }
   169  
   170  // LookupOffset maps an argument net.IP to a corresponding record offset in the
   171  // database. NotFound is returned if no such record is found, and a record may
   172  // otherwise be extracted by passing the returned offset to Decode. LookupOffset
   173  // is an advanced API, which exists to provide clients with a means to cache
   174  // previously-decoded records.
   175  func (r *Reader) LookupOffset(ip net.IP) (uintptr, error) {
   176  	if r.buffer == nil {
   177  		return 0, errors.New("cannot call LookupOffset on a closed database")
   178  	}
   179  	pointer, _, _, err := r.lookupPointer(ip)
   180  	if pointer == 0 || err != nil {
   181  		return NotFound, err
   182  	}
   183  	return r.resolveDataPointer(pointer)
   184  }
   185  
   186  func (r *Reader) cidr(ip net.IP, prefixLength int) *net.IPNet {
   187  	// This is necessary as the node that the IPv4 start is at may
   188  	// be at a bit depth that is less that 96, i.e., ipv4Start points
   189  	// to a leaf node. For instance, if a record was inserted at ::/8,
   190  	// the ipv4Start would point directly at the leaf node for the
   191  	// record and would have a bit depth of 8. This would not happen
   192  	// with databases currently distributed by MaxMind as all of them
   193  	// have an IPv4 subtree that is greater than a single node.
   194  	if r.Metadata.IPVersion == 6 &&
   195  		len(ip) == net.IPv4len &&
   196  		r.ipv4StartBitDepth != 96 {
   197  		return &net.IPNet{IP: net.ParseIP("::"), Mask: net.CIDRMask(r.ipv4StartBitDepth, 128)}
   198  	}
   199  
   200  	mask := net.CIDRMask(prefixLength, len(ip)*8)
   201  	return &net.IPNet{IP: ip.Mask(mask), Mask: mask}
   202  }
   203  
   204  // Decode the record at |offset| into |result|. The result value pointed to
   205  // must be a data value that corresponds to a record in the database. This may
   206  // include a struct representation of the data, a map capable of holding the
   207  // data or an empty any value.
   208  //
   209  // If result is a pointer to a struct, the struct need not include a field
   210  // for every value that may be in the database. If a field is not present in
   211  // the structure, the decoder will not decode that field, reducing the time
   212  // required to decode the record.
   213  //
   214  // As a special case, a struct field of type uintptr will be used to capture
   215  // the offset of the value. Decode may later be used to extract the stored
   216  // value from the offset. MaxMind DBs are highly normalized: for example in
   217  // the City database, all records of the same country will reference a
   218  // single representative record for that country. This uintptr behavior allows
   219  // clients to leverage this normalization in their own sub-record caching.
   220  func (r *Reader) Decode(offset uintptr, result any) error {
   221  	if r.buffer == nil {
   222  		return errors.New("cannot call Decode on a closed database")
   223  	}
   224  	return r.decode(offset, result)
   225  }
   226  
   227  func (r *Reader) decode(offset uintptr, result any) error {
   228  	rv := reflect.ValueOf(result)
   229  	if rv.Kind() != reflect.Ptr || rv.IsNil() {
   230  		return errors.New("result param must be a pointer")
   231  	}
   232  
   233  	if dser, ok := result.(deserializer); ok {
   234  		_, err := r.decoder.decodeToDeserializer(uint(offset), dser, 0, false)
   235  		return err
   236  	}
   237  
   238  	_, err := r.decoder.decode(uint(offset), rv, 0)
   239  	return err
   240  }
   241  
   242  func (r *Reader) lookupPointer(ip net.IP) (uint, int, net.IP, error) {
   243  	if ip == nil {
   244  		return 0, 0, nil, errors.New("IP passed to Lookup cannot be nil")
   245  	}
   246  
   247  	ipV4Address := ip.To4()
   248  	if ipV4Address != nil {
   249  		ip = ipV4Address
   250  	}
   251  	if len(ip) == 16 && r.Metadata.IPVersion == 4 {
   252  		return 0, 0, ip, fmt.Errorf(
   253  			"error looking up '%s': you attempted to look up an IPv6 address in an IPv4-only database",
   254  			ip.String(),
   255  		)
   256  	}
   257  
   258  	bitCount := uint(len(ip) * 8)
   259  
   260  	var node uint
   261  	if bitCount == 32 {
   262  		node = r.ipv4Start
   263  	}
   264  	node, prefixLength := r.traverseTree(ip, node, bitCount)
   265  
   266  	nodeCount := r.Metadata.NodeCount
   267  	if node == nodeCount {
   268  		// Record is empty
   269  		return 0, prefixLength, ip, nil
   270  	} else if node > nodeCount {
   271  		return node, prefixLength, ip, nil
   272  	}
   273  
   274  	return 0, prefixLength, ip, newInvalidDatabaseError("invalid node in search tree")
   275  }
   276  
   277  func (r *Reader) traverseTree(ip net.IP, node, bitCount uint) (uint, int) {
   278  	nodeCount := r.Metadata.NodeCount
   279  
   280  	i := uint(0)
   281  	for ; i < bitCount && node < nodeCount; i++ {
   282  		bit := uint(1) & (uint(ip[i>>3]) >> (7 - (i % 8)))
   283  
   284  		offset := node * r.nodeOffsetMult
   285  		if bit == 0 {
   286  			node = r.nodeReader.readLeft(offset)
   287  		} else {
   288  			node = r.nodeReader.readRight(offset)
   289  		}
   290  	}
   291  
   292  	return node, int(i)
   293  }
   294  
   295  func (r *Reader) retrieveData(pointer uint, result any) error {
   296  	offset, err := r.resolveDataPointer(pointer)
   297  	if err != nil {
   298  		return err
   299  	}
   300  	return r.decode(offset, result)
   301  }
   302  
   303  func (r *Reader) resolveDataPointer(pointer uint) (uintptr, error) {
   304  	resolved := uintptr(pointer - r.Metadata.NodeCount - dataSectionSeparatorSize)
   305  
   306  	if resolved >= uintptr(len(r.buffer)) {
   307  		return 0, newInvalidDatabaseError("the MaxMind DB file's search tree is corrupt")
   308  	}
   309  	return resolved, nil
   310  }