np.where怎么用
时间: 2023-07-08 12:42:20 浏览: 45
`np.where` 是 NumPy 中用于根据条件返回数组中的元素下标或值的函数。具体使用方式如下:
```python
import numpy as np
# 生成一个随机数组
a = np.random.randint(10, size=(3, 4))
print(a)
# 输出:
# [[7 7 7 0]
# [3 5 2 3]
# [9 6 8 9]]
# 返回数组中所有等于 7 的元素下标
indices = np.where(a == 7)
print(indices)
# 输出:
# (array([0, 0, 0]), array([0, 1, 2]))
# 返回数组中所有小于 5 的元素下标和值
indices, values = np.where(a < 5)
print(indices)
# 输出:
# (array([0, 1, 1, 1]), array([3, 0, 2, 3]))
print(values)
# 输出:
# [0 3 2 3]
```
可以看到,`np.where` 的返回值是一个元组,元组中的每个元素都是一个数组,分别表示满足条件的元素在原数组中的下标或值。第一个元素是行下标,第二个元素是列下标(如果原数组是二维的)。需要注意的是,如果条件中使用了逻辑运算符 `and` 或 `or`,需要使用括号将条件括起来,例如 `np.where((a > 3) & (a < 7))`。
相关问题
torch.where np.where
torch.where()是PyTorch和NumPy中的函数,用于根据给定的条件选择元素。torch.where()用于PyTorch张量,而np.where()用于NumPy数组。
在PyTorch中,torch.where()函数接受一个条件和两个张量作为输入。它会根据条件选择元素,返回与条件为True的元素对应的索引。例如,如果我们有一个张量x,我们可以使用torch.where(x==0)来找到x中为0的元素的索引。
在NumPy中,np.where()函数的功能与torch.where()类似,但输入和输出的类型不同。np.where()函数接受一个条件和一个数组作为输入,并返回与条件为True的元素对应的索引。例如,如果我们有一个数组arr,我们可以使用np.where(arr==0)来找到arr中为0的元素的索引。
np.where的使用
np.where()函数可以根据条件返回数组中符合条件的元素的索引或值。它的语法如下:
np.where(condition, [x, y])
其中,condition是一个布尔数组,x和y是可选参数。如果只有condition参数,则返回一个元素为True的索引数组;如果同时有x和y参数,则返回一个数组,其元素的值根据condition数组的值来决定,即当condition为True时,返回x的对应元素,否则返回y的对应元素。
例如,下面的代码演示了如何使用np.where()函数:
```python
import numpy as np
# 创建一个数组
arr = np.array([1, 2, 3, 4, 5])
# 获取所有大于3的元素的索引
indices = np.where(arr > 3)
print(indices)
# 根据条件返回不同的值
new_arr = np.where(arr > 3, arr, 0)
print(new_arr)
```
输出结果为:
```
(array([3, 4]),)
[0 0 0 4 5]
```
其中,`np.where(arr > 3)`返回了一个元素为True的索引数组,即[3, 4];而`np.where(arr > 3, arr, 0)`返回了一个新的数组,其中大于3的元素的值保持不变,小于等于3的元素的值被替换为0。