.flatten
时间: 2025-03-16 20:12:07 浏览: 9
Python 中 flatten
方法的具体实现
1. 使用 NumPy 的 ndarray.flatten()
方法
在 NumPy 库中,ndarray.flatten()
是一种强大的工具,用于将多维数组转换为一维数组。此方法不会修改原始数组,而是返回一个新的展平后的副本[^2]。
以下是具体的代码示例:
import numpy as np
# 创建一个多维数组
arr = np.array([[1, 2], [3, 4]])
# 调用 flatten() 方法将其展平为一维数组
flattened_arr = arr.flatten()
print(flattened_arr) # 输出: [1 2 3 4]
2. PyTorch 中 Tensor 的 flatten()
方法
在 PyTorch 中,Tensor.flatten(start_dim=0, end_dim=-1)
可以用来将指定范围内的维度展平成单一维度,默认是从第 0 维到最后一维[^3]。
以下是一个简单的例子:
import torch
# 创建一个三维张量
tensor = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# 将第二维和第三维展平
flattened_tensor = tensor.flatten(start_dim=1)
print(flattened_tensor)
# 输出:
# tensor([[1, 2, 3, 4],
# [5, 6, 7, 8]])
3. 自定义实现 flatten
函数
对于普通的嵌套列表结构,可以通过多种方式手动实现 flatten
方法。下面展示两种常见的实现方式:
(a) 基于递归的实现
这种方法适用于任意深度的嵌套列表。
def flatten(lst):
result = []
for item in lst:
if isinstance(item, list): # 如果当前项是列表,则递归调用 flatten
result.extend(flatten(item))
else:
result.append(item)
return result
nested_list = [1, 2, [3, 4, [5, 6]], 7]
flat_list = flatten(nested_list)
print(flat_list) # 输出: [1, 2, 3, 4, 5, 6, 7]
(b) 利用迭代器的方式
这种方式更加高效,适合浅层嵌套的情况。
from collections import Iterable
def flatten(lst):
""" 展开任何层次的可迭代对象 """
for elem in lst:
if isinstance(elem, Iterable) and not isinstance(elem, (str, bytes)):
yield from flatten(elem)
else:
yield elem
nested_list = ['a', 'b', ['c', 'd', ['e']], 'f']
flat_generator = list(flatten(nested_list))
print(flat_generator) # 输出: ['a', 'b', 'c', 'd', 'e', 'f']
4. JavaScript 实现(作为对比)
虽然问题是关于 Python 的 flatten
方法,但为了扩展理解,这里也给出基于 JavaScript 的实现方式[^4]。
function flatten(arr) {
while (arr.some(Array.isArray)) { // 检查是否有子数组存在
arr = [].concat.apply([], arr);
}
return arr;
}
const nestedArray = [1, 2, [3, [4, 5]]];
console.log(flatten(nestedArray)); // 输出: [1, 2, 3, 4, 5]
总结
不同场景下可以选择不同的 flatten
方法来满足需求。如果是处理数值型数据并希望获得高性能支持,推荐使用 NumPy 或 PyTorch 提供的功能;而对于纯 Python 数据结构的操作,则可以根据实际情况选用自定义函数或第三方库辅助完成任务。
相关推荐

















