Performance squeeze: Don't pre-compute anything for the derivative when initially solving #8

@bamos

It should be pretty easy to update solve_and_derivative so that it doesn't pre-compute anything for the derivative when the user doesn't need it. As it stands, somebody who only wants the derivative sometimes (but not always) pays the derivative pre-computation overhead on every solve, even when they never use the result. This also makes the timing results of the forward/backward passes slightly off, since time that should be measured in the backward pass is actually spent in the forward pass. I think this overhead might be even larger for the explicit mode in #2 that calls into cone_lib.dpi_explicit.
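One possible shape for this, as a minimal sketch rather than the actual diffcp code: defer the pre-computation into a cached helper that only runs the first time D or DT is called. The names make_lazy_derivatives, precompute, derivative, and adjoint are hypothetical stand-ins, and the toy precompute below just returns a number in place of quantities such as M and pi_z.

```python
from functools import lru_cache


def make_lazy_derivatives(precompute, derivative, adjoint):
    """Return (D, DT) that run `precompute` at most once, on first use.

    All three arguments are hypothetical callables used for illustration:
      precompute() -> state shared by the derivative and its adjoint
      derivative(state, *args) and adjoint(state, *args) -> results
    """

    @lru_cache(maxsize=None)
    def cached_state():
        # Runs only when D or DT is first called, not at solve time.
        return precompute()

    def D(*args):
        return derivative(cached_state(), *args)

    def DT(*args):
        return adjoint(cached_state(), *args)

    return D, DT


# Toy stand-ins so the sketch runs on its own.
calls = {"precompute": 0}


def toy_precompute():
    calls["precompute"] += 1
    return 3.0  # placeholder for the real pre-computed quantities


D, DT = make_lazy_derivatives(
    toy_precompute,
    derivative=lambda state, dx: state * dx,
    adjoint=lambda state, dy: state * dy,
)
```

With this structure, a caller who never touches D or DT never pays for the pre-computation, and a caller who uses both pays for it once. The cost would also then show up in the backward-pass timing instead of the forward pass.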

I just tried running the following quick example, and the pre-computation seems to add ~15% overhead:

#!/usr/bin/env python3

import numpy as np
from scipy import sparse

import diffcp

nzero = 100
npos = 100
nsoc = 100
m = nzero + npos + nsoc
n = 100

cone_dict = {
    diffcp.ZERO: nzero,
    diffcp.POS: npos,
    diffcp.SOC: [nsoc]
}

A, b, c = diffcp.utils.random_cone_prog(m, n, cone_dict)
x, y, s, D, DT = diffcp.solve_and_derivative(A, b, c, cone_dict)

# evaluate the derivative
nonzeros = A.nonzero()
data = 1e-4 * np.random.randn(A.size)
dA = sparse.csc_matrix((data, nonzeros), shape=A.shape)
db = 1e-4 * np.random.randn(m)
dc = 1e-4 * np.random.randn(n)
dx, dy, ds = D(dA, db, dc)

# evaluate the adjoint of the derivative
dx = c
dy = np.zeros(m)
ds = np.zeros(m)
dA, db, dc = DT(dx, dy, ds)

Line-by-line profile of solve_and_derivative:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    55                                           @profile
    56                                           def solve_and_derivative(A, b, c, cone_dict, warm_start=None, **kwargs):

...

   128         1      79706.0  79706.0     84.1      result = scs.solve(data, cone_dict, **kwargs)
   129
   130                                               # check status
   131         1          6.0      6.0      0.0      status = result["info"]["status"]
   132         1          4.0      4.0      0.0      if status == "Solved/Innacurate":
   133                                                   warnings.warn("Solved/Innacurate.")
   134         1          4.0      4.0      0.0      elif status != "Solved":
   135                                                   raise SolverError("Solver scs returned status %s" % status)
   136
   137         1          3.0      3.0      0.0      x = result["x"]
   138         1          4.0      4.0      0.0      y = result["y"]
   139         1          3.0      3.0      0.0      s = result["s"]
   140
   141                                               # pre-compute quantities for the derivative
   142         1          7.0      7.0      0.0      m, n = A.shape
   143         1          4.0      4.0      0.0      N = m + n + 1
   144         1         14.0     14.0      0.0      cones = cone_lib.parse_cone_dict(cone_dict)
   145         1         21.0     21.0      0.0      z = (x, y - s, np.array([1]))
   146         1          4.0      4.0      0.0      u, v, w = z
   147         1       1850.0   1850.0      2.0      D_proj_dual_cone = cone_lib.dpi(v, cones, dual=True)
   148         1          5.0      5.0      0.0      Q = sparse.bmat([
   149         1        271.0    271.0      0.3          [None, A.T, np.expand_dims(c, - 1)],
   150         1        299.0    299.0      0.3          [-A, None, np.expand_dims(b, -1)],
   151         1       4230.0   4230.0      4.5          [-np.expand_dims(c, -1).T, -np.expand_dims(b, -1).T, None]
   152                                               ])
   153         1       2878.0   2878.0      3.0      M = splinalg.aslinearoperator(Q - sparse.eye(N)) @ dpi(
   154         1       3301.0   3301.0      3.5          z, cones) + splinalg.aslinearoperator(sparse.eye(N))
   155         1        445.0    445.0      0.5      pi_z = pi(z, cones)
   156         1       1742.0   1742.0      1.8      rows, cols = A.nonzero()
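
For reference, the ~15% figure can be roughly reproduced from the rows shown above. This is only a back-of-the-envelope check: the profile is truncated, so the total below counts only the visible rows (lines 128 to 156) and the variable names are my own.

```python
# "Time" column values copied from the visible profiler rows.
solve_time = 79706.0  # line 128: scs.solve

# lines 131-139: status check and unpacking x, y, s
bookkeeping_times = [6.0, 4.0, 4.0, 3.0, 4.0, 3.0]

# lines 142-156: derivative pre-computation
precompute_times = [
    7.0, 4.0, 14.0, 21.0, 4.0,      # shapes, cone parsing, z
    1850.0,                         # cone_lib.dpi(v, cones, dual=True)
    5.0, 271.0, 299.0, 4230.0,      # building Q with sparse.bmat
    2878.0, 3301.0,                 # building the linear operator M
    445.0,                          # pi(z, cones)
    1742.0,                         # A.nonzero()
]

precompute_total = sum(precompute_times)
visible_total = solve_time + sum(bookkeeping_times) + precompute_total

# Share of the visible time spent on derivative pre-computation.
share_of_visible = precompute_total / visible_total
# The same cost expressed relative to the solve call alone.
relative_to_solve = precompute_total / solve_time

print(f"pre-computation share of visible time: {share_of_visible:.1%}")
print(f"pre-computation relative to scs.solve: {relative_to_solve:.1%}")
```

Most of that cost comes from three places in the visible rows: building Q, building M, and the dpi call.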
