Skip to content

Commit 278fc95

Browse files
Add depth support (#6)
* Depth support * Added depth support * Fixed depth w/ draccus config * Added tests for depth image and all inputs available * Cameras folder moved to alignit * Formatting fix * Before measuring precision * Added measuring/precision to compare ideal aligned pose to real pose after infere. XY values look fine but Z value is offset, probably due to depth image handling * Z offset still an issue, out of ideas * Z offset still an issue, out of ideas * Fixed z offset issue * Depth values clipping now done in get_observation, before train and infere * cleanup --------- Co-authored-by: Darko Lukic <lukicdarkoo@gmail.com>
1 parent 8b8d05c commit 278fc95

File tree

13 files changed

+973
-258
lines changed

13 files changed

+973
-258
lines changed

alignit/cameras/realsense.py

Lines changed: 583 additions & 0 deletions
Large diffs are not rendered by default.

alignit/config.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from dataclasses import dataclass, field
44
from typing import Optional, List
5+
56
import numpy as np
67

78

@@ -10,7 +11,7 @@ class DatasetConfig:
1011
"""Configuration for dataset paths and loading."""
1112

1213
path: str = field(
13-
default="./data/duck", metadata={"help": "Path to the dataset directory"}
14+
default="./data/default", metadata={"help": "Path to the dataset directory"}
1415
)
1516

1617

@@ -46,6 +47,12 @@ class ModelConfig:
4647
default="alignnet_model.pth",
4748
metadata={"help": "Path to save/load trained model"},
4849
)
50+
use_depth_input: bool = field(
51+
default=True, metadata={"help": "Whether to use depth input for the model"}
52+
)
53+
depth_hidden_dim: int = field(
54+
default=128, metadata={"help": "Output dimension of depth CNN"}
55+
)
4956

5057

5158
@dataclass
@@ -98,6 +105,13 @@ class RecordConfig:
98105
ang_tol_trajectory: float = field(
99106
default=0.05, metadata={"help": "Angular tolerance for trajectory servo"}
100107
)
108+
manual_height: float = field(
109+
default=-0.05, metadata={"help": "Height above surface for manual movement"}
110+
)
111+
world_z_offset: float = field(
112+
default=-0.02,
113+
metadata={"help": "World frame Z offset after manual positioning"},
114+
)
101115

102116

103117
@dataclass
@@ -106,7 +120,7 @@ class TrainConfig:
106120

107121
dataset: DatasetConfig = field(default_factory=DatasetConfig)
108122
model: ModelConfig = field(default_factory=ModelConfig)
109-
batch_size: int = field(default=8, metadata={"help": "Training batch size"})
123+
batch_size: int = field(default=4, metadata={"help": "Training batch size"})
110124
learning_rate: float = field(
111125
default=1e-4, metadata={"help": "Learning rate for optimizer"}
112126
)
@@ -133,22 +147,31 @@ class InferConfig:
133147
metadata={"help": "Starting pose RPY angles"},
134148
)
135149
lin_tolerance: float = field(
136-
default=2e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
150+
default=5e-3, metadata={"help": "Linear tolerance for convergence (meters)"}
137151
)
138152
ang_tolerance: float = field(
139-
default=2, metadata={"help": "Angular tolerance for convergence (degrees)"}
153+
default=5, metadata={"help": "Angular tolerance for convergence (degrees)"}
140154
)
141155
max_iterations: Optional[int] = field(
142-
default=None,
156+
default=20,
143157
metadata={"help": "Maximum iterations before stopping (None = infinite)"},
144158
)
145159
debug_output: bool = field(
146160
default=True, metadata={"help": "Print debug information during inference"}
147161
)
148162
debouncing_count: int = field(
149-
default=5,
163+
default=20,
150164
metadata={"help": "Number of iterations within tolerance before stopping"},
151165
)
166+
rotation_matrix_multiplier: int = field(
167+
default=3,
168+
metadata={
169+
"help": "Number of times to multiply the rotation matrix of relative action in order to speed up convergence"
170+
},
171+
)
172+
manual_height: float = field(
173+
default=0.08, metadata={"help": "Height above surface for manual movement"}
174+
)
152175

153176

154177
@dataclass

alignit/infere.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import transforms3d as t3d
2-
import numpy as np
31
import time
4-
import draccus
5-
from alignit.config import InferConfig
62

73
import torch
4+
import transforms3d as t3d
5+
import numpy as np
6+
import draccus
87

8+
from alignit.config import InferConfig
99
from alignit.models.alignnet import AlignNet
1010
from alignit.utils.zhou import sixd_se3
1111
from alignit.utils.tfs import print_pose, are_tfs_close
@@ -19,7 +19,6 @@ def main(cfg: InferConfig):
1919
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2020
print(f"Using device: {device}")
2121

22-
# load model from file
2322
net = AlignNet(
2423
backbone_name=cfg.model.backbone,
2524
backbone_weights=cfg.model.backbone_weights,
@@ -28,75 +27,100 @@ def main(cfg: InferConfig):
2827
vector_hidden_dim=cfg.model.vector_hidden_dim,
2928
output_dim=cfg.model.output_dim,
3029
feature_agg=cfg.model.feature_agg,
30+
use_depth_input=cfg.model.use_depth_input,
3131
)
3232
net.load_state_dict(torch.load(cfg.model.path, map_location=device))
3333
net.to(device)
3434
net.eval()
3535

3636
robot = XarmSim()
3737

38-
# Set initial pose from config
3938
start_pose = t3d.affines.compose(
4039
[0.23, 0, 0.25], t3d.euler.euler2mat(np.pi, 0, 0), [1, 1, 1]
4140
)
4241
robot.servo_to_pose(start_pose, lin_tol=1e-2)
43-
4442
iteration = 0
4543
iterations_within_tolerance = 0
4644
ang_tol_rad = np.deg2rad(cfg.ang_tolerance)
47-
4845
try:
4946
while True:
5047
observation = robot.get_observation()
51-
images = [observation["camera.rgb"].astype(np.float32) / 255.0]
52-
53-
# Convert images to tensor and reshape from HWC to CHW format
54-
images_tensor = (
55-
torch.from_numpy(np.array(images))
56-
.permute(0, 3, 1, 2)
48+
rgb_image = observation["rgb"].astype(np.float32) / 255.0
49+
depth_image = observation["depth"].astype(np.float32)
50+
print(
51+
"Min/Max depth,mean (raw):",
52+
observation["depth"].min(),
53+
observation["depth"].max(),
54+
observation["depth"].mean(),
55+
)
56+
print(
57+
"Min/Max depth,mean (scaled):",
58+
depth_image.min(),
59+
depth_image.max(),
60+
depth_image.mean(),
61+
)
62+
rgb_image_tensor = (
63+
torch.from_numpy(np.array(rgb_image))
64+
.permute(2, 0, 1) # (H, W, C) -> (C, H, W)
5765
.unsqueeze(0)
5866
.to(device)
5967
)
6068

61-
if cfg.debug_output:
62-
print(f"Max pixel value: {torch.max(images_tensor)}")
69+
depth_image_tensor = (
70+
torch.from_numpy(np.array(depth_image))
71+
.unsqueeze(0) # Add channel dimension: (1, H, W)
72+
.unsqueeze(0) # Add batch dimension: (1, 1, H, W)
73+
.to(device)
74+
)
75+
rgb_images_batch = rgb_image_tensor.unsqueeze(1)
76+
depth_images_batch = depth_image_tensor.unsqueeze(1)
6377

64-
start = time.time()
6578
with torch.no_grad():
66-
relative_action = net(images_tensor)
79+
relative_action = net(rgb_images_batch, depth_images=depth_images_batch)
6780
relative_action = relative_action.squeeze(0).cpu().numpy()
6881
relative_action = sixd_se3(relative_action)
6982

7083
if cfg.debug_output:
7184
print_pose(relative_action)
7285

73-
# Check convergence
86+
relative_action[:3, :3] = np.linalg.matrix_power(
87+
relative_action[:3, :3], cfg.rotation_matrix_multiplier
88+
)
7489
if are_tfs_close(
7590
relative_action, lin_tol=cfg.lin_tolerance, ang_tol=ang_tol_rad
7691
):
7792
iterations_within_tolerance += 1
7893
else:
7994
iterations_within_tolerance = 0
8095

81-
if iterations_within_tolerance >= cfg.debouncing_count:
82-
print("Alignment achieved - stopping.")
83-
break
84-
96+
print(relative_action)
8597
target_pose = robot.pose() @ relative_action
8698
iteration += 1
8799
action = {
88100
"pose": target_pose,
89101
"gripper.pos": 1.0,
90102
}
91103
robot.send_action(action)
92-
93-
# Check max iterations
94-
if cfg.max_iterations and iteration >= cfg.max_iterations:
104+
if iterations_within_tolerance >= cfg.max_iterations:
95105
print(f"Reached maximum iterations ({cfg.max_iterations}) - stopping.")
106+
print("Moving robot to final pose.")
107+
current_pose = robot.pose()
108+
gripper_z_offset = np.array(
109+
[
110+
[1, 0, 0, 0],
111+
[0, 1, 0, 0],
112+
[0, 0, 1, cfg.manual_height],
113+
[0, 0, 0, 1],
114+
]
115+
)
116+
offset_pose = current_pose @ gripper_z_offset
117+
robot.servo_to_pose(pose=offset_pose)
118+
robot.close_gripper()
119+
robot.gripper_off()
120+
96121
break
97122

98123
time.sleep(10.0)
99-
100124
except KeyboardInterrupt:
101125
print("\nExiting...")
102126

alignit/models/alignnet.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ def __init__(
1010
backbone_name="efficientnet_b0",
1111
backbone_weights="DEFAULT",
1212
use_vector_input=True,
13+
use_depth_input=True,
1314
fc_layers=[256, 128],
1415
vector_hidden_dim=64,
16+
depth_hidden_dim=128,
1517
output_dim=7,
1618
feature_agg="mean",
1719
):
@@ -23,27 +25,39 @@ def __init__(
2325
:param vector_hidden_dim: output dim of the vector MLP
2426
:param output_dim: final output vector size
2527
:param feature_agg: 'mean' or 'max' across image views
28+
:param use_depth_input: whether to accept depth input
29+
:param depth_hidden_dim: output dim of the depth MLP
2630
"""
2731
super().__init__()
2832
self.use_vector_input = use_vector_input
33+
self.use_depth_input = use_depth_input
2934
self.feature_agg = feature_agg
3035

31-
# CNN backbone
3236
self.backbone, self.image_feature_dim = self._build_backbone(
3337
backbone_name, backbone_weights
3438
)
3539

36-
# Linear projection of image features
3740
self.image_fc = nn.Sequential(
3841
nn.Linear(self.image_feature_dim, fc_layers[0]), nn.ReLU()
3942
)
4043

44+
if use_depth_input:
45+
self.depth_cnn = nn.Sequential(
46+
nn.Conv2d(1, 8, 3, padding=1),
47+
nn.ReLU(),
48+
nn.Conv2d(8, 16, 3, padding=1),
49+
nn.ReLU(),
50+
nn.AdaptiveAvgPool2d(1),
51+
)
52+
self.depth_fc = nn.Sequential(nn.Linear(16, depth_hidden_dim), nn.ReLU())
53+
input_dim = fc_layers[0] + depth_hidden_dim
54+
else:
55+
input_dim = fc_layers[0]
56+
4157
# Optional vector input processing
4258
if use_vector_input:
4359
self.vector_fc = nn.Sequential(nn.Linear(1, vector_hidden_dim), nn.ReLU())
44-
input_dim = fc_layers[0] + vector_hidden_dim
45-
else:
46-
input_dim = fc_layers[0]
60+
input_dim += vector_hidden_dim
4761

4862
# Fully connected layers
4963
layers = []
@@ -81,10 +95,11 @@ def aggregate_image_features(self, feats):
8195
else:
8296
raise ValueError("Invalid aggregation type")
8397

84-
def forward(self, rgb_images, vector_inputs=None):
98+
def forward(self, rgb_images, vector_inputs=None, depth_images=None):
8599
"""
86100
:param rgb_images: Tensor of shape (B, N, 3, H, W)
87101
:param vector_inputs: List of tensors of shape (L_i,) or None
102+
:param depth_images: Tensor of shape (B, N, 1, H, W) or None
88103
:return: Tensor of shape (B, output_dim)
89104
"""
90105
B, N, C, H, W = rgb_images.shape
@@ -93,15 +108,25 @@ def forward(self, rgb_images, vector_inputs=None):
93108
image_feats = self.aggregate_image_features(feats)
94109
image_feats = self.image_fc(image_feats)
95110

111+
features = [image_feats]
112+
113+
if self.use_depth_input and depth_images is not None:
114+
depth = depth_images.view(B * N, 1, H, W)
115+
depth_feats = self.depth_cnn(depth).view(B, N, -1)
116+
depth_feats = self.aggregate_image_features(depth_feats)
117+
depth_feats = self.depth_fc(depth_feats)
118+
features.append(depth_feats)
119+
96120
if self.use_vector_input and vector_inputs is not None:
97121
vec_feats = []
98122
for vec in vector_inputs:
99123
vec = vec.unsqueeze(1) # (L, 1)
100124
pooled = self.vector_fc(vec).mean(dim=0) # (D,)
101125
vec_feats.append(pooled)
102126
vec_feats = torch.stack(vec_feats, dim=0)
103-
fused = torch.cat([image_feats, vec_feats], dim=1)
104-
else:
105-
fused = image_feats
127+
features.append(vec_feats)
128+
129+
fused = torch.cat(features, dim=1)
130+
print("Fused shape:", fused.shape)
106131

107132
return self.head(fused) # (B, output_dim)

0 commit comments

Comments (0)