diff --git a/libs/jit/src/jit.erl b/libs/jit/src/jit.erl index ce4519a8f..10bc23d7a 100644 --- a/libs/jit/src/jit.erl +++ b/libs/jit/src/jit.erl @@ -2998,17 +2998,26 @@ first_pass_bs_match_ensure_at_least( ]), {J0, Rest0, MatchState, BSOffsetReg, MSt1}; true -> - % TODO: check use of unit here (TODO is the same in opcodeswitch.h) - {_Unit, Rest2} = decode_literal(Rest1), - ?TRACE("{ensure_at_least,~p,~p},", [Stride, _Unit]), + {Unit, Rest2} = decode_literal(Rest1), + ?TRACE("{ensure_at_least,~p,~p},", [Stride, Unit]), {MSt1, Reg} = MMod:get_array_element(MSt0, BSBinaryReg, 1), MSt2 = MMod:shift_left(MSt1, Reg, 3), - % Reg is bs_bin_size * 8 (use unit instead ??) + % Reg is bs_bin_size * 8 MSt3 = MMod:sub(MSt2, Reg, BSOffsetReg), - % Reg is (bs_bin_size * 8) - bs_offset + % Reg is (bs_bin_size * 8) - bs_offset = remaining bits MSt4 = cond_jump_to_label({Reg, '<', Stride}, Fail, MMod, MSt3), - MSt5 = MMod:free_native_registers(MSt4, [Reg]), - {J0 - 2, Rest2, MatchState, BSOffsetReg, MSt5} + % Also check unit alignment: (remaining - stride) % unit == 0 + MSt7 = + if + Unit > 1 -> + MSt4b = MMod:sub(MSt4, Reg, Stride), + {MSt5, UnitReg} = MMod:and_(MSt4b, {free, Reg}, Unit - 1), + MSt6 = cond_jump_to_label({{free, UnitReg}, '!=', 0}, Fail, MMod, MSt5), + MSt6; + true -> + MMod:free_native_registers(MSt4, [Reg]) + end, + {J0 - 2, Rest2, MatchState, BSOffsetReg, MSt7} end. first_pass_bs_match_ensure_exactly( diff --git a/src/libAtomVM/opcodesswitch.h b/src/libAtomVM/opcodesswitch.h index 576211bba..663ae7951 100644 --- a/src/libAtomVM/opcodesswitch.h +++ b/src/libAtomVM/opcodesswitch.h @@ -7547,7 +7547,7 @@ HOT_FUNC int scheduler_entry_point(GlobalContext *glb) int stride; DECODE_LITERAL(stride, pc); j++; - int unit; // TODO: check use of unit here + int unit; DECODE_LITERAL(unit, pc); j++; #ifdef IMPL_EXECUTE_LOOP @@ -7556,7 +7556,8 @@ HOT_FUNC int scheduler_entry_point(GlobalContext *glb) RAISE_ERROR(BADARG_ATOM); } size_t unsigned_stride = (size_t) stride; - if ((bs_bin_size * 8) - bs_offset < unsigned_stride) { + size_t remaining = (bs_bin_size * 8) - bs_offset; + if (remaining < unsigned_stride || (remaining - unsigned_stride) % unit != 0) { TRACE("bs_match/3: ensure_at_least failed -- bs_bin_size = %d, bs_offset = %d, stride = %d, unit = %d\n", (int) bs_bin_size, (int) bs_offset, (int) stride, (int) unit); goto bs_match_jump_to_fail; } diff --git a/tests/erlang_tests/CMakeLists.txt b/tests/erlang_tests/CMakeLists.txt index 8b133068e..613149dc0 100644 --- a/tests/erlang_tests/CMakeLists.txt +++ b/tests/erlang_tests/CMakeLists.txt @@ -624,7 +624,7 @@ compile_erlang(test_op_bs_create_bin) compile_assembler(test_op_bs_create_bin_asm) compile_erlang(test_op_bs_test_unit) -compile_assembler(test_op_bs_test_unit_asm) +compile_erlang(test_bs_match_ensure_at_least) compile_assembler(bs_get_binary2_all_asm) @@ -649,7 +649,6 @@ compile_erlang(test_inline_arith) if(Erlang_VERSION VERSION_GREATER_EQUAL "23") set(OTP23_OR_GREATER_TESTS test_op_bs_start_match_asm.beam - test_op_bs_test_unit_asm.beam bs_get_binary2_all_asm.beam ) else() @@ -1195,6 +1194,7 @@ set(erlang_test_beams test_op_bs_start_match.beam test_op_bs_test_unit.beam + test_bs_match_ensure_at_least.beam test_op_bs_create_bin.beam bigint.beam diff --git a/tests/erlang_tests/test_bs_match_ensure_at_least.erl b/tests/erlang_tests/test_bs_match_ensure_at_least.erl new file mode 100644 index 000000000..0553639f2 --- /dev/null +++ b/tests/erlang_tests/test_bs_match_ensure_at_least.erl @@ -0,0 +1,87 @@ +% +% This file is part of AtomVM. +% +% Copyright 2026 Paul Guyot +% +% Licensed under the Apache License, Version 2.0 (the "License"); +% you may not use this file except in compliance with the License. +% You may obtain a copy of the License at +% +% http://www.apache.org/licenses/LICENSE-2.0 +% +% Unless required by applicable law or agreed to in writing, software +% distributed under the License is distributed on an "AS IS" BASIS, +% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +% See the License for the specific language governing permissions and +% limitations under the License. +% +% SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later +% + +%% Test bs_match ensure_at_least with stride > 0 and unit > 1 +%% after non-byte-aligned bits have been consumed. +%% +%% When the compiler splits binary matching across multiple bs_match +%% instructions (e.g. because an extracted value is used in a guard), +%% the second bs_match may see ensure_at_least with unit > 1 when +%% the remaining bits are not byte-aligned. The unit check must apply +%% to (remaining - stride), not to remaining. +%% +%% This reproduces the bug where parse_dns_name matching +%% <<3:2, Ptr:14, Tail/binary>> failed because after consuming 2 bits, +%% remaining % 8 != 0, but (remaining - 14) % 8 == 0. + +-module(test_bs_match_ensure_at_least). + +-export([start/0]). + +start() -> + ok = test_parse_name(), + ok = test_parse_name_fail(), + 0. + +%% Multi-clause function that forces the compiler to split bs_match: +%% - Clause 1 tries <<0, Tail/binary>> +%% - Clause 2 tries <<3:2, Ptr:14, Tail/binary>> with guard using Ptr +%% - Clause 3 tries <> with guard +%% The guard on Ptr forces the compiler to extract the 2-bit value and +%% 14-bit Ptr in a first bs_match, check the guard, then issue a second +%% bs_match with ensure_at_least 14,8 for the remaining tail. +parse_name(Msg, <<0, Tail/binary>>) -> + {ok, {done, Tail, Msg}}; +parse_name(Msg, <<3:2, Ptr:14, Tail/binary>>) when byte_size(Msg) > Ptr -> + {ok, {ptr, Ptr, Tail}}; +parse_name(_Msg, <>) when N < 64 -> + {ok, {label, Data, Rest}}; +parse_name(_Msg, Other) -> + {error, Other}. + +test_parse_name() -> + Msg = id( + <<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, 95, 114, 100, 108, 105, 110, 107, 192, 28, 0, 12, + 0, 1>> + ), + %% Parse the 7-byte label "_rdlink" + {ok, {label, <<"_rdlink">>, Rest}} = parse_name( + Msg, id(<<7, 95, 114, 100, 108, 105, 110, 107, 192, 28, 0, 12, 0, 1>>) + ), + %% Rest is <<192, 28, 0, 12, 0, 1>> -- a DNS compression pointer + %% 192 = 0xC0 = 11_000000, so 3:2 matches. + %% Ptr:14 = 28. byte_size(Msg) = 24 > 28 is false, so guard fails. + %% But with 24-byte Msg, let's use a bigger message. + + % 31 bytes + BigMsg = id(<<0:240, 0>>), + {ok, {ptr, 28, <<0, 12, 0, 1>>}} = parse_name(BigMsg, Rest), + ok. + +test_parse_name_fail() -> + % 10 bytes, smaller than Ptr=28 + SmallMsg = id(<<0:80>>), + %% <<192, 28>> is 3:2, Ptr=28, but guard byte_size(SmallMsg)=10 > 28 fails + %% Falls through to clause 3: N=192, N >= 64, guard fails + %% Falls through to clause 4: error + {error, <<192, 28, 0, 12, 0, 1>>} = parse_name(SmallMsg, id(<<192, 28, 0, 12, 0, 1>>)), + ok. + +id(X) -> X. diff --git a/tests/erlang_tests/test_op_bs_test_unit.erl b/tests/erlang_tests/test_op_bs_test_unit.erl index 275cc5929..cba61301e 100644 --- a/tests/erlang_tests/test_op_bs_test_unit.erl +++ b/tests/erlang_tests/test_op_bs_test_unit.erl @@ -18,46 +18,64 @@ % SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later % +%% Force the compiler to use bs_test_unit opcode instead of the +%% newer bs_match opcode (OTP 25+). On older OTP this is ignored. +%% On OTP 29+, no_bs_match was removed but no_ssa_opt_bs_ensure +%% prevents the SSA pass from converting bs_test_unit to bs_match +%% with ensure_at_least. +-ifdef(OTP_RELEASE). +-if(?OTP_RELEASE =< 28). +-compile([no_bs_match]). +-endif. +-if(?OTP_RELEASE >= 29). +-compile([no_ssa_opt_bs_ensure]). +-endif. +-endif. + -module(test_op_bs_test_unit). -export([start/0]). start() -> - HasBSStartMatch3 = - case erlang:system_info(machine) of - "BEAM" -> - erlang:system_info(otp_release) >= "23"; - "ATOM" -> - % If code was compiled with OTP < 23, we won't have bs_test_unit asm file - ?OTP_RELEASE >= 23 - end, - ok = - if - HasBSStartMatch3 -> - ok = test_byte_aligned(), - ok = test_unit_16(); - true -> - ok - end, + ok = test_byte_aligned(), + ok = test_unit_16(), 0. test_byte_aligned() -> %% A byte-aligned binary should pass bs_test_unit with unit=8 - <<1, 2, 3, 4>> = test_op_bs_test_unit_asm:get_tail_if_byte_aligned(id(<<1, 2, 3, 4>>)), - <<>> = test_op_bs_test_unit_asm:get_tail_if_byte_aligned(id(<<>>)), - <<2, 3>> = test_op_bs_test_unit_asm:get_tail_if_byte_aligned(id(<<2, 3>>)), + <<1, 2, 3, 4>> = get_tail_if_byte_aligned(id(<<1, 2, 3, 4>>)), + <<>> = get_tail_if_byte_aligned(id(<<>>)), + <<2, 3>> = get_tail_if_byte_aligned(id(<<2, 3>>)), ok. test_unit_16() -> %% After skipping 8 bits, remaining must be 16-bit aligned %% 5 bytes = 40 bits, skip 8 => 32 remaining, 32 rem 16 = 0 => ok - <<2, 3, 4, 5>> = test_op_bs_test_unit_asm:get_tail_unit_16(id(<<1, 2, 3, 4, 5>>)), + <<2, 3, 4, 5>> = get_tail_unit_16(id(<<1, 2, 3, 4, 5>>)), %% 3 bytes = 24 bits, skip 8 => 16 remaining, 16 rem 16 = 0 => ok - <<2, 3>> = test_op_bs_test_unit_asm:get_tail_unit_16(id(<<1, 2, 3>>)), + <<2, 3>> = get_tail_unit_16(id(<<1, 2, 3>>)), %% 2 bytes = 16 bits, skip 8 => 8 remaining, 8 rem 16 = 8 => fail - error = test_op_bs_test_unit_asm:get_tail_unit_16(id(<<1, 2>>)), + error = get_tail_unit_16(id(<<1, 2>>)), %% 4 bytes = 32 bits, skip 8 => 24 remaining, 24 rem 16 = 8 => fail - error = test_op_bs_test_unit_asm:get_tail_unit_16(id(<<1, 2, 3, 4>>)), + error = get_tail_unit_16(id(<<1, 2, 3, 4>>)), ok. +%% Start a match, verify the remaining bits are byte-aligned (unit=8), +%% then return the tail. Returns 'error' on failure. +get_tail_if_byte_aligned(Bin) -> + case Bin of + <<_/binary-unit:8>> -> Bin; + _ -> error + end. + +%% Start a match, skip 8 bits, verify remaining bits are 16-bit aligned, +%% then return the tail. Returns 'error' on failure. +get_tail_unit_16(<<_:8, Rest/binary>>) -> + case Rest of + <<_/binary-unit:16>> -> Rest; + _ -> error + end; +get_tail_unit_16(_) -> + error. + id(X) -> X. diff --git a/tests/erlang_tests/test_op_bs_test_unit_asm.S b/tests/erlang_tests/test_op_bs_test_unit_asm.S deleted file mode 100644 index 159e54854..000000000 --- a/tests/erlang_tests/test_op_bs_test_unit_asm.S +++ /dev/null @@ -1,65 +0,0 @@ -% -% This file is part of AtomVM. -% -% Copyright 2026 Paul Guyot -% -% Licensed under the Apache License, Version 2.0 (the "License"); -% you may not use this file except in compliance with the License. -% You may obtain a copy of the License at -% -% http://www.apache.org/licenses/LICENSE-2.0 -% -% Unless required by applicable law or agreed to in writing, software -% distributed under the License is distributed on an "AS IS" BASIS, -% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -% See the License for the specific language governing permissions and -% limitations under the License. -% -% SPDX-License-Identifier: Apache-2.0 OR LGPL-2.1-or-later -% - -%% BEAM assembly file to test the bs_test_unit opcode. -%% Modern OTP compilers (26+) fold bs_test_unit into the bs_match opcode, -%% so we need hand-written assembly to exercise this code path. - -{module, test_op_bs_test_unit_asm}. - -{exports, [ - {get_tail_if_byte_aligned, 1}, - {get_tail_unit_16, 1} -]}. - -{attributes, []}. - -{labels, 9}. - -%% get_tail_if_byte_aligned(Bin) -> -%% Start a match, verify the remaining bits are byte-aligned (unit=8), -%% then return the tail. Returns 'error' on failure. -{function, get_tail_if_byte_aligned, 1, 2}. -{label, 1}. -{func_info, {atom, test_op_bs_test_unit_asm}, {atom, get_tail_if_byte_aligned}, 1}. -{label, 2}. -{test, bs_start_match3, {f, 3}, 1, [{x, 0}], {x, 1}}. -{test, bs_test_unit, {f, 3}, [{x, 1}, 8]}. -{bs_get_tail, {x, 1}, {x, 0}, 2}. -return. -{label, 3}. -{move, {atom, error}, {x, 0}}. -return. - -%% get_tail_unit_16(Bin) -> -%% Start a match, skip 8 bits, verify remaining bits are 16-bit aligned, -%% then return the tail. Returns 'error' on failure. -{function, get_tail_unit_16, 1, 5}. -{label, 4}. -{func_info, {atom, test_op_bs_test_unit_asm}, {atom, get_tail_unit_16}, 1}. -{label, 5}. -{test, bs_start_match3, {f, 6}, 1, [{x, 0}], {x, 1}}. -{test, bs_skip_bits2, {f, 6}, [{x, 1}, {integer, 8}, 1, {field_flags, [unsigned, big]}]}. -{test, bs_test_unit, {f, 6}, [{x, 1}, 16]}. -{bs_get_tail, {x, 1}, {x, 0}, 2}. -return. -{label, 6}. -{move, {atom, error}, {x, 0}}. -return. diff --git a/tests/test.c b/tests/test.c index aacc70689..8e2a782cc 100644 --- a/tests/test.c +++ b/tests/test.c @@ -565,6 +565,7 @@ struct Test tests[] = { TEST_CASE(test_op_bs_start_match), TEST_CASE(test_op_bs_test_unit), + TEST_CASE(test_bs_match_ensure_at_least), TEST_CASE(test_op_bs_create_bin), TEST_CASE(test_multi_value_comprehension),