Skip to content

Commit 264b414

Browse files
committed
some progress on sparse sparse but no dice
1 parent 1703471 commit 264b414

File tree

1 file changed

+79
-47
lines changed

1 file changed

+79
-47
lines changed

ceviche/primitives.py

Lines changed: 79 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ def make_sparse(entries, indices, N):
3232
coo = sp.coo_matrix((entries, indices), shape=shape, dtype=np.complex128)
3333
return coo.tocsc()
3434

35+
def make_sparse_MxN(entries, indices, shape):
36+
"""Construct a sparse csc matrix
37+
Args:
38+
entries: numpy array with shape (M,) giving values for non-zero
39+
matrix entries.
40+
indices: numpy array with shape (2, M) giving x and y indices for
41+
non-zero matrix entries.
42+
shape: shape of matrix
43+
Returns:
44+
sparse, complex, matrix with specified values
45+
"""
46+
coo = sp.coo_matrix((entries, indices), shape=shape, dtype=np.complex128)
47+
return coo.tocsc()
48+
3549
def transpose_indices(indices):
3650
# returns the transposed indices for transpose sparse matrix creation
3751
return np.flip(indices, axis=0)
@@ -146,13 +160,27 @@ def sp_mult(entries, indices, x):
146160
A = make_sparse(entries, indices, N=x.size)
147161
return A.dot(x)
148162

149-
def grad_sp_mult_entries_reverse(ans, entries, indices, x):
150-
i, j = indices
163+
# def grad_sp_mult_entries_reverse(ans, entries, indices, x):
164+
# i, j = indices
165+
# def vjp(v):
166+
# return v[i] * x[j]
167+
# return vjp
168+
169+
def grad_sp_mult_entries_reverse(b, entries, indices, x):
170+
entries_1 = np.ones(entries.shape)
171+
num_k = entries.size
172+
ik, jk = indices
173+
indices_BT = np.vstack((np.arange(num_k), jk))
174+
BT = make_sparse_MxN(entries_1, indices_BT, shape=(num_k, x.size))
175+
BTx = BT.dot(x)
151176
def vjp(v):
152-
return v[i] * x[j]
177+
indices_AT = np.vstack((np.arange(num_k), ik))
178+
AT = make_sparse_MxN(entries_1, indices_AT, shape=(num_k, v.size))
179+
ATv = AT.dot(v)
180+
return BTx * ATv
153181
return vjp
154182

155-
def grad_sp_mult_x_reverse(ans, entries, indices, x):
183+
def grad_sp_mult_x_reverse(b, entries, indices, x):
156184
indices_T = transpose_indices(indices)
157185
def vjp(v):
158186
return sp_mult(entries, indices_T, v)
@@ -216,56 +244,60 @@ def grad_sp_solve_x_forward(g, x, entries, indices, b):
216244
""" ==========================Sparse Matrix-Sparse Matrix Multiplication ========================== """
217245

218246
@ag.primitive
219-
def spsp_mult(entries_a, indices_a, entries_b, indices_b, N):
220-
""" Multiply a sparse matrix (A) by a sparse matrix (B)
247+
def spsp_mult(entries_a, indices_a, entries_x, indices_x, N):
248+
""" Multiply a sparse matrix (A) by a sparse matrix (X) A @ X = B
221249
Args:
222250
entries_a: numpy array with shape (num_non_zeros,) giving values for non-zero
223251
matrix entries into A.
224252
indices_a: numpy array with shape (2, num_non_zeros) giving x and y indices for
225253
non-zero matrix entries into A.
226-
entries_b: numpy array with shape (num_non_zeros,) giving values for non-zero
227-
matrix entries into B.
228-
indices_b: numpy array with shape (2, num_non_zeros) giving x and y indices for
229-
non-zero matrix entries into B.
254+
entries_x: numpy array with shape (num_non_zeros,) giving values for non-zero
255+
matrix entries into X.
256+
indices_x: numpy array with shape (2, num_non_zeros) giving x and y indices for
257+
non-zero matrix entries into X.
230258
N: all matrices are assumed of shape (N, N) (need to specify because no dense vector supplied)
231259
Returns:
232-
entries_c: numpy array with shape (num_non_zeros,) giving values for non-zero
233-
matrix entries into the result C.
234-
indices_c: numpy array with shape (2, num_non_zeros) giving x and y indices for
235-
non-zero matrix entries into the result C.
260+
entries_b: numpy array with shape (num_non_zeros,) giving values for non-zero
261+
matrix entries into the result B.
262+
indices_b: numpy array with shape (2, num_non_zeros) giving i, j indices for
263+
non-zero matrix entries into the result B.
236264
"""
237265
A = make_sparse(entries_a, indices_a, N=N)
238-
B = make_sparse(entries_b, indices_b, N=N)
239-
C = A.dot(B)
240-
entries_c, indices_c = get_entries_indices(C)
241-
return entries_c, indices_c
242-
243-
def grad_spsp_mult_entries_a_reverse(ans, entries_a, indices_a, entries_b, indices_b, N):
266+
X = make_sparse(entries_x, indices_x, N=N)
267+
B = A.dot(X)
268+
entries_b, indices_b = get_entries_indices(B)
269+
return entries_b, indices_b
270+
271+
def grad_spsp_mult_entries_a_reverse(b_out, entries_a, indices_a, entries_x, indices_x, N):
272+
entries_1 = np.ones(entries_a.shape)
273+
num_k = entries_a.size
244274
ik, jk = indices_a
245-
def vjp(v):
246-
entries_v, indices_v = v
247-
V = make_sparse(entries_v, indices_v, N).todense()
248-
B = make_sparse(entries_b, indices_b, N).todense()
249-
V_z = V[ik, indices_v[1]]
250-
B_z = B[jk, indices_b[1]]
251-
V_B = np.multiply(B_z, V_z)
252-
return V_B.flatten()
275+
indices_BT = np.vstack((np.arange(num_k), jk))
276+
BT = make_sparse_MxN(entries_1, indices_BT, shape=(num_k, N))
277+
X = make_sparse(entries_x, indices_x, N)
278+
BTX = BT.dot(X)
279+
def vjp(V):
280+
indices_AT = np.vstack((np.arange(num_k), ik))
281+
AT = make_sparse_MxN(entries_1, indices_AT, shape=(num_k, N))
282+
entries_v, indices_v = V
283+
indices_v = np.vstack((np.arange(N), np.arange(N)))
284+
print(entries_v, indices_v)
285+
V = sp.diags(entries_v, shape=(N,N))
286+
print(V.todense())
287+
ATV = AT.dot(V)
288+
ATV_BTX = BTX.T.multiply(ATV.T)
289+
print(ATV_BTX)
290+
return ATV_BTX.sum(axis=0)
253291
return vjp
254292

255-
def grad_spsp_mult_entries_a_reverse(ans, entries_a, indices_a, entries_b, indices_b, N):
256-
# why you no work?
257-
ik, jk = indices_a
258-
def vjp(v):
259-
entries_v, indices_v = v
260-
return entries_v[ik] * entries_b[jk]
261-
return vjp
262293

263294
ag.extend.defvjp(spsp_mult, grad_spsp_mult_entries_a_reverse, None, None)
264295

265-
def grad_spsp_mult_entries_a_forward(g, ans, entries_a, indices_a, entries_b, indices_b, N):
266-
# out = spsp_mult(g, iandices_a, entries_b, indices_b, N)
267-
# entries_out, indices_out = out
268-
return spsp_mult(g, indices_a, entries_b, indices_b, N)
296+
def grad_spsp_mult_entries_a_forward(g, out_b, entries_a, indices_a, entries_x, indices_x, N):
297+
return spsp_mult(g, indices_a, entries_x, indices_x, N)
298+
299+
# def grad_sp_mult_x_forward(g, b, entries, indices, x):
300+
# return sp_mult(entries, indices, g)
269301

270302
ag.extend.defjvp(spsp_mult, grad_spsp_mult_entries_a_forward, None, None)
271303

@@ -392,8 +424,8 @@ def vjp(v):
392424

393425
## Setup
394426

395-
N = 5 # size of matrix dimensions. matrix shape = (N, N)
396-
M = N**2 # number of non-zeros (make it dense for numerical stability)
427+
N = 4 # size of matrix dimensions. matrix shape = (N, N)
428+
M = N**2- 1 # number of non-zeros (make it dense for numerical stability)
397429

398430
# these are the default values used within the test functions
399431
indices_const = make_rand_indeces(N, M)
@@ -408,25 +440,25 @@ def out_fn(output_vector):
408440
def fn_spsp_entries(entries):
409441
# sparse matrix multiplication (Ax = b) as a function of matrix entries 'A(entries)'
410442
entries_c, indices_c = spsp_mult(entries, indices_const, entries_const, indices_const, N=N)
411-
return out_fn(entries_c)
443+
# return out_fn(entries_c)
412444
x = sp_solve(entries_c, indices_c, b_const)
413445
return out_fn(x)
414446

415447
entries = make_rand_complex(M)
416448

449+
# doesnt pass yet
417450
grad_rev = ceviche.jacobian(fn_spsp_entries, mode='reverse')(entries)[0]
418451
grad_true = grad_num(fn_spsp_entries, entries)
419-
420-
# doesnt pass yet
421452
np.testing.assert_almost_equal(grad_rev, grad_true, decimal=DECIMAL)
422453

423454
# Testing Gradients of 'Sparse-Sparse Multiply entries Forward-mode'
424455

425-
grad_for = ceviche.jacobian(fn_spsp_entries, mode='forward')(entries)[0]
426-
grad_true = grad_num(fn_spsp_entries, entries)
456+
# grad_for = ceviche.jacobian(fn_spsp_entries, mode='forward')(entries)[0]
457+
# grad_true = grad_num(fn_spsp_entries, entries)
427458

459+
# print(grad_for, grad_true)
428460
# doesnt pass for more complicated functions
429-
np.testing.assert_almost_equal(grad_for, grad_true, decimal=DECIMAL)
461+
# np.testing.assert_almost_equal(grad_for, grad_true, decimal=DECIMAL)
430462

431463
## TESTS SPARSE MATRX CREATION
432464

0 commit comments

Comments
 (0)