Skip to content

Commit 326fbb4

Browse files
committed
diamond improvements
1 parent b2f477d commit 326fbb4

File tree

4 files changed

+331
-146
lines changed

4 files changed

+331
-146
lines changed

src/diamond/DiamondInspectFacet.sol

Lines changed: 89 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ pragma solidity >=0.8.30;
66
*/
77

88
interface IFacet {
9-
function functionSelectors() external view returns (bytes4[] memory);
9+
function packedSelectors() external view returns (bytes memory);
1010
}
1111

1212
contract DiamondInspectFacet {
@@ -56,6 +56,60 @@ contract DiamondInspectFacet {
5656
facet = s.facetNodes[_functionSelector].facet;
5757
}
5858

59+
/**
60+
* @notice Decodes a packed byte stream into a standard bytes4[] array.
61+
* @param packed The packed bytes (e.g., from `bytes.concat`).
62+
* @return unpacked The standard padded bytes4[] array.
63+
*/
64+
function unpackSelectors(bytes memory packed) internal pure returns (bytes4[] memory unpacked) {
65+
/*
66+
* Allocate the output array
67+
*/
68+
uint256 count = packed.length / 4;
69+
unpacked = new bytes4[](count);
70+
/*
71+
* Copy from packed to unpacked
72+
*/
73+
assembly ("memory-safe") {
74+
/*
75+
* 'src' points to the start of the data in the packed array (skip 32-byte length)
76+
*/
77+
let src := add(packed, 32)
78+
/*
79+
* 'dst' points to the start of the data in the new selectors array (skip 32-byte length)
80+
*/
81+
let dst := add(unpacked, 32)
82+
/*
83+
* 'end' is the stopping point for the destination pointer
84+
*/
85+
let end := add(dst, mul(count, 32))
86+
/*
87+
* While 'dst' is less than 'end', keep copying
88+
*/
89+
for {} lt(dst, end) {} {
90+
/*
91+
* A. Load 32 bytes from the packed source.
92+
* We read "dirty" data (neighboring bytes), but it doesn't matter
93+
* because we truncate it when writing.
94+
*/
95+
let value := mload(src)
96+
/*
97+
* B. Clearn up the value to extract only the 4 bytes we want.
98+
*/
99+
value := and(value, 0xFFFFFFFF00000000000000000000000000000000000000000000000000000000)
100+
/*
101+
* C. Store the value into the destination
102+
*/
103+
mstore(dst, value)
104+
/*
105+
* D. Advance pointers
106+
*/
107+
src := add(src, 4) // Move forward 4 bytes in packed source
108+
dst := add(dst, 32) // Move forward 32 bytes in destination target
109+
}
110+
}
111+
}
112+
59113
/**
60114
* @notice Gets the function selectors that are handled by the given facet.
61115
* @dev If facet is not found return empty array.
@@ -64,7 +118,7 @@ contract DiamondInspectFacet {
64118
*/
65119
function facetFunctionSelectors(address _facet) external view returns (bytes4[] memory facetSelectors) {
66120
DiamondStorage storage s = getStorage();
67-
facetSelectors = IFacet(_facet).functionSelectors();
121+
facetSelectors = unpackSelectors(IFacet(_facet).packedSelectors());
68122
if (facetSelectors.length == 0 || s.facetNodes[facetSelectors[0]].facet == address(0)) {
69123
facetSelectors = new bytes4[](0);
70124
}
@@ -105,13 +159,30 @@ contract DiamondInspectFacet {
105159
facetsAndSelectors = new Facet[](facetList.facetCount);
106160
for (uint256 i; i < facetList.facetCount; i++) {
107161
address facet = s.facetNodes[currentSelector].facet;
108-
bytes4[] memory facetSelectors = IFacet(facet).functionSelectors();
162+
bytes4[] memory facetSelectors = unpackSelectors(IFacet(facet).packedSelectors());
109163
facetsAndSelectors[i].facet = facet;
110164
facetsAndSelectors[i].functionSelectors = facetSelectors;
111165
currentSelector = s.facetNodes[currentSelector].nextFacetNodeId;
112166
}
113167
}
114168

169+
function at(bytes memory selectors, uint256 index) internal pure returns (bytes4 selector) {
170+
assembly ("memory-safe") {
171+
/**
172+
* 1. Calculate Pointer
173+
* add(selectors, 32) - skips the length field of the bytes array
174+
* shl(2, index) is the same as index * 4 but cheaper
175+
*/
176+
let ptr := add(add(selectors, 32), shl(2, index))
177+
/**
178+
* 2. Load & Return
179+
* We load 32 bytes, but Solidity truncates to 4 bytes automatically
180+
* upon return assignment, so masking is unnecessary.
181+
*/
182+
selector := mload(ptr)
183+
}
184+
}
185+
115186
struct FunctionFacetPair {
116187
bytes4 selector;
117188
address facet;
@@ -134,9 +205,13 @@ contract DiamondInspectFacet {
134205
uint256 selectorCount;
135206
for (uint256 i; i < facetList.facetCount; i++) {
136207
address facet = s.facetNodes[currentSelector].facet;
137-
bytes4[] memory facetSelectors = IFacet(facet).functionSelectors();
138-
for (uint256 selectorIndex; selectorIndex < facetSelectors.length; selectorIndex++) {
139-
bytes4 selector = facetSelectors[selectorIndex];
208+
bytes memory selectors = IFacet(facet).packedSelectors();
209+
uint256 selectorLength;
210+
unchecked {
211+
selectorLength = selectors.length / 4;
212+
}
213+
for (uint256 selectorIndex; selectorIndex < selectorLength; selectorIndex++) {
214+
bytes4 selector = at(selectors, selectorIndex);
140215
pairs[selectorCount] = FunctionFacetPair(selector, facet);
141216
unchecked {
142217
selectorCount++;
@@ -146,12 +221,13 @@ contract DiamondInspectFacet {
146221
}
147222
}
148223

149-
function functionSelectors() external pure returns (bytes4[] memory selectors) {
150-
selectors = new bytes4[](5);
151-
selectors[0] = DiamondInspectFacet.facetAddress.selector;
152-
selectors[1] = DiamondInspectFacet.facetFunctionSelectors.selector;
153-
selectors[2] = DiamondInspectFacet.facetAddresses.selector;
154-
selectors[3] = DiamondInspectFacet.facets.selector;
155-
selectors[4] = DiamondInspectFacet.functionFacetPairs.selector;
224+
function packedSelectors() external pure returns (bytes memory) {
225+
return bytes.concat(
226+
DiamondInspectFacet.facetAddress.selector,
227+
DiamondInspectFacet.facetFunctionSelectors.selector,
228+
DiamondInspectFacet.facetAddresses.selector,
229+
DiamondInspectFacet.facets.selector,
230+
DiamondInspectFacet.functionFacetPairs.selector
231+
);
156232
}
157233
}

src/diamond/DiamondMod.sol

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -51,61 +51,78 @@ error FunctionSelectorsCallFailed(address _facet);
5151
error NoSelectorsForFacet(address _facet);
5252
error NoBytecodeAtAddress(address _contractAddress);
5353

54-
function functionSelectors(address _facet) view returns (bytes4[] memory) {
54+
/**
55+
* @notice Emitted when a function is added to a diamond.
56+
*
57+
* @param _selector The function selector being added.
58+
* @param _facet The facet address that will handle calls to `_selector`.
59+
*/
60+
event DiamondFunctionAdded(bytes4 indexed _selector, address indexed _facet);
61+
62+
error CannotAddFunctionToDiamondThatAlreadyExists(bytes4 _selector);
63+
error NoFacetsToAdd();
64+
65+
function packedSelectors(address _facet) view returns (bytes memory) {
5566
if (_facet.code.length == 0) {
5667
revert NoBytecodeAtAddress(_facet);
5768
}
58-
(bool success, bytes memory data) =
59-
_facet.staticcall(abi.encodeWithSelector(bytes4(keccak256("functionSelectors()"))));
69+
(bool success, bytes memory selectors) =
70+
_facet.staticcall(abi.encodeWithSelector(bytes4(keccak256("packedSelectors()"))));
6071

6172
if (success == false) {
6273
revert FunctionSelectorsCallFailed(_facet);
6374
}
64-
bytes4[] memory selectors = abi.decode(data, (bytes4[]));
65-
if (selectors.length == 0) {
75+
if (selectors.length < 4) {
6676
revert NoSelectorsForFacet(_facet);
6777
}
6878
return selectors;
6979
}
7080

71-
/**
72-
* @notice Emitted when a function is added to a diamond.
73-
*
74-
* @param _selector The function selector being added.
75-
* @param _facet The facet address that will handle calls to `_selector`.
76-
*/
77-
event DiamondFunctionAdded(bytes4 indexed _selector, address indexed _facet);
78-
79-
error CannotAddFunctionToDiamondThatAlreadyExists(bytes4 _selector);
80-
error NoFacetsToAdd();
81+
function at(bytes memory selectors, uint256 index) pure returns (bytes4 selector) {
82+
assembly ("memory-safe") {
83+
/**
84+
* 1. Calculate Pointer
85+
* add(selectors, 32) - skips the length field of the bytes array
86+
* shl(2, index) is the same as index * 4 but cheaper
87+
*/
88+
let ptr := add(add(selectors, 32), shl(2, index))
89+
/**
90+
* 2. Load & Return
91+
* We load 32 bytes, but Solidity truncates to 4 bytes automatically
92+
* upon return assignment, so masking is unnecessary.
93+
*/
94+
selector := mload(ptr)
95+
}
96+
}
8197

82-
/**
83-
* @notice Adds facets and their function selectors to the diamond.
84-
*/
8598
function addFacets(address[] memory _facets) {
8699
DiamondStorage storage s = getStorage();
87100
if (_facets.length == 0) {
88101
return;
89102
}
90103
FacetList memory facetList = s.facetList;
91104
bytes4 prevFacetNodeId = facetList.lastFacetNodeId;
92-
bytes4[] memory facetSelectors = functionSelectors(_facets[0]);
105+
bytes memory selectors = packedSelectors(_facets[0]);
106+
bytes4 currentFacetNodeId = at(selectors, 0);
93107
if (facetList.facetCount == 0) {
94-
facetList.firstFacetNodeId = facetSelectors[0];
108+
facetList.firstFacetNodeId = currentFacetNodeId;
109+
} else {
110+
s.facetNodes[prevFacetNodeId].nextFacetNodeId = currentFacetNodeId;
95111
}
96112
for (uint256 i; i < _facets.length; i++) {
113+
uint256 selectorsLength;
97114
uint256 nextI;
98115
unchecked {
99116
nextI = i + 1;
100-
facetList.selectorCount += uint32(facetSelectors.length);
117+
selectorsLength = selectors.length / 4;
118+
facetList.selectorCount += uint32(selectorsLength);
101119
}
102-
bytes4[] memory nextFacetSelectors;
120+
bytes memory nextSelectors;
103121
bytes4 nextFacetNodeId;
104122
if (nextI < _facets.length) {
105-
nextFacetSelectors = functionSelectors(_facets[nextI]);
106-
nextFacetNodeId = nextFacetSelectors[0];
123+
nextSelectors = packedSelectors(_facets[nextI]);
124+
nextFacetNodeId = at(nextSelectors, 0);
107125
}
108-
bytes4 currentFacetNodeId = facetSelectors[0];
109126
address oldFacet = s.facetNodes[currentFacetNodeId].facet;
110127
if (oldFacet != address(0)) {
111128
revert CannotAddFunctionToDiamondThatAlreadyExists(currentFacetNodeId);
@@ -114,8 +131,8 @@ function addFacets(address[] memory _facets) {
114131
s.facetNodes[currentFacetNodeId] = FacetNode(facet, prevFacetNodeId, nextFacetNodeId);
115132
emit DiamondFunctionAdded(currentFacetNodeId, facet);
116133

117-
for (uint256 selectorIndex = 1; selectorIndex < facetSelectors.length; selectorIndex++) {
118-
bytes4 selector = facetSelectors[selectorIndex];
134+
for (uint256 selectorIndex = 1; selectorIndex < selectorsLength; selectorIndex++) {
135+
bytes4 selector = at(selectors, selectorIndex);
119136
oldFacet = s.facetNodes[selector].facet;
120137
if (oldFacet != address(0)) {
121138
revert CannotAddFunctionToDiamondThatAlreadyExists(selector);
@@ -124,7 +141,8 @@ function addFacets(address[] memory _facets) {
124141
emit DiamondFunctionAdded(selector, facet);
125142
}
126143
prevFacetNodeId = currentFacetNodeId;
127-
facetSelectors = nextFacetSelectors;
144+
selectors = nextSelectors;
145+
currentFacetNodeId = nextFacetNodeId;
128146
}
129147
unchecked {
130148
facetList.facetCount += uint32(_facets.length);

0 commit comments

Comments
 (0)