PyTorch笔记之scatter()函数的使用
在PyTorch中,`scatter()`函数是一种非常实用的操作,用于根据提供的索引(index)将源元素(src)分散到目标张量中。该函数有两个变体:`scatter()`和`scatter_()`,它们的主要区别在于是否直接修改原始张量。`scatter_()`会在原地修改张量,而`scatter()`则返回一个新的张量,不会影响原始数据。 `scatter(dim, index, src)`函数接收三个参数: 1. `dim`: 这是一个整数,表示沿哪个维度进行索引和放置元素。它决定了操作发生在张量的哪个轴上。 2. `index`: 一个张量,定义了要修改的元素的位置。它的形状应该与`src`的形状匹配,除了`dim`对应的尺寸之外。例如,如果`dim=0`,那么`index`的形状应与`src`的形状相同,但如果`dim=1`,那么`index`的形状应是`src.shape[:-1]`。 3. `src`: 也是一个张量,提供了要放置的值。它可以是标量,也可以是与`index`形状相同的张量。如果`src`是标量,那么所有指定的索引位置都将被这个值覆盖。 理解`scatter()`的一个关键点是,它根据`index`中的值在目标张量的相应位置放置`src`中的值。如果`dim`是0,那么`index`的第一个维度对应于目标张量的第一维度;如果是1,则对应第二维度,依此类推。 让我们看一个具体的例子: ```python x = torch.rand(2, 5) # tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945], # [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]]) ``` 这里,我们创建了一个2x5的张量`x`,并使用`scatter_()`创建一个3x5的零张量,然后根据`index`分布`x`的元素: ```python index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) torch.zeros(3, 5).scatter_(0, index, x) ``` 这将会把`x`的元素根据`index`的指示放置到新张量的相应位置。例如,第一个元素`x[0][0]`(值为0.1940)会被放置到新张量的`[0][0]`位置,因为`index[0][0]`是0。同样,`x[1][0]`(值为0.2078)会被放置到新张量的`[2][0]`位置,因为`index[1][0]`是2。 此外,`src`也可以是一个标量。例如,我们可以将7分散到张量的指定位置: ```python torch.zeros(3, 5).scatter_(0, index, 7) ``` 这将创建一个3x5的张量,其中`index`指示的位置都被7填充。 `scatter()`函数的一个常见用途是对标签进行one-hot编码。例如,假设我们有一个表示类别标签的张量`label`: ```python class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size, 1).random_() % class_num ``` 我们可以使用`scatter_()`将其转换为one-hot编码: ```python torch.zeros(batch_size, class_num).scatter_(1, label, 1) ``` 这将创建一个`batch_size`行`class_num`列的张量,其中`label`的每个值对应的列被设置为1,其余位置为0。 `scatter()`函数在处理张量索引和分配时提供了很大的灵活性,尤其在需要根据特定索引值改变张量数据的场景中,如one-hot编码、分布式存储等。熟悉并掌握这个函数对于深度学习和数据处理任务是十分重要的。