wdata_origin = wdata; index_list = randperm(size(wdata, 1)); ind = round(0.8*length(index_list)); train_index = index_list(1:ind); test_index = index_list(ind+1:end); train_index = sort(train_index); test_index = sort(test_index); dataTrain = wdata(train_index, :); dataTest = wdata(test_index, :); XTrain = dataTrain(:, 1:end-1)'; YTrain = dataTrain(:, end)'; XTest = dataTest(:, 1:end-1)'; YTest = dataTest(:, end)'; layers = get_lstm_net(wd); options = trainingOptions('adam', ... 'MaxEpochs',200, ... 'GradientThreshold',1, ... 'InitialLearnRate',0.005, ... 'LearnRateSchedule','piecewise', ... 'LearnRateDropPeriod',125, ... 'LearnRateDropFactor',0.2, ... 'Verbose',0, ... 'Plots','training-progress'); net = trainNetwork(XTrain,YTrain,layers,options); Xall = wdata_origin(:, 1:end-1)'; Yall = wdata_origin(:, end)'; YPred = predict(net,Xall,'MiniBatchSize',1);
时间: 2023-12-19 18:06:50 浏览: 352
wdata_neq_0_wvalid_up.vcd
这段代码是在进行数据预处理、模型训练和测试的过程。首先,将原始数据集wdata复制一份存储到wdata_origin中。然后,将数据集的索引随机打乱,划分成80%的训练集和20%的测试集,存储到train_index和test_index中。接下来,通过train_index和test_index将数据集分别划分成训练集和测试集,存储到dataTrain和dataTest中。将训练集和测试集的输入和输出分别存储到XTrain、YTrain、XTest和YTest中。然后,通过get_lstm_net函数获取一个LSTM神经网络模型的层次结构,存储到layers中。接着,通过trainingOptions函数设置训练选项,包括优化器、最大训练轮数、梯度阈值、初始学习率、学习率调整方式、学习率下降周期、学习率下降因子、是否显示详细信息和绘制训练进度图等。最后,通过trainNetwork函数使用训练集和训练选项训练LSTM模型,存储到net中。将所有数据集的输入存储到Xall中,使用训练好的LSTM模型对Xall进行预测,并将预测结果存储到YPred中。
阅读全文