-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathmain.py
More file actions
351 lines (302 loc) · 14.9 KB
/
main.py
File metadata and controls
351 lines (302 loc) · 14.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
import json
import os
import random
import timeit
import numpy as np
import torch
from torch import nn
from torchvision.models import resnet18
from torchvision.transforms import RandomApply, GaussianBlur, ElasticTransform
from transformers import SegformerForSemanticSegmentation
from transformers.utils import logging
import datasets.np_transforms as nptr
import datasets.ss_transforms as sstr
import datasets.weather as weather
from client import Client
from datasets.gta5 import GTA5Dataset
from datasets.idda import IDDADataset
from datasets.loveda import LoveDADataset
from fda_server import FdaServer
from models.bisenetv2 import BiSeNetV2
from models.deeplabv3 import deeplabv3_mobilenetv2
from server import Server
from utils.args import get_parser
from utils.stream_metrics import StreamClsMetrics, StreamSegMetrics
from utils.utils import split_list_balanced, split_list_random
def set_seed(random_seed):
    """Seed every RNG source (python, numpy, torch CPU and CUDA) so runs are reproducible."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed,
                   torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seeder(random_seed)
    # Trade cuDNN autotuning speed for deterministic kernel selection.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
def get_dataset_num_classes(dataset):
    """Return the number of classes for the given dataset name (used to size the model head)."""
    class_counts = {'idda': 16, 'gta5': 16, 'loveda': 8}
    if dataset not in class_counts:
        raise NotImplementedError
    return class_counts[dataset]
def model_init(args):
    """Build the network selected by ``args.model``, sized for ``args.dataset``."""
    n_classes = get_dataset_num_classes(args.dataset)
    if args.model == 'deeplabv3_mobilenetv2':
        return deeplabv3_mobilenetv2(num_classes=n_classes)
    elif args.model == 'resnet18':
        net = resnet18()
        # Adapt the stock ResNet-18 to single-channel input and the task's class count.
        net.conv1 = torch.nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        net.fc = nn.Linear(in_features=512, out_features=n_classes)
        return net
    elif args.model == 'segformer':
        # Silence HF warnings caused by the resized classification head.
        logging.set_verbosity(logging.FATAL)
        return SegformerForSemanticSegmentation.from_pretrained(
            f"nvidia/mit-{args.transformer_model}",
            num_labels=n_classes,
            ignore_mismatched_sizes=True,
        )
    elif args.model == "bisenetv2":
        return BiSeNetV2(n_classes, pretrained=True)
    raise NotImplementedError
def get_transforms(args):
    """Return the (train, test) transform pipelines for args.model / args.dataset."""
    segmentation_models = ("segformer", 'deeplabv3_mobilenetv2', 'bisenetv2')
    if args.model in segmentation_models:
        if args.dataset == "loveda":
            train_transforms = sstr.Compose([
                sstr.RandomCrop((512, 512)),
                sstr.ToTensor(),
                sstr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
        else:
            # Segformer works on square crops; the other nets use a wide crop.
            crop_width = 512 if args.model == "segformer" else 928
            train_transforms = [
                # Weather augmentation applied on the raw image with low probability.
                sstr.Compose([
                    RandomApply([sstr.Lambda(lambda x: weather.add_rain(x))], p=0.15),
                ]),
                sstr.Compose([
                    sstr.RandomCrop((512, crop_width)),
                    sstr.ToTensor(),
                    sstr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ]),
            ]
        test_transforms = sstr.Compose([
            sstr.ToTensor(),
            sstr.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
    elif args.model == 'resnet18':
        train_transforms = nptr.Compose([
            nptr.ToTensor(),
            nptr.Normalize((0.5,), (0.5,)),
        ])
        test_transforms = nptr.Compose([
            nptr.ToTensor(),
            nptr.Normalize((0.5,), (0.5,)),
        ])
    else:
        raise NotImplementedError
    return train_transforms, test_transforms
def _read_split(path):
    """Return the newline-separated sample list stored in the text file at *path*."""
    with open(path, 'r') as f:
        return f.read().splitlines()


def _idda_test_datasets(root, test_transforms):
    """Build the IDDA same-domain / different-domain test datasets rooted at *root*."""
    test_same_dom_dataset = IDDADataset(
        root=root, list_samples=_read_split(os.path.join(root, 'test_same_dom.txt')),
        transform=test_transforms, client_name='test_same_dom')
    test_diff_dom_dataset = IDDADataset(
        root=root, list_samples=_read_split(os.path.join(root, 'test_diff_dom.txt')),
        transform=test_transforms, client_name='test_diff_dom')
    return [test_same_dom_dataset, test_diff_dom_dataset]


def get_datasets(args):
    """Build the train/test (and optional validation) datasets selected by *args*.

    Returns:
        (train_datasets, test_datasets, validation_datasets): three lists;
        validation_datasets is only populated for 'gta5' (IDDA is used for
        validation there) and is None otherwise.

    Raises:
        NotImplementedError: if args.dataset is not 'idda', 'gta5' or 'loveda'.
    """
    train_datasets = []
    train_transforms, test_transforms = get_transforms(args)
    if args.dataset == 'idda':
        root = 'data/idda'
        if args.centr:
            # Centralized: all training samples go to one single client.
            print("Centralized mode set")
            all_data = _read_split(os.path.join(root, 'train.txt'))
            train_datasets.append(IDDADataset(root=root, list_samples=all_data,
                                              transform=train_transforms,
                                              client_name='centralized'))
        else:
            # Federated: train.json already maps each client id to its samples.
            print("Distributed Mode Set")
            with open(os.path.join(root, 'train.json'), 'r') as f:
                all_data = json.load(f)
            for client_id, samples in all_data.items():
                train_datasets.append(IDDADataset(root=root, list_samples=samples,
                                                  transform=train_transforms,
                                                  client_name=client_id))
        test_datasets = _idda_test_datasets(root, test_transforms)
    elif args.dataset == 'gta5':
        root = 'data/gta5'
        all_data_train = _read_split(os.path.join(root, 'train.txt'))
        print(f"Total number of images to be loaded: {len(all_data_train)}")
        if args.centr:
            # Centralized: all training samples go to one single client.
            print("Centralized mode set.")
            train_datasets.append(GTA5Dataset(root=root, list_samples=all_data_train,
                                              transform=train_transforms,
                                              client_name='centralized'))
        else:
            # Federated: split the sample list evenly across args.n_clients clients.
            print("Distributed Mode Set.")
            for i, samples in enumerate(split_list_balanced(all_data_train, args.n_clients)):
                train_datasets.append(GTA5Dataset(root=root, list_samples=samples,
                                                  transform=train_transforms,
                                                  client_name="client_" + str(i)))
        # GTA5 trains on synthetic data but is tested and validated on IDDA.
        root_idda = "data/idda"
        test_datasets = _idda_test_datasets(root_idda, test_transforms)
        validation_data = [IDDADataset(root=root_idda,
                                       list_samples=_read_split(os.path.join(root_idda, 'train.txt')),
                                       transform=train_transforms,
                                       client_name='centralized')]
        return train_datasets, test_datasets, validation_data
    elif args.dataset == "loveda":
        root = 'data/loveda'
        # FDA mode trains on the unlabeled 'target' split, otherwise on 'Urban'.
        folder_loveda = "target" if args.fda else "Urban"
        all_data_train = os.listdir(os.path.join(root, folder_loveda, "images_png"))
        print(f"Total number of images to be loaded: {len(all_data_train)}")
        if args.centr:
            # Centralized: all training samples go to one single client.
            print("Centralized mode set.")
            train_datasets.append(LoveDADataset(root=root, list_samples=all_data_train,
                                                folder=folder_loveda, transform=train_transforms,
                                                client_name='centralized'))
        else:
            # Federated: split the sample list evenly across args.n_clients clients.
            print("Distributed Mode Set.")
            for i, samples in enumerate(split_list_balanced(all_data_train, args.n_clients)):
                train_datasets.append(LoveDADataset(root=root, list_samples=samples,
                                                    folder=folder_loveda, transform=train_transforms,
                                                    client_name="client_" + str(i)))
        # Urban2 is the same-domain test set, Rural the different-domain one.
        test_same_dom_dataset = LoveDADataset(
            root=root, list_samples=os.listdir(os.path.join(root, "Urban2", "images_png")),
            folder="Urban2", transform=test_transforms, client_name='test_same_dom')
        test_diff_dom_dataset = LoveDADataset(
            root=root, list_samples=os.listdir(os.path.join(root, "Rural", "images_png")),
            folder="Rural", transform=test_transforms, client_name='test_diff_dom')
        test_datasets = [test_same_dom_dataset, test_diff_dom_dataset]
    else:
        raise NotImplementedError
    return train_datasets, test_datasets, None
def get_source_client(args, model):
    """Build the source-domain client used in the FDA setting.

    This helper exists because 'gen_clients' only handles the target-side
    dataset split; the FDA source domain is loaded separately here.

    Args:
        args: parsed command-line arguments.
        model: the shared pytorch model.

    Returns:
        A single-element list containing the client that holds the source
        training set, or None when FDA is disabled or the dataset has no
        associated source domain.
    """
    if not args.fda:
        return None
    train_transforms, _ = get_transforms(args)
    if args.dataset == "idda":  # target == idda, source == gta5
        root = 'data/gta5'
        # `with` closes the file; the original also called f.close() redundantly.
        with open(os.path.join(root, 'train.txt'), 'r') as f:
            all_data_train = f.read().splitlines()
        source_dataset = GTA5Dataset(root=root, list_samples=all_data_train,
                                     transform=train_transforms, client_name='gta5_all')
    elif args.dataset == "loveda":
        root = 'data/loveda'
        # The labeled Urban split acts as the source domain.
        all_data_train = os.listdir(os.path.join(root, "Urban", "images_png"))
        source_dataset = LoveDADataset(root=root, list_samples=all_data_train, folder="Urban",
                                       transform=train_transforms, client_name='loveda_all')
    else:
        return None
    return [Client(args, source_dataset, model)]
def set_metrics(args):
    """Instantiate the streaming metrics matching the task implied by args.model."""
    num_classes = get_dataset_num_classes(args.dataset)
    if args.model in ('deeplabv3_mobilenetv2', "segformer", "bisenetv2"):
        # Segmentation task: one metric tracker per evaluation split.
        splits = ('eval_train', 'test_same_dom', 'test_diff_dom')
        return {name: StreamSegMetrics(num_classes, name) for name in splits}
    if args.model in ('resnet18', 'cnn'):
        # Classification task.
        return {name: StreamClsMetrics(num_classes, name) for name in ('eval_train', 'test')}
    raise NotImplementedError
def gen_clients(args, train_datasets, test_datasets, validation_datasets, model):
    """Wrap each dataset in a Client; returns (train, test, validation) client lists."""
    train_clients = [Client(args, ds, model, test_client=False) for ds in train_datasets]
    test_clients = [Client(args, ds, model, test_client=True) for ds in test_datasets]
    valid_clients = []
    if validation_datasets:
        # Only the first validation dataset is wrapped, mirroring the original behavior.
        valid_clients.append(Client(args, validation_datasets[0], model, test_client=True))
    return train_clients, test_clients, valid_clients
def main():
    """Entry point: parse args, build the model/datasets/clients, then run training."""
    # Initialize the parser to get all the parameters.
    parser = get_parser()
    args = parser.parse_args()
    # Setting up the seed for reproducibility.
    set_seed(args.seed)
    # Get the model and move it to the GPU (this code requires CUDA).
    print('Initializing model...', end=" ")
    try:
        model = model_init(args)
        model.cuda()
        print("Model Loaded: " + args.model)
    except Exception as e:
        # Fixed: the original used a bare `except:` (which also caught
        # KeyboardInterrupt) and the invalid escape sequence "\FATAL".
        print(f"\nFATAL: CUDA is not enabled or no valid model was specified ({e}). Try again.")
        exit(1)
    print('Done.')
    # Get the datasets needed.
    train_datasets, test_datasets, validation_dataset = get_datasets(args)
    print('Done.')
    source_dataset = get_source_client(args, model)
    # Get the metrics needed.
    metrics = set_metrics(args)
    # Generate the clients.
    print('Generate clients...', end=" ")
    train_clients, test_clients, valid_clients = gen_clients(args, train_datasets, test_datasets,
                                                             validation_dataset, model)
    print('Done.')
    # Set up the server based on the mode chosen; two classes exist: Server / FdaServer.
    print('Setup server...', end=" ")
    if not args.fda:
        if args.dataset == "gta5":
            # GTA5 training additionally validates on the IDDA clients.
            server = Server(args, train_clients, test_clients, model, metrics, True, valid_clients)
        else:
            server = Server(args, train_clients, test_clients, model, metrics)
    else:
        print("\nActivating FDA mode...\t", end="")
        server = FdaServer(args, source_dataset, train_clients, test_clients, model, metrics)
    print('Done.')
    execution_time = timeit.timeit(server.train, number=1)
    print(f"Execution time: {execution_time} seconds")
    # Predict a new image if requested; the output is saved in the root directory
    # as image_fin (centralized setting) or fda_imagine_fin (FDA setting).
    if args.pred:
        print("Predicting " + args.pred)
        server.predict(args.pred)


if __name__ == '__main__':
    main()