
Prediction degrades to a mean forecast when the period changes #41

Open

mklpr opened this issue Dec 20, 2024 · 4 comments

mklpr commented Dec 20, 2024

The predictions of the timer-base-84m model on my own data are basically equivalent to predicting the mean of the context. Simplifying the test down to a standard sine signal, the result is reasonable when the period is 24, but at period 7 it again degrades to a mean forecast. A period of 7 should also be a very common period in real time-series data; is this because Timer's pretraining data contains no period-7 series? On the other hand, testing Chronos on many different periods, including ones that are not common natural-data periods, the predictions are basically fine. Chronos's training samples presumably do not directly cover every period I picked at random, so it seems to have some ability to generalize across periods. What might make Timer so sensitive to changes in period?

The Timer test code and some of the results are attached:

import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM

# Standard sine signal with period T
t = np.arange(0, 10000)
T = 7
y = np.sin(2 * np.pi / T * t)

context_length = 96
prediction_length = 28

model = AutoModelForCausalLM.from_pretrained('thuml/timer-base-84m', trust_remote_code=True)

# Forecast prediction_length points from the first context_length points
context = torch.tensor(y[:context_length]).unsqueeze(0).float()
predict = model.generate(context, max_new_tokens=prediction_length).squeeze()

plt.plot(t[: context_length], y[: context_length], '-+', color="royalblue", label="historical data")
plt.plot(t[context_length: context_length + prediction_length], y[context_length: context_length + prediction_length], '-+', color="yellow", label="true data")
plt.plot(t[context_length: context_length + prediction_length], predict, '-+', color="tomato", label="predict data")
plt.legend()
plt.grid()
plt.title(f'[timer-base-84m] period T = {T}')
plt.show()

[Figures: timer-base-84m forecast vs. ground truth for several periods T]
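To make the "degrades to a mean forecast" observation concrete, a small check of my own (meant to be appended to the script above, reusing its `context` and `predict` variables) compares the forecast against a flat context-mean forecast:

# Continues the script above: how close is the forecast to a flat context-mean forecast?
context_mean = context.mean().item()
print('context mean:', context_mean)
print('forecast min/max:', predict.min().item(), predict.max().item())
print('MSE between forecast and context-mean forecast:',
      ((predict - context_mean) ** 2).mean().item())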

WenWeiTHU (Collaborator) commented Dec 20, 2024

Thank you for the test case, it is very instructive for us. A large share of timer-base-84m's training samples have a period of 24, while samples with periods such as 7 are relatively scarce, so zero-shot prediction on those periods is poor. We also found that Chronos's pretraining data includes some synthetic data (figure below) whose period mix covers roughly 1–512. We tested Chronos on out-of-distribution large-period data such as T = 720 and also observed degraded prediction performance. So at the moment we attribute this mainly to insufficient diversity in the pretraining dataset.

[Figure: synthetic series in the Chronos pretraining data; the period mix covers roughly 1–512]
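For illustration only, here is a minimal sketch of how a synthetic pretraining corpus with a broad period mix could be generated. This is a simplified stand-in, not the actual Chronos synthetic-data procedure; the period range, number of components, and noise level are assumptions:

import numpy as np

rng = np.random.default_rng(0)

def synth_series(length=1024, min_T=1, max_T=512, n_components=3, noise_std=0.05):
    # Sum a few sinusoids whose periods are drawn log-uniformly from [min_T, max_T],
    # plus Gaussian noise, so the corpus covers a wide range of periodicities.
    t = np.arange(length)
    y = np.zeros(length)
    for _ in range(n_components):
        T = np.exp(rng.uniform(np.log(min_T), np.log(max_T)))
        phase = rng.uniform(0, 2 * np.pi)
        amp = rng.uniform(0.5, 1.5)
        y += amp * np.sin(2 * np.pi / T * t + phase)
    return y + rng.normal(0, noise_std, size=length)

corpus = [synth_series() for _ in range(1000)]  # e.g., 1000 synthetic training series

A pretraining mix built along these lines would at least expose the model to periods like 7 as well as 24.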

mklpr (Author) commented Dec 20, 2024

I swept Chronos over different periods from 1 to 365 and compared its prediction performance against the periods used in its synthetic training data. At fairly low periods (T < 50) the performance fluctuates a lot, and sometimes even periods that do appear in the synthetic data are predicted poorly. At higher periods, although the periods covered by the synthetic training data become quite sparse, the performance across different periods is still decent. There may be some model-internal reason for this, or it may be related to the fact that at high periods prediction_len is small relative to T, which makes the points easier to predict. For reference:

[Figures: MSE vs. period T for the chronos-bolt models and the mean baseline, over several T ranges]

import pandas as pd
import torch
from chronos import BaseChronosPipeline
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 14

min_T = 1
max_T = 365
mses = []
context_Ts = 4  # number of periods the context should cover (clamped to [min_context_len, max_context_len] below)
min_context_len = 96
max_context_len = 365 * 2
prediction_len = 64
models = [f'amazon/chronos-bolt-{m}' for m in ['tiny', 'mini', 'small', 'base']] + ['baseline']
for model in models:
    if model != 'baseline':
        pipeline = BaseChronosPipeline.from_pretrained(
            model,
            device_map="cpu", 
            torch_dtype=torch.bfloat16,
        )
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            y = torch.tensor(y)
            _, predict_mean = pipeline.predict_quantiles(y[: context_len], prediction_len, [0.5])
            predict_mean = predict_mean.squeeze()
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean().item()
            mses.append([model, T, mse])
    else:
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            predict_mean = np.ones(prediction_len) * y[: context_len].mean()  # baseline: flat context-mean forecast
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean()
            mses.append(['baseline', T, mse])

mses_df = pd.DataFrame(mses, columns=['model', 'T', 'mse'])
mses_df.to_csv('test_T_mses.csv', index=False)

for T_range in [(1, 365), (1, 30), (30, 96), (96, 365)]:
    plt.figure()
    for model in models:
        T = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1])]['T']
        mse = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1])]['mse']
        plt.plot(T, mse, label=model)

    traindata_T = np.array([4, 6, 7, 10, 12, 14, 24, 26, 30, 40, 48, 52, 60, 96, 168, 336, 365])  # periods used in Chronos's synthetic training data
    traindata_T = traindata_T[(traindata_T >= T_range[0]) & (traindata_T <= T_range[1])]
    plt.plot(traindata_T, np.zeros_like(traindata_T), 'o', color='black', alpha=0.5, label='traindata_T')
    plt.legend()
    plt.xlabel('T')
    plt.ylabel('mse')
    plt.grid()
    plt.show()

I don't have a GPU at the moment and it runs slowly, so I haven't done a full sweep over larger periods, but a few random spot checks look quite good. For example, below is a period-720 series predicted with a 600-point context and a 360-point horizon.

[Figure: Chronos forecast for a period-720 sine with a 600-point context and a 360-point horizon]
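For reference, a spot check along these lines could look like the sketch below. The checkpoint choice (chronos-bolt-small) and the handling of the 360-step horizon, which is longer than the model's native prediction length, are assumptions on my part:

import numpy as np
import torch
import matplotlib.pyplot as plt
from chronos import BaseChronosPipeline

T = 720
context_len = 600
prediction_len = 360  # longer than the native horizon; may warn or require chunked prediction

t = np.arange(0, context_len + prediction_len)
y = np.sin(2 * np.pi / T * t)
context = torch.tensor(y[:context_len])

pipeline = BaseChronosPipeline.from_pretrained(
    'amazon/chronos-bolt-small',  # checkpoint choice is arbitrary here
    device_map='cpu',
    torch_dtype=torch.bfloat16,
)
_, predict_mean = pipeline.predict_quantiles(context, prediction_len, [0.5])
predict_mean = predict_mean.squeeze().float()

mse = ((predict_mean - torch.tensor(y[context_len:])) ** 2).mean().item()
print(f'T={T}, mse={mse:.4f}')

plt.plot(t[:context_len], y[:context_len], label='historical data')
plt.plot(t[context_len:], y[context_len:], label='true data')
plt.plot(t[context_len:], predict_mean, label='predict data')
plt.legend()
plt.grid()
plt.title(f'period T = {T}, context {context_len}, prediction {prediction_len}')
plt.show()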

mklpr (Author) commented Dec 21, 2024

Adding timer and moirai tests to make comparison easier.

timer

[Figures: MSE vs. period T for timer-base-84m and the mean baseline, over several T ranges]

import pandas as pd
import torch
from transformers import AutoModelForCausalLM
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 14

min_T = 1
max_T = 365
mses = []
context_Ts = 4
min_context_len = 96
max_context_len = 365 * 2
prediction_len = 64
models = ['thuml/timer-base-84m', 'baseline']
for model in models:
    if model != 'baseline':
        pipeline = AutoModelForCausalLM.from_pretrained(model, trust_remote_code=True)
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            y = torch.tensor(y)
            predict_mean = pipeline.generate(y[: context_len].unsqueeze(0).float(), max_new_tokens=prediction_len).squeeze()
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean().item()
            mses.append([model, T, mse])
    else:
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            predict_mean = np.ones(prediction_len) * y[: context_len].mean()
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean()
            mses.append(['baseline', T, mse])

mses_df = pd.DataFrame(mses, columns=['model', 'T', 'mse'])
mses_df.to_csv('test_T_mses.csv', index=False)

for T_range in [(1, 365), (1, 30), (30, 96), (96, 365)]:
    plt.figure()
    for model in models:
        T = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1])]['T']
        mse = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1])]['mse']
        plt.plot(T, mse, label=model)

    plt.legend()
    plt.xlabel('T')
    plt.ylabel('mse')
    plt.grid()
    plt.show()

moirai

Moirai's predictions diverge with extremely high error at T = 1 and T = 2, i.e. it behaves abnormally on the simplest possible input, a constant series, so the test below uses T = 3 to 365. Four model/period combinations still show divergent predictions; they are excluded when plotting (the mse < 3 filter below) so the remaining data is easier to read.
[Figures: MSE vs. period T for the moirai-1.1-R and moirai-moe-1.0-R models and the mean baseline, over several T ranges]

import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from gluonts.dataset.split import split
from gluonts.dataset.common import ListDataset
from uni2ts.model.moirai import MoiraiForecast, MoiraiModule
from uni2ts.model.moirai_moe import MoiraiMoEForecast, MoiraiMoEModule

plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams['font.size'] = 14

min_T = 3
max_T = 365
mses = []
context_Ts = 4
min_context_len = 96
max_context_len = 365 * 2
prediction_len = 64
patch_size = 16
batch_size = 64
models = [f'Salesforce/moirai-1.1-R-{m}' for m in ['small', 'base', 'large']] + [
    f'Salesforce/moirai-moe-1.0-R-{m}' for m in ['small', 'base']] + ['baseline']
for model in models:
    if model != 'baseline':
        if 'moe' not in model:
            module = MoiraiModule.from_pretrained(model)
        else:
            module = MoiraiMoEModule.from_pretrained(model)
            
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            start = pd.Period("01-01-2024", freq='D')
            ds = ListDataset([{'target': y, "start": start}], freq='D')

            if 'moe' not in model:
                pipeline = MoiraiForecast(
                    module=module,
                    prediction_length=prediction_len,
                    context_length=context_len,
                    patch_size=patch_size,
                    num_samples=batch_size,
                    target_dim=1,
                    feat_dynamic_real_dim=0,
                    past_feat_dynamic_real_dim=0,
                )
            else:
                pipeline = MoiraiMoEForecast(
                    module=module,
                    prediction_length=prediction_len,
                    context_length=context_len,
                patch_size=patch_size,
                    num_samples=batch_size,
                    target_dim=1,
                    feat_dynamic_real_dim=0,
                    past_feat_dynamic_real_dim=0,
                )
            
            train, test_template = split(ds, offset=-prediction_len)
            test_data = test_template.generate_instances(
                prediction_length=prediction_len,
                windows=1,
                distance=prediction_len,
            )
            predictor = pipeline.create_predictor(batch_size=batch_size)
            forecasts = predictor.predict(test_data.input)
            forecast = next(iter(forecasts))
            predict_mean = forecast.samples.mean(axis=0)
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean().item()
            mses.append([model, T, mse])
    else:
        for T in tqdm(range(min_T, max_T + 1)):
            context_len = min([max_context_len, max([min_context_len, T * context_Ts])])
            t = np.arange(0, context_len + prediction_len)
            y = np.sin(2 * np.pi / T * t)
            predict_mean = np.ones(prediction_len) * y[: context_len].mean()
            mse = ((predict_mean - y[context_len: context_len + prediction_len]) ** 2).mean()
            mses.append(['baseline', T, mse])

mses_df = pd.DataFrame(mses, columns=['model', 'T', 'mse'])
mses_df.to_csv('test_T_mses.csv', index=False)

for T_range in [(3, 365), (3, 30), (30, 96), (96, 365)]:
    plt.figure()
    for model in models:
        # exclude the few divergent model/period combinations (mse >= 3) mentioned above
        T = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1]) & (mses_df['mse'] < 3)]['T']
        mse = mses_df[(mses_df['model'] == model) & (mses_df['T'] >= T_range[0]) & (mses_df['T'] <= T_range[1]) & (mses_df['mse'] < 3)]['mse']
        plt.plot(T, mse, label=model)

    plt.legend()
    plt.xlabel('T')
    plt.ylabel('mse')
    plt.grid()
    plt.show()

WenWeiTHU (Collaborator) commented

@mklpr

Thank you for the prediction cases; we will continue to investigate this in depth.
