-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgs_dataset.py
More file actions
86 lines (68 loc) · 4.15 KB
/
gs_dataset.py
File metadata and controls
86 lines (68 loc) · 4.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from voxelize import voxelize
from plyfile import PlyData, PlyElement
import spconv.pytorch as spconv
from spconv.pytorch.utils import PointToVoxel
class gs_dataset(data.Dataset):
    """Dataset of per-scene Gaussian-splat point clouds loaded from PLY files.

    ``root`` holds one sub-folder per scene. Each scene is read from
    ``<root>/<scene>/point_cloud/iteration_30000/gs_filtered.ply``.

    ``__getitem__`` returns ``(params, index)``. ``params`` is a float32 array
    of shape ``(num_points, 18)`` with one row per Gaussian and these columns::

        [0:3]    center of the positional-encoding grid voxel holding the point
        [3]      per-point value from the project-local ``voxelize`` helper
        [4:7]    x, y, z
        [7:10]   f_dc_0, f_dc_1, f_dc_2   (as stored in the PLY)
        [10]     opacity                  (as stored in the PLY)
        [11:14]  scale_0, scale_1, scale_2
        [14:18]  rot_0, rot_1, rot_2, rot_3

    Scenes with fewer than ``num_points`` Gaussians are padded by repeating
    their last row.
    """

    # Side length of the cubic positional-encoding grid. The grid spans
    # [-GRID_EXTENT/2, GRID_EXTENT/2] = [-8, 8] on every axis.
    GRID_EXTENT = 16.0

    # PLY vertex properties stacked, in this order, into the last 14 columns.
    _PLY_FIELDS = (
        "x", "y", "z",
        "f_dc_0", "f_dc_1", "f_dc_2",
        "opacity",
        "scale_0", "scale_1", "scale_2",
        "rot_0", "rot_1", "rot_2", "rot_3",
    )

    def __init__(self, root, resol, random_permute=False, train=True,
                 num_points=40000, max_items=1000, volume_dims=20):
        """
        Args:
            root: directory containing one sub-folder per scene.
            resol: stored as ``self.resol``. It is not used by the current PLY path.
            random_permute: stored as ``self.random_permute``. It is not applied yet.
            train: accepted for API compatibility. Unused.
            num_points: number of rows in every returned array (padding target).
            max_items: use at most this many scene folders.
            volume_dims: voxels per axis of the positional-encoding grid.
                The default of 20 gives a voxel size of 16 / 20 = 0.8.
        """
        self.data_path = root
        self.resol = resol
        # TODO: random_permute is stored but never applied in __getitem__.
        self.random_permute = random_permute
        self.num_points = num_points
        self.volume_dims = volume_dims
        # os.listdir order is filesystem-dependent. Sort it so the selected
        # subset and the index -> scene mapping are reproducible.
        self.folder_path_each = sorted(os.listdir(self.data_path))[:max_items]

    def __getitem__(self, index):
        """Load scene ``index`` and return ``(params, index)``.

        Raises:
            ValueError: if the scene holds more than ``num_points`` Gaussians.
        """
        ply_path = os.path.join(self.data_path, self.folder_path_each[index],
                                "point_cloud", "iteration_30000",
                                "gs_filtered.ply")
        gs_params = self._read_gaussians(ply_path)
        xyz = gs_params[:, :3]

        # Per-point ids from the project-local voxelize helper, computed on
        # coordinates shifted so each axis minimum is 0. This assumes its first
        # return value is a length-N per-point array; the column stacking
        # below relies on that.
        coord = xyz - np.min(xyz, axis=0)
        voxel_size = self.GRID_EXTENT / self.volume_dims
        uniq_idx, _ = voxelize(coord, voxel_size, 'fnv')

        gs_full_params = np.concatenate(
            (self._voxel_centers(xyz), np.asarray(uniq_idx)[:, None], gs_params),
            axis=1,
        )
        return self._pad_to_num_points(gs_full_params), index

    def __len__(self):
        return len(self.folder_path_each)

    def _read_gaussians(self, ply_path):
        """Return the ``_PLY_FIELDS`` of the first PLY element as an (N, 14) array."""
        vertex = PlyData.read(ply_path).elements[0]
        return np.stack([np.asarray(vertex[name]) for name in self._PLY_FIELDS],
                        axis=1)

    def _voxel_centers(self, xyz):
        """Return the center of the grid voxel containing each point.

        Voxel ``i`` covers ``[-E/2 + i*s, -E/2 + (i+1)*s)``, where ``E`` is
        GRID_EXTENT and ``s`` is the voxel size. Points outside the grid are
        clamped to the nearest boundary voxel.
        """
        dims = self.volume_dims
        voxel_size = self.GRID_EXTENT / dims
        # Shift by the full half-extent so that flooring yields the voxel that
        # actually contains the point. The reported center is then symmetric
        # around 0, from -(dims-1)/2*s to +(dims-1)/2*s.
        voxel_indices = np.floor((xyz + self.GRID_EXTENT / 2) / voxel_size).astype(int)
        voxel_indices = np.clip(voxel_indices, 0, dims - 1)
        return (voxel_indices - (dims - 1) / 2) * voxel_size

    def _pad_to_num_points(self, params):
        """Return a float32 copy of ``params`` with exactly ``num_points`` rows.

        Missing rows repeat the last real row, so padded entries are valid
        duplicates rather than all-zero rows.
        """
        n_rows = params.shape[0]
        if n_rows > self.num_points:
            raise ValueError(
                f"scene has {n_rows} Gaussians, more than num_points={self.num_points}"
            )
        padded = np.zeros((self.num_points, params.shape[1]), dtype=np.float32)
        padded[:n_rows] = params
        if 0 < n_rows < self.num_points:
            padded[n_rows:] = params[-1]
        return padded