Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 20 additions & 60 deletions comms/torchcomms/tests/integration/py/MultiCommTest.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,12 @@ def test_two_comms_no_store(self):
wrappers = []
comms = []

# Create first communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create first communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

store = None
store_deletion_barrier(comms[-1])

# Create second communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create second communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

# Test communication on each communicator individually
Expand All @@ -254,34 +249,22 @@ def test_two_comms_no_store(self):
# Test simultaneous communication across all communicators
self._verify_simultaneous_communication(wrappers)

store = None
store_deletion_barrier(comms[-1])

def test_three_comms_no_store(self):
"""Test with three communicators with no store."""
# Create three communicators
wrappers = []
comms = []

# Create first communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create first communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

store = None
store_deletion_barrier(comms[-1])

# Create second communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create second communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

store = None
store_deletion_barrier(comms[-1])

# Create third communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create third communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

# Test communication on each communicator individually
Expand All @@ -291,26 +274,18 @@ def test_three_comms_no_store(self):
# Test simultaneous communication across all communicators
self._verify_simultaneous_communication(wrappers)

store = None
store_deletion_barrier(comms[-1])

def test_mixed_ops_no_store(self):
"""Test mixed operations across multiple communicators with no store."""
# Create two communicators
wrappers = []
comms = []

# Create first communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create first communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

store = None
store_deletion_barrier(comms[-1])

# Create second communicator with a store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create second communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

# Prepare tensors for different operations
Expand Down Expand Up @@ -341,9 +316,6 @@ def test_mixed_ops_no_store(self):
verify_tensor_equality(input1.cpu(), expected1, "comm_0 all_reduce result")
verify_tensor_equality(input2.cpu(), broadcast_value, "comm_1 broadcast result")

store = None
store_deletion_barrier(comms[-1])

def test_two_comms_mixed_store(self):
"""Test with two communicators with mixed store (one explicit, one None)."""
# Create two communicators
Expand All @@ -361,9 +333,8 @@ def test_two_comms_mixed_store(self):
store = None
store_deletion_barrier(comms[0])

# Create second communicator with a new store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create second communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

# Test communication on each communicator individually
Expand All @@ -373,9 +344,6 @@ def test_two_comms_mixed_store(self):
# Test simultaneous communication across all communicators
self._verify_simultaneous_communication(wrappers)

store = None
store_deletion_barrier(comms[-1])

def test_three_comms_mixed_store(self):
"""Test with three communicators with mixed store (two explicit, one None)."""
# Create three communicators
Expand All @@ -402,9 +370,8 @@ def test_three_comms_mixed_store(self):
store = None
store_deletion_barrier(comms[0])

# Create third communicator with a new store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create third communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

# Test communication on each communicator individually
Expand All @@ -414,9 +381,6 @@ def test_three_comms_mixed_store(self):
# Test simultaneous communication across all communicators
self._verify_simultaneous_communication(wrappers)

store = None
store_deletion_barrier(comms[-1])

def test_mixed_ops_mixed_store(self):
"""Test mixed operations across multiple communicators with mixed store."""
# Create two communicators
Expand All @@ -434,9 +398,8 @@ def test_mixed_ops_mixed_store(self):
store = None
store_deletion_barrier(comms[0])

# Create second communicator with a new store
store = create_store()
wrappers.append(TorchCommTestWrapper(store=store))
# Create second communicator with no store
wrappers.append(TorchCommTestWrapper())
comms.append(wrappers[-1].get_torchcomm())

device0 = comms[0].get_device()
Expand Down Expand Up @@ -465,9 +428,6 @@ def test_mixed_ops_mixed_store(self):
verify_tensor_equality(input1.cpu(), expected1, "comm_0 all_reduce result")
verify_tensor_equality(input2.cpu(), broadcast_value, "comm_1 broadcast result")

store = None
store_deletion_barrier(comms[-1])


if __name__ == "__main__":
unittest.main(failfast=True)
Loading