Skip to content

Commit 1ddf445

Browse files
committed
Fix tests.
1 parent 2656e29 commit 1ddf445

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

tests/test_dataset.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,12 @@ def test_IterableLLMDataset_padding(mock_hf_dataset: MockHFDataset):
100100
assert torch.equal(input_ids, expected_input)
101101

102102

103-
def test_IterableLLMDataset_with_dataloader(mock_hf_dataset: MockHFDataset):
104-
"""Test if it works with actual PyTorch DataLoader."""
105-
tokenizer = BPETokenizer()
106-
dataset = IterableLLMDataset(
107-
dataset=mock_hf_dataset, tokenizer=tokenizer, max_length=5, stride=2
103+
def test_create_llm_dataloader_from_dataset(mock_hf_dataset: MockHFDataset):
104+
"""Test if it works with actual DataLoader."""
105+
data_loader = create_llm_dataloader_from_dataset(
106+
mock_hf_dataset, batch_size=2, max_length=5, stride=2, shuffle=False
108107
)
109108

110-
data_loader = create_llm_dataloader_from_dataset(dataset, batch_size=2, shuffle=False)
111-
112109
batch = next(iter(data_loader))
113110
input_batch, target_batch = batch
114111

0 commit comments

Comments
 (0)