Skip to content

Commit 0da673a

Browse files
committed
Fix Core.box in GR
1 parent cba7298 commit 0da673a

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

src/regularizer.jl

+15-9
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ of a n-dimensional array.
220220
different dimensions. If `weights=nothing` all dimensions are weighted equally.
221221
- `step=1`: A integer indicating the step width for the array indexing
222222
- `mode="forward"`: Either `"central"` or `"forward"` accounting for different
223-
modes of the spatial gradient. Default is "central".
223+
modes of the spatial gradient. Default is "forward".
224224
- `ϵ=1f-8` is a smoothness variable, to make it differentiable
225225
226226
# Examples
@@ -234,22 +234,28 @@ julia> reg([1 2 3; 4 5 6; 7 8 9])
234234
```
235235
"""
236236
function GR(; num_dims=2, sum_dims=1:num_dims, weights=[1, 1], step=1,
237-
mode="central", ϵ=1f-8)
237+
mode="forward", ϵ=1f-8)
238238
if weights == nothing
239239
weights = ones(Int, num_dims)
240240
end
241241
if mode == "central"
242-
GRf = @eval arr -> ($(generate_GR(num_dims, sum_dims, weights,
243-
step, (-1) * step)...))
242+
GRf = @eval arr2 -> begin
243+
arr = sqrt.(arr2 .+ $ϵ)
244+
$(generate_GR(num_dims, sum_dims, weights,
245+
step, (-1) * step)...)
246+
end
244247
elseif mode == "forward"
245-
GRf = @eval arr -> ($(generate_GR(num_dims, sum_dims, weights,
248+
GRf = @eval arr2 -> begin
249+
arr = sqrt.(arr2 .+ $ϵ)
250+
($(generate_GR(num_dims, sum_dims, weights,
246251
step, 0)...))
252+
end
247253
else
248254
throw(ArgumentError("The provided mode is not valid."))
249255
end
250256

251257
# we need to add a ϵ to prevent NaN in the derivative of it
252-
return arr -> GRf(sqrt.(arr .+ ϵ))
258+
return GRf#arr -> begin
253259
end
254260

255261

@@ -292,8 +298,8 @@ of a n-dimensional array.
292298
- `weights=nothing`: A array containing weights to weight the contribution of
293299
different dimensions. If `weights=nothing` all dimensions are weighted equally.
294300
- `step=1`: A integer indicating the step width for the array indexing
295-
- `mode="central"`: Either `"central"` or `"forward"` accounting for different
296-
modes of the spatial gradient. Default is "central".
301+
- `mode="forward"`: Either `"central"` or `"forward"` accounting for different
302+
modes of the spatial gradient. Default is "forward".
297303
- `ϵ=1f-8` is a smoothness variable, to make it differentiable
298304
299305
# Examples
@@ -306,7 +312,7 @@ julia> reg([1 2 3; 4 5 6; 7 8 9])
306312
12.649111f0
307313
```
308314
"""
309-
function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="central", ϵ=1f-8)
315+
function TV(; num_dims=2, sum_dims=1:num_dims, weights=nothing, step=1, mode="forward", ϵ=1f-8)
310316

311317
if weights == nothing
312318
weights = ones(Int, num_dims)

0 commit comments

Comments
 (0)