matlab实现LSTM时序预测-ChatGPT4o+mathworks文档

一、初始化与加载数据clearclose allclcclear: 清除工作区变量;close all: 关闭所有图窗;clc: 清空命令窗口。load WaveformData加载一个内


一、初始化与加载数据

clear
close all
clc
  • clear: 清除工作区变量;
  • close all: 关闭所有图窗;
  • clc: 清空命令窗口。
load WaveformData
  • 加载一个内置的数据集,格式为 data{n} 是第 n 个观测序列,大小为 [时间步 × 通道数]。
  • 比如 data{1} 是一个形如 [100 × 3] 的矩阵,表示有 100 个时间步、3 个通道。
numChannels = size(data{1},2)
  • 获取每个序列的通道数(列数),用于后续定义网络。
tiledlayout(2,2)
for i = 1:4
    nexttile
    stackedplot(data{i})
    xlabel("Time Step")
end
  • 可视化前 4 个时间序列样本;
  • stackedplot 会将多通道数据画在一张图里,便于观察每个通道的变化趋势。

二、划分训练集和测试集

numObservations = numel(data);
idxTrain = 1:floor(0.9*numObservations);
idxTest = floor(0.9*numObservations)+1:numObservations;
dataTrain = data(idxTrain);
dataTest = data(idxTest);
  • 把数据分为 90% 的训练集,10% 的测试集;
  • dataTraindataTest 都是 cell 数组,每个元素是一个 [时间步 × 通道数] 的矩阵。

三、构造训练数据:输入X和目标T

X = dataTrain{n};
XTrain{n} = X(1:end-1,:);
TTrain{n} = X(2:end,:);
  • 每个样本的训练对是:
    • 输入 XTrain{n}:去掉最后一个时间点;
    • 目标 TTrain{n}:去掉第一个时间点;
    • 表示:当前时刻预测下一时刻

归一化处理:

muX = mean(cell2mat(XTrain));
sigmaX = std(cell2mat(XTrain),0);
  • 把所有 XTrain 拼接起来,计算均值和标准差;
  • 之后标准化每个样本:
XTrain{n} = (XTrain{n} - muX) ./ sigmaX;

这样做能提升 LSTM 收敛速度与精度。


四、定义 LSTM 网络

layers = [
    sequenceInputLayer(numChannels)  % 输入层
    lstmLayer(128)                  % LSTM 隐藏层
    fullyConnectedLayer(numChannels)]; % 输出层(与输入通道数相同)
  • 网络结构很简单:
    • 输入:多通道序列;
    • LSTM:128 个隐藏单元,捕捉时间特征;
    • 输出:一个时刻的多通道预测。

五、设置训练选项并训练模型

options = trainingOptions("adam", ...
    MaxEpochs=200, ...
    SequencePaddingDirection="left", ...
    Shuffle="every-epoch", ...
    Plots="training-progress", ...
    Verbose=false);
  • 优化器:Adam,自适应梯度更新;
  • SequencePaddingDirection="left":填充短序列时靠左对齐;
  • Plots="training-progress":开启训练可视化图。
net = trainnet(XTrain,TTrain,layers,"mse",options);
  • trainnet() 是 R2023b 新函数,替代 trainNetwork()
  • "mse":均方误差损失。

六、测试模型性能

XTest{n} = (X(1:end-1,:) - muX) ./ sigmaX;
TTest{n} = (X(2:end,:) - muT) ./ sigmaT;
  • 用相同的均值/标准差对测试数据归一化;
  • 注意 TTest 是用于评估模型预测性能。
YTest = minibatchpredict(net,XTest, ...);
  • 这个函数会一次性预测所有测试样本,结果是 cell 数组,每个元素是 [时间步 × 通道数] 的预测结果。
err(n) = rmse(Y,T,"all");
  • 计算每个测试样本的 RMSE(预测 vs 真实);
  • 用于衡量测试集上的平均预测误差。

七、Open Loop 预测

offset = 75;
[Z,state] = predict(net,X(1:offset,:));
net.State = state;
  • 使用前 75 个真实输入来“热启动”网络;
  • net.State 是网络内部的 LSTM 状态,会持续更新。
for t = 1:numPredictionTimeSteps-1
    Xt = X(offset+t,:);
    [Y(t+1,:),state] = predict(net,Xt);
    net.State = state;
end
  • 每一步预测都用当前真实输入;
  • 预测结果与目标比较,绘图展示。

八、Closed Loop 预测(自回归)

[Z,state] = predict(net,X(1:offset,:));
Y(1,:) = Z(end,:);
  • 用历史数据预测最后一个点,作为预测序列的起点。
for t = 2:numPredictionTimeSteps
    [Y(t,:),state] = predict(net,Y(t-1,:));
  • 之后每一步都是用上一时刻的预测值来继续预测;
  • 完全自回归预测,更接近真实使用场景(比如天气预测、金融预测等)。

图示解读

  • 预测图会显示:
    • 蓝线:原始真实数据;
    • 虚线:预测的时间段。

Open Loop:真实输入 + 预测输出(预测一段)
Closed Loop:预测输入 + 预测输出(连续多步)


总结

这个项目的流程:

  1. 数据预处理:标准化 + 输入-目标配对;
  2. 网络搭建:LSTM 用于建模时序;
  3. 训练模型:多序列训练;
  4. 评估性能:RMSE,误差分布;
  5. 未来预测
    • Open Loop:部分已知输入;
    • Closed Loop:全程自预测。

发布者:admin,转转请注明出处:http://www.yc00.com/web/1754769572a5199978.html

相关推荐

发表回复

评论列表(0条)

  • 暂无评论

联系我们

400-800-8888

在线咨询: QQ交谈

邮件:admin@example.com

工作时间:周一至周五,9:30-18:30,节假日休息

关注微信