-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpreprocessing.py
More file actions
66 lines (59 loc) · 2 KB
/
preprocessing.py
File metadata and controls
66 lines (59 loc) · 2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import cv2
import numpy as np
import torch
from torchvision import transforms
from typing import Tuple, List, Optional, Union
class ImagePreprocessor:
    """
    Image preprocessor for ViT models.

    Builds torchvision transform pipelines: a training pipeline with
    optional random augmentations, and a deterministic validation pipeline.
    """

    # ImageNet per-channel normalization statistics, used as defaults.
    _DEFAULT_MEAN = (0.485, 0.456, 0.406)
    _DEFAULT_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        target_size: Tuple[int, int] = (224, 224),
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        augmentations: bool = True,
    ):
        """
        Args:
            target_size: (height, width) that images are resized to.
            mean: Per-channel normalization means; defaults to ImageNet stats.
            std: Per-channel normalization stds; defaults to ImageNet stats.
            augmentations: If True, ``train_transform()`` additionally applies
                random affine, color-jitter, and rotation augmentations.
        """
        self.target_size = target_size
        # Fall back to fresh lists so instances never share a mutable default.
        self.mean = list(self._DEFAULT_MEAN) if mean is None else mean
        self.std = list(self._DEFAULT_STD) if std is None else std
        self.augmentations = augmentations

    def train_transform(self) -> transforms.Compose:
        """Return the training pipeline: resize, random flip, optional
        augmentations, tensor conversion, and normalization."""
        transform_list = [
            transforms.Resize(self.target_size),
            transforms.RandomHorizontalFlip(p=0.5),
        ]
        if self.augmentations:
            transform_list.extend(
                [
                    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
                    transforms.ColorJitter(
                        brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
                    ),
                    transforms.RandomRotation(degrees=15),
                ]
            )
        # ToTensor/Normalize go last so augmentations operate on PIL images.
        transform_list.extend(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=self.mean, std=self.std),
            ]
        )
        return transforms.Compose(transform_list)

    def val_transform(self) -> transforms.Compose:
        """Return the deterministic validation pipeline: resize, tensor
        conversion, and normalization (no random augmentation)."""
        transform_list = [
            transforms.Resize(self.target_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std),
        ]
        return transforms.Compose(transform_list)