怎么把onehot的二维数组转化成一维的
时间: 2024-03-03 07:52:40 浏览: 65
如果你想将 one-hot 编码的二维数组转换为一维数组,可以使用 `argmax` 方法。`argmax` 方法可以返回每个样本中最大值所在的索引,这个索引就是对应的类别标签。例如,假设某个样本的 one-hot 编码为 `[0, 1, 0]`,则 `argmax` 方法将返回 `1`,表示该样本属于第二个类别。
下面是一个例子,假设 `y` 是一个 one-hot 编码的二维数组,形状为 `(n_samples, n_classes)`,你可以使用 `argmax` 方法将其转换为一维数组:
```python
import numpy as np
y = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) # 一个形状为 (3, 3) 的 one-hot 编码的二维数组
y_pred = np.argmax(y, axis=1) # 将 y 转换为一维数组
print(y_pred) # 输出 [1, 0, 2]
```
在上面的例子中,`argmax` 方法的 `axis` 参数被设置为 `1`,表示在每行中查找最大值所在的索引。输出的 `y_pred` 是一个长度为 3 的一维数组,每个元素表示对应样本的类别标签。
相关问题
请给我解释一下这几行代码onehot_encoder = OneHotEncoder(sparse=False) integer_encoded = integer_encoded.reshape(len(integer_encoded), 1) onehot_encoded = onehot_encoder.fit_transform(integer_encoded) print(onehot_encoded)
这段代码使用了sklearn库中的OneHotEncoder类,将整数编码转化为独热编码。
代码的具体解释如下:
1. 首先,创建了一个OneHotEncoder对象onehot_encoder,其中sparse=False是指不使用稀疏矩阵存储独热编码。
2. 接着,将整数编码integer_encoded进行了reshape操作,将其转换成了一个二维数组,其中每个元素为一个整数。
3. 然后,使用onehot_encoder对整数编码进行了转换,得到了独热编码,存储在onehot_encoded中。
4. 最后,使用print语句输出了onehot_encoded。
总的来说,这段代码实现了将整数编码转换成独热编码的功能。
X_train,T_train=idx2numpy.convert_from_file('emnist/emnist-letters-train-images-idx3-ubyte'),idx2numpy.convert_from_file('emnist/emnist-letters-train-labels-idx1-ubyte')转化为相同形式train_num = 60000 test_num = 10000 img_dim = (1, 28, 28) img_size = 784 def _download(file_name): file_path = dataset_dir + "/" + file_name if os.path.exists(file_path): return print("Downloading " + file_name + " ... ") urllib.request.urlretrieve(url_base + file_name, file_path) print("Done") def download_mnist(): for v in key_file.values(): _download(v) def _load_label(file_name): file_path = dataset_dir + "/" + file_name print("Converting " + file_name + " to NumPy Array ...") with gzip.open(file_path, 'rb') as f: labels = np.frombuffer(f.read(), np.uint8, offset=8) print("Done") return labels def _load_img(file_name): file_path = dataset_dir + "/" + file_name print("Converting " + file_name + " to NumPy Array ...") with gzip.open(file_path, 'rb') as f: data = np.frombuffer(f.read(), np.uint8, offset=16) data = data.reshape(-1, img_size) print("Done") return data def _convert_numpy(): dataset = {} dataset['train_img'] = _load_img(key_file['train_img']) dataset['train_label'] = _load_label(key_file['train_label']) dataset['test_img'] = _load_img(key_file['test_img']) dataset['test_label'] = _load_label(key_file['test_label']) return dataset def init_mnist(): download_mnist() dataset = _convert_numpy() print("Creating pickle file ...") with open(save_file, 'wb') as f: pickle.dump(dataset, f, -1) print("Done!") def _change_one_hot_label(X): T = np.zeros((X.size, 10)) for idx, row in enumerate(T): row[X[idx]] = 1 return T def load_mnist(normalize=True, flatten=True, one_hot_label=False): """读入MNIST数据集 Parameters ---------- normalize : 将图像的像素值正规化为0.0~1.0 one_hot_label : one_hot_label为True的情况下,标签作为one-hot数组返回 one-hot数组是指[0,0,1,0,0,0,0,0,0,0]这样的数组 flatten : 是否将图像展开为一维数组 Returns ------- (训练图像, 训练标签), (测试图像, 测试标签) """ if not os.path.exists(save_file): init_mnist() with open(save_file, 'rb') as f: dataset = pickle.load(f) if normalize: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].astype(np.float32) dataset[key] /= 255.0 if one_hot_label: dataset['train_label'] = _change_one_hot_label(dataset['train_label']) dataset['test_label'] = _change_one_hot_label(dataset['test_label']) if not flatten: for key in ('train_img', 'test_img'): dataset[key] = dataset[key].reshape(-1, 1, 28, 28) return (dataset['train_img'], dataset['train_label']), (dataset['test_img'], dataset['test_label']) if name == 'main': init_mnist()模仿这段代码将获取同样形式
你提供的是MNIST数据集的加载代码,而不是EMNIST。如果要加载EMNIST数据集,可以使用以下代码:
```python
import idx2numpy
import numpy as np
def load_emnist():
X_train = idx2numpy.convert_from_file('emnist/emnist-letters-train-images-idx3-ubyte')
T_train = idx2numpy.convert_from_file('emnist/emnist-letters-train-labels-idx1-ubyte')
X_test = idx2numpy.convert_from_file('emnist/emnist-letters-test-images-idx3-ubyte')
T_test = idx2numpy.convert_from_file('emnist/emnist-letters-test-labels-idx1-ubyte')
# 将数据展开为一维数组
X_train = X_train.reshape(X_train.shape[0], -1)
X_test = X_test.reshape(X_test.shape[0], -1)
# 将标签转换为one-hot编码
T_train = np.eye(26)[T_train]
T_test = np.eye(26)[T_test]
return (X_train, T_train), (X_test, T_test)
```
这个代码将返回一个tuple,其中第一个元素是训练数据,第二个元素是测试数据,每个数据都是一个tuple,其中第一个元素是图像数据,第二个元素是标签数据。图像数据是展开为一维数组的,标签数据是one-hot编码的。
阅读全文