Skip to content

Commit 8f2ca51

Browse files
committed
ice_rus
1 parent 3460111 commit 8f2ca51

File tree

5 files changed

+169
-1
lines changed

5 files changed

+169
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
## Sea Ice Concentration Forecasting Example
2+
3+
The ice condition data was prepared based on the OSI SAF product.
4+
You can download a prepared sample dataset to run the example [here](https://disk.yandex.ru/d/C8KknPCr65nqSw).
5+
6+
After downloading the data, you need to specify the data directory in the
7+
[```data_loader```](data_loader.py) file - ```/path_to_data/```.
8+
9+
The [```data_loader```](data_loader.py) file contains a function for loading specific
10+
time intervals for ice concentration matrices.
11+
12+
After downloading the data and configuring the directory, you can proceed to train
13+
a convolutional neural network model by running the [```train_cnn```](train_cnn.py) script.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
## Пример прогнозирования концентрации льда
3+
4+
Данные о ледовой обстановке были подготовлены на основе продукта OSI SAF.
5+
Скачать подготовленный образец для запуска примера можно [здесь](https://disk.yandex.ru/d/C8KrnPCr65nqSw).
6+
7+
После скачивания данных, необходимо указать директорию с данными в файле
8+
[``` data_loader```](data_loader.py) - ``` /path_to_data/``` .
9+
10+
Файл [``` data_loader```](data_loader.py) содержит функцию загрузки определенного
11+
временного интервала для матриц с концентрацией льда.
12+
13+
После скачивания данных и настройки директории можно приступить к обучению
14+
сверточной нейросетевой модели посредством запуска скрипта [```train_cnn```](train_cnn.py).
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
from datetime import datetime
3+
import numpy as np
4+
5+
6+
def get_timespatial_series(sea_name, start_date, stop_date):
7+
"""
8+
Function for loading spatiotemporal data for sea
9+
"""
10+
datamodule_path = '/path_to_data/'
11+
files_path = f'{datamodule_path}/{sea_name}'
12+
timespatial_series = []
13+
dates_series = []
14+
for file in os.listdir(files_path):
15+
date = datetime.strptime(file, f'osi_%Y%m%d.npy')
16+
if start_date <= date.strftime('%Y%m%d') < stop_date:
17+
array = np.load(f'{files_path}/{file}')
18+
timespatial_series.append(array)
19+
dates_series.append(date)
20+
else:
21+
break
22+
timespatial_series = np.array(timespatial_series)
23+
return timespatial_series, dates_series
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
2+
#################################################################
3+
# Following block install additional packages used in this example #
4+
# If your environment is already set up, install them manually to avoid version conflicts #
5+
#################################################################
6+
7+
try:
8+
import numpy as np
9+
except ImportError:
10+
print(f'numpy not found, installing')
11+
import pip
12+
pip.main(["install", "numpy"])
13+
import numpy as np
14+
15+
try:
16+
import matplotlib.pyplot as plt
17+
except ImportError:
18+
print(f'matplotlib not found, installing')
19+
import pip
20+
pip.main(["install", "matplotlib"])
21+
import matplotlib.pyplot as plt
22+
23+
try:
24+
import pandas as pd
25+
except ImportError:
26+
print(f'pandas not found, installing')
27+
import pip
28+
pip.main(["install", "pandas"])
29+
import pandas as pd
30+
31+
try:
32+
from pytorch_msssim import ssim
33+
except ImportError:
34+
print(f'pytorch_msssim not found, installing')
35+
import pip
36+
pip.main(["install", "pytorch_msssim"])
37+
from pytorch_msssim import ssim
38+
39+
#################################################################
40+
41+
42+
import time
43+
import torch
44+
import torch.nn as nn
45+
import torch.optim as optim
46+
from torch.utils.data import DataLoader
47+
from torchcnnbuilder.preprocess import multi_output_tensor
48+
from torchcnnbuilder.models import ForecasterBase
49+
from data_loader import get_timespatial_series
50+
51+
# This script generate 2D CNN with 5 layers and train it with saving weights of model
52+
53+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54+
print(f'Calculating on device: {device}')
55+
56+
sea_name = 'kara'
57+
start_date = '19790101'
58+
end_date = '20200101'
59+
sea_data, dates = get_timespatial_series(sea_name, start_date, end_date)
60+
sea_data = sea_data[::7]
61+
dates = dates[::7]
62+
63+
pre_history_size = 104
64+
forecast_size = 52
65+
66+
dataset = multi_output_tensor(data=sea_data,
67+
forecast_len=forecast_size,
68+
pre_history_len=pre_history_size)
69+
dataloader = DataLoader(dataset, batch_size=200, shuffle=False)
70+
print('Loader created')
71+
72+
encoder = ForecasterBase(input_size=(sea_data.shape[1], sea_data.shape[2]),
73+
n_layers=5,
74+
in_time_points=pre_history_size,
75+
out_time_points=forecast_size)
76+
encoder.to(device)
77+
print(encoder)
78+
79+
optimizer = optim.Adam(encoder.parameters(), lr=0.001)
80+
criterion = nn.L1Loss()
81+
82+
losses = []
83+
start = time.time()
84+
epochs = 1000
85+
best_loss = 999
86+
best_model = None
87+
for epoch in range(epochs):
88+
loss = 0
89+
for train_features, test_features in dataloader:
90+
train_features = train_features.to(device)
91+
test_features = test_features.to(device)
92+
optimizer.zero_grad()
93+
outputs = encoder(train_features)
94+
train_loss = criterion(outputs, test_features)
95+
train_loss.backward()
96+
optimizer.step()
97+
loss += train_loss.item()
98+
99+
loss = loss / len(dataloader)
100+
if loss is None:
101+
break
102+
if loss < best_loss and loss is not None:
103+
print('Upd best model')
104+
best_model = encoder
105+
best_loss = loss
106+
losses.append(loss)
107+
108+
print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))
109+
110+
end = time.time() - start
111+
print(f'Runtime seconds: {end}')
112+
torch.save(encoder.state_dict(), f"models/{sea_name}_{pre_history_size}_{forecast_size}_l1({start_date}-{end_date}){epochs}.pt")
113+
plt.plot(np.arange(len(losses)), losses)
114+
plt.xlabel('Epoch')
115+
plt.ylabel('Loss')
116+
plt.title(f'Runtime={end}')
117+
plt.savefig(f"models/{sea_name}_{pre_history_size}_{forecast_size}_l1({start_date}-{end_date}){epochs}.png")
118+
plt.show()

examples/synthetic_noise_examples/grid_train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
try:
2424
import pandas as pd
2525
except ImportError:
26-
print(f'matplotlib not found, installing')
26+
print(f'pandas not found, installing')
2727
import pip
2828
pip.main(["install", "pandas"])
2929
import pandas as pd

0 commit comments

Comments
 (0)