diff --git a/go/compiler_test.go b/go/compiler_test.go index d334c7572..28a28b983 100644 --- a/go/compiler_test.go +++ b/go/compiler_test.go @@ -1,6 +1,7 @@ package yara_x import ( + "bytes" "github.com/stretchr/testify/assert" "testing" ) @@ -50,12 +51,20 @@ func TestSerialization(t *testing.T) { r, err := Compile("rule test { condition: true }") assert.NoError(t, err) - b, _ := r.Serialize() - r, _ = Deserialize(b) + var buf bytes.Buffer + // Write rules into buffer + n, err := r.WriteTo(&buf) + + assert.NoError(t, err) + assert.Len(t, buf.Bytes(), int(n)) + + // Read rules from buffer + r, _ = ReadFrom(&buf) + + // Make sure the rules work properly. s := NewScanner(r) scanResults, _ := s.Scan([]byte{}) - assert.Len(t, scanResults.MatchingRules(), 1) } @@ -118,7 +127,7 @@ func TestError(t *testing.T) { func TestCompilerFeatures(t *testing.T) { rules := `import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar }` - _, err := Compile(rules) + _, err := Compile(rules) assert.EqualError(t, err, `error[E034]: foo is required --> line:1:57 | @@ -126,16 +135,16 @@ func TestCompilerFeatures(t *testing.T) { | ^^^^^^^^^^^^^^^^^^^^ this field was used without foo |`) - _, err = Compile(rules, WithFeature("foo")) - assert.EqualError(t, err, `error[E034]: bar is required + _, err = Compile(rules, WithFeature("foo")) + assert.EqualError(t, err, `error[E034]: bar is required --> line:1:57 | 1 | import "test_proto2" rule test { condition: test_proto2.requires_foo_and_bar } | ^^^^^^^^^^^^^^^^^^^^ this field was used without bar |`) - _, err = Compile(rules, WithFeature("foo"), WithFeature("bar")) - assert.NoError(t, err) + _, err = Compile(rules, WithFeature("foo"), WithFeature("bar")) + assert.NoError(t, err) } func TestErrors(t *testing.T) { @@ -186,8 +195,8 @@ func TestRulesIter(t *testing.T) { }`) assert.NoError(t, err) - rules := c.Build() - assert.Equal(t, 2, rules.Count()) + rules := c.Build() + assert.Equal(t, 2, rules.Count()) slice := rules.Slice() assert.Len(t, slice, 2) @@ -200,7 +209,7 @@ func TestRulesIter(t *testing.T) { assert.Len(t, slice[0].Metadata(), 0) assert.Len(t, slice[1].Metadata(), 1) - assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) + assert.Equal(t, "foo", slice[1].Metadata()[0].Identifier()) } func TestImportsIter(t *testing.T) { @@ -216,12 +225,12 @@ func TestImportsIter(t *testing.T) { }`) assert.NoError(t, err) - rules := c.Build() - imports := rules.Imports() + rules := c.Build() + imports := rules.Imports() - assert.Len(t, imports, 2) - assert.Equal(t, "pe", imports[0]) - assert.Equal(t, "elf", imports[1]) + assert.Len(t, imports, 2) + assert.Equal(t, "pe", imports[0]) + assert.Equal(t, "elf", imports[1]) } func TestWarnings(t *testing.T) { diff --git a/go/main.go b/go/main.go index 1368c451c..dd3267f94 100644 --- a/go/main.go +++ b/go/main.go @@ -31,6 +31,8 @@ import "C" import ( "errors" + "io" + "reflect" "runtime" "runtime/cgo" "unsafe" @@ -49,25 +51,30 @@ func Compile(src string, opts ...CompileOption) (*Rules, error) { return c.Build(), nil } -// Deserialize deserializes rules from a byte slice. +// ReadFrom reads compiled rules from a reader. // -// The counterpart is [Rules.Serialize] -func Deserialize(data []byte) (*Rules, error) { +// The counterpart is [Rules.WriteTo]. +func ReadFrom(r io.Reader) (*Rules, error) { + data, err := io.ReadAll(r) + if err != nil { + return nil, err + } + var ptr *C.uint8_t if len(data) > 0 { ptr = (*C.uint8_t)(unsafe.Pointer(&(data[0]))) } - r := &Rules{cRules: nil} + rules := &Rules{cRules: nil} runtime.LockOSThread() defer runtime.UnlockOSThread() - if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &r.cRules) != C.SUCCESS { + if C.yrx_rules_deserialize(ptr, C.size_t(len(data)), &rules.cRules) != C.SUCCESS { return nil, errors.New(C.GoString(C.yrx_last_error())) } - return r, nil + return rules, nil } // Rules represents a set of compiled YARA rules. @@ -79,17 +86,60 @@ func (r *Rules) Scan(data []byte) (*ScanResults, error) { return scanner.Scan(data) } -// Serialize converts the compiled rules into a byte slice. -func (r *Rules) Serialize() ([]byte, error) { +// WriteTo writes the compiled rules into a writer. +// +// The counterpart is [ReadFrom]. +func (r *Rules) WriteTo(w io.Writer) (int64, error) { var buf *C.YRX_BUFFER runtime.LockOSThread() defer runtime.UnlockOSThread() if C.yrx_rules_serialize(r.cRules, &buf) != C.SUCCESS { - return nil, errors.New(C.GoString(C.yrx_last_error())) + return 0, errors.New(C.GoString(C.yrx_last_error())) } defer C.yrx_buffer_destroy(buf) runtime.KeepAlive(r) - return C.GoBytes(unsafe.Pointer(buf.data), C.int(buf.length)), nil + + // We are going to write into `w` in chunks of 64MB. + const chunkSize = 1 << 26 + + // This is the slice that contains the next chunk that will be written. + var chunk []byte + + // Modify the `chunk` slice, making it point to the buffer returned + // by yrx_rules_serialize. This allows us to access the buffer from + // Go without copying the data. This is safe because the slice won't + // be used after the buffer is destroyed. + chunkHdr := (*reflect.SliceHeader)(unsafe.Pointer(&chunk)) + chunkHdr.Data = uintptr(unsafe.Pointer(buf.data)) + chunkHdr.Len = chunkSize + chunkHdr.Cap = chunkSize + + bufLen := C.ulong(buf.length) + bytesWritten := int64(0) + + for { + // If the data to be written is shorted than `chunkSize`, set the length + // of the `chunk` slice to this length. + if bufLen < chunkSize { + chunkHdr.Len = int(bufLen) + chunkHdr.Cap = int(bufLen) + } + if n, err := w.Write(chunk); err == nil { + bytesWritten += int64(n) + } else { + return 0, err + } + // If `bufLen` is still greater than `chunkSize`, there's more data to + // write, if not, we are done. + if bufLen > chunkSize { + chunkHdr.Data += chunkSize + bufLen -= chunkSize + } else { + break + } + } + + return bytesWritten, nil } // Destroy destroys the compiled YARA rules represented by [Rules]. @@ -106,6 +156,7 @@ func (r *Rules) Destroy() { // This is the callback called by yrx_rules_iterate, when Rules.GetRules is // called. +// //export onRule func onRule(rule *C.YRX_RULE, handle unsafe.Pointer) { h := (cgo.Handle)(handle) @@ -143,6 +194,7 @@ func (r *Rules) Count() int { // This is the callback called by yrx_rules_iterate_imports, when Rules.Imports // is called. +// //export onImport func onImport(module_name *C.char, handle unsafe.Pointer) { h := (cgo.Handle)(handle) diff --git a/go/scanner_test.go b/go/scanner_test.go index 60a76a06e..a888e64c8 100644 --- a/go/scanner_test.go +++ b/go/scanner_test.go @@ -19,6 +19,14 @@ func TestScanner1(t *testing.T) { assert.Equal(t, "t", matchingRules[0].Identifier()) assert.Equal(t, "default", matchingRules[0].Namespace()) assert.Len(t, matchingRules[0].Patterns(), 0) + + scanResults, _ = s.Scan(nil) + matchingRules = scanResults.MatchingRules() + + assert.Len(t, matchingRules, 1) + assert.Equal(t, "t", matchingRules[0].Identifier()) + assert.Equal(t, "default", matchingRules[0].Namespace()) + assert.Len(t, matchingRules[0].Patterns(), 0) } func TestScanner2(t *testing.T) {