clear all; clc; % 载入数据 data = xlsread('Copy_of_数据集.xlsx'); input = data((1:120), 2:6)'; output = data((1:120), 7:9)'; % 划分训练集和测试集 input_train = input(:, 1:80); output_train = output(:, 1:80); input_test = input(:, 81:100); output_test = output(:, 81:100); % 归一化 [input_train_n, input_ps] = mapminmax(input_train, -1, 1); [output_train_n, output_ps] = mapminmax(output_train, -1, 1); % 建立模型 input_size = size(input_train_n, 1); hidden_size = 10; output_size = size(output_train_n, 1); net = newff(input_train_n, output_train_n, hidden_size, {'tansig','purelin'}, 'trainlm'); net.trainParam.epochs = 15000; net.trainParam.lr = 0.01; net.trainParam.goal = 0.0001; % 训练模型 [net, tr] = train(net, input_train_n, output_train_n); % 测试模型 input_test_n = mapminmax('apply', input_test, input_ps); output_test_n = mapminmax('apply', output_test, output_ps); output_pred_n = sim(net, input_test_n); %% 反归一化 output_test_pred = mapminmax('reverse', output_pred_n, output_ps); output_test_pred = round(output_test_pred); % 四舍五入取整 % 使用测试集评估网络性能 pos_pred = sim(net, input_test_n); % 预测位置 ori_pred = sim(net, input_test_n); % 预测姿态 pos_error = pos_pred - output_test(:,1:20)% 位置误差 ori_error = ori_pred - output_test(:,1:20);% 姿态误差 mse_pos = mean(pos_error.^2); % 位置均方误差 mse_ori = mean(ori_error.^2); % 姿态均方误差 % 使用附加测试集评估网络性能 % additional_test_data = [theta([6, 12, 18], :), actual_poses([6, 12, 18], :)]; additional_test_data = input(81:100,:); additional_test_data_n = mapminmax('apply', additional_test_data, input_ps); pos_pred = sim(net, additional_test_data_n); % 预测位置 ori_pred = sim(net, additional_test_data_n); % 预测姿态 pos_error = pos_pred - output(1,:); % 位置误差 ori_error = ori_pred - output(1,:); % 姿态误差 mse_pos_additional = mean(pos_error.^2); % 位置均方误差 mse_ori_additional = mean(ori_error.^2); % 姿态均方误差 % 调整维度为 2 x 10 % 绘制预测结果和真实结果的对比图 figure; plot(output_test(1,:), 'bo-'); hold on; plot(output_test_pred(1,:)', 'r*-'); % 注意转置 legend('真实结果', '预测结果'); xlabel('样本编号'); ylabel('输出值'); title('预测结果和真实结果');additional_test_data = input(81:100,:); 位置 1 处的索引超出数组边界(不能超出 5)。帮我修改
时间: 2024-01-04 15:04:07 浏览: 282
arima_test.zip_ARIMA代码_arima.test_arima预测_stationary test_平稳性检验
5星 · 资源好评率100%
在这段代码中,出现了 `input` 变量的命名冲突。建议将 `input` 变量改为其他名称,以避免和 MATLAB 内置函数冲突。另外,在计算 `pos_error` 和 `ori_error` 时,应该使用 `output_test_n`,而不是 `output_test(:,1:20)`。修改后的代码如下:
```
clear all;
clc;
% 载入数据
data = xlsread('Copy_of_数据集.xlsx');
input_data = data((1:120), 2:6)';
output_data = data((1:120), 7:9)';
% 划分训练集和测试集
input_train = input_data(:, 1:80);
output_train = output_data(:, 1:80);
input_test = input_data(:, 81:100);
output_test = output_data(:, 81:100);
% 归一化
[input_train_n, input_ps] = mapminmax(input_train, -1, 1);
[output_train_n, output_ps] = mapminmax(output_train, -1, 1);
% 建立模型
input_size = size(input_train_n, 1);
hidden_size = 10;
output_size = size(output_train_n, 1);
net = newff(input_train_n, output_train_n, hidden_size, {'tansig','purelin'}, 'trainlm');
net.trainParam.epochs = 15000;
net.trainParam.lr = 0.01;
net.trainParam.goal = 0.0001;
% 训练模型
[net, tr] = train(net, input_train_n, output_train_n);
% 测试模型
input_test_n = mapminmax('apply', input_test, input_ps);
output_test_n = mapminmax('apply', output_test, output_ps);
output_pred_n = sim(net, input_test_n);
%% 反归一化
output_test_pred = mapminmax('reverse', output_pred_n, output_ps);
output_test_pred = round(output_test_pred); % 四舍五入取整
% 使用测试集评估网络性能
pos_pred = sim(net, input_test_n); % 预测位置
ori_pred = sim(net, input_test_n); % 预测姿态
pos_error = pos_pred - output_test_n(1,:); % 位置误差
ori_error = ori_pred - output_test_n(2,:); % 姿态误差
mse_pos = mean(pos_error.^2); % 位置均方误差
mse_ori = mean(ori_error.^2); % 姿态均方误差
% 使用附加测试集评估网络性能
additional_test_data = input_data(81:100,:);
additional_test_data_n = mapminmax('apply', additional_test_data, input_ps);
pos_pred = sim(net, additional_test_data_n); % 预测位置
ori_pred = sim(net, additional_test_data_n); % 预测姿态
pos_error = pos_pred - output_data(81,:); % 位置误差
ori_error = ori_pred - output_data(81,:); % 姿态误差
mse_pos_additional = mean(pos_error.^2); % 位置均方误差
mse_ori_additional = mean(ori_error.^2); % 姿态均方误差
% 调整维度为 2 x 10
% 绘制预测结果和真实结果的对比图
figure;
plot(output_test_n(1,:), 'bo-');
hold on;
plot(output_test_pred(1,:)', 'r*-'); % 注意转置
legend('真实结果', '预测结果');
xlabel('样本编号');
ylabel('输出值');
title('预测结果和真实结果');
```
阅读全文