Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1632,7 +1632,7 @@ defmodule Axon.Loop do
final_metrics_map = loop_state.metrics
loop_state = %{loop_state | metrics: zero_metrics}

{status, final_metrics_map, state} =
{status, final_metrics_map, %State{} = state} =
case fire_event(:started, handler_fns, loop_state, debug?) do
{:halt_epoch, state} ->
{:halted, final_metrics_map, state}
Expand Down Expand Up @@ -1691,7 +1691,7 @@ defmodule Axon.Loop do
{:halt_loop, state} ->
{:halt, {final_metrics_map, state}}

{:continue, state} ->
{:continue, %State{} = state} ->
{:cont,
{batch_fn, Map.put(final_metrics_map, epoch, state.metrics),
%State{
Expand Down Expand Up @@ -1922,9 +1922,9 @@ defmodule Axon.Loop do
end

# Halts an epoch during looping
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, loop_state, debug?) do
defp halt_epoch(handler_fns, batch_fn, final_metrics_map, %State{} = loop_state, debug?) do
case fire_event(:epoch_halted, handler_fns, loop_state, debug?) do
{:halt_epoch, %{epoch: epoch, metrics: metrics} = state} ->
{:halt_epoch, %State{epoch: epoch, metrics: metrics} = state} ->
final_metrics_map = Map.put(final_metrics_map, epoch, metrics)
{:cont, {batch_fn, final_metrics_map, %State{state | epoch: epoch + 1, iteration: 0}}}

Expand Down
5 changes: 2 additions & 3 deletions lib/axon/quantization/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,11 @@ defmodule Axon.Quantization.Layers do

deftransformp reshape_scales(scales, y) do
ones = List.to_tuple(List.duplicate(1, Nx.rank(y) - 1))
Nx.reshape(scales, Tuple.append(ones, :auto))
Nx.reshape(scales, :erlang.append_element(ones, :auto))
end

deftransformp reshape_output(output, x_shape) do
all_but_last = Tuple.delete_at(x_shape, tuple_size(x_shape) - 1)
new_shape = Tuple.append(all_but_last, :auto)
Nx.reshape(output, new_shape)
Nx.reshape(output, :erlang.append_element(all_but_last, :auto))
end
end
13 changes: 8 additions & 5 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ defmodule Axon.MixProject do
deps: deps(),
docs: docs(),
description: "Create and train neural networks in Elixir",
package: package(),
preferred_cli_env: [
docs: :docs,
"hex.publish": :docs
]
package: package()
]
end

def cli do
[
docs: :docs,
"hex.publish": :docs
]
end

Expand Down
Loading