Skip to content

Commit

Permalink
Merge pull request #3173 from ruby/code-units-cache
Browse files Browse the repository at this point in the history
Prism::CodeUnitsCache
  • Loading branch information
kddnewton authored Oct 10, 2024
2 parents e6794e6 + 2e3e1a4 commit ba89182
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 0 deletions.
112 changes: 112 additions & 0 deletions lib/prism/parse_result.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def code_units_offset(byte_offset, encoding)
end
end

# Generate a cache that targets a specific encoding for calculating code
# unit offsets.
def code_units_cache(encoding)
CodeUnitsCache.new(source, encoding)
end

# Returns the column number in code units for the given encoding for the
# given byte offset.
def code_units_column(byte_offset, encoding)
Expand Down Expand Up @@ -149,6 +155,76 @@ def find_line(byte_offset)
end
end

# A cache that can be used to quickly compute code unit offsets from byte
# offsets. It purposefully provides only a single #[] method to access the
# cache in order to minimize surface area.
#
# Note that there are some known issues here that may or may not be addressed
# in the future:
#
# * The first is that there are issues when the cache computes values that are
# not on character boundaries. This can result in subsequent computations
# being off by one or more code units.
# * The second is that this cache is currently unbounded. In theory we could
# introduce some kind of LRU cache to limit the number of entries, but this
# has not yet been implemented.
#
class CodeUnitsCache
class UTF16Counter # :nodoc:
def initialize(source, encoding)
@source = source
@encoding = encoding
end

def count(byte_offset, byte_length)
@source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).bytesize / 2
end
end

class LengthCounter # :nodoc:
def initialize(source, encoding)
@source = source
@encoding = encoding
end

def count(byte_offset, byte_length)
@source.byteslice(byte_offset, byte_length).encode(@encoding, invalid: :replace, undef: :replace).length
end
end

private_constant :UTF16Counter, :LengthCounter

# Initialize a new cache with the given source and encoding.
def initialize(source, encoding)
@source = source
@counter =
if encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE
UTF16Counter.new(source, encoding)
else
LengthCounter.new(source, encoding)
end

@cache = {}
@offsets = []
end

# Retrieve the code units offset from the given byte offset.
def [](byte_offset)
@cache[byte_offset] ||=
if (index = @offsets.bsearch_index { |offset| offset > byte_offset }).nil?
@offsets << byte_offset
@counter.count(0, byte_offset)
elsif index == 0
@offsets.unshift(byte_offset)
@counter.count(0, byte_offset)
else
@offsets.insert(index, byte_offset)
offset = @offsets[index - 1]
@cache[offset] + @counter.count(offset, byte_offset - offset)
end
end
end

# Specialized version of Prism::Source for source code that includes ASCII
# characters only. This class is used to apply performance optimizations that
# cannot be applied to sources that include multibyte characters.
Expand Down Expand Up @@ -178,6 +254,13 @@ def code_units_offset(byte_offset, encoding)
byte_offset
end

# Returns a cache that is the identity function in order to maintain the
# same interface. We can do this because code units are always equivalent to
# byte offsets for ASCII-only sources.
def code_units_cache(encoding)
->(byte_offset) { byte_offset }
end

# Specialized version of `code_units_column` that does not depend on
# `code_units_offset`, which is a more expensive operation. This is
# essentially the same as `Prism::Source#column`.
Expand Down Expand Up @@ -287,6 +370,12 @@ def start_code_units_offset(encoding = Encoding::UTF_16LE)
source.code_units_offset(start_offset, encoding)
end

# The start offset from the start of the file in code units using the given
# cache to fetch or calculate the value.
def cached_start_code_units_offset(cache)
cache[start_offset]
end

# The byte offset from the beginning of the source where this location ends.
def end_offset
start_offset + length
Expand All @@ -303,6 +392,12 @@ def end_code_units_offset(encoding = Encoding::UTF_16LE)
source.code_units_offset(end_offset, encoding)
end

# The end offset from the start of the file in code units using the given
# cache to fetch or calculate the value.
def cached_end_code_units_offset(cache)
cache[end_offset]
end

# The line number where this location starts.
def start_line
source.line(start_offset)
Expand Down Expand Up @@ -337,6 +432,12 @@ def start_code_units_column(encoding = Encoding::UTF_16LE)
source.code_units_column(start_offset, encoding)
end

# The start column in code units using the given cache to fetch or calculate
# the value.
def cached_start_code_units_column(cache)
cache[start_offset] - cache[source.line_start(start_offset)]
end

# The column number in bytes where this location ends from the start of the
# line.
def end_column
Expand All @@ -355,6 +456,12 @@ def end_code_units_column(encoding = Encoding::UTF_16LE)
source.code_units_column(end_offset, encoding)
end

# The end column in code units using the given cache to fetch or calculate
# the value.
def cached_end_code_units_column(cache)
cache[end_offset] - cache[source.line_start(end_offset)]
end

# Implement the hash pattern matching interface for Location.
def deconstruct_keys(keys)
{ start_offset: start_offset, end_offset: end_offset }
Expand Down Expand Up @@ -604,6 +711,11 @@ def success?
def failure?
!success?
end

# Create a code units cache for the given encoding.
def code_units_cache(encoding)
source.code_units_cache(encoding)
end
end

# This is a result specific to the `parse` and `parse_file` methods.
Expand Down
29 changes: 29 additions & 0 deletions rbi/prism/parse_result.rbi
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,21 @@ class Prism::Source
sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_offset(byte_offset, encoding); end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end

sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_column(byte_offset, encoding); end
end

class Prism::CodeUnitsCache
sig { params(source: String, encoding: Encoding).void }
def initialize(source, encoding); end

sig { params(byte_offset: Integer).returns(Integer) }
def [](byte_offset); end
end

class Prism::ASCIISource < Prism::Source
sig { params(byte_offset: Integer).returns(Integer) }
def character_offset(byte_offset); end
Expand All @@ -54,6 +65,9 @@ class Prism::ASCIISource < Prism::Source
sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_offset(byte_offset, encoding); end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end

sig { params(byte_offset: Integer, encoding: Encoding).returns(Integer) }
def code_units_column(byte_offset, encoding); end
end
Expand Down Expand Up @@ -107,6 +121,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def start_code_units_offset(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_start_code_units_offset(cache); end

sig { returns(Integer) }
def end_offset; end

Expand All @@ -116,6 +133,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def end_code_units_offset(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_end_code_units_offset(cache); end

sig { returns(Integer) }
def start_line; end

Expand All @@ -134,6 +154,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def start_code_units_column(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_start_code_units_column(cache); end

sig { returns(Integer) }
def end_column; end

Expand All @@ -143,6 +166,9 @@ class Prism::Location
sig { params(encoding: Encoding).returns(Integer) }
def end_code_units_column(encoding = Encoding::UTF_16LE); end

sig { params(cache: T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))).returns(Integer) }
def cached_end_code_units_column(cache); end

sig { params(keys: T.nilable(T::Array[Symbol])).returns(T::Hash[Symbol, T.untyped]) }
def deconstruct_keys(keys); end

Expand Down Expand Up @@ -296,6 +322,9 @@ class Prism::Result

sig { returns(T::Boolean) }
def failure?; end

sig { params(encoding: Encoding).returns(T.any(Prism::CodeUnitsCache, T.proc.params(byte_offset: Integer).returns(Integer))) }
def code_units_cache(encoding); end
end

class Prism::ParseResult < Prism::Result
Expand Down
12 changes: 12 additions & 0 deletions sig/prism/_private/parse_result.rbs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ module Prism
def find_line: (Integer) -> Integer
end

class CodeUnitsCache
class UTF16Counter
def initialize: (String source, Encoding encoding) -> void
def count: (Integer byte_offset, Integer byte_length) -> Integer
end

class LengthCounter
def initialize: (String source, Encoding encoding) -> void
def count: (Integer byte_offset, Integer byte_length) -> Integer
end
end

class Location
private

Expand Down
20 changes: 20 additions & 0 deletions sig/prism/parse_result.rbs
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
module Prism
interface _CodeUnitsCache
def []: (Integer byte_offset) -> Integer
end

class Source
attr_reader source: String
attr_reader start_line: Integer
Expand All @@ -16,15 +20,22 @@ module Prism
def character_offset: (Integer byte_offset) -> Integer
def character_column: (Integer byte_offset) -> Integer
def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer

def self.for: (String source) -> Source
end

class CodeUnitsCache
def initialize: (String source, Encoding encoding) -> void
def []: (Integer byte_offset) -> Integer
end

class ASCIISource < Source
def character_offset: (Integer byte_offset) -> Integer
def character_column: (Integer byte_offset) -> Integer
def code_units_offset: (Integer byte_offset, Encoding encoding) -> Integer
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
def code_units_column: (Integer byte_offset, Encoding encoding) -> Integer
end

Expand All @@ -45,15 +56,23 @@ module Prism
def slice: () -> String
def slice_lines: () -> String
def start_character_offset: () -> Integer
def start_code_units_offset: (Encoding encoding) -> Integer
def cached_start_code_units_offset: (_CodeUnitsCache cache) -> Integer
def end_offset: () -> Integer
def end_character_offset: () -> Integer
def end_code_units_offset: (Encoding encoding) -> Integer
def cached_end_code_units_offset: (_CodeUnitsCache cache) -> Integer
def start_line: () -> Integer
def start_line_slice: () -> String
def end_line: () -> Integer
def start_column: () -> Integer
def start_character_column: () -> Integer
def start_code_units_column: (Encoding encoding) -> Integer
def cached_start_code_units_column: (_CodeUnitsCache cache) -> Integer
def end_column: () -> Integer
def end_character_column: () -> Integer
def end_code_units_column: (Encoding encoding) -> Integer
def cached_end_code_units_column: (_CodeUnitsCache cache) -> Integer
def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped]
def pretty_print: (untyped q) -> untyped
def join: (Location other) -> Location
Expand Down Expand Up @@ -125,6 +144,7 @@ module Prism
def deconstruct_keys: (Array[Symbol]? keys) -> Hash[Symbol, untyped]
def success?: () -> bool
def failure?: () -> bool
def code_units_cache: (Encoding encoding) -> _CodeUnitsCache
end

class ParseResult < Result
Expand Down
46 changes: 46 additions & 0 deletions test/prism/ruby/location_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,52 @@ def test_code_units
assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE)
end

def test_cached_code_units
result = Prism.parse("πŸ˜€ + πŸ˜€\n😍 ||= 😍")

utf8_cache = result.code_units_cache(Encoding::UTF_8)
utf16_cache = result.code_units_cache(Encoding::UTF_16LE)
utf32_cache = result.code_units_cache(Encoding::UTF_32LE)

# first πŸ˜€
location = result.value.statements.body.first.receiver.location

assert_equal 0, location.cached_start_code_units_offset(utf8_cache)
assert_equal 0, location.cached_start_code_units_offset(utf16_cache)
assert_equal 0, location.cached_start_code_units_offset(utf32_cache)

assert_equal 1, location.cached_end_code_units_offset(utf8_cache)
assert_equal 2, location.cached_end_code_units_offset(utf16_cache)
assert_equal 1, location.cached_end_code_units_offset(utf32_cache)

assert_equal 0, location.cached_start_code_units_column(utf8_cache)
assert_equal 0, location.cached_start_code_units_column(utf16_cache)
assert_equal 0, location.cached_start_code_units_column(utf32_cache)

assert_equal 1, location.cached_end_code_units_column(utf8_cache)
assert_equal 2, location.cached_end_code_units_column(utf16_cache)
assert_equal 1, location.cached_end_code_units_column(utf32_cache)

# second πŸ˜€
location = result.value.statements.body.first.arguments.arguments.first.location

assert_equal 4, location.cached_start_code_units_offset(utf8_cache)
assert_equal 5, location.cached_start_code_units_offset(utf16_cache)
assert_equal 4, location.cached_start_code_units_offset(utf32_cache)

assert_equal 5, location.cached_end_code_units_offset(utf8_cache)
assert_equal 7, location.cached_end_code_units_offset(utf16_cache)
assert_equal 5, location.cached_end_code_units_offset(utf32_cache)

assert_equal 4, location.cached_start_code_units_column(utf8_cache)
assert_equal 5, location.cached_start_code_units_column(utf16_cache)
assert_equal 4, location.cached_start_code_units_column(utf32_cache)

assert_equal 5, location.cached_end_code_units_column(utf8_cache)
assert_equal 7, location.cached_end_code_units_column(utf16_cache)
assert_equal 5, location.cached_end_code_units_column(utf32_cache)
end

def test_code_units_binary_valid_utf8
program = Prism.parse(<<~RUBY).value
# -*- encoding: binary -*-
Expand Down

0 comments on commit ba89182

Please sign in to comment.