-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathgs_pttransformer_dataset.py
More file actions
69 lines (57 loc) · 3.34 KB
/
gs_pttransformer_dataset.py
File metadata and controls
69 lines (57 loc) · 3.34 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import torch
from PIL import Image
from torch.utils import data
import numpy as np
from torch.utils.data import DataLoader
from voxelize import voxelize
from plyfile import PlyData, PlyElement
class gs_dataset(data.Dataset):
def __init__(self, root, resol, random_permute = False, train=True):
self.data_path = root
self.resol = resol
self.random_permute = random_permute
self.folder_path_each = os.listdir(self.data_path)[:1000]
def __getitem__(self, index):
gs_params_path_each = self.data_path + self.folder_path_each[index] + f"/point_cloud/iteration_30000/point_cloud_{self.resol}_norm.ply"
# gs_params_path_each = self.data_path + self.folder_path_each[index] + f"/point_cloud/iteration_30000/gs_filtered.ply"
plydata = PlyData.read(gs_params_path_each)
xyz = np.stack((np.asarray(plydata.elements[0]["x"]),
np.asarray(plydata.elements[0]["y"]),
np.asarray(plydata.elements[0]["z"])), axis=1)
color_rgb = np.stack((np.asarray(plydata.elements[0]["f_dc_0"]),
np.asarray(plydata.elements[0]["f_dc_1"]),
np.asarray(plydata.elements[0]["f_dc_2"])), axis=1)
opacity = np.asarray(plydata.elements[0]["opacity"])
scale = np.stack((np.asarray(plydata.elements[0]["scale_0"]),
np.asarray(plydata.elements[0]["scale_1"]),
np.asarray(plydata.elements[0]["scale_2"])), axis=1)
rot = np.stack((np.asarray(plydata.elements[0]["rot_0"]),
np.asarray(plydata.elements[0]["rot_1"]),
np.asarray(plydata.elements[0]["rot_2"]),
np.asarray(plydata.elements[0]["rot_3"])), axis=1)
coord_min = np.min(xyz, 0)
coord = xyz - coord_min
uniq_idx, count = voxelize(coord, 0.4, 'fnv') # [-8, 8] with voxel_size=0.4 # ravel, fnv
gs_full_params = np.concatenate((xyz, color_rgb, opacity[:,None], scale, rot), axis=1)
####### centers as PE
volume_dims = 40
resolution = 16.0/volume_dims
origin_offset = np.array([(volume_dims - 1) / 2, (volume_dims - 1) / 2, (volume_dims - 1) / 2]) * resolution
shifted_points = xyz + origin_offset
voxel_indices = np.floor(shifted_points / resolution).astype(int)
voxel_indices = np.clip(voxel_indices, 0, np.array(volume_dims) - 1)
voxel_centers = (voxel_indices - (np.array(volume_dims) - 1) / 2) * resolution
gs_full_params = np.concatenate((voxel_centers, np.array(uniq_idx)[:,None], gs_full_params), axis=1)
##### padding in case...
# if gs_full_params.shape[0] != 40000:
# dummpy_gs_full_params = np.zeros([40000,14],dtype=np.float32)
# dummpy_gs_full_params[:gs_full_params.shape[0],:] = gs_full_params
# dummpy_gs_full_params[gs_full_params.shape[0],:] = gs_full_params[-1,:]
# gs_full_params = dummpy_gs_full_params
# if self.random_permute == True:
# gs_full_params = gs_full_params[torch.randperm(gs_full_params.size()[1])]
# gs_full_params = gs_full_params[uniq_idx]
return gs_full_params, index
def __len__(self):
return len(self.folder_path_each)