-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcub_dataset.py
More file actions
96 lines (77 loc) · 2.92 KB
/
cub_dataset.py
File metadata and controls
96 lines (77 loc) · 2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.datasets.folder import default_loader
from torchvision.datasets.utils import download_url
class Cub2011(Dataset):
    """Caltech-UCSD Birds-200-2011 image-classification dataset.

    Each item is an ``(image, target)`` pair where ``target`` is a
    zero-based class index (the metadata files store 1-based labels).

    Args:
        root: Directory containing (or to receive) ``CUB_200_2011/``;
            ``~`` is expanded.
        train: If True, yield the training split; otherwise the test split.
        transform: Optional callable applied to each loaded image.
        loader: Callable mapping a file path to an image
            (defaults to torchvision's ``default_loader``).
        download: If True, download and extract the archive when the
            on-disk copy is missing or fails the integrity check.

    Raises:
        RuntimeError: If the dataset is absent/corrupted after the
            optional download attempt.
    """

    base_folder = "CUB_200_2011/images"
    # NOTE(review): Caltech has reorganized data.caltech.edu; this direct
    # link may be stale — verify before relying on download=True.
    url = (
        "https://data.caltech.edu/tindfiles/serve/1239ea37-e132-42ee-8c09-c383bb54e7ff/"
    )
    filename = "CUB_200_2011.tgz"
    tgz_md5 = "97eceeb196236b17998738112f37df78"

    def __init__(
        self, root, train=True, transform=None, loader=default_loader, download=True
    ):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # BUG FIX: was hard-coded to ``default_loader``, silently ignoring
        # any custom ``loader`` the caller passed in.
        self.loader = loader
        self.train = train
        if download:
            self._download()
        if not self._check_integrity():
            raise RuntimeError(
                "Dataset not found or corrupted."
                + " You can use download=True to download it"
            )

    def _load_metadata(self):
        """Parse the metadata text files and set ``self.data`` to the rows
        of the requested split (train or test)."""
        images = pd.read_csv(
            os.path.join(self.root, "CUB_200_2011", "images.txt"),
            sep=" ",
            names=["img_id", "filepath"],
        )
        image_class_labels = pd.read_csv(
            os.path.join(self.root, "CUB_200_2011", "image_class_labels.txt"),
            sep=" ",
            names=["img_id", "target"],
        )
        train_test_split = pd.read_csv(
            os.path.join(self.root, "CUB_200_2011", "train_test_split.txt"),
            sep=" ",
            names=["img_id", "is_training_img"],
        )
        data = images.merge(image_class_labels, on="img_id")
        self.data = data.merge(train_test_split, on="img_id")
        # is_training_img is 1 for training rows, 0 for test rows.
        wanted = 1 if self.train else 0
        self.data = self.data[self.data.is_training_img == wanted]

    def _check_integrity(self):
        """Return True iff the metadata parses and every listed image file
        exists under ``base_folder`` (O(n) file-system stats)."""
        try:
            self._load_metadata()
        except Exception:
            return False
        for _, row in self.data.iterrows():
            filepath = os.path.join(self.root, self.base_folder, row.filepath)
            if not os.path.isfile(filepath):
                print(filepath)
                return False
        return True

    def _download(self):
        """Download and extract the archive unless a valid copy is present."""
        import tarfile

        if self._check_integrity():
            print("Files already downloaded and verified")
            return
        download_url(self.url, self.root, self.filename, self.tgz_md5)
        # NOTE(review): extractall without an extraction ``filter=`` trusts
        # the archive's member paths (path-traversal risk for untrusted
        # tarballs); acceptable here because the MD5 is pinned above.
        with tarfile.open(os.path.join(self.root, self.filename), "r:gz") as tar:
            tar.extractall(path=self.root)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data.iloc[idx]
        path = os.path.join(self.root, self.base_folder, sample.filepath)
        target = sample.target - 1  # file labels are 1-based; shift to 0-based
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        return img, target