215 changes: 203 additions & 12 deletions cirkit/backend/torch/compiler.py

Large diffs are not rendered by default.

53 changes: 50 additions & 3 deletions cirkit/backend/torch/graph/folding.py
@@ -71,6 +71,34 @@ def build_folded_graph(
list[TorchModuleT],
FoldIndexInfo[TorchModuleT],
]:
"""Find and apply all possible foldings on a graph.

Args:
ordering (Iterable[list[TorchModuleT]]):
Modules of the graph in layerwise topological order.
outputs (Iterable[TorchModuleT]): Outputs of the graph.
incomings_fn (Callable[[TorchModuleT], Sequence[TorchModuleT]]):
Function returning the input modules of a given module.
fold_group_fn (Callable[[list[TorchModuleT]], TorchModuleT]):
Function returning a folded module given a group of modules.

Returns:
tuple[
list[TorchModuleT],
dict[TorchModuleT, list[TorchModuleT]],
list[TorchModuleT],
FoldIndexInfo[TorchModuleT],
]:
- The final, potentially folded, modules.
- The adjacency list updated with the folded modules.
- The list of modules that act as outputs of the graph.
- A `FoldIndexInfo` object which stores the information necessary
to locate an unfolded module within the folded circuit.
It is essentially a map between a module of the unfolded circuit
and a pair (id_folded_module, fold_id).
"""
# A useful data structure mapping each unfolded module to
# (i) a 'fold_id' (a natural number) pointing to the module layer it is associated to; and
# (ii) a 'slice_idx' (a natural number) within the output of the folded module,
Expand Down Expand Up @@ -111,7 +139,9 @@ def build_folded_graph(
# Check if we are folding input modules
in_modules_idx: list[list[tuple[int, int]]]
if in_group_modules[0]:
in_modules_idx = [[fold_idx[mi] for mi in msi] for msi in in_group_modules]
in_modules_idx = [
[fold_idx[mi] for mi in msi] for msi in in_group_modules
]
else:
in_modules_idx = []

@@ -130,12 +160,27 @@ def build_folded_graph(
# Construct the sequence of folded output modules
outputs = list(dict.fromkeys(modules[fi[0]] for fi in out_fold_idx))

return modules, in_modules, outputs, FoldIndexInfo(modules, in_fold_idx, out_fold_idx)
return (
modules,
in_modules,
outputs,
FoldIndexInfo(modules, in_fold_idx, out_fold_idx),
)


def group_foldable_modules(
modules: list[TorchModuleT],
) -> list[list[TorchModuleT]]:
"""Group modules that can be folded together.

Args:
modules (list[TorchModuleT]): Modules from the same level in the graph's
layerwise topological ordering.

Returns:
list[list[TorchModuleT]]: Groups of torch modules, where the modules in each group can be folded together.
"""

def _gather_fold_settings(module: AbstractTorchModule) -> tuple[Any, ...]:
ss = [type(module), *module.fold_settings]
for _, sub_module in module.sub_modules.items():
Expand Down Expand Up @@ -176,7 +221,9 @@ def build_address_book_stacked_entry(
)

# Build the bookkeeping entry
cum_fold_idx = [[cum_module_ids[idx[0]] + idx[1] for idx in fi] for fi in in_fold_idx]
cum_fold_idx = [
[cum_module_ids[idx[0]] + idx[1] for idx in fi] for fi in in_fold_idx
]

# Check if we are computing the output stacked address book entry
# If so, then squeeze the fold dimension that is equal to one
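As a rough illustration of the bookkeeping documented in the new `build_folded_graph` and `group_foldable_modules` docstrings above, the sketch below groups modules of one topological layer by their fold settings and maps each unfolded module to a (fold_id, slice_idx) pair locating it inside its folded group. `FakeModule`, `group_layer`, and `fold_index` are hypothetical stand-ins used only for illustration; they are not part of cirkit's API.

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(frozen=True)
class FakeModule:
    # Hypothetical stand-in for a torch module: only the attributes needed
    # to illustrate grouping and fold indexing are modelled here.
    name: str
    fold_settings: tuple = field(default_factory=tuple)


def group_layer(layer: list[FakeModule]) -> list[list[FakeModule]]:
    # Modules sharing both type and fold settings can be stacked into one fold.
    groups: dict[tuple, list[FakeModule]] = defaultdict(list)
    for m in layer:
        groups[(type(m), *m.fold_settings)].append(m)
    return list(groups.values())


def fold_index(layer: list[FakeModule]) -> dict[FakeModule, tuple[int, int]]:
    # Map each unfolded module to (fold_id, slice_idx): which folded group it
    # belongs to, and its slice within that folded group's output.
    idx: dict[FakeModule, tuple[int, int]] = {}
    for fold_id, group in enumerate(group_layer(layer)):
        for slice_idx, m in enumerate(group):
            idx[m] = (fold_id, slice_idx)
    return idx


layer = [FakeModule("a", (2, 3)), FakeModule("b", (2, 3)), FakeModule("c", (5,))]
print(fold_index(layer))  # 'a' -> (0, 0), 'b' -> (0, 1), 'c' -> (1, 0)

The actual implementation additionally tracks cross-layer connectivity and output ordering through `FoldIndexInfo`, which this sketch omits.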