本說明書涉及計算機(jī),尤其涉及一種基于特征分布的模型訓(xùn)練方法和任務(wù)執(zhí)行方法。
背景技術(shù):
1、隨著現(xiàn)代芯片制造業(yè)的快速發(fā)展,計算機(jī)的算力得到了極大的提升,從而為大型神經(jīng)網(wǎng)絡(luò)模型的訓(xùn)練提供了基礎(chǔ)。但是一個復(fù)雜神經(jīng)網(wǎng)絡(luò)模型的訓(xùn)練通常需要大量的數(shù)據(jù)。而由于數(shù)據(jù)隱私保護(hù)等原因,很多數(shù)據(jù)只允許在不同企業(yè)或機(jī)構(gòu)側(cè)的客戶端本地進(jìn)行存儲和使用。為了能夠利用這些數(shù)據(jù)共同對模型進(jìn)行訓(xùn)練,同時又不將數(shù)據(jù)傳送到一個服務(wù)器端,聯(lián)邦學(xué)習(xí)應(yīng)運(yùn)而生。
2、然而,在現(xiàn)有的聯(lián)邦學(xué)習(xí)方法中,客戶端會將本地訓(xùn)練的訓(xùn)練數(shù)據(jù)上傳給服務(wù)器,再由服務(wù)器對全部客戶端發(fā)送過來的訓(xùn)練數(shù)據(jù)進(jìn)行統(tǒng)一的處理,得到統(tǒng)一的模型數(shù)據(jù)后發(fā)送給每個客戶端。但是在實際應(yīng)用場景中,不同客戶端下的數(shù)據(jù)所對應(yīng)的數(shù)據(jù)類型的分布情況往往是不同的,如果所有客戶端中的模型均采用相同的模型數(shù)據(jù),則會導(dǎo)致不同客戶端中的模型無法適用于本地數(shù)據(jù)的分布情況,難以對數(shù)據(jù)進(jìn)行精準(zhǔn)的識別。
3、因此,如何在聯(lián)邦學(xué)習(xí)的過程中,提高不同客戶端下的模型的準(zhǔn)確性,是一個亟待解決的問題。
技術(shù)實現(xiàn)思路
1、本說明書提供一種基于特征分布的模型訓(xùn)練方法和任務(wù)執(zhí)行方法,以部分的解決現(xiàn)有技術(shù)存在的上述問題。
2、本說明書采用下述技術(shù)方案:
3、本說明書提供了基于特征分布的模型訓(xùn)練方法,所述方法應(yīng)用于客戶端,包括:
4、獲取本地的各樣本圖像,并確定每個樣本圖像對應(yīng)的標(biāo)簽信息;
5、針對每個樣本圖像,將該樣本圖像輸入待訓(xùn)練的本地分類模型,以通過所述本地分類模型,確定該樣本圖像對應(yīng)的數(shù)據(jù)特征,并根據(jù)所述數(shù)據(jù)特征確定該樣本圖像的分類結(jié)果;
6、根據(jù)每個樣本圖像的標(biāo)簽信息,確定每個樣本圖像對應(yīng)數(shù)據(jù)特征的數(shù)據(jù)分布,并根據(jù)所述數(shù)據(jù)分布,確定個體數(shù)據(jù)特征,并將所述個體數(shù)據(jù)特征發(fā)送給服務(wù)器,以使服務(wù)器根據(jù)接收到的各客戶端發(fā)送的個體數(shù)據(jù)特征,確定全局?jǐn)?shù)據(jù)特征,并將全局?jǐn)?shù)據(jù)特征返回給各客戶端,所述個體數(shù)據(jù)特征用于表征所述本地分類模型所使用的不同標(biāo)簽信息下的樣本圖像在整體上的特征表示;
7、根據(jù)所述分類結(jié)果和所述標(biāo)簽信息的之間偏差,以及所述個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定所述本地分類模型的損失值;
8、根據(jù)所述損失值,對所述本地分類模型的模型參數(shù)進(jìn)行更新,得到訓(xùn)練后本地分類模型。
9、可選地,根據(jù)所述分類結(jié)果和所述標(biāo)簽信息的之間偏差,以及所述個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定所述分類模型的損失值,具體包括:
10、根據(jù)所述分類結(jié)果和所述標(biāo)簽信息之間的偏差,確定所述分類模型的分類損失值,以及,根據(jù)所述個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定所述分類模型的修正值;
11、根據(jù)所述修正值,對所述分類損失值進(jìn)行修正,以確定出所述本地分類模型的損失值。
12、可選地,所述本地分類模型的模型參數(shù)包括:特征提取層參數(shù);
13、根據(jù)所述損失值,對所述分類模型的公共模型參數(shù)進(jìn)行更新,具體包括:
14、根據(jù)所述損失值,對所述特征提取層參數(shù)進(jìn)行更新。
15、可選地,所述本地分類模型的模型參數(shù)包括:分類層參數(shù);
16、根據(jù)所述損失值,對所述本地分類模型的模型參數(shù)進(jìn)行更新,具體包括:
17、確定根據(jù)所述分類損失值對所述分類層參數(shù)進(jìn)行更新時,每個分類層參數(shù)所對應(yīng)的變化值;
18、根據(jù)所述分類損失值以及所述變化值,對所述分類層參數(shù)進(jìn)行更新,得到更新后分類層參數(shù)并發(fā)送給所述服務(wù)器,以使所述服務(wù)器對每個客戶端發(fā)送的更新后分類層參數(shù)進(jìn)行融合,得到融合后分類層參數(shù)并返回給各客戶端。
19、可選地,根據(jù)所述損失值以及所述變化值,對所述分類層參數(shù)進(jìn)行更新,具體包括:
20、按照所述變化值,將所述分類層參數(shù)劃分為個性化分類層參數(shù)和聯(lián)合分類層參數(shù);
21、根據(jù)所述分類損失值,對所述個性化分類層參數(shù)進(jìn)行更新,并根據(jù)所述聯(lián)合分類層參數(shù)以及更新后的個性化分類層參數(shù),確定所述更新后分類層參數(shù)。
22、可選地,按照所述變化值,將所述分類層參數(shù)劃分為個性化分類層參數(shù)和聯(lián)合分類層參數(shù),具體包括:
23、根據(jù)每個更新后分類層參數(shù)對應(yīng)的變化值,將各分類層參數(shù)按照從大到小的順序進(jìn)行排序;
24、將排序位于指定位次之前的分類層參數(shù)作為所述個性化分類層參數(shù),將排序位于所述指定位次之后的分類層參數(shù)作為所述聯(lián)合分類層參數(shù)。
25、本說明書提供了一種任務(wù)執(zhí)行方法,包括:
26、獲取待分類的目標(biāo)圖像;
27、將所述目標(biāo)圖像輸入預(yù)先訓(xùn)練的分類模型,以通過所述分類模型,確定所述目標(biāo)圖像對應(yīng)的分類結(jié)果,其中,所述分類模型是通過上述模型訓(xùn)練的方法訓(xùn)練得到的;
28、根據(jù)所述分類結(jié)果,執(zhí)行任務(wù)。
29、本說明書提供了一種基于特征分布的模型訓(xùn)練裝置,包括:
30、獲取模塊,用于獲取本地的各樣本圖像,并確定每個樣本圖像對應(yīng)的標(biāo)簽信息;
31、輸入模塊,用于針對每個樣本圖像,將該樣本圖像輸入待訓(xùn)練的本地分類模型,以通過所述本地分類模型,確定該樣本圖像對應(yīng)的數(shù)據(jù)特征,并根據(jù)所述數(shù)據(jù)特征確定該樣本圖像的分類結(jié)果;
32、第一確定模塊,用于根據(jù)每個樣本圖像的標(biāo)簽信息,確定每個樣本圖像對應(yīng)數(shù)據(jù)特征的數(shù)據(jù)分布,并根據(jù)所述數(shù)據(jù)分布,確定個體數(shù)據(jù)特征,并將所述個體數(shù)據(jù)特征發(fā)送給服務(wù)器,以使服務(wù)器根據(jù)接收到的各客戶端發(fā)送的個體數(shù)據(jù)特征,確定全局?jǐn)?shù)據(jù)特征,并將全局?jǐn)?shù)據(jù)特征返回給各客戶端,所述個體數(shù)據(jù)特征用于表征所述本地分類模型所使用的不同標(biāo)簽信息下的樣本圖像在整體上的特征表示;
33、第二確定模塊,用于根據(jù)所述分類結(jié)果和所述標(biāo)簽信息的之間偏差,以及所述個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定所述本地分類模型的損失值;
34、更新模塊,用于根據(jù)所述損失值,對所述本地分類模型的模型參數(shù)進(jìn)行更新,得到訓(xùn)練后本地分類模型。
35、本說明書提供了一種計算機(jī)可讀存儲介質(zhì),所述存儲介質(zhì)存儲有計算機(jī)程序,所述計算機(jī)程序被處理器執(zhí)行時實現(xiàn)上述基于特征分布的模型訓(xùn)練方法和任務(wù)執(zhí)行方法。
36、本說明書提供了一種電子設(shè)備,包括存儲器、處理器及存儲在存儲器上并可在處理器上運(yùn)行的計算機(jī)程序,所述處理器執(zhí)行所述程序時實現(xiàn)上述基于特征分布的模型訓(xùn)練方法和任務(wù)執(zhí)行方法。
37、本說明書采用的上述至少一個技術(shù)方案能夠達(dá)到以下有益效果:
38、在本說明書提供的基于特征分布的模型訓(xùn)練方法中,客戶端獲取本地的各樣本圖像,并確定每個樣本圖像對應(yīng)的標(biāo)簽信息;針對每個樣本圖像,將該樣本圖像輸入待訓(xùn)練的本地分類模型,確定該樣本圖像對應(yīng)的數(shù)據(jù)特征并確定分類結(jié)果;根據(jù)每個樣本圖像的標(biāo)簽信息,確定每個樣本圖像對應(yīng)數(shù)據(jù)特征的數(shù)據(jù)分布,并根據(jù)數(shù)據(jù)分布確定個體數(shù)據(jù)特征,將個體數(shù)據(jù)特征發(fā)送給服務(wù)器,服務(wù)器根據(jù)接收到的各客戶端發(fā)送的個體數(shù)據(jù)特征,確定全局?jǐn)?shù)據(jù)特征,并將全局?jǐn)?shù)據(jù)特征返回給各客戶端;根據(jù)分類結(jié)果和標(biāo)簽信息的之間偏差,以及個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定損失值;根據(jù)損失值對本地分類模型的模型參數(shù)進(jìn)行更新。
39、從上述方法可以看出,本方案在對模型進(jìn)行訓(xùn)練的過程中,會基于分類結(jié)果和標(biāo)簽信息的之間偏差,以及個體數(shù)據(jù)特征和全局?jǐn)?shù)據(jù)特征之間的偏差,確定損失值,相比于聯(lián)邦學(xué)習(xí)的過程中不同客戶端中的模型單純地共享模型參數(shù),本方案可以通過對不同客戶端數(shù)據(jù)特征分布在整體上對齊的方式,使得所有客戶的樣本表示空間更加緊致,有利于提高特征的表達(dá)能力,即使各個客戶的特征偏移較大,依然可以通過學(xué)習(xí)特定的特征提取能力,讓它們的表示中心盡可能一致,從而提高模型的泛化能力,進(jìn)一步提高不同客戶端個體中模型的性能。