forked from catalyst-team/catalyst
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexperiment.py
More file actions
54 lines (39 loc) · 1.48 KB
/
experiment.py
File metadata and controls
54 lines (39 loc) · 1.48 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import Tuple
from collections import OrderedDict
from torch.utils.data import Dataset
from catalyst.contrib.datasets import MNIST as _MNIST
from catalyst.dl.experiment import ConfigExperiment
class MNIST(_MNIST):
    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.

    Overrides item access so that ``transform`` is called with a
    ``{"image": ...}`` dict and the ``"image"`` entry of its result is used.
    """

    def __getitem__(self, index: int) -> Tuple:
        """Return the sample stored at ``index``.

        Args:
            index: position of the element in the dataset

        Returns:
            tuple: ``(image, target)`` where ``target`` is the entry of
            ``self.targets`` at ``index`` (the index of the target class)
        """
        sample = self.data[index]
        label = self.targets[index]

        transform = self.transform
        if transform is None:
            # No transform configured: hand back the raw stored image.
            return sample, label

        # The transform takes and returns a dict keyed by "image".
        transformed = transform({"image": sample})
        return transformed["image"], label
class Experiment(ConfigExperiment):
    """``ConfigExperiment`` with MNIST dataset."""

    def get_datasets(
        self, stage: str, **kwargs
    ) -> "OrderedDict[str, Dataset]":
        """Provides train/validation subsets from MNIST dataset.

        The ``"train"`` subset is built from the MNIST training split
        (``train=True``). The ``"valid"`` subset is built from the held-out
        split (``train=False``). As a result, validation is done on images
        the model was not trained on.

        Args:
            stage: stage name e.g. ``'stage1'`` or ``'infer'``
            **kwargs: extra params (not used here)

        Returns:
            ordered dict with datasets, keyed ``"train"`` then ``"valid"``
        """
        datasets = OrderedDict()
        for mode in ("train", "valid"):
            datasets[mode] = MNIST(
                "./data",
                # Select the split per subset. Using ``train=False`` for both
                # would make train and valid iterate over the same images.
                train=(mode == "train"),
                download=True,
                transform=self.get_transforms(stage=stage, dataset=mode),
            )
        return datasets