Skip to content

Commit 19175fd

Browse files
committed
Add PersistentBlock class for dislib integration
1 parent dbdbd97 commit 19175fd

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
from sklearn.metrics import pairwise_distances
5+
6+
from dataclay import DataClayObject, activemethod
7+
from dataclay.event_loop import run_dc_coroutine
8+
9+
try:
10+
from pycompss.api.task import task
11+
from pycompss.api.parameter import IN
12+
except ImportError:
13+
from dataclay.contrib.dummy_pycompss import task, IN
14+
15+
16+
class PersistentBlock(DataClayObject):
17+
block_data: np.ndarray
18+
shape: tuple[int, ...]
19+
ndim: int
20+
nbytes: int
21+
itemsize: int
22+
size: int
23+
24+
@activemethod
25+
def __init__(self, data: np.ndarray):
26+
self.block_data = data
27+
self.shape = data.shape
28+
self.ndim = data.ndim
29+
self.size = data.size
30+
self.itemsize = data.itemsize
31+
self.nbytes = data.nbytes
32+
33+
@activemethod
34+
def __getitem__(self, key) -> np.ndarray:
35+
return self.block_data[key]
36+
37+
@activemethod
38+
def __setitem__(self, key, value):
39+
self.block_data[key] = value
40+
41+
@activemethod
42+
def __delitem__(self, key):
43+
del self.block_data[key]
44+
45+
@activemethod
46+
def __array__(self) -> np.ndarray:
47+
return self.block_data
48+
49+
@activemethod
50+
def transpose(self) -> np.ndarray:
51+
return self.block_data.transpose()
52+
53+
@activemethod
54+
def __len__(self) -> int:
55+
return len(self.block_data)
56+
57+
@task(target_direction=IN)
58+
@activemethod
59+
def rotate_in_place(self, rotation_matrix: np.ndarray):
60+
self.block_data = self.block_data @ rotation_matrix
61+
62+
@task(target_direction=IN, returns=object)
63+
@activemethod
64+
def partial_sum(self, centers: np.ndarray) -> np.ndarray:
65+
partials = np.zeros((centers.shape[0], 2), dtype=object)
66+
arr = self.block_data
67+
close_centers = pairwise_distances(arr, centers).argmin(axis=1)
68+
for center_idx in range(len(centers)):
69+
indices = np.argwhere(close_centers == center_idx).flatten()
70+
partials[center_idx][0] = np.sum(arr[indices], axis=0)
71+
partials[center_idx][1] = indices.shape[0]
72+
return partials
73+
74+
@task(target_direction=IN, returns=np.ndarray)
75+
@activemethod
76+
def partial_histogram(self, n_bins: int, n_dimensions: int) -> np.ndarray:
77+
values, _ = np.histogramdd(self.block_data, n_bins, [(0, 1)] * n_dimensions)
78+
return values

src/storage/api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
# "Publish" the StorageObject (which is a plain DataClayObject internally)
99
from dataclay import DataClayObject as StorageObject
10+
from dataclay.contrib.persistent_block import PersistentBlock
1011
from dataclay.client.api import Client
1112

1213
# Also "publish" the split method

0 commit comments

Comments
 (0)