-
Notifications
You must be signed in to change notification settings - Fork 0
/
Config.py
38 lines (33 loc) · 1.14 KB
/
Config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# @Author : ChaoQiezi
# @Time : 2024/5/15 21:16
# @FileName : Config.py
# @Email : [email protected]
"""
This script is used to ...
"""
import os.path
import joblib
import matplotlib.pyplot as plt
import torch
from datetime import datetime
# 设置相关
plt.rcParams['font.family'] = 'Microsoft YaHei' # 可正常显示中文
plt.rcParams['axes.unicode_minus'] = True # 显示正负号
# plt.rcParams['font.family'] = 'Simhei'
# plt.rcParams['font.family'] = 'Times New Roman'
# 初始化参数
split_time = datetime(2020, 7, 1) # 数据集的划分时间节点, 5~7月为训练集, 8月为验证集, 约为3:1
seq_len_day = 7 # 记忆时间(时间分辨率: day)
pred_len_day = 1 # 预见期(day)
seq_len_hour = 96 # 记忆时间(时间分辨率: hour)
pred_len_hour = 1 # 预见期(hour)
# 模型相关
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 50 # 训练次数
lr = 1e-4 # 学习率
batch_size = 512 # 批次大小
scalers_path = r'I:\PyProJect\RetrievalPrecipitation\Assets\scalers.pkl'
if not os.path.exists(scalers_path):
joblib.dump({}, scalers_path)
else:
scalers = joblib.load(scalers_path)