Skip to content

Commit a4f79d8

Browse files
khevsejohnweldon
authored andcommitted
fix: memory allocations (#252)
* fix: memory allocations * fix: replace strings.Builder to bytes.Buffer for old goland versions * fix: rename escapedStringToEncodedBytes to decodeEscapedSymbols * feat: remove one allocation * fix (v3): memory allocations
1 parent a75d3c9 commit a4f79d8

File tree

4 files changed

+276
-156
lines changed

4 files changed

+276
-156
lines changed

filter.go

Lines changed: 99 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ import (
55
hexpac "encoding/hex"
66
"errors"
77
"fmt"
8+
"io"
89
"strings"
10+
"unicode"
911
"unicode/utf8"
1012

11-
"github.com/go-asn1-ber/asn1-ber"
13+
ber "github.com/go-asn1-ber/asn1-ber"
1214
)
1315

1416
// Filter choices
@@ -69,6 +71,8 @@ var MatchingRuleAssertionMap = map[uint64]string{
6971
MatchingRuleAssertionDNAttributes: "Matching Rule Assertion DN Attributes",
7072
}
7173

74+
var _SymbolAny = []byte{'*'}
75+
7276
// CompileFilter converts a string representation of a filter into a BER-encoded packet
7377
func CompileFilter(filter string) (*ber.Packet, error) {
7478
if len(filter) == 0 || filter[0] != '(' {
@@ -88,74 +92,75 @@ func CompileFilter(filter string) (*ber.Packet, error) {
8892
}
8993

9094
// DecompileFilter converts a packet representation of a filter into a string representation
91-
func DecompileFilter(packet *ber.Packet) (ret string, err error) {
95+
func DecompileFilter(packet *ber.Packet) (_ string, err error) {
9296
defer func() {
9397
if r := recover(); r != nil {
9498
err = NewError(ErrorFilterDecompile, errors.New("ldap: error decompiling filter"))
9599
}
96100
}()
97-
ret = "("
98-
err = nil
101+
102+
buf := bytes.NewBuffer(nil)
103+
buf.WriteByte('(')
99104
childStr := ""
100105

101106
switch packet.Tag {
102107
case FilterAnd:
103-
ret += "&"
108+
buf.WriteByte('&')
104109
for _, child := range packet.Children {
105110
childStr, err = DecompileFilter(child)
106111
if err != nil {
107112
return
108113
}
109-
ret += childStr
114+
buf.WriteString(childStr)
110115
}
111116
case FilterOr:
112-
ret += "|"
117+
buf.WriteByte('|')
113118
for _, child := range packet.Children {
114119
childStr, err = DecompileFilter(child)
115120
if err != nil {
116121
return
117122
}
118-
ret += childStr
123+
buf.WriteString(childStr)
119124
}
120125
case FilterNot:
121-
ret += "!"
126+
buf.WriteByte('!')
122127
childStr, err = DecompileFilter(packet.Children[0])
123128
if err != nil {
124129
return
125130
}
126-
ret += childStr
131+
buf.WriteString(childStr)
127132

128133
case FilterSubstrings:
129-
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
130-
ret += "="
134+
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
135+
buf.WriteByte('=')
131136
for i, child := range packet.Children[1].Children {
132137
if i == 0 && child.Tag != FilterSubstringsInitial {
133-
ret += "*"
138+
buf.Write(_SymbolAny)
134139
}
135-
ret += EscapeFilter(ber.DecodeString(child.Data.Bytes()))
140+
buf.WriteString(EscapeFilter(ber.DecodeString(child.Data.Bytes())))
136141
if child.Tag != FilterSubstringsFinal {
137-
ret += "*"
142+
buf.Write(_SymbolAny)
138143
}
139144
}
140145
case FilterEqualityMatch:
141-
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
142-
ret += "="
143-
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
146+
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
147+
buf.WriteByte('=')
148+
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
144149
case FilterGreaterOrEqual:
145-
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
146-
ret += ">="
147-
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
150+
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
151+
buf.WriteString(">=")
152+
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
148153
case FilterLessOrEqual:
149-
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
150-
ret += "<="
151-
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
154+
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
155+
buf.WriteString("<=")
156+
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
152157
case FilterPresent:
153-
ret += ber.DecodeString(packet.Data.Bytes())
154-
ret += "=*"
158+
buf.WriteString(ber.DecodeString(packet.Data.Bytes()))
159+
buf.WriteString("=*")
155160
case FilterApproxMatch:
156-
ret += ber.DecodeString(packet.Children[0].Data.Bytes())
157-
ret += "~="
158-
ret += EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes()))
161+
buf.WriteString(ber.DecodeString(packet.Children[0].Data.Bytes()))
162+
buf.WriteString("~=")
163+
buf.WriteString(EscapeFilter(ber.DecodeString(packet.Children[1].Data.Bytes())))
159164
case FilterExtensibleMatch:
160165
attr := ""
161166
dnAttributes := false
@@ -176,21 +181,22 @@ func DecompileFilter(packet *ber.Packet) (ret string, err error) {
176181
}
177182

178183
if len(attr) > 0 {
179-
ret += attr
184+
buf.WriteString(attr)
180185
}
181186
if dnAttributes {
182-
ret += ":dn"
187+
buf.WriteString(":dn")
183188
}
184189
if len(matchingRule) > 0 {
185-
ret += ":"
186-
ret += matchingRule
190+
buf.WriteString(":")
191+
buf.WriteString(matchingRule)
187192
}
188-
ret += ":="
189-
ret += EscapeFilter(value)
193+
buf.WriteString(":=")
194+
buf.WriteString(EscapeFilter(value))
190195
}
191196

192-
ret += ")"
193-
return
197+
buf.WriteByte(')')
198+
199+
return buf.String(), nil
194200
}
195201

196202
func compileFilterSet(filter string, pos int, parent *ber.Packet) (int, error) {
@@ -253,11 +259,10 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
253259
)
254260

255261
state := stateReadingAttr
256-
257-
attribute := ""
262+
attribute := bytes.NewBuffer(nil)
258263
extensibleDNAttributes := false
259-
extensibleMatchingRule := ""
260-
condition := ""
264+
extensibleMatchingRule := bytes.NewBuffer(nil)
265+
condition := bytes.NewBuffer(nil)
261266

262267
for newPos < len(filter) {
263268
remainingFilter := filter[newPos:]
@@ -324,7 +329,7 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
324329

325330
// Still reading the attribute name
326331
default:
327-
attribute += fmt.Sprintf("%c", currentRune)
332+
attribute.WriteRune(currentRune)
328333
newPos += currentWidth
329334
}
330335

@@ -338,13 +343,13 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
338343

339344
// Still reading the matching rule oid
340345
default:
341-
extensibleMatchingRule += fmt.Sprintf("%c", currentRune)
346+
extensibleMatchingRule.WriteRune(currentRune)
342347
newPos += currentWidth
343348
}
344349

345350
case stateReadingCondition:
346351
// append to the condition
347-
condition += fmt.Sprintf("%c", currentRune)
352+
condition.WriteRune(currentRune)
348353
newPos += currentWidth
349354
}
350355
}
@@ -368,17 +373,17 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
368373
// }
369374

370375
// Include the matching rule oid, if specified
371-
if len(extensibleMatchingRule) > 0 {
372-
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule, MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
376+
if extensibleMatchingRule.Len() > 0 {
377+
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionMatchingRule, extensibleMatchingRule.String(), MatchingRuleAssertionMap[MatchingRuleAssertionMatchingRule]))
373378
}
374379

375380
// Include the attribute, if specified
376-
if len(attribute) > 0 {
377-
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute, MatchingRuleAssertionMap[MatchingRuleAssertionType]))
381+
if attribute.Len() > 0 {
382+
packet.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionType, attribute.String(), MatchingRuleAssertionMap[MatchingRuleAssertionType]))
378383
}
379384

380385
// Add the value (only required child)
381-
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
386+
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
382387
if encodeErr != nil {
383388
return packet, newPos, encodeErr
384389
}
@@ -389,16 +394,16 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
389394
packet.AppendChild(ber.NewBoolean(ber.ClassContext, ber.TypePrimitive, MatchingRuleAssertionDNAttributes, extensibleDNAttributes, MatchingRuleAssertionMap[MatchingRuleAssertionDNAttributes]))
390395
}
391396

392-
case packet.Tag == FilterEqualityMatch && condition == "*":
393-
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute, FilterMap[FilterPresent])
394-
case packet.Tag == FilterEqualityMatch && strings.Contains(condition, "*"):
395-
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
397+
case packet.Tag == FilterEqualityMatch && bytes.Equal(condition.Bytes(), _SymbolAny):
398+
packet = ber.NewString(ber.ClassContext, ber.TypePrimitive, FilterPresent, attribute.String(), FilterMap[FilterPresent])
399+
case packet.Tag == FilterEqualityMatch && bytes.Index(condition.Bytes(), _SymbolAny) > -1:
400+
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
396401
packet.Tag = FilterSubstrings
397402
packet.Description = FilterMap[uint64(packet.Tag)]
398403
seq := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "Substrings")
399-
parts := strings.Split(condition, "*")
404+
parts := bytes.Split(condition.Bytes(), _SymbolAny)
400405
for i, part := range parts {
401-
if part == "" {
406+
if len(part) == 0 {
402407
continue
403408
}
404409
var tag ber.Tag
@@ -410,19 +415,19 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
410415
default:
411416
tag = FilterSubstringsAny
412417
}
413-
encodedString, encodeErr := escapedStringToEncodedBytes(part)
418+
encodedString, encodeErr := decodeEscapedSymbols(part)
414419
if encodeErr != nil {
415420
return packet, newPos, encodeErr
416421
}
417422
seq.AppendChild(ber.NewString(ber.ClassContext, ber.TypePrimitive, tag, encodedString, FilterSubstringsMap[uint64(tag)]))
418423
}
419424
packet.AppendChild(seq)
420425
default:
421-
encodedString, encodeErr := escapedStringToEncodedBytes(condition)
426+
encodedString, encodeErr := decodeEscapedSymbols(condition.Bytes())
422427
if encodeErr != nil {
423428
return packet, newPos, encodeErr
424429
}
425-
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute, "Attribute"))
430+
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, attribute.String(), "Attribute"))
426431
packet.AppendChild(ber.NewString(ber.ClassUniversal, ber.TypePrimitive, ber.TagOctetString, encodedString, "Condition"))
427432
}
428433

@@ -432,34 +437,51 @@ func compileFilter(filter string, pos int) (*ber.Packet, int, error) {
432437
}
433438

434439
// Convert from "ABC\xx\xx\xx" form to literal bytes for transport
435-
func escapedStringToEncodedBytes(escapedString string) (string, error) {
436-
var buffer bytes.Buffer
437-
i := 0
438-
for i < len(escapedString) {
439-
currentRune, currentWidth := utf8.DecodeRuneInString(escapedString[i:])
440-
if currentRune == utf8.RuneError {
441-
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", i))
440+
func decodeEscapedSymbols(src []byte) (string, error) {
441+
442+
var (
443+
buffer bytes.Buffer
444+
offset int
445+
reader = bytes.NewReader(src)
446+
byteHex []byte
447+
byteVal []byte
448+
)
449+
450+
for {
451+
runeVal, runeSize, err := reader.ReadRune()
452+
if err == io.EOF {
453+
return buffer.String(), nil
454+
} else if err != nil {
455+
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: failed to read filter: %v", err))
456+
} else if runeVal == unicode.ReplacementChar {
457+
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: error reading rune at position %d", offset))
442458
}
443459

444-
// Check for escaped hex characters and convert them to their literal value for transport.
445-
if currentRune == '\\' {
460+
if runeVal == '\\' {
446461
// http://tools.ietf.org/search/rfc4515
447462
// \ (%x5C) is not a valid character unless it is followed by two HEX characters due to not
448463
// being a member of UTF1SUBSET.
449-
if i+2 > len(escapedString) {
450-
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
464+
if byteHex == nil {
465+
byteHex = make([]byte, 2)
466+
byteVal = make([]byte, 1)
467+
}
468+
469+
if _, err := io.ReadFull(reader, byteHex); err != nil {
470+
if err == io.ErrUnexpectedEOF {
471+
return "", NewError(ErrorFilterCompile, errors.New("ldap: missing characters for escape in filter"))
472+
}
473+
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
451474
}
452-
escByte, decodeErr := hexpac.DecodeString(escapedString[i+1 : i+3])
453-
if decodeErr != nil {
454-
return "", NewError(ErrorFilterCompile, errors.New("ldap: invalid characters for escape in filter"))
475+
476+
if _, err := hexpac.Decode(byteVal, byteHex); err != nil {
477+
return "", NewError(ErrorFilterCompile, fmt.Errorf("ldap: invalid characters for escape in filter: %v", err))
455478
}
456-
buffer.WriteByte(escByte[0])
457-
i += 2 // +1 from end of loop, so 3 total for \xx.
479+
480+
buffer.Write(byteVal)
458481
} else {
459-
buffer.WriteRune(currentRune)
482+
buffer.WriteRune(runeVal)
460483
}
461484

462-
i += currentWidth
485+
offset += runeSize
463486
}
464-
return buffer.String(), nil
465487
}

filter_test.go

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import (
44
"strings"
55
"testing"
66

7-
"github.com/go-asn1-ber/asn1-ber"
7+
ber "github.com/go-asn1-ber/asn1-ber"
88
)
99

1010
type compileTest struct {
@@ -213,6 +213,44 @@ func TestFilter(t *testing.T) {
213213
}
214214
}
215215

216+
func TestDecodeEscapedSymbols(t *testing.T) {
217+
218+
for _, testInfo := range []struct {
219+
Src string
220+
Err string
221+
}{
222+
{
223+
Src: "a\u0100\x80",
224+
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: error reading rune at position 3`,
225+
},
226+
{
227+
Src: `start\d`,
228+
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: missing characters for escape in filter`,
229+
},
230+
{
231+
Src: `\`,
232+
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: EOF`,
233+
},
234+
{
235+
Src: `start\--end`,
236+
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+002D '-'`,
237+
},
238+
{
239+
Src: `start\d0\hh`,
240+
Err: `LDAP Result Code 201 "Filter Compile Error": ldap: invalid characters for escape in filter: encoding/hex: invalid byte: U+0068 'h'`,
241+
},
242+
} {
243+
244+
res, err := decodeEscapedSymbols([]byte(testInfo.Src))
245+
if err == nil || err.Error() != testInfo.Err {
246+
t.Fatal(testInfo.Src, "=> ", err, "!=", testInfo.Err)
247+
}
248+
if res != "" {
249+
t.Fatal(testInfo.Src, "=> ", "invalid result", res)
250+
}
251+
}
252+
}
253+
216254
func TestInvalidFilter(t *testing.T) {
217255
for _, filterStr := range testInvalidFilters {
218256
if _, err := CompileFilter(filterStr); err == nil {

0 commit comments

Comments
 (0)