Skip to content

Commit 5350fc9

Browse files
authored
Xrdmatch (#1234)
* 提交文件 * Create vgg.py * 修改 * Delete xrdmatch.yaml * 提交 * 修改 * 提交 * 修改 * Update mkdocs.yml * 完善md文件 * 修改说明文件 * 提交代码 帮我把之前的模型都删除吧 * Update xrdmatch.md * Update xrdmatch.yaml * 修改配置 * 修改训练轮次
1 parent a961fab commit 5350fc9

File tree

6 files changed

+1287
-0
lines changed

6 files changed

+1287
-0
lines changed

docs/zh/examples/xrdmatch.md

Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# XRDMatch
2+
3+
## 1.模型训练与评估
4+
=== "模型训练命令"
5+
``` sh
6+
python main.py
7+
```
8+
=== "模型评估命令"
9+
``` sh
10+
python main.py --mode eval --exp_id x --epoch x
11+
```
12+
13+
## 2.背景简介
14+
15+
XRDMatch 是一个基于 PaddleScience 的 XRD 数据半监督学习示例,使用 FlexMatch 算法进行材料分类。该示例展示了如何使用少量有标签数据和大量无标签数据来训练高性能的分类模型,特别适用于材料科学中的 XRD 谱线分析。
16+
17+
X射线衍射(XRD)是材料科学中重要的表征技术,能够提供材料的晶体结构信息。在实际应用中,获取大量有标签的 XRD 数据成本高昂且耗时,而半监督学习可以充分利用大量无标签数据来提升模型性能,降低标注成本。
18+
19+
本工作目的是利用锂离子固态电解质材料的XRD数据训练,得到相应的结构和性能关系。通过FlexMatch算法,结合数据增强、伪标签生成、动态阈值和一致性正则化等技术,实现高效的半监督学习。
20+
21+
22+
## 3.模型原理
23+
24+
该方法的主要思想是通过卷积神经网络建立XRD谱线数据与材料性能之间的非线性映射关系。模型采用VGG网络作为特征提取器,结合FlexMatch半监督学习算法,能够有效利用大量无标签数据提升模型性能。
25+
26+
本案例采用VGG网络作为基础模型架构,主要包括以下几个部分:
27+
28+
1. 输入层:接收 1×4501 的 XRD 谱线数据
29+
2. 卷积层:多层卷积块,提取局部特征模式
30+
3. 池化层:降维和特征聚合
31+
4. 全连接层:特征映射到分类结果
32+
5. 输出层:2类分类(正类/负类)
33+
34+
通过FlexMatch算法,模型能够:
35+
- 基于弱增强数据生成伪标签
36+
- 使用强增强数据进行一致性训练
37+
- 动态调整选择阈值,平衡各类别样本
38+
39+
### 3.1数据格式说明
40+
41+
数据集包含材料XRD谱线数据和对应的性能标签:
42+
- **数据连接**:
43+
```
44+
https://paddle-org.bj.bcebos.com/paddlescience/datasets/xrdmatch/lbs.csv
45+
https://paddle-org.bj.bcebos.com/paddlescience/datasets/xrdmatch/ulbs.csv
46+
```
47+
- **`xrd_data/lbs.csv`**: 有标签数据
48+
- 包含样本名称、ID、标签和 XRD 谱线数据(4501维特征)
49+
- 标签:0(正类)、1(负类)
50+
51+
- **`xrd_data/ulbs.csv`**: 无标签数据
52+
- 包含样本名称、ID 和 XRD 谱线数据(4501维特征)
53+
- 无标签信息,用于半监督学习
54+
55+
### 3.2数据预处理与增强策略
56+
57+
1. **归一化**:将 XRD 强度值归一化到 [0,1] 范围
58+
2. **噪声处理**:去除低强度噪声(阈值 < 0.1)
59+
3. **数据增强**
60+
- **弱增强**:添加少量噪声(10%)和位移(100像素)
61+
- **强增强**:缩放(15%)、消除(15%)、大幅位移(200像素)和噪声(20%)
62+
``` py linenums="42" title="examples/xrdmatch/main.py"
63+
--8<--
64+
examples/xrdmatch/main.py:42:89
65+
--8<--
66+
```
67+
### 3.3自定义数据集类
68+
``` py linenums="201" title="examples/xrdmatch/main.py"
69+
--8<--
70+
examples/xrdmatch/main.py:201:239
71+
--8<--
72+
```
73+
### 3.4 FlexMatch 半监督损失函数
74+
75+
1. **有标签数据训练**:使用交叉熵损失进行监督学习
76+
2. **无标签数据处理**
77+
- 生成弱增强和强增强版本
78+
- 基于弱增强版本生成伪标签
79+
- 使用强增强版本进行一致性训练
80+
3. **动态阈值**:根据类别置信度动态调整选择阈值
81+
``` py linenums="242" title="examples/xrdmatch/main.py"
82+
--8<--
83+
examples/xrdmatch/main.py:242:327
84+
--8<--
85+
```
86+
87+
### 3.5 损失函数
88+
```python
89+
total_loss = loss_lb + lambda_u * loss_ulb
90+
```
91+
92+
其中:
93+
- `loss_lb`: 有标签数据的交叉熵损失
94+
- `loss_ulb`: 无标签数据的一致性损失
95+
- `lambda_u`: 无标签损失权重(默认1.0)
96+
97+
### 3.6 训练配置
98+
99+
- **优化器**:AdamW (lr=3e-4, weight_decay=0.01)
100+
- **批次大小**:有标签32,无标签96
101+
- **实验次数**:100次独立实验
102+
- **训练轮数**:每个实验100轮(每轮10个迭代)
103+
- **数据划分**:正类前20个,负类前75个用于训练
104+
- **模型保存**:仅当F1分数≥0.7时保存模型
105+
106+
## 3.7 评估指标
107+
108+
- **准确率 (Accuracy)**:正确分类的样本比例
109+
- **精确率 (Precision)**:预测为正类中实际为正类的比例
110+
- **召回率 (Recall)**:实际正类中被正确预测的比例
111+
- **F1 分数 (F1-Score)**:精确率和召回率的调和平均
112+
- **混淆矩阵 (Confusion Matrix)**:各类别预测结果的详细分布
113+
- **评估方法**:支持训练时评估和独立评估两种模式
114+
- **训练时评估**:在训练过程中自动调用,将日志保存到每个实验的 saved_models_ppsci/exp_*/log.txt 文件中
115+
- **独立评估**:使用 `--mode eval` 参数对已保存的模型进行评估,结果保存到 eval_log.txt 文件中
116+
- **模型保存策略**:仅当F1分数≥0.7时保存模型
117+
- **内置评估实现**:该函数会在训练过程中自动调用,并将日志保存到每个实验的 saved_models_ppsci/exp_*/log.txt 文件中。代码实现:
118+
``` py linenums="412" title="examples/xrdmatch/main.py"
119+
--8<--
120+
examples/xrdmatch/main.py:412:457
121+
--8<--
122+
```
123+
124+
## 4.结果示例
125+
126+
### 训练日志示例
127+
128+
```
129+
Epoch: 0
130+
[2025-8-27 02:40:12,747 INFO] confusion matrix
131+
[2025-8-27 02:40:12,748 INFO] [[0.22222222 0.77777778]
132+
[0.2 0.8 ]]
133+
[2025-8-27 02:40:12,748 INFO] evaluation metric
134+
[2025-8-27 02:40:12,748 INFO] acc: 0.7188
135+
[2025-8-27 02:40:12,748 INFO] precision: 0.5083
136+
[2025-8-27 02:40:12,750 INFO] recall: 0.5111
137+
[2025-8-27 02:40:12,750 INFO] f1: 0.5060
138+
F1 score 0.5060 < 0.7, model not saved at epoch 0
139+
```
140+
141+
### 性能指标
142+
143+
在标准测试集上的典型性能:
144+
145+
| 指标 ||
146+
|------|-----|
147+
| 准确率 | 0.797 |
148+
| 精确率 | 0.673 |
149+
| 召回率 | 0.789 |
150+
| F1分数 | 0.695 |
151+
152+
### 评估日志示例
153+
154+
```
155+
Evaluating experiment 0 epoch 11 model...
156+
Starting prediction...
157+
confusion matrix
158+
[[0.77777778 0.22222222]
159+
[0.16363636 0.83636364]]
160+
evaluation metric
161+
acc: 0.6480
162+
precision: 0.6480
163+
recall: 0.6480
164+
f1: 0.6480
165+
```
166+
167+
## 5.完整代码
168+
``` py linenums="1" title="examples/xrdmatch/main.py"
169+
--8<--
170+
examples/xrdmatch/main.py
171+
--8<--
172+
```
173+
174+
## 参考文献
175+
176+
Zheng Wan., et al.** "XRDMatch: a semi-supervised learning framework to efficiently discover room temperature lithium superionic conductors." *Energy Environ. Sci.*, 2024, 17, 9487. (https://pubs.rsc.org/en/content/articlelanding/2024/ee/d4ee02970d)
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
Global:
2+
seed: 0
3+
device: gpu
4+
5+
MODEL:
6+
name: VGG
7+
in_channel: 1
8+
num_classes: 2
9+
10+
# Semi-supervised learning configuration
11+
SEMI_SUPERVISED:
12+
num_labels: 20
13+
num_train_iter: 1000 # Total training iterations
14+
num_eval_iter: 10 # Evaluate every 10 epochs
15+
un_ratio: 0.8
16+
ulb_num_labels: 10000
17+
include_lb_to_ulb: true
18+
uratio: 3 # Unlabeled data ratio for batch size
19+
20+
# Data path configuration
21+
DATA:
22+
ulbs_path: ./xrd_data/ulbs.csv
23+
lbs_path: ./xrd_data/lbs.csv
24+
25+
# Training configuration
26+
TRAIN:
27+
epochs: 100
28+
log_interval: 10
29+
save_dir: ./saved_models_ppsci/
30+
num_experiments: 10
31+
32+
# Data loader configuration
33+
DATALOADER:
34+
batch_size: 32
35+
eval_batch_size: 32
36+
shuffle: true
37+
num_workers: 0
38+
39+
# Optimizer configuration
40+
OPTIMIZER:
41+
name: AdamW
42+
learning_rate: 0.0003
43+
weight_decay: 0.01
44+
45+
# Loss function configuration
46+
LOSS:
47+
name: CrossEntropyLoss
48+
49+
# Evaluation configuration
50+
EVAL:
51+
interval: 1
52+
metric: f1
53+
54+
# Data augmentation configuration
55+
AUGMENTATION:
56+
weak_aug:
57+
noise_ratio: 0.1
58+
noise_peak: 0.05
59+
move_gap: 100
60+
strong_aug:
61+
noise_ratio: 0.2
62+
noise_peak: 0.1
63+
move_gap: 200

0 commit comments

Comments
 (0)