本發(fā)明涉及圖像處理領(lǐng)域,具體涉及一種圖像描述模型的訓(xùn)練方法及系統(tǒng)、圖像描述的生成方法。
背景技術(shù):
1、圖像描述,或者稱為圖像標(biāo)注,是一種將圖像內(nèi)容轉(zhuǎn)化為自然語(yǔ)言描述的技術(shù)。圖像描述在現(xiàn)階段被廣泛應(yīng)用于各種場(chǎng)景,例如視覺(jué)搜索、自動(dòng)圖像標(biāo)注、內(nèi)容推薦以及為視覺(jué)障礙人士提供輔助等。圖像描述的準(zhǔn)確性直接關(guān)系到這些應(yīng)用的效果和用戶體驗(yàn),具有重要的實(shí)踐意義。
2、現(xiàn)有的圖像描述方法大多采用深度學(xué)習(xí)模型來(lái)實(shí)現(xiàn),其中最常見(jiàn)的方法是將卷積神經(jīng)網(wǎng)絡(luò)和transformer結(jié)合起來(lái),使用卷積神經(jīng)網(wǎng)絡(luò)來(lái)提取圖像的視覺(jué)特征,然后用transformer生成描述。然而,卷積神經(jīng)網(wǎng)絡(luò)提取的網(wǎng)格特征通常著重于圖像的全局特征和細(xì)粒度細(xì)節(jié),而忽略了粗粒度信息,如對(duì)象信息;而且由于網(wǎng)格特征在輸入transformer之前通常要進(jìn)行切分和展平,這可能導(dǎo)致以對(duì)象為中心的高級(jí)語(yǔ)義信息的丟失,造成現(xiàn)有的圖像描述方法生成的圖像描述不準(zhǔn)確或者不完整。
技術(shù)實(shí)現(xiàn)思路
1、為解決上述問(wèn)題,本發(fā)明提供一種圖像描述模型的訓(xùn)練方法及系統(tǒng)、圖像描述的生成方法。
2、本發(fā)明第一方面公開(kāi)了一種圖像描述模型的訓(xùn)練方法,包括:
3、將樣本輸入圖像特征提取器提取網(wǎng)格特征和區(qū)域特征,所述樣本包括目標(biāo)圖像及其對(duì)應(yīng)的圖像描述標(biāo)簽;
4、在編碼器的每一層,將網(wǎng)格特征和區(qū)域特征融合后進(jìn)行編碼,得到最終編碼;
5、將最終編碼輸入解碼器進(jìn)行解碼,得到目標(biāo)圖像的圖像描述;
6、根據(jù)預(yù)定義的損失函數(shù)對(duì)所述圖像描述模型進(jìn)行訓(xùn)練,直到達(dá)到預(yù)定的收斂標(biāo)準(zhǔn)。
7、進(jìn)一步的,所述在編碼器的每一層,將網(wǎng)格特征和區(qū)域特征融合后進(jìn)行編碼,得到最終編碼,包括:
8、對(duì)于編碼器第一層,將區(qū)域特征注入網(wǎng)格特征中得到第一層增強(qiáng)網(wǎng)格特征,將第一層增強(qiáng)網(wǎng)格特征輸入第一層編碼器得到第一層對(duì)應(yīng)的融合網(wǎng)格特征,以區(qū)域特征作為編碼器的第一層對(duì)應(yīng)的增強(qiáng)區(qū)域特征;
9、對(duì)于編碼器的第i+1層,將編碼器的第o層輸出的融合網(wǎng)格特征提取全局語(yǔ)義信息和細(xì)粒度語(yǔ)義信息融合到編碼器的第i層對(duì)應(yīng)的增強(qiáng)區(qū)域特征中,得到第i+1層的增強(qiáng)區(qū)域特征,并將第i+1層的增強(qiáng)區(qū)域特征注入編碼器的第i層輸出的融合網(wǎng)格特征中得到第i+1層增強(qiáng)網(wǎng)格特征,將第i+1層增強(qiáng)網(wǎng)格特征輸入第i+1層編碼器得到第i+1層對(duì)應(yīng)的融合網(wǎng)格特征;
10、以編碼器最后一層輸出的融合網(wǎng)格特征作為最終編碼。
11、進(jìn)一步的,將第i+1層增強(qiáng)網(wǎng)格特征輸入第i+1層編碼器得到第i+1層對(duì)應(yīng)的融合網(wǎng)格特征,包括:
12、將第i+1層增強(qiáng)網(wǎng)格特征輸入第i+1層編碼器得到第i+1層對(duì)應(yīng)的融合網(wǎng)格特征
13、
14、其中,ffn表示前饋網(wǎng)絡(luò),layernorm代表層歸一化操作,mhmla為多頭多級(jí)注意力,第i+1層增強(qiáng)網(wǎng)格特征為:
15、
16、其中,代表編碼器第i+1層對(duì)應(yīng)的增強(qiáng)區(qū)域特征,α為預(yù)設(shè)的權(quán)重參數(shù),g代表網(wǎng)格特征,r代表區(qū)域特征。
17、進(jìn)一步的,將編碼器的第i層輸出的融合網(wǎng)格特征提取全局語(yǔ)義信息和細(xì)粒度語(yǔ)義信息融合到編碼器的第i層對(duì)應(yīng)的增強(qiáng)區(qū)域特征中,得到第i+1層的增強(qiáng)區(qū)域特征,包括:
18、編碼器第i+1層對(duì)應(yīng)的增強(qiáng)區(qū)域特征為:
19、
20、
21、其中,為第i+1層編碼器對(duì)應(yīng)的自注意力區(qū)域特征,為第i+1層編碼器對(duì)應(yīng)的歸一化后的自注意力區(qū)域特征,為第i+1層編碼器對(duì)應(yīng)的歸一化后的融合網(wǎng)格特征,為第i+1層編碼器對(duì)應(yīng)的交叉注意力區(qū)域特征,ffn表示前饋網(wǎng)絡(luò),layernorm代表層歸一化操作,mhmla為多頭多級(jí)注意力。
22、進(jìn)一步的,所述多頭多級(jí)注意力mhmla為:
23、mhmla(q,k,v)=(concat(head1,head2,…,headh))wo;
24、headh=multilevelattention(q,k,v);
25、其中,concat表示向量拼接操作,h代表多頭注意力中的注意力頭的數(shù)量,headh代表第h個(gè)注意力頭的輸出,wo為可訓(xùn)練的投影矩陣,q,k,v分別為注意力機(jī)制的查詢矩陣、鍵矩陣與值矩陣,multilevelattention為多級(jí)注意力:
26、multilevelattention(q,k,v)=ffn(layernorm(q+(s+λc)));
27、其中,layernorm表示層歸一化操作,λ表示權(quán)重因子,s為空間注意力權(quán)重,c為通道注意力權(quán)重:
28、
29、其中,spatialatention代表空間注意力,dk為放縮因子,代表矩陣相乘符號(hào),channelatention代表通道注意力。
30、進(jìn)一步的,將最終編碼輸入解碼器進(jìn)行解碼,得到目標(biāo)圖像的圖像描述,包括:
31、對(duì)于待生成的圖像描述的第t個(gè)單詞的詞向量wt,以已生成的單詞的詞向量按單詞的順序組成輸入矩陣w<t:
32、w<t=[w0,…,wt-1];
33、其中,w0表示預(yù)定義的句子開(kāi)始標(biāo)識(shí);
34、將輸入矩陣輸入解碼器,第l+1層編碼器對(duì)第t個(gè)單詞的輸出為:
35、
36、其中,crossattention代表多頭交叉注意力模塊,代表交叉注意力詞向量,代表歸一化交叉注意力詞向量,代表歸一化詞向量,mmhsa代表帶掩碼的多頭自注意力模塊,代表帶掩碼的多頭自注意力詞向量,代表第l層解碼器生成的前t-1個(gè)單詞的詞向量矩陣,layernorm代表層歸一化操作;
37、將最后一層解碼器對(duì)第t個(gè)單詞的輸出作為圖像描述中的第t個(gè)單詞wt;
38、經(jīng)過(guò)t個(gè)時(shí)間步的單詞自回歸生成,獲得目標(biāo)圖像的描述句子序列。
39、進(jìn)一步的,所述損失函數(shù)為:
40、
41、其中,θ表示圖像描述模型的參數(shù),pθ代表模型在t時(shí)刻下在條件基礎(chǔ)上生成單詞的條件概率,t表示生成圖像描述需要的時(shí)間步的數(shù)量,是目標(biāo)圖像對(duì)應(yīng)的圖像描述標(biāo)簽。
42、進(jìn)一步的,所述生成方法還包括:
43、利用強(qiáng)化學(xué)習(xí)模型來(lái)微調(diào)圖像描述模型,采用cider分?jǐn)?shù)作為獎(jiǎng)勵(lì)函數(shù)r,并進(jìn)行自我批評(píng)序列訓(xùn)練,強(qiáng)化學(xué)習(xí)模型的損失函數(shù)為:
44、
45、其中,e代表期望值,u1:tpθ代表第1到第t個(gè)時(shí)間步生成對(duì)應(yīng)的圖像描述的幾率;
46、基于梯度下降法進(jìn)行模型參數(shù)θ的調(diào)整,使用下述的梯度表達(dá)式計(jì)算樣本梯度:
47、
48、其中,b是采樣序列獲得的獎(jiǎng)勵(lì)的平均值,k是采樣的序列的數(shù)量,uk代表第k個(gè)采樣序列,代表當(dāng)前樣本關(guān)于圖像描述模型參數(shù)為θ時(shí)的梯度。
49、本發(fā)明第二方面公開(kāi)了一種圖像描述的生成方法,包括:
50、根據(jù)如本發(fā)明第一方面公開(kāi)的任一種圖像描述模型的訓(xùn)練方法獲得圖像描述模型;
51、將待處理圖像輸入至圖像描述模型,得到待處理圖像對(duì)應(yīng)的圖像描述。
52、本發(fā)明第三方面公開(kāi)了一種圖像描述模型的訓(xùn)練系統(tǒng),包括:
53、提取模塊,用于將樣本輸入圖像特征提取器提取網(wǎng)格特征和區(qū)域特征,所述樣本包括目標(biāo)圖像及其對(duì)應(yīng)的圖像描述標(biāo)簽;
54、編碼模塊,用于在編碼器的每一層,將網(wǎng)格特征和區(qū)域特征融合后進(jìn)行編碼,得到最終編碼;
55、解碼模塊,用于將最終編碼輸入解碼器進(jìn)行解碼,得到目標(biāo)圖像的圖像描述;
56、訓(xùn)練模塊,用于根據(jù)預(yù)定義的損失函數(shù)對(duì)所述圖像描述模型進(jìn)行訓(xùn)練,直到達(dá)到預(yù)定的收斂標(biāo)準(zhǔn)。
57、本發(fā)明通過(guò)將區(qū)域特征中的對(duì)象級(jí)信息注入到網(wǎng)格特征中,并從網(wǎng)格特征中提取全局信息和細(xì)粒度語(yǔ)義到區(qū)域特征中,可以有效地整合細(xì)粒度細(xì)節(jié)和粗粒度信息,彌補(bǔ)了網(wǎng)格特征和區(qū)域特征之間的差異,從而提高了生成的圖像描述的準(zhǔn)確度。