请你帮我写出修改后的完整代码
时间: 2024-10-24 22:12:50 浏览: 13
设计模式之美——教你写出高质量代码
当然,这里是经过修改后的完整代码,包含了`process_grid_for_resnet`函数用于处理单个网格的特征提取以及新的`extract_features`函数处理整个栅格列表:
```python
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image
import pandas as pd
# 其他代码不变
# 图像预处理
preprocess = transforms.Compose([
transforms.Resize((224, 224)), # 使用固定大小以保持ResNet的一致性
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
# 定义处理单个网格的函数
def process_grid_for_resnet(grid_tensor, resnet, preprocess):
with torch.no_grad():
grid_tensor = preprocess(grid_tensor)
grid_tensor = grid_tensor.unsqueeze(0) # 添加 batch dimension
feature_map = resnet(grid_tensor) # 前向传播
# 获取最后一层的特征图,通常是卷积层的结果
features = feature_map[:, :, :, :].squeeze(0) # 按需选择特定层的特征
return features
# 提取特征函数
def extract_features(grids, resnet, preprocess):
features_list = []
for grid in grids:
features = process_grid_for_resnet(grid, resnet, preprocess)
features_list.append(features)
return features_list
# 加载并划分图像
poi_grids = load_and_split_image(poi_image_path, grid_width, grid_height)
nightlight_grids = load_and_split_image(nightlight_image_path, grid_width, grid_height)
# 提取特征
poi_features = extract_features(poi_grids, resnet, preprocess)
nightlight_features = extract_features(nightlight_grids, resnet, preprocess)
# 融合特征
def fuse_features(poi_features, nightlight_features, method='sum'):
fused_features = []
for poi_feat, nightlight_feat in zip(poi_features, nightlight_features):
if method == 'sum':
fused_feature = poi_feat + nightlight_feat
elif method == 'average':
fused_feature = (poi_feat + nightlight_feat) / 2
else:
raise ValueError("Invalid fusion method. Choose 'sum' or 'average'.")
fused_features.append(fused_feature)
return fused_features
fused_features = fuse_features(poi_features, nightlight_features, 'average') # 或者改为 'sum'
# 输出融合后的特征值列表
fused_feature_values = [feature.tolist() for feature in fused_features]
# 转换为 DataFrame
df = pd.DataFrame(fused_feature_values, columns=['Fused Feature'])
# 保存到 Excel 文件
df.to_excel('fused_features.xlsx', index=False)
```
在这个版本中,我添加了一个`fuse_features`函数,它接受一个融合方法(默认为求和,也可以改选为平均)。根据您选择的融合方法,您可以调用该函数对 `poi_features` 和 `nightlight_features` 进行融合。
阅读全文