github.com/hellobchain/third_party@v0.0.0-20230331131523-deb0478a2e52/ldap.v2/filter.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ldap
     6  
     7  import (
     8  	"bytes"
     9  	hexpac "encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"strings"
    13  	"unicode/utf8"
    14  
    15  	"gopkg.in/asn1-ber.v1"
    16  )
    17  
    18  // Filter choices
    19  const (
    20  	FilterAnd             = 0
    21  	FilterOr              = 1
    22  	FilterNot             = 2
    23  	FilterEqualityMatch   = 3
    24  	FilterSubstrings      = 4
    25  	FilterGreaterOrEqual  = 5
    26  	FilterLessOrEqual     = 6
    27  	FilterPresent         = 7
    28  	FilterApproxMatch     = 8
    29  	FilterExtensibleMatch = 9
    30  )
    31  
    32  // FilterMap contains human readable descriptions of Filter choices
    33  var FilterMap = map[uint64]string{
    34  	FilterAnd:             "And",
    35  	FilterOr:              "Or",
    36  	FilterNot:             "Not",
    37  	FilterEqualityMatch:   "Equality Match",
    38  	FilterSubstrings:      "Substrings",
    39  	FilterGreaterOrEqual:  "Greater Or Equal",
    40  	FilterLessOrEqual:     "Less Or Equal",
    41  	FilterPresent:         "Present",
    42  	FilterApproxMatch:     "Approx Match",
    43  	FilterExtensibleMatch: "Extensible Match",
    44  }
    45  
    46  // SubstringFilter options
    47  const (
    48  	FilterSubstringsInitial = 0
    49  	FilterSubstringsAny     = 1
    50  	FilterSubstringsFinal   = 2
    51  )
    52  
    53  // FilterSubstringsMap contains human readable descriptions of SubstringFilter choices
    54  var FilterSubstringsMap = map[uint64]string{
    55  	FilterSubstringsInitial: "Substrings Initial",
    56  	FilterSubstringsAny:     "Substrings Any",
    57  	FilterSubstringsFinal:   "Substrings Final",
    58  }
    59  
    60  // MatchingRuleAssertion choices
    61  const (
    62  	MatchingRuleAssertionMatchingRule = 1
    63  	MatchingRuleAssertionType         = 2
    64  	MatchingRuleAssertionMatchValue   = 3
    65  	MatchingRuleAssertionDNAttributes = 4
    66  )
    67  
    68  // MatchingRuleAssertionMap contains human readable descriptions of MatchingRuleAssertion choices
    69  var MatchingRuleAssertionMap = map[uint64]string{
    70  	MatchingRuleAssertionMatchingRule: "Matching Rule Assertion Matching Rule",
    71  	MatchingRuleAssertionType:         "Matching Rule Assertion Type",
    72  	MatchingRuleAssertionMatchValue:   "Matching Rule Assertion Match Value",
    73  	MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
    74  }
    75  
    76  // CompileFilter converts a string representation of a filter into a BER-encoded packet
    77  func CompileFilter(filter string) (*ber.Packet, error) {
    78  	if len(filter) == 0 || filter[0] != '(' {
    79  		return nil, NewError(ErrorFilterCompile, errors.New("ldap: filter does not start with an '('"))
    80  	}
    81  	packet, pos, err := compileFilter(filter, 1)
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  	if pos != len(filter) {
    86  		return nil, NewError(ErrorFilterCompile, errors.New("ldap: finished compiling filter with extra at end: "+fmt.Sprint(filter[pos:])))
    87  	}
    88  	return packet, nil
    89  }
    90  
    91  // DecompileFilter converts a packet representation of a filter into a string representation
    92  func DecompileFilter(packet *ber.Packet) (ret string, err error) {
    93  	defer func() {
    94  		if r := recover(); r != nil {
    95  			err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
    96  		}
    97  	}()
    98  	ret = "("
    99  	err = nil
   100  	childStr := ""
   101  
   102  	switch packet.Tag {
   103  	case FilterAnd:
   104  		ret += "&"
   105  		for _, child := range packet.Children {
   106  			childStr, err = DecompileFilter(child)
   107  			if err != nil {
   108  				return
   109  			}
   110  			ret += childStr
   111  		}
   112  	case FilterOr:
   113  		ret += "|"
   114  		for _, child := range packet.Children {
   115  			childStr, err = DecompileFilter(child)
   116  			if err != nil {
   117  				return
   118  			}
   119  			ret += childStr
   120  		}
   121  	case FilterNot:
   122  		ret += "!"
   123  		childStr, err = DecompileFilter(packet.Children[0])
   124  		if err != nil {
   125  			return
   126  		}
   127  		ret += childStr
   128  
   129  	case FilterSubstrings:
   130  		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
   131  		ret += "="
   132  		for i, child := range packet.Children[1].Children {
   133  			if i == 0 && child.Tag != FilterSubstringsInitial {
   134  				ret += "*"
   135  			}
   136  			ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
   137  			if child.Tag != FilterSubstringsFinal {
   138  				ret += "*"
   139  			}
   140  		}
   141  	case FilterEqualityMatch:
   142  		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
   143  		ret += "="
   144  		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
   145  	case FilterGreaterOrEqual:
   146  		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
   147  		ret += ">="
   148  		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
   149  	case FilterLessOrEqual:
   150  		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
   151  		ret += "<="
   152  		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
   153  	case FilterPresent:
   154  		ret += ber.DecodeString(packet.Data.Bytes())
   155  		ret += "=*"
   156  	case FilterApproxMatch:
   157  		ret += ber.DecodeString(packet.Children[0].Data.Bytes())
   158  		ret += "~="
   159  		ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
   160  	case FilterExtensibleMatch:
   161  		attr := ""
   162  		dnAttributes := false
   163  		matchingRule := ""
   164  		value := ""
   165  
   166  		for _, child := range packet.Children {
   167  			switch child.Tag {
   168  			case MatchingRuleAssertionMatchingRule:
   169  				matchingRule = ber.DecodeString(child.Data.Bytes())
   170  			case MatchingRuleAssertionType:
   171  				attr = ber.DecodeString(child.Data.Bytes())
   172  			case MatchingRuleAssertionMatchValue:
   173  				value = ber.DecodeString(child.Data.Bytes())
   174  			case MatchingRuleAssertionDNAttributes:
   175  				dnAttributes = child.Value.(bool)
   176  			}
   177  		}
   178  
   179  		if len(attr) > 0 {
   180  			ret += attr
   181  		}
   182  		if dnAttributes {
   183  			ret += ":dn"
   184  		}
   185  		if len(matchingRule) > 0 {
   186  			ret += ":"
   187  			ret += matchingRule
   188  		}
   189  		ret += ":="
   190  		ret += EscapeFilter(value)
   191  	}
   192  
   193  	ret += ")"
   194  	return
   195  }
   196  
   197  func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
   198  	for pos < len(filter) && filter[pos] == '(' {
   199  		child, newPos, err := compileFilter(filter, pos+1)
   200  		if err != nil {
   201  			return pos, err
   202  		}
   203  		pos = newPos
   204  		parent.AppendChild(child)
   205  	}
   206  	if pos == len(filter) {
   207  		return pos, NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
   208  	}
   209  
   210  	return pos + 1, nil
   211  }
   212  
   213  func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
   214  	var (
   215  		packet *ber.Packet
   216  		err    error
   217  	)
   218  
   219  	defer func() {
   220  		if r := recover(); r != nil {
   221  			err = NewError(ErrorFilterCompile, errors.New("ldap: error compiling filter"))
   222  		}
   223  	}()
   224  	newPos := pos
   225  
   226  	currentRune, currentWidth := utf8.DecodeRuneInString(filter[newPos:])
   227  
   228  	switch currentRune {
   229  	case utf8.RuneError:
   230  		return nil, 0, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
   231  	case '(':
   232  		packet, newPos, err = compileFilter(filter, pos+currentWidth)
   233  		newPos++
   234  		return packet, newPos, err
   235  	case '&':
   236  		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterAnd, nil, FilterMap[FilterAnd])
   237  		newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
   238  		return packet, newPos, err
   239  	case '|':
   240  		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterOr, nil, FilterMap[FilterOr])
   241  		newPos, err = compileFilterSet(filter, pos+currentWidth, packet)
   242  		return packet, newPos, err
   243  	case '!':
   244  		packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterNot, nil, FilterMap[FilterNot])
   245  		var child *ber.Packet
   246  		child, newPos, err = compileFilter(filter, pos+currentWidth)
   247  		packet.AppendChild(child)
   248  		return packet, newPos, err
   249  	default:
   250  		const (
   251  			stateReadingAttr                   = 0
   252  			stateReadingExtensibleMatchingRule = 1
   253  			stateReadingCondition              = 2
   254  		)
   255  
   256  		state := stateReadingAttr
   257  
   258  		attribute := ""
   259  		extensibleDNAttributes := false
   260  		extensibleMatchingRule := ""
   261  		condition := ""
   262  
   263  		for newPos < len(filter) {
   264  			remainingFilter := filter[newPos:]
   265  			currentRune, currentWidth = utf8.DecodeRuneInString(remainingFilter)
   266  			if currentRune == ')' {
   267  				break
   268  			}
   269  			if currentRune == utf8.RuneError {
   270  				return packet, newPos, NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", newPos))
   271  			}
   272  
   273  			switch state {
   274  			case stateReadingAttr:
   275  				switch {
   276  				// Extensible rule, with only DN-matching
   277  				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:="):
   278  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
   279  					extensibleDNAttributes = true
   280  					state = stateReadingCondition
   281  					newPos += 5
   282  
   283  				// Extensible rule, with DN-matching and a matching OID
   284  				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":dn:"):
   285  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
   286  					extensibleDNAttributes = true
   287  					state = stateReadingExtensibleMatchingRule
   288  					newPos += 4
   289  
   290  				// Extensible rule, with attr only
   291  				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
   292  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
   293  					state = stateReadingCondition
   294  					newPos += 2
   295  
   296  				// Extensible rule, with no DN attribute matching
   297  				case currentRune == ':':
   298  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterExtensibleMatch, nil, FilterMap[FilterExtensibleMatch])
   299  					state = stateReadingExtensibleMatchingRule
   300  					newPos++
   301  
   302  				// Equality condition
   303  				case currentRune == '=':
   304  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterEqualityMatch, nil, FilterMap[FilterEqualityMatch])
   305  					state = stateReadingCondition
   306  					newPos++
   307  
   308  				// Greater-than or equal
   309  				case currentRune == '>' && strings.HasPrefix(remainingFilter, ">="):
   310  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterGreaterOrEqual, nil, FilterMap[FilterGreaterOrEqual])
   311  					state = stateReadingCondition
   312  					newPos += 2
   313  
   314  				// Less-than or equal
   315  				case currentRune == '<' && strings.HasPrefix(remainingFilter, "<="):
   316  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterLessOrEqual, nil, FilterMap[FilterLessOrEqual])
   317  					state = stateReadingCondition
   318  					newPos += 2
   319  
   320  				// Approx
   321  				case currentRune == '~' && strings.HasPrefix(remainingFilter, "~="):
   322  					packet = ber.Encode(ber.ClassContext, ber.TypeConstructed, FilterApproxMatch, nil, FilterMap[FilterApproxMatch])
   323  					state = stateReadingCondition
   324  					newPos += 2
   325  
   326  				// Still reading the attribute name
   327  				default:
   328  					attribute += fmt.Sprintf("%c", currentRune)
   329  					newPos += currentWidth
   330  				}
   331  
   332  			case stateReadingExtensibleMatchingRule:
   333  				switch {
   334  
   335  				// Matching rule OID is done
   336  				case currentRune == ':' && strings.HasPrefix(remainingFilter, ":="):
   337  					state = stateReadingCondition
   338  					newPos += 2
   339  
   340  				// Still reading the matching rule oid
   341  				default:
   342  					extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
   343  					newPos += currentWidth
   344  				}
   345  
   346  			case stateReadingCondition:
   347  				// append to the condition
   348  				condition += fmt.Sprintf("%c", currentRune)
   349  				newPos += currentWidth
   350  			}
   351  		}
   352  
   353  		if newPos == len(filter) {
   354  			err = NewError(ErrorFilterCompile, errors.New("ldap: unexpected end of filter"))
   355  			return packet, newPos, err
   356  		}
   357  		if packet == nil {
   358  			err = NewError(ErrorFilterCompile, errors.New("ldap: error parsing filter"))
   359  			return packet, newPos, err
   360  		}
   361  
   362  		switch {
   363  		case packet.Tag == FilterExtensibleMatch:
   364  			// MatchingRuleAssertion ::= SEQUENCE {
   365  			//         matchingRule    [1] MatchingRuleID OPTIONAL,
   366  			//         type            [2] AttributeDescription OPTIONAL,
   367  			//         matchValue      [3] AssertionValue,
   368  			//         dnAttributes    [4] BOOLEAN DEFAULT FALSE
   369  			// }
   370  
   371  			// Include the matching rule oid, if specified
   372  			if len(extensibleMatchingRule) > 0 {
   373  				packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
   374  			}
   375  
   376  			// Include the attribute, if specified
   377  			if len(attribute) > 0 {
   378  				packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
   379  			}
   380  
   381  			// Add the value (only required child)
   382  			encodedString, encodeErr := escapedStringToEncodedBytes(condition)
   383  			if encodeErr != nil {
   384  				return packet, newPos, encodeErr
   385  			}
   386  			packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchValue, encodedString, MatchingRuleAssertionMap[MatchingRuleAssertionMatchValue]))
   387  
   388  			// Defaults to false, so only include in the sequence if true
   389  			if extensibleDNAttributes {
   390  				packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
   391  			}
   392  
   393  		case packet.Tag == FilterEqualityMatch && condition == "*":
   394  			packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
   395  		case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
   396  			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
   397  			packet.Tag = FilterSubstrings
   398  			packet.Description = FilterMap[uint64(packet.Tag)]
   399  			seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
   400  			parts := strings.Split(condition, "*")
   401  			for i, part := range parts {
   402  				if part == "" {
   403  					continue
   404  				}
   405  				var tag ber.Tag
   406  				switch i {
   407  				case 0:
   408  					tag = FilterSubstringsInitial
   409  				case len(parts) - 1:
   410  					tag = FilterSubstringsFinal
   411  				default:
   412  					tag = FilterSubstringsAny
   413  				}
   414  				encodedString, encodeErr := escapedStringToEncodedBytes(part)
   415  				if encodeErr != nil {
   416  					return packet, newPos, encodeErr
   417  				}
   418  				seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
   419  			}
   420  			packet.AppendChild(seq)
   421  		default:
   422  			encodedString, encodeErr := escapedStringToEncodedBytes(condition)
   423  			if encodeErr != nil {
   424  				return packet, newPos, encodeErr
   425  			}
   426  			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
   427  			packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
   428  		}
   429  
   430  		newPos += currentWidth
   431  		return packet, newPos, err
   432  	}
   433  }
   434  
   435  // Convert from "ABC\xx\xx\xx" form to literal bytes for transport
   436  func escapedStringToEncodedBytes(escapedString string) (string, error) {
   437  	var buffer bytes.Buffer
   438  	i := 0
   439  	for i < len(escapedString) {
   440  		currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
   441  		if currentRune == utf8.RuneError {
   442  			return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
   443  		}
   444  
   445  		// Check for escaped hex characters and convert them to their literal value for transport.
   446  		if currentRune == '\\' {
   447  			// http://tools.ietf.org/search/rfc4515
   448  			// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
   449  			// being a member of UTF1SUBSET.
   450  			if i+2 > len(escapedString) {
   451  				return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
   452  			}
   453  			escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
   454  			if decodeErr != nil {
   455  				return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
   456  			}
   457  			buffer.WriteByte(escByte[0])
   458  			i += 2 // +1 from end of loop, so 3 total for \xx.
   459  		} else {
   460  			buffer.WriteRune(currentRune)
   461  		}
   462  
   463  		i += currentWidth
   464  	}
   465  	return buffer.String(), nil
   466  }