import os import pickle import cv2 import matplotlib.pyplot as plt import numpy as np from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout from keras.models import Sequential from keras.optimizers import adam_v2 from keras_preprocessing.image import ImageDataGenerator from sklearn.model_selection import train_test_split from sklearn.preprocessing import LabelEncoder, OneHotEncoder, LabelBinarizer def load_data(filename=r'/root/autodl-tmp/RML2016.10b.dat'): with open(r'/root/autodl-tmp/RML2016.10b.dat', 'rb') as p_f: Xd = pickle.load(p_f, encoding="latin-1") # 提取频谱图数据和标签 spectrograms = [] labels = [] train_idx = [] val_idx = [] test_idx = [] np.random.seed(2016) a = 0 for (mod, snr) in Xd: X_mod_snr = Xd[(mod, snr)] for i in range(X_mod_snr.shape[0]): data = X_mod_snr[i, 0] frequency_spectrum = np.fft.fft(data) power_spectrum = np.abs(frequency_spectrum) ** 2 spectrograms.append(power_spectrum) labels.append(mod) train_idx += list(np.random.choice(range(a * 6000, (a + 1) * 6000), size=3600, replace=False)) val_idx += list(np.random.choice(list(set(range(a * 6000, (a + 1) * 6000)) - set(train_idx)), size=1200, replace=False)) a += 1 # 数据预处理 # 1. 将频谱图的数值范围调整到0到1之间 spectrograms_normalized = spectrograms / np.max(spectrograms) # 2. 对标签进行独热编码 label_binarizer = LabelBinarizer() labels_encoded= label_binarizer.fit_transform(labels) # transfor the label form to one-hot # 3. 划分训练集、验证集和测试集 # X_train, X_temp, y_train, y_temp = train_test_split(spectrograms_normalized, labels_encoded, test_size=0.15, random_state=42) # X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42) spectrogramss = np.array(spectrograms_normalized) print(spectrogramss.shape) labels = np.array(labels) X = np.vstack(spectrogramss) n_examples = X.shape[0] test_idx = list(set(range(0, n_examples)) - set(train_idx) - set(val_idx)) np.random.shuffle(train_idx) np.random.shuffle(val_idx) np.random.shuffle(test_idx) X_train = X[train_idx] X_val = X[val_idx] X_test = X[test_idx] print(X_train.shape) print(X_val.shape) print(X_test.shape) y_train = labels[train_idx] y_val = labels[val_idx] y_test = labels[test_idx] print(y_train.shape) print(y_val.shape) print(y_test.shape) # X_train = np.expand_dims(X_train,axis=-1) # X_test = np.expand_dims(X_test,axis=-1) # print(X_train.shape) return (mod, snr), (X_train, y_train), (X_val, y_val), (X_test, y_test) 这是我的数据预处理代码
时间: 2024-04-13 07:25:48 浏览: 245
你的数据预处理代码看起来很完整。它首先加载pickle文件,然后提取频谱图数据和标签。接下来,它对频谱图数据进行归一化处理,并对标签进行独热编码。最后,它将数据集划分为训练集、验证集和测试集,并返回这些数据集。
在代码中,你还有一些注释掉的代码段,例如对数据集进行划分的部分。如果你打算使用这些划分代码,你可以取消注释相应的行。另外,如果你的频谱图数据维度是二维的,你可能需要使用`np.expand_dims`函数将其扩展为三维,以便与Conv2D层兼容。
如果你还有其他问题或需要进一步帮助,请随时问我。
相关问题
import os import pickle import numpy as np import pandas as pd import matplotlib.pyplot as plt %matplotlib inline
这是一个 Python 脚本的开头部分,它引入了一些常用的数据科学和可视化库,包括 os、pickle、numpy、pandas 和 matplotlib。其中 %matplotlib inline 表示将绘图直接嵌入到 Jupyter Notebook 或 IPython 中,而不是弹出一个独立的窗口。这个语句通常放在开头,以确保绘图能够正确显示。
帮我修改错误import pandas import pandas as pd import np import numpy np_load_old = np.load np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k) import matplotlib.pyplot matplotlib.pyplot.rcParams['font.sans-serif'] = 'SimHei' matplotlib.pyplot.rcParams['axes.unicode_minus'] = False data = numpy.load('D:\\temp\\test\\国民经济核算季度数据.npz') name = data['columns'] values = data['values'] matplotlib.pyplot.figure(figsize = (8,7)) matplotlib.pyplot.scatter(values[:,0],values[:,3], marker = 'o',c = '#0099FF') matplotlib.pyplot.scatter(values[:,0],values[:,4], marker = '*',c = '#00FF00') matplotlib.pyplot.scatter(values[:,0],values[:,5], marker = '+',c = '#FF3300') matplotlib.pyplot.xlabel('年份') matplotlib.pyplot.ylabel('生产总值(亿元)') matplotlib.pyplot.xticks(range(0,70,4),values[range(0,70,4),1],rotation=45) matplotlib.pyplot.title('2000-2017年各产业季度生产总值散点图') matplotlib.pyplot.legend(['第一产业','第二产业','第三产业']) matplotlib.pyplot.savefig('D:\\temp\\test\\2022231168许开国c.png') matplotlib.pyplot.show name1 = pandas.DataFrame(list(data['columns'])) values1 = pandas.DataFrame(list(data['values'])) with pd.ExcelWriter('D:\\temp\\test\\2022231201许开国c.xlsx') as writer: name1.to_excel(writer, sheet_name= 'Sheet1', index=False, header=True) values1.to_excel(writer, sheet_name='Sheet2', index=False, header=True)
以下是修改过的代码,其中主要是将numpy和pandas的导入方式进行了修改,并且在matplotlib.pyplot的导入语句中添加了as关键字,以便在使用时不需要使用完整的模块名:
```python
import pandas as pd
import numpy as np
np_load_old = np.load
np.load = lambda *a,**k: np_load_old(*a, allow_pickle=True, **k)
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = 'SimHei'
plt.rcParams['axes.unicode_minus'] = False
data = np.load('D:\\temp\\test\\国民经济核算季度数据.npz')
name = data['columns']
values = data['values']
plt.figure(figsize = (8,7))
plt.scatter(values[:,0],values[:,3], marker = 'o',c = '#0099FF')
plt.scatter(values[:,0],values[:,4], marker = '*',c = '#00FF00')
plt.scatter(values[:,0],values[:,5], marker = '+',c = '#FF3300')
plt.xlabel('年份')
plt.ylabel('生产总值(亿元)')
plt.xticks(range(0,70,4),values[range(0,70,4),1],rotation=45)
plt.title('2000-2017年各产业季度生产总值散点图')
plt.legend(['第一产业','第二产业','第三产业'])
plt.savefig('D:\\temp\\test\\2022231168许开国c.png')
plt.show()
name1 = pd.DataFrame(list(data['columns']))
values1 = pd.DataFrame(list(data['values']))
with pd.ExcelWriter('D:\\temp\\test\\2022231201许开国c.xlsx') as writer:
name1.to_excel(writer, sheet_name= 'Sheet1', index=False, header=True)
values1.to_excel(writer, sheet_name='Sheet2', index=False, header=True)
```
这样的话,你便可以在运行程序时直接使用pd和np作为pandas和numpy的别名,使用plt来代替matplotlib.pyplot。
阅读全文