本技術(shù)涉及深度學(xué)習(xí),尤其涉及一種網(wǎng)絡(luò)模型的訓(xùn)練方法及裝置。
背景技術(shù):
1、現(xiàn)有技術(shù)中,訓(xùn)練大模型往往需要海量的數(shù)據(jù)。這些海量的數(shù)據(jù)在存儲(chǔ)和訓(xùn)練的過(guò)程中都需要占用大量的硬件資源,而對(duì)于企業(yè)而言這些硬件資源的花費(fèi)往往過(guò)于高昂。如何在低損耗的情況下壓縮數(shù)據(jù)是未來(lái)應(yīng)對(duì)海量數(shù)據(jù)的重要方向。數(shù)據(jù)集蒸餾旨在以生成的方式將一個(gè)大的數(shù)據(jù)集壓縮成一個(gè)小的數(shù)據(jù)集。得益于生成數(shù)據(jù)的高信息密度,在壓縮率較高時(shí),數(shù)據(jù)集蒸餾的效果遠(yuǎn)超數(shù)據(jù)集剪枝等其他數(shù)據(jù)壓縮方法。主流的數(shù)據(jù)集蒸餾方法大多通過(guò)匹配模型在目標(biāo)數(shù)據(jù)集和生成數(shù)據(jù)集上的某種指標(biāo)來(lái)完成特征的提取與壓縮。常用的方法有知識(shí)蒸餾,自蒸餾和適用網(wǎng)絡(luò)等。然而,隨著壓縮率的逐漸降低,現(xiàn)有數(shù)據(jù)集蒸餾方法的效果逐漸變差,甚至弱于從原數(shù)據(jù)集中隨機(jī)選取等量數(shù)據(jù)。這導(dǎo)致現(xiàn)有數(shù)據(jù)集蒸餾方法局限于高壓縮率場(chǎng)景,無(wú)法實(shí)現(xiàn)對(duì)原數(shù)據(jù)集的無(wú)損壓縮。
技術(shù)實(shí)現(xiàn)思路
1、本技術(shù)提供了一種網(wǎng)絡(luò)模型的訓(xùn)練方法及裝置,用以提高通過(guò)網(wǎng)絡(luò)模型對(duì)數(shù)據(jù)集蒸餾的有效性。
2、第一方面,本技術(shù)實(shí)施例提供了一種網(wǎng)絡(luò)模型的訓(xùn)練方法,包括:
3、基于預(yù)訓(xùn)練過(guò)程中網(wǎng)絡(luò)模型的損失值變化,確定待訓(xùn)練的初始網(wǎng)絡(luò)模型和訓(xùn)練次數(shù)n;
4、通過(guò)所述初始網(wǎng)絡(luò)模型對(duì)樣本數(shù)據(jù)集進(jìn)行識(shí)別,確定第一樣本集和第二樣本集,所述第一樣本集包括所述樣本數(shù)據(jù)集中所述初始網(wǎng)絡(luò)模型正確識(shí)別的樣本,所述第二樣本集包括所述樣本數(shù)據(jù)集中除所述第一樣本集之外的樣本;
5、通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行第一階段的訓(xùn)練,直至第一階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到中間網(wǎng)絡(luò)模型;
6、通過(guò)所述樣本數(shù)據(jù)集對(duì)所述中間網(wǎng)絡(luò)模型進(jìn)行第二階段的訓(xùn)練,直至第二階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到用于數(shù)據(jù)集蒸餾的目標(biāo)網(wǎng)絡(luò)模型。
7、基于上述方案,本技術(shù)在訓(xùn)練網(wǎng)絡(luò)模型時(shí),通過(guò)預(yù)訓(xùn)練的損失變化控制訓(xùn)練時(shí)的初始網(wǎng)絡(luò)模型和訓(xùn)練次數(shù),進(jìn)而控制模型對(duì)數(shù)據(jù)集的壓縮比率和網(wǎng)絡(luò)模型生成數(shù)據(jù)的難易程度。在第一輪訓(xùn)練時(shí),使用初始網(wǎng)絡(luò)模型不能識(shí)別的樣本集進(jìn)行訓(xùn)練,以使中間網(wǎng)絡(luò)模型能夠提取到復(fù)雜特征。進(jìn)而,在第二輪訓(xùn)練時(shí),使用第一樣本集和第二樣本集對(duì)中間網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,使得模型在訓(xùn)練過(guò)程中不遺忘簡(jiǎn)單特征,以使訓(xùn)練好的網(wǎng)絡(luò)模型能夠在低壓縮率的情況下保證數(shù)據(jù)集蒸餾的有效性。
8、在一種可能的實(shí)現(xiàn)方式中,基于預(yù)訓(xùn)練過(guò)程中網(wǎng)絡(luò)模型的損失值變化,確定待訓(xùn)練的初始網(wǎng)絡(luò)模型和訓(xùn)練次數(shù)n,包括:
9、基于訓(xùn)練樣本集對(duì)所述網(wǎng)絡(luò)模型進(jìn)行m輪的預(yù)訓(xùn)練,確定每輪預(yù)訓(xùn)練分別對(duì)應(yīng)的損失值,m大于n;
10、根據(jù)所述m個(gè)損失值確定總損失梯度,將滿足第一條件時(shí)網(wǎng)絡(luò)模型對(duì)應(yīng)的輪次作為第一訓(xùn)練輪次,并將所述第一訓(xùn)練輪次對(duì)應(yīng)的網(wǎng)絡(luò)模型作為初始網(wǎng)絡(luò)模型,其中,所述第一條件為損失梯度為總損失梯度的第一設(shè)定占比;
11、將滿足第二條件時(shí)網(wǎng)絡(luò)模型對(duì)應(yīng)的輪次作為第二訓(xùn)練輪次,并基于所述第一訓(xùn)練輪次和所述第二訓(xùn)練輪次確定所述訓(xùn)練次數(shù)n;其中,所述第二條件為損失梯度為總損失梯度的第二設(shè)定占比,所述第二設(shè)定占比大于所述第一設(shè)定占比。
12、基于上述方案,通過(guò)預(yù)訓(xùn)練過(guò)程中損失值的變化情況,確定從初始網(wǎng)絡(luò)模型開始進(jìn)行訓(xùn)練,并根據(jù)損失值的變化情況確定訓(xùn)練輪次,通過(guò)控制網(wǎng)絡(luò)模型的初始網(wǎng)絡(luò)參數(shù)和訓(xùn)練輪次,可以減少網(wǎng)絡(luò)模型的訓(xùn)練時(shí)間。
13、在一種可能的實(shí)現(xiàn)方式中,所述通過(guò)所述樣本數(shù)據(jù)集對(duì)所述中間網(wǎng)絡(luò)模型進(jìn)行第二階段的訓(xùn)練,直至第二階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,包括:
14、每個(gè)訓(xùn)練輪次執(zhí)行以下操作:
15、將所述第二樣本集分為k個(gè)批次,并將所述第一樣本集插值到所述k個(gè)批次中,得到k個(gè)樣本集;
16、將所述k個(gè)樣本集輸入到所述中間網(wǎng)絡(luò)模型中,以根據(jù)所述中間網(wǎng)絡(luò)模型輸出的結(jié)果對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)整。
17、基于上述方案,在訓(xùn)練過(guò)程中增加反向蒸餾,基于后期復(fù)雜特征與簡(jiǎn)單特征共同對(duì)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,以保證網(wǎng)絡(luò)模型對(duì)特征提取的穩(wěn)定性。
18、在一種可能的實(shí)現(xiàn)方式中,所述通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行第一階段的訓(xùn)練,直至第一階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到中間網(wǎng)絡(luò)模型,包括:
19、每個(gè)訓(xùn)練輪次執(zhí)行以下操作:
20、通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,得到目標(biāo)損失值;
21、根據(jù)所述目標(biāo)損失值對(duì)所述初始網(wǎng)絡(luò)模型的網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)整,得到中間網(wǎng)絡(luò)模型。
22、在一種可能的實(shí)現(xiàn)方式中,所述通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,得到目標(biāo)損失值,包括:
23、將所述第二樣本集輸入所述初始網(wǎng)絡(luò)模型進(jìn)行特征提取,得到所述第二樣本集中各樣本的預(yù)測(cè)軟標(biāo)簽;
24、針對(duì)每個(gè)樣本的所述預(yù)測(cè)軟標(biāo)簽和樣本標(biāo)簽確定第一損失值,并根據(jù)所述第二樣本集中樣本標(biāo)簽的種類和所述每個(gè)樣本的預(yù)測(cè)軟標(biāo)簽的標(biāo)簽種類確定第二損失值;
25、將所述第一損失值和所述第二損失值進(jìn)行加權(quán)得到目標(biāo)損失值。
26、基于上述方案,通過(guò)標(biāo)簽平滑的方法將硬標(biāo)簽替換為軟標(biāo)簽,進(jìn)而通過(guò)樣本標(biāo)簽的第一損失值和第二損失值確定損失值,進(jìn)而提高網(wǎng)絡(luò)模型提取的特征的豐富性。
27、第二方面,本技術(shù)實(shí)施例提供了一種網(wǎng)絡(luò)模型的訓(xùn)練裝置,包括:
28、確定模塊,用于基于預(yù)訓(xùn)練過(guò)程中網(wǎng)絡(luò)模型的損失值變化,確定待訓(xùn)練的初始網(wǎng)絡(luò)模型和訓(xùn)練次數(shù)n;
29、通過(guò)所述初始網(wǎng)絡(luò)模型對(duì)樣本數(shù)據(jù)集進(jìn)行識(shí)別,確定第一樣本集和第二樣本集,所述第一樣本集包括所述樣本數(shù)據(jù)集中所述初始網(wǎng)絡(luò)模型正確識(shí)別的樣本,所述第二樣本集包括所述樣本數(shù)據(jù)集中除所述第一樣本集之外的樣本;
30、訓(xùn)練模塊,用于通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行第一階段的訓(xùn)練,直至第一階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到中間網(wǎng)絡(luò)模型;
31、通過(guò)所述樣本數(shù)據(jù)集對(duì)所述中間網(wǎng)絡(luò)模型進(jìn)行第二階段的訓(xùn)練,直至第二階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到用于數(shù)據(jù)集蒸餾的目標(biāo)網(wǎng)絡(luò)模型。
32、在一種可能的實(shí)現(xiàn)方式中,所述確定模塊,在基于預(yù)訓(xùn)練過(guò)程中網(wǎng)絡(luò)模型的損失值變化,確定待訓(xùn)練的初始網(wǎng)絡(luò)模型和訓(xùn)練次數(shù)n時(shí),具體用于:
33、基于訓(xùn)練樣本集對(duì)所述網(wǎng)絡(luò)模型進(jìn)行m輪的預(yù)訓(xùn)練,確定每輪預(yù)訓(xùn)練分別對(duì)應(yīng)的損失值,m大于n;
34、根據(jù)所述m個(gè)損失值確定總損失梯度,將滿足第一條件時(shí)網(wǎng)絡(luò)模型對(duì)應(yīng)的輪次作為第一訓(xùn)練輪次,并將所述第一訓(xùn)練輪次對(duì)應(yīng)的網(wǎng)絡(luò)模型作為初始網(wǎng)絡(luò)模型,其中,所述第一條件為損失梯度為總損失梯度的第一設(shè)定占比;
35、將滿足第二條件時(shí)網(wǎng)絡(luò)模型對(duì)應(yīng)的輪次作為第二訓(xùn)練輪次,并基于所述第一訓(xùn)練輪次和所述第二訓(xùn)練輪次確定所述訓(xùn)練次數(shù)n;其中,所述第二條件為損失梯度為總損失梯度的第二設(shè)定占比,所述第二設(shè)定占比大于所述第一設(shè)定占比。
36、在一種可能的實(shí)現(xiàn)方式中,所述訓(xùn)練模塊,在通過(guò)所述樣本數(shù)據(jù)集對(duì)所述中間網(wǎng)絡(luò)模型進(jìn)行第二階段的訓(xùn)練,直至第二階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n時(shí),具體用于:
37、每個(gè)訓(xùn)練輪次執(zhí)行以下操作:
38、將所述第二樣本集分為k個(gè)批次,并將所述第一樣本集插值到所述k個(gè)批次中,得到k個(gè)樣本集;
39、將所述k個(gè)樣本集輸入到所述中間網(wǎng)絡(luò)模型中,以根據(jù)所述中間網(wǎng)絡(luò)模型輸出的結(jié)果對(duì)網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)整。
40、在一種可能的實(shí)現(xiàn)方式中,所述訓(xùn)練模塊,在通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行第一階段的訓(xùn)練,直至第一階段的訓(xùn)練次數(shù)滿足所述訓(xùn)練次數(shù)n,得到中間網(wǎng)絡(luò)模型時(shí),具體用于:
41、每個(gè)訓(xùn)練輪次執(zhí)行以下操作:
42、通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,得到目標(biāo)損失值;
43、根據(jù)所述目標(biāo)損失值對(duì)所述初始網(wǎng)絡(luò)模型的網(wǎng)絡(luò)參數(shù)進(jìn)行調(diào)整,得到中間網(wǎng)絡(luò)模型。
44、在一種可能的實(shí)現(xiàn)方式中,所述訓(xùn)練模塊,在通過(guò)所述第二樣本集對(duì)所述初始網(wǎng)絡(luò)模型進(jìn)行訓(xùn)練,得到目標(biāo)損失值時(shí),具體用于:
45、將所述第二樣本集輸入所述初始網(wǎng)絡(luò)模型進(jìn)行特征提取,得到所述第二樣本集中各樣本的預(yù)測(cè)軟標(biāo)簽;
46、針對(duì)每個(gè)樣本的所述預(yù)測(cè)軟標(biāo)簽和樣本標(biāo)簽確定第一損失值,并根據(jù)所述第二樣本集中樣本標(biāo)簽的種類和所述每個(gè)樣本的預(yù)測(cè)軟標(biāo)簽的標(biāo)簽種類確定第二損失值;
47、將所述第一損失值和所述第二損失值進(jìn)行加權(quán)得到目標(biāo)損失值。
48、第三方面,本技術(shù)實(shí)施例提供了一種執(zhí)行設(shè)備,包括:
49、存儲(chǔ)器,用于存儲(chǔ)程序指令;
50、處理器,用于獲取所述存儲(chǔ)器中的程序指令,并按照獲得的程序指令實(shí)現(xiàn)第一方面以及第一方面不同實(shí)現(xiàn)方式所述的方法。
51、第四方面,本技術(shù)實(shí)施例提供了一種計(jì)算機(jī)可讀存儲(chǔ)介質(zhì),所述計(jì)算機(jī)可讀存儲(chǔ)介質(zhì)包括計(jì)算機(jī)指令,當(dāng)所述計(jì)算機(jī)指令被計(jì)算機(jī)執(zhí)行時(shí),實(shí)現(xiàn)如第一方面以及第一方面不同實(shí)現(xiàn)方式所述的方法。
52、另外,第二方面至第四方面中任一種實(shí)現(xiàn)方式所帶來(lái)的技術(shù)效果可參見第一方面以及第一方面不同實(shí)現(xiàn)方式所帶來(lái)的技術(shù)效果,此處不再贅述。