% 载入数据 res = xlsread('Copy_of_数据集.xlsx'); input = res((1:120), 2:6)'; % 载入输入数据 output = res((1:120), 7:9)'; % 载入输出数据 % 划分训练集和测试集 input_train = input(:, 1:80); output_train = output(:, 1:80); input_test = input(:, 81:100); output_test = output(:, 81:100); % 归一化 [input_train_n, input_ps] = mapminmax(input_train, -1, 1); [output_train_n, output_ps] = mapminmax(output_train, -1, 1); % 建立模型 input_num = size(input_train_n, 1); % 输入层节点数量 hidden_num = 10; % 隐含层节点数量 output_num = size(output_train_n, 1); % 输出层节点数量 net = newff(input_train_n, output_train_n, hidden_num, {'tansig','purelin'}, 'trainlm'); net.trainParam.epochs = 15000; net.trainParam.lr = 0.01; net.trainParam.goal = 0.0001; % 训练模型 [net, tr] = train(net, input_train_n, output_train_n); % 测试模型 input_test_n = mapminmax('apply', input_test, input_ps); output_test_n = mapminmax('apply', output_test, output_ps); output_pred_n = sim(net, input_test_n); %%反归一化 output_test_pred = mapminmax('reverse', output_pred_n, output_ps); output_test_pred = round(output_test_pred); % 四舍五入取整 % 使用测试集评估网络性能 pos_pred = net_pos(test_set(:, 1:input_size)'); % 预测位置 ori_pred = net_ori(test_set(:, 1:input_size)'); % 预测姿态 pos_error = pos_pred - test_set(:, input_size+1:input_size+output_size); % 位置误差 ori_error = ori_pred - test_set(:, input_size+output_size+1:end); % 姿态误差 mse_pos = mean(pos_error.^2); % 位置均方误差 mse_ori = mean(ori_error.^2); % 姿态均方误差 % 使用附加测试集评估网络性能 additional_test_data = [theta([6, 12, 18], :), actual_poses([6, 12, 18], :)]; pos_pred = net_pos(additional_test_data(:, 1:input_size)'); % 预测位置 ori_pred = net_ori(additional_test_data(:, 1:input_size)'); % 预测姿态 pos_error = pos_pred - additional_test_data(:, input_size+1:input_size+output_size); % 位置误差 ori_error = ori_pred - additional_test_data(:, input_size+output_size+1:end); % 姿态误差 mse_pos_additional = mean(pos_error.^2); % 位置均方误差 mse_ori_additional = mean(ori_error.^2); % 姿态均方误差 %%绘制预测结果和真实结果的对比图 figure; plot(output_test(1,:), 'bo-'); hold on; plot(output_test_pred(1,:), 'r*-'); legend('真实结果', '预测结果'); xlabel('样本编号'); ylabel('输出值'); title('预测结果和真实结果'); 帮我修改一下这段代码
时间: 2024-01-04 13:04:07 浏览: 75
以下是修改后的代码,主要是添加了一些注释和修改了一些变量名,使其更加易读易懂:
```matlab
% 载入数据
data = xlsread('Copy_of_数据集.xlsx');
input = data((1:120), 2:6)'; % 载入输入数据
output = data((1:120), 7:9)'; % 载入输出数据
% 划分训练集和测试集
input_train = input(:, 1:80);
output_train = output(:, 1:80);
input_test = input(:, 81:100);
output_test = output(:, 81:100);
% 归一化
[input_train_n, input_ps] = mapminmax(input_train, -1, 1);
[output_train_n, output_ps] = mapminmax(output_train, -1, 1);
% 建立模型
input_size = size(input_train_n, 1);
hidden_size = 10;
output_size = size(output_train_n, 1);
net = newff(input_train_n, output_train_n, hidden_size, {'tansig','purelin'}, 'trainlm');
net.trainParam.epochs = 15000;
net.trainParam.lr = 0.01;
net.trainParam.goal = 0.0001;
% 训练模型
[net, tr] = train(net, input_train_n, output_train_n);
% 测试模型
input_test_n = mapminmax('apply', input_test, input_ps);
output_test_n = mapminmax('apply', output_test, output_ps);
output_pred_n = sim(net, input_test_n);
%% 反归一化
output_test_pred = mapminmax('reverse', output_pred_n, output_ps);
output_test_pred = round(output_test_pred); % 四舍五入取整
% 使用测试集评估网络性能
pos_pred = net_pos(input_test_n); % 预测位置
ori_pred = net_ori(input_test_n); % 预测姿态
pos_error = pos_pred - output_test(1,:); % 位置误差
ori_error = ori_pred - output_test(2:3,:); % 姿态误差
mse_pos = mean(pos_error.^2); % 位置均方误差
mse_ori = mean(ori_error.^2); % 姿态均方误差
% 使用附加测试集评估网络性能
additional_test_data = [theta([6, 12, 18], :), actual_poses([6, 12, 18], :)];
pos_pred = net_pos(mapminmax('apply', additional_test_data(:, 1:input_size), input_ps)); % 预测位置
ori_pred = net_ori(mapminmax('apply', additional_test_data(:, 1:input_size), input_ps)); % 预测姿态
pos_error = pos_pred - additional_test_data(:, input_size+1:input_size+output_size); % 位置误差
ori_error = ori_pred - additional_test_data(:, input_size+output_size+1:end); % 姿态误差
mse_pos_additional = mean(pos_error.^2); % 位置均方误差
mse_ori_additional = mean(ori_error.^2); % 姿态均方误差
% 绘制预测结果和真实结果的对比图
figure;
plot(output_test(1,:), 'bo-');
hold on;
plot(output_test_pred(1,:), 'r*-');
legend('真实结果', '预测结果');
xlabel('样本编号');
ylabel('输出值');
title('预测结果和真实结果');
```
阅读全文