Skip to content

Commit 15f05fb

Browse files
authored
cleanup & bugfix dot (#524)
1 parent f47b5a4 commit 15f05fb

File tree

10 files changed

+56
-83
lines changed

10 files changed

+56
-83
lines changed

docs/src/operations.md

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,15 @@ LP solver.
3434
| `x-y` or `x.-y` | subtraction | affine | increasing in $x$ decreasing in $y$ | none none |
3535
| `x*y` | multiplication | affine | increasing if constant term $\ge 0$ decreasing if constant term $\le 0$ not monotonic otherwise | PR: one argument is constant |
3636
| `x/y` | division | affine | increasing | PR: $y$ is scalar constant |
37-
| `dot(*)(x, y)` | elementwise multiplication | affine | increasing | PR: one argument is constant |
38-
| `dot(/)(x, y)` | elementwise division | affine | increasing | PR: one argument is constant |
37+
| `x .* y` | elementwise multiplication | affine | increasing | PR: one argument is constant |
38+
| `x ./ y` | elementwise division | affine | increasing | PR: one argument is constant |
3939
| `x[1:4, 2:3]` | indexing and slicing | affine | increasing | none |
4040
| `diag(x, k)` | $k$-th diagonal of a matrix | affine | increasing | none |
4141
| `diagm(x)` | construct diagonal matrix | affine | increasing | PR: $x$ is a vector |
4242
| `x'` | transpose | affine | increasing | none |
4343
| `vec(x)` | vector representation | affine | increasing | none |
4444
| `dot(x,y)` | $\sum_i x_i y_i$ | affine | increasing | PR: one argument is constant |
4545
| `kron(x,y)` | Kronecker product | affine | increasing | PR: one argument is constant |
46-
| `vecdot(x,y)` | `dot(vec(x),vec(y))` | affine | increasing | PR: one argument is constant |
4746
| `sum(x)` | $\sum_{ij} x_{ij}$ | affine | increasing | none |
4847
| `sum(x, k)` | sum elements across dimension $k$ | affine | increasing | none |
4948
| `sumlargest(x, k)` | sum of $k$ largest elements of $x$ | convex | increasing | none |
@@ -82,7 +81,7 @@ any solver that can solve both LPs and SOCPs can solve the problem.
8281
| `sumsquares(x)` | $\sum x_i^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | none |
8382
| `sqrt(x)` | $\sqrt{x}$ | concave | decreasing | IC: $x>0$ |
8483
| `square(x), x^2` | $x^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | PR : $x$ is scalar |
85-
| `dot(^)(x,2)` | $x.^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | elementwise |
84+
| `x .^ 2` | $x.^2$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | elementwise |
8685
| `geomean(x, y)` | $\sqrt{xy}$ | concave | increasing | IC: $x\ge0$, $y\ge0$ |
8786
| `huber(x, M=1)` | $\begin{cases} x^2 &\|x\| \leq M \\ 2M\|x\| - M^2 &\|x\| > M \end{cases}$ | convex | increasing on $x \ge 0$ decreasing on $x \le 0$ | PR: $M>=1$ |
8887

docs/src/release_notes.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Breaking changes:
77
* `x + A` will error if `x` is a scalar variable and `A` is an array. Instead, use `x * ones(size(A)) + A`.
88
* The `RelativeEntropyAtom` now returns a scalar value instead of elementwise values. This does not affect the result of `relative_entropy`.
99
* The function `constant` should be used instead of the type `Constant` (which now refers to exclusively real constants).
10+
* The syntaxes `dot(*)`, `dot(/)` and `dot(^)` have been removed in favor of explicit broadcasting (`x .* y`, `x ./ y`, and `x .^ y`). These were (mild) type piracy.
11+
* `vecdot(x,y)` has been removed. Call `dot(vec(x), vec(y))` instead.
1012

1113

1214
Other changes:
@@ -15,6 +17,7 @@ Other changes:
1517
* `geomean` supports more than 2 arguments
1618
* [Type piracy](https://docs.julialang.org/en/v1/manual/style-guide/#Avoid-type-piracy) of `imag` and `real` has been removed. This should not affect use of Convex. Unfortunately, piracy of `hcat`, `vcat`, and `hvcat` still remains.
1719
* `sumlargesteigs` now enforces that it's argument is hermitian.
20+
* Bugfix: `dot` now correctly complex-conjugates its first argument
1821

1922
## v0.15.4 (October 24, 2023)
2023

src/Convex.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,6 @@ include("utilities/tree_print.jl")
263263
include("utilities/tree_interface.jl")
264264
include("utilities/show.jl")
265265
include("utilities/iteration.jl")
266-
include("utilities/broadcast.jl")
267266
include("problem_depot/problem_depot.jl")
268267

269268
end

src/atoms/affine/dot.jl

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,8 @@ ismatrix(::Any) = false
1010
# as extending singleton dimensions. We need to ensure that the inputs have the same
1111
# length, which broadcast will check for us if both inputs are vectors.
1212
asvec(x) = convert(AbstractExpr, ismatrix(x) ? vec(x) : x)
13-
_vecdot(x, y) = sum(broadcast(*, asvec(x), asvec(y)))
13+
_vecdot(x, y) = sum(broadcast(*, conj(asvec(x)), asvec(y)))
1414

1515
dot(x::AbstractExpr, y::AbstractExpr) = _vecdot(x, y)
1616
dot(x::Value, y::AbstractExpr) = _vecdot(x, y)
1717
dot(x::AbstractExpr, y::Value) = _vecdot(x, y)
18-
19-
if isdefined(LinearAlgebra, :vecdot) # defined but deprecated
20-
import LinearAlgebra: vecdot
21-
end
22-
Base.@deprecate vecdot(x::AbstractExpr, y::AbstractExpr) dot(x, y)
23-
Base.@deprecate vecdot(x::Value, y::AbstractExpr) dot(x, y)
24-
Base.@deprecate vecdot(x::AbstractExpr, y::Value) dot(x, y)

src/atoms/affine/multiply_divide.jl

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ end
195195
# end
196196

197197
function dotmultiply(x, y)
198+
if size(x) == (1, 1) || size(y) == (1, 1)
199+
return x * y
200+
end
201+
198202
if vexity(x) != ConstVexity()
199203
if vexity(y) != ConstVexity()
200204
error(
@@ -223,39 +227,12 @@ function dotmultiply(x, y)
223227
return reshape(const_multiplier * vec(var), size(var)...)
224228
end
225229

226-
function broadcasted(
227-
::typeof(*),
228-
x::Union{Constant,ComplexConstant},
229-
y::AbstractExpr,
230-
)
231-
if x.size == (1, 1) || y.size == (1, 1)
232-
return x * y
233-
elseif size(y, 1) < size(x, 1) && size(y, 1) == 1
234-
return dotmultiply(x, ones(size(x, 1)) * y)
235-
elseif size(y, 2) < size(x, 2) && size(y, 2) == 1
236-
return dotmultiply(x, y * ones(1, size(x, 1)))
237-
else
238-
return dotmultiply(x, y)
239-
end
240-
end
241-
function broadcasted(
242-
::typeof(*),
243-
y::AbstractExpr,
244-
x::Union{Constant,ComplexConstant},
245-
)
246-
return dotmultiply(x, y)
247-
end
248-
249230
# if neither is a constant it's not DCP, but might be nice to support anyway for eg MultiConvex
250231
function broadcasted(::typeof(*), x::AbstractExpr, y::AbstractExpr)
251-
if x.size == (1, 1) || y.size == (1, 1)
252-
return x * y
253-
elseif vexity(x) == ConstVexity()
254-
return dotmultiply(x, y)
255-
elseif isequal(x, y)
232+
if isequal(x, y)
256233
return square(x)
257234
else
258-
return dotmultiply(y, x)
235+
return dotmultiply(x, y)
259236
end
260237
end
261238
function broadcasted(::typeof(*), x::Value, y::AbstractExpr)

src/atoms/second_order_cone/qol_elementwise.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,16 @@ function broadcasted(::typeof(^), x::AbstractExpr, k::Int)
5656
error("raising variables to powers other than 2 is not implemented")
5757
end
5858

59+
# handle literal case
60+
function broadcasted(
61+
::typeof(Base.literal_pow),
62+
::typeof(^),
63+
x::AbstractExpr,
64+
::Val{k},
65+
) where {k}
66+
return broadcasted(^, x, k)
67+
end
68+
5969
invpos(x::AbstractExpr) = QolElemAtom(constant(ones(x.size[1], x.size[2])), x)
6070
function broadcasted(::typeof(/), x::Value, y::AbstractExpr)
6171
return dotmultiply(constant(x), invpos(y))

src/problem_depot/problems/affine.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -405,31 +405,31 @@ end
405405
::Type{T},
406406
) where {T,test}
407407
x = Variable(3)
408-
p = maximize(sum(dot(*)(x, [1, 2, 3])), x <= 1; numeric_type = T)
408+
p = maximize(sum(x .* [1, 2, 3]), x <= 1; numeric_type = T)
409409

410410
if test
411411
@test problem_vexity(p) == AffineVexity()
412412
end
413413
handle_problem!(p)
414414
if test
415415
@test p.optval 6 atol = atol rtol = rtol
416-
@test evaluate(sum((dot(*))(x, [1, 2, 3]))) 6 atol = atol rtol = rtol
416+
@test evaluate(sum(x .* [1, 2, 3])) 6 atol = atol rtol = rtol
417417
end
418418

419419
x = Variable(3, 3)
420-
p = maximize(sum(dot(*)(x, eye(3))), x <= 1; numeric_type = T)
420+
p = maximize(sum(x .* eye(3)), x <= 1; numeric_type = T)
421421

422422
if test
423423
@test problem_vexity(p) == AffineVexity()
424424
end
425425
handle_problem!(p)
426426
if test
427427
@test p.optval 3 atol = atol rtol = rtol
428-
@test evaluate(sum((dot(*))(x, eye(3)))) 3 atol = atol rtol = rtol
428+
@test evaluate(sum(x .* eye(3))) 3 atol = atol rtol = rtol
429429
end
430430

431431
x = Variable(5, 5)
432-
p = minimize(x[1, 1], dot(*)(3, x) >= 3; numeric_type = T)
432+
p = minimize(x[1, 1], 3 .* x >= 3; numeric_type = T)
433433

434434
if test
435435
@test problem_vexity(p) == AffineVexity()
@@ -441,7 +441,7 @@ end
441441
end
442442

443443
x = Variable(3, 1)
444-
p = minimize(sum(dot(*)(ones(3, 3), x)), x >= 1; numeric_type = T)
444+
p = minimize(sum(ones(3, 3) .* x), x >= 1; numeric_type = T)
445445

446446
if test
447447
@test problem_vexity(p) == AffineVexity()
@@ -453,7 +453,7 @@ end
453453
end
454454

455455
x = Variable(1, 3)
456-
p = minimize(sum(dot(*)(ones(3, 3), x)), x >= 1; numeric_type = T)
456+
p = minimize(sum(ones(3, 3) .* x), x >= 1; numeric_type = T)
457457

458458
if test
459459
@test problem_vexity(p) == AffineVexity()
@@ -465,16 +465,15 @@ end
465465
end
466466

467467
x = Variable(1, 3, Positive())
468-
p = maximize(sum(dot(/)(x, [1 2 3])), x <= 1; numeric_type = T)
468+
p = maximize(sum(x ./ [1 2 3]), x <= 1; numeric_type = T)
469469

470470
if test
471471
@test problem_vexity(p) == AffineVexity()
472472
end
473473
handle_problem!(p)
474474
if test
475475
@test p.optval 11 / 6 atol = atol rtol = rtol
476-
@test evaluate(sum((dot(/))(x, [1 2 3]))) 11 / 6 atol = atol rtol =
477-
rtol
476+
@test evaluate(sum(x ./ [1 2 3])) 11 / 6 atol = atol rtol = rtol
478477
end
479478

480479
# Broadcast fusion works

src/problem_depot/problems/socp.jl

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ end
177177
A = [1 2; 2 1; 3 4]
178178
b = [2; 3; 4]
179179
expr = A * x + b
180-
p = minimize(sum(dot(^)(expr, 2)); numeric_type = T) # elementwise ^
180+
# `literal_pow` case:
181+
p = minimize(sum(expr .^ 2); numeric_type = T) # elementwise ^
181182
if test
182183
@test problem_vexity(p) == ConvexVexity()
183184
end
@@ -188,16 +189,28 @@ end
188189
rtol
189190
end
190191

191-
p = minimize(sum(dot(*)(expr, expr)); numeric_type = T) # elementwise *
192+
# Test non-literal case:
193+
k = 2
194+
p = minimize(sum(expr .^ k); numeric_type = T) # elementwise ^
192195
if test
193196
@test problem_vexity(p) == ConvexVexity()
194197
end
195198
handle_problem!(p)
196199
if test
197200
@test p.optval 0.42105 atol = atol rtol = rtol
198-
@test evaluate(sum((dot(*))(expr, expr))) 0.42105 atol = atol rtol =
201+
@test evaluate(sum(broadcast(^, expr, 2))) 0.42105 atol = atol rtol =
199202
rtol
200203
end
204+
205+
p = minimize(sum(expr .* expr); numeric_type = T) # elementwise *
206+
if test
207+
@test problem_vexity(p) == ConvexVexity()
208+
end
209+
handle_problem!(p)
210+
if test
211+
@test p.optval 0.42105 atol = atol rtol = rtol
212+
@test evaluate(sum(expr .* expr)) 0.42105 atol = atol rtol = rtol
213+
end
201214
end
202215

203216
@add_problem socp function socp_inv_pos_atom(
@@ -227,13 +240,13 @@ end
227240
end
228241

229242
x = Variable(3)
230-
p = minimize(sum(dot(/)([3, 6, 9], x)), x <= 3; numeric_type = T)
243+
p = minimize(sum([3, 6, 9] ./ x), x <= 3; numeric_type = T)
231244

232245
handle_problem!(p)
233246
if test
234247
@test evaluate(x) fill(3.0, (3, 1)) atol = atol rtol = rtol
235248
@test p.optval 6 atol = atol rtol = rtol
236-
@test evaluate(sum((dot(/))([3, 6, 9], x))) 6 atol = atol rtol = rtol
249+
@test evaluate(sum([3, 6, 9] ./ x)) 6 atol = atol rtol = rtol
237250
end
238251

239252
x = Variable()

src/utilities/broadcast.jl

Lines changed: 0 additions & 26 deletions
This file was deleted.

test/definitions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,9 @@ end
3838
set_value!(x, 1.5)
3939
@test evaluate(expr) log(1 + exp(1.5))
4040
end
41+
42+
@testset "`dot` (issue #508)" begin
43+
x = [1.0 + 1.0im]
44+
y = [-1.0im]
45+
@test dot(x, y) evaluate(dot(constant(x), y))
46+
end

0 commit comments

Comments
 (0)