def load_data(): """ Standardize the data and package it into a list together with labels """ tr_d, va_d, te_d = load() scaler = StandardScaler() scaler.fit_transform(tr_d[0]) scaler.fit_transform(va_d[0]) scaler.fit_transform(te_d[0]) training_data = [(np.reshape(d, (784, 1)), one_hot_enc(l)) for d, l in zip(tr_d[0], tr_d[1])] validation_data = [(np.reshape(d, (784, 1)), one_hot_enc(l)) for d, l in zip(va_d[0], va_d[1])] test_data = [(np.reshape(d, (784, 1)), one_hot_enc(l)) for d, l in zip(te_d[0], te_d[1])] return training_data, validation_data, test_data
时间: 2024-04-18 10:31:16 浏览: 61
计算机专业英语_考试题目.pdf
这段代码是一个数据加载的函数,它的功能是将数据标准化并打包成一个列表,同时包含标签。
函数中首先调用了一个名为`load()`的函数,返回训练集、验证集和测试集数据(tr_d, va_d, te_d)。
接下来使用`StandardScaler()`对训练集、验证集和测试集的数据进行标准化处理。标准化是一种常见的数据预处理技术,可以将数据缩放到均值为0,方差为1的范围内。
然后,使用`np.reshape()`将每个样本的数据从原始的一维数组形式转换为二维矩阵,大小为(784, 1)。这里假设每个样本的原始数据是长度为784的一维数组。
最后,将每个样本的特征向量和对应的标签进行组合,并使用`one_hot_enc()`函数对标签进行独热编码。独热编码是将离散型的标签转换为二进制形式的向量表示,方便计算机处理。
最终,将处理后的训练集、验证集和测试集数据打包成一个列表,并作为函数的返回值。
阅读全文