Learning Augmentation Network via Influence Functions

Donghoon Lee, Hyunsin Park, Trung Pham, Chang D. Yoo
Korea Advanced Institute of Science and Technology (KAIST)
{iamdh, hs.park, trungpx, cd yoo}@kaist.ac.kr

Abstract

Data augmentation can impact the generalization performance of an image classification model in a significant way. However, it is currently conducted on the basis of trial and error, and its impact on the generalization performance cannot be predicted during training. This paper considers an influence function that predicts how generalization performance, in terms of validation loss, is affected by a particular augmented training sample. The influence function provides an approximation of the change in validation loss without actually comparing the performances that include and exclude the sample in the training process. Based on this function, a differentiable augmentation network is learned to augment an input training sample to reduce validation loss. The augmented sample is fed into the classification network, and its influence is approximated as a function of the parameters of the last fully-connected layer of the classification network. By backpropagating the influence to the augmentation network, the augmentation network parameters are learned. Experimental results on CIFAR-10, CIFAR-100, and ImageNet show that the proposed method provides better generalization performance than conventional data augmentation methods do.

1. Introduction

In supervised learning, deep neural networks generally require large amounts of labeled data for training. An insufficient amount of labeled data will lead to poor generalization performance due to overfitting. One simple method to reduce overfitting and improve generalization performance is to perform data augmentation, whereby each training sample is transformed with a label-preserving transformation to create additional labeled data.
Different augmentations can result in significant differences in generalization performance depending on the task [7, 11, 14, 49]; however, this has not yet been extensively explored. Even in a well-studied image classification task [15, 34, 46], training data is often augmented in a manner similar to that performed in training AlexNet [21]. The composition of predefined transformations, such as rotation, translation, cropping, scaling, and color perturbation, is a popular choice; however, choosing the transformations and determining the strength of each transformation, e.g., rotation angle, that result in the best performance is often conducted empirically through observations of validation loss [3, 33].

Most recently, strategies for composing the transformations through learning [6, 22, 28, 29, 35] rather than tuning by trial and error have been studied. To learn such a strategy, various learning criteria have been considered. Without considering a classification model, Ratner et al. [29] and Sixt et al. [35] adopted a strategy for augmenting realistic samples. Given a classification model, [22] and [28] consider antithetical strategies for augmenting samples that minimize and maximize training loss, respectively. It is not yet fully understood why these antithetical studies, which relate data augmentation to training loss, are effective in improving performance on test samples. Cubuk et al. [6] considered small child models to compute validation loss for augmenting samples to improve generalization performance. Here, learning requires a reinforcement learning framework with the validation loss as a reward. This necessitates the classification model to be learned from scratch for every update of the augmentation model parameters, requiring thousands of GPU hours for learning.

This paper proposes a data augmentation method that links the impact of augmentation to the validation loss.
To predict this impact, an influence function [5, 20] is incorporated to compute the effect of a particular augmented training sample on the validation loss. Without a leave-one-out retraining process, the influence function approximates the change in validation loss due to the inclusion or exclusion of the augmented training sample. A differentiable augmentation network is proposed to augment the sample based on its predicted impact on the validation loss. The transformation space of the proposed network encompasses compositions of predefined transformations. The influence function and the differentiable augmentation model enable the gradient of the validation loss to flow to the augmentation model.

The remainder of this paper is organized as follows. Section 2 briefly reviews some of the most relevant literature related to the proposed method, while Sections 3 and 4 describe the details of the proposed method. Experimental and comparative results are reported in Section 5. Section 6 summarizes and concludes the paper.

2. Related work

2.1. Data augmentation methods

This sub-section provides a brief overview of data augmentation methods in the following three categories: (i) unsupervised methods that do not involve labels during learning¹, (ii) adversarial methods that maximize classification loss, and (iii) supervised methods that minimize classification loss.

Unsupervised methods. Unsupervised methods include conventional data augmentation methods, which use a composition of predefined transformations, such as rotating, translating, cropping, scaling, and color perturbation [15, 21, 34, 46]. The transformations are manually chosen through trial and error by empirically observing validation loss [3, 33]. Ratner et al. [29] consider a generator that generates a sequence of predefined transformations. Given the sequence, the training sample is augmented by consecutively applying the predefined transformations.
The generator is learned in the generative adversarial network (GAN) framework [13]. The generated sequence produces a realistic augmented sample; however, the classifier is not involved during the learning process, and the effect of data augmentation can only be observed through trial and error.

Adversarial methods. Adversarial methods include hard sample mining, which collects or augments samples that are misclassified by the current classification model. They have been used in training support vector machines (SVMs) [8], boosted decision trees [11], shallow neural networks [30], and deep neural networks [31]. Wang et al. [42] and Peng et al. [28] selected hard samples by adversarially updating the ranges of predefined transformations, such as occluding [28, 42], scaling, and rotating [28].

Adversarial examples are slightly perturbed or transformed samples that cause a classification model to predict an incorrect answer with high confidence [38]. For these methods, flexible and complex transformation models [2, 45] cannot be used to generate adversarial examples [28]. Training the classification model with adversarial examples will improve robustness to those adversarial examples but may degrade performance on clean test samples [40, 41]. Recent studies have shown that adversarial updates in convolution-based transformations [2] and spatial transformations [45] can potentially generate adversarial examples.

¹ Unsupervised methods by chance can lead to validation loss; however, augmentation is conducted without any supervision to optimize an objective.

Supervised methods. Lemley et al. [22] designed an augmentation model that learns to augment samples that reduce the training classification loss, but the performance on test samples can only be empirically evaluated and cannot be predicted. Cubuk et al. [6] considered small child classification models for computing validation loss to evaluate several augmentation policies over predefined transformations.
However, learning requires a reinforcement learning framework, as predefined transformations are non-differentiable and cannot be backpropagated through. The child model must be learned from scratch for every update of the augmentation model parameters and thus requires thousands of GPU hours. Furthermore, the validation loss of the small child model may not be a good predictor of the validation loss of the final classification model.

2.2. Generative adversarial networks for data augmentation

Goodfellow et al. [13] proposed a framework for training a deep generative network in an adversarial manner, referred to as the generative adversarial network (GAN). The framework simultaneously trains two networks: a generator that generates a sample from random noise and a discriminator that estimates the probability that a sample came from the training data rather than from the generator. An adversarial process is formed by a two-player minimax game in which the discriminator tries to distinguish the source of a sample while the generator tries to generate an indistinguishable sample.

Several studies have demonstrated the potential of using the GAN framework in data augmentation, either by improving the realism of synthetic samples [32, 35, 43] or by generating class-conditional images [24, 26, 25]. However, generative models are generally known to require more data to train than a classification model, and for training, additional synthetic [32, 35, 43] and/or unlabeled [24, 26, 25] samples are required.

2.3. Influence functions

The influence function is a tool from robust statistics [5] for estimating how model parameters change due to upweighting a particular training sample. Cook and Weisberg [5] derived the influence function for removing training data when learning a linear model, and in [4, 39, 44], influence functions concerning a wider variety of perturbations were studied.
Koh and Liang [20] extended influence functions to non-convex and highly complex models, including deep neural networks, by using efficient approximations based on Hessian-vector products (HVPs) [27], conjugate gradients [23], and stochastic estimation [1]. They also considered the influence of upweighting a particular training sample on the validation loss. In [20], influence functions are used for various purposes: debugging models, detecting dataset errors, and creating training-set attacks.

Figure 1: A generic data augmentation framework for a classification task. An input training sample x is transformed to x̃ by a transformation model G, which is parameterized by τ. The conventional method usually defines G as a composition of predefined transformations based on a randomly sampled range τ (solid line path). In this paper, G is a differentiable network. A learnable network E estimates τ given x to obtain the transformed sample x̃, where x̃ maximizes the generalization performance of the classification model F (dashed line path).

3. Augmented data evaluation

Figure 1 depicts a general framework for tuning or learning an augmentation model that may involve evaluating an augmented sample. Given a classification model, the evaluation is often performed by computing the training loss of the sample; however, the effect on test samples cannot be predicted and can only be empirically evaluated. Evaluating the impact of an augmented sample using validation loss requires learning two classifiers: one that includes and the other that excludes the sample during the learning process, so that their performances can be compared. Learning both classifiers is computationally expensive, as the models need to be fully trained from scratch and evaluated over all validation samples.
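The two-classifier comparison just described can be sketched on a toy model. In the sketch below, a 1-D least-squares fit stands in for the classifier; the model, the scaling augmentation x̃ = τx, and all variable names are illustrative assumptions, not the paper's setup:

```python
import numpy as np

# Naive evaluation of one augmented sample: train two models, one that
# includes the augmented sample and one that excludes it, then compare
# their validation losses. Toy 1-D least-squares model, illustrative only.
rng = np.random.default_rng(0)
N, M = 50, 20
x_tr = rng.normal(size=N); y_tr = 2.0 * x_tr + 0.1 * rng.normal(size=N)
x_va = rng.normal(size=M); y_va = 2.0 * x_va + 0.1 * rng.normal(size=M)

def fit(x, y):
    # closed-form minimizer of (1/N) * sum 0.5 * (theta*x - y)^2
    return np.sum(x * y) / np.sum(x * x)

def val_loss(theta):
    return np.mean(0.5 * (theta * x_va - y_va) ** 2)

i, tau = 3, 1.2                                  # augment sample i: x~ = tau * x
theta_excl = fit(x_tr, y_tr)                     # trained on the original set
theta_incl = fit(np.append(np.delete(x_tr, i), tau * x_tr[i]),
                 np.append(np.delete(y_tr, i), y_tr[i]))

# the quantity of interest: change in validation loss due to augmenting z_i;
# scoring many candidate augmentations this way means retraining every time
delta = val_loss(theta_incl) - val_loss(theta_excl)
print(delta)
```

Even in this toy case, every candidate augmentation costs one full retraining, which is the expense the influence-function approximation removes.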
Rather than repeating this prohibitive process, a method is proposed to approximate the validation loss difference due to a particular augmented sample, thus eliminating the retraining process.

3.1. Problem setup

In a classification task, given an input space X, an output space Y, and a parameter space Θ, a learner aims to learn a classification model F that maps X → Y and is parameterized by θ. Define l(z, θ) to be the loss evaluated on the sample z = (x, y) ∈ X × Y with model parameters θ ∈ Θ. Given training data z^tr = {z_i}_{i=1}^N, an empirical risk minimizer is given as

$$\hat{\theta}(z^{tr}) = \arg\min_{\theta \in \Theta} L(z^{tr}, \theta), \qquad (1)$$

where the empirical risk is given as

$$L(z^{tr}, \theta) = \frac{1}{N} \sum_{i : z_i \in z^{tr}} l(z_i, \theta). \qquad (2)$$

To measure the generalization performance of the classification model F, validation data z^val = {z_j}_{j=N+1}^{N+M} is often considered. Generalization performance is approximated as the average loss over the validation data z^val with parameters θ̂(z^tr):

$$L(z^{val}, \hat{\theta}(z^{tr})) = \frac{1}{M} \sum_{j : z_j \in z^{val}} l(z_j, \hat{\theta}(z^{tr})). \qquad (3)$$

Consider a label-preserving transformation G that maps X → X, such as the one shown in Figure 1. Let τ be the control parameters, x̃ = G(x, τ) be an augmented input, z̃ = (x̃, y) be an augmented sample, and z̃^tr = {z̃_i}_{i=1}^N be an augmented training dataset. In addition, consider a learnable network E parameterized by φ that estimates τ given x. Thus, x̃ = G(x, E(x, φ)).

Given the transformation model G, the goal is to find the optimal τ for each input x or, alternatively, to find optimal parameters φ of E that minimize the validation loss when the classification model is learned using z̃^tr. This is mathematically represented as follows:

$$\phi = \arg\min_{\phi \in \Phi} L(z^{val}, \hat{\theta}(\tilde{z}^{tr})), \qquad (4)$$

where

$$\hat{\theta}(\tilde{z}^{tr}) = \arg\min_{\theta \in \Theta} L(\tilde{z}^{tr}, \theta). \qquad (5)$$

Solving Equations 4–5 requires bilevel optimization, where one problem is embedded (nested) within the other.

3.2. Influence by upweighting a training sample

Consider the change in model parameters θ due to the exclusion of a particular training sample z_i. Formally, this change is given as θ̂(z^tr \ z_i) − θ̂(z^tr).
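To make the bilevel problem of Equations 4–5 concrete, consider a deliberately tiny instance in which a single scalar τ is shared by all training samples (x̃ = τx) and the inner problem of Equation 5 is a 1-D least-squares fit with a closed-form solution. The brute-force outer search below stands in for learning E and is purely illustrative:

```python
import numpy as np

# Toy bilevel optimization: outer variable tau controls the augmentation,
# inner variable theta is the empirical risk minimizer on the augmented set.
# The setup and all names are illustrative assumptions, not the paper's models.
rng = np.random.default_rng(0)
N, M = 50, 20
x_tr = rng.normal(size=N); y_tr = 2.0 * x_tr + 0.1 * rng.normal(size=N)
x_va = rng.normal(size=M); y_va = 2.0 * x_va + 0.1 * rng.normal(size=M)

def inner_fit(tau):
    # inner problem (cf. Eq. 5): minimize empirical risk on z~tr = {(tau*x, y)}
    x_aug = tau * x_tr
    return np.sum(x_aug * y_tr) / np.sum(x_aug * x_aug)

def outer_loss(tau):
    # outer objective (cf. Eq. 4): validation loss of the inner solution
    theta = inner_fit(tau)
    return np.mean(0.5 * (theta * x_va - y_va) ** 2)

# brute-force outer search over tau
taus = np.linspace(0.5, 2.0, 31)
best_tau = taus[np.argmin([outer_loss(t) for t in taus])]
print(best_tau, outer_loss(best_tau))
```

Each evaluation of the outer objective solves the inner problem from scratch; with a deep classifier that inner solve is a full training run, which is why the paper approximates the change in validation loss instead of retraining.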
Influence functions [5, 20] provide an efficient approximation without a retraining process to obtain θ̂(z^tr \ z_i). Let us consider the change in model parameters due to upweighting z_i by an amount εl(z_i, θ) in the loss function:

$$\hat{\theta}(z^{tr} \cup \epsilon z_i) = \arg\min_{\theta \in \Theta} L(z^{tr}, \theta) + \epsilon\, l(z_i, \theta). \qquad (6)$$

Then, from [5], the following approximation can be derived:

$$-\frac{1}{N} I_{\text{up,params}}(z_i) \simeq \hat{\theta}(z^{tr} \setminus z_i) - \hat{\theta}(z^{tr}), \qquad (7)$$

where

$$I_{\text{up,params}}(z_i) \triangleq \frac{d\hat{\theta}(z^{tr} \cup \epsilon z_i)}{d\epsilon}\bigg|_{\epsilon=0} \qquad (8)$$

$$= -H(\hat{\theta}(z^{tr}))^{-1}\, \nabla_\theta l(z_i, \hat{\theta}(z^{tr})). \qquad (9)$$

Here, $H(\theta) \triangleq \frac{1}{N}\sum_{i=1}^{N} \nabla^2_\theta l(z_i, \theta)$ is the Hessian evaluated at θ.

Using Equation 9 and applying the chain rule, the influence of upweighting z_i ∈ z^tr on the validation loss at z_j ∈ z^val can be approximated [20] as shown below:

$$-\frac{1}{N} I_{\text{up,loss}}(z_i, z_j) \simeq l(z_j, \hat{\theta}(z^{tr} \setminus z_i)) - l(z_j, \hat{\theta}(z^{tr})), \qquad (10)$$

where

$$I_{\text{up,loss}}(z_i, z_j) \triangleq \frac{d\, l(z_j, \hat{\theta}(z^{tr} \cup \epsilon z_i))}{d\epsilon}\bigg|_{\epsilon=0} \qquad (11)$$

$$= \nabla_\theta l(z_j, \hat{\theta}(z^{tr}))^\top \frac{d\hat{\theta}(z^{tr} \cup \epsilon z_i)}{d\epsilon}\bigg|_{\epsilon=0} \qquad (12)$$

$$= -\nabla_\theta l(z_j, \hat{\theta}(z^{tr}))^\top H(\hat{\theta}(z^{tr}))^{-1}\, \nabla_\theta l(z_i, \hat{\theta}(z^{tr})). \qquad (13)$$

For the validation dataset z^val, Equation 12 can be expanded as:

$$I_{\text{up,loss}}(z_i, z^{val}) = -\nabla_\theta L(z^{val}, \hat{\theta}(z^{tr}))^\top H(\hat{\theta}(z^{tr}))^{-1}\, \nabla_\theta l(z_i, \hat{\theta}(z^{tr})). \qquad (14)$$

Equation 11 describes the gradient of l(z_j, θ̂(z^tr ∪ εz_i)) with respect to ε near ε = 0. The influence of excluding z_i can then be approximated by Equation 10.

3.3. Influence by augmentation

With a training sample z_i and the corresponding augmented training sample z̃_i, let θ̂(z^tr ∪ εz̃_i \ εz_i) be the estimate of θ obtained by downweighting z_i and upweighting z̃_i by ε. Let θ̂(z^tr ∪ z̃_i \ z_i) be the estimate of θ obtained by replacing z_i with z̃_i. An approximation analogous to Equations 10–13 yields:

$$-\frac{1}{N} I_{\text{aug,loss}}(z_i, \tilde{z}_i, z^{val}) \simeq L(z^{val}, \hat{\theta}(z^{tr})) - L(z^{val}, \hat{\theta}(z^{tr} \cup \tilde{z}_i \setminus z_i)), \qquad (15)$$

where the influence function I_aug,loss(z_i, z̃_i, z^val) is:

$$I_{\text{aug,loss}}(z_i, \tilde{z}_i, z^{val}) = -\nabla_\theta L(z^{val}, \hat{\theta}(z^{tr}))^\top H(\hat{\theta}(z^{tr}))^{-1} \left( \nabla_\theta l(\tilde{z}_i, \hat{\theta}(z^{tr})) - \nabla_\theta l(z_i, \hat{\theta}(z^{tr})) \right). \qquad (16)$$
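These approximations can be checked numerically on a toy model in which the Hessian is a scalar and can be inverted exactly. The sketch below computes the upweighting influence of Equation 14 and an augmentation influence of the gradient-difference form I_up,loss(z̃_i, z^val) − I_up,loss(z_i, z^val) (i.e., upweight z̃_i while downweighting z_i), and compares each against actual retraining; the 1-D least-squares model and the scaling augmentation are illustrative assumptions:

```python
import numpy as np

# Toy 1-D least-squares model with loss l(z, theta) = 0.5*(theta*x - y)^2,
# so per-sample gradient = (theta*x - y)*x and Hessian H = mean(x^2).
# Illustrative only; not the paper's classification network.
rng = np.random.default_rng(0)
N, M = 50, 20
x_tr = rng.normal(size=N); y_tr = 2.0 * x_tr + 0.1 * rng.normal(size=N)
x_va = rng.normal(size=M); y_va = 2.0 * x_va + 0.1 * rng.normal(size=M)

fit = lambda x, y: np.sum(x * y) / np.sum(x * x)
grad = lambda th, x, y: (th * x - y) * x           # per-sample gradient of l
val_loss = lambda th: np.mean(0.5 * (th * x_va - y_va) ** 2)

theta = fit(x_tr, y_tr)
H = np.mean(x_tr * x_tr)                           # Hessian of the empirical risk
g_val = np.mean(grad(theta, x_va, y_va))           # gradient of L(zval, theta)

# upweighting influence (cf. Eq. 14): -grad L_val^T * H^{-1} * grad l(z_i)
i = 3
I_up = -g_val * (1.0 / H) * grad(theta, x_tr[i], y_tr[i])
# leave-one-out retraining; -I_up/N should approximate `loo` (cf. Eq. 10)
loo = val_loss(fit(np.delete(x_tr, i), np.delete(y_tr, i))) - val_loss(theta)

# augmentation influence: replace z_i by z~_i = (tau*x_i, y_i), i.e.
# I_up,loss(z~_i) - I_up,loss(z_i), a gradient-difference form
tau = 1.2
g_new = grad(theta, tau * x_tr[i], y_tr[i])
g_old = grad(theta, x_tr[i], y_tr[i])
I_aug = -g_val * (1.0 / H) * (g_new - g_old)
# retraining with z_i replaced; -I_aug/N should approximate `actual`
theta_rep = fit(np.append(np.delete(x_tr, i), tau * x_tr[i]),
                np.append(np.delete(y_tr, i), y_tr[i]))
actual = val_loss(theta) - val_loss(theta_rep)
print(-I_up / N, loo, -I_aug / N, actual)
```

Because the toy loss is quadratic in θ, the influence estimates track the retrained losses closely; for a deep network the Hessian-inverse-vector products would instead be approximated, e.g., with the HVP-based methods cited above.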