本發(fā)明涉及深度學習與計算機視覺,特別是關(guān)于一種基于提示學習的鳥類細粒度識別增量學習方法及裝置。
背景技術(shù):
1、鳥類圖像識別屬于細粒度圖像識別任務(wù)。細粒度圖像識別是對屬于同一大類物體下的不同子類進行區(qū)分,例如區(qū)分不同型號的車,不同科、目的鳥等。一般的圖像識別任務(wù)是對物體所屬的大類進行劃分,比如區(qū)分狗、車、鳥等。對于一般的圖像識別任務(wù),不同種類間的物體特征差異較大,識別相對容易。對于細粒度圖像,不同子類別間的拍攝物體有相似的外形和顏色,而相同子類別內(nèi)受到拍攝物體背景、視角和姿態(tài)等因素影響,存在較大的類內(nèi)差異。因此,細粒度圖像識別任務(wù)更具有挑戰(zhàn)性。
技術(shù)實現(xiàn)思路
1、本發(fā)明的目的在于提供一種基于提示學習的鳥類細粒度識別增量學習方法及裝置,其能夠獲得鳥類細粒度圖像更高的識別精度。
2、為實現(xiàn)上述目的,本發(fā)明提供一種基于提示學習的鳥類細粒度識別增量學習方法,其包括:
3、步驟1,搭建增量學習模型,增量學習模型具體包括圖像特征提取模塊、查詢特征提取模塊、文本提示編碼模塊、視覺提示生成模塊和分類頭;
4、步驟2,通過圖像特征提取模塊將預處理圖像x∈rh×w×c重構(gòu)變?yōu)橐粋€序列其中,p表示每個圖像塊的分辨率,c為通道數(shù),l為得到序列的長度,圖像特征提取模塊描述為fe為編碼模塊,fr為嵌入層,預處理圖像x輸入嵌入層fr,得到嵌入特征xe=fr(x)∈rd×l;
5、步驟3,通過查詢特征提取模塊接收預處理圖像x,輸出嵌入特征h(x),再將嵌入特征h(x)與可學習參數(shù)wp做點積,獲得查詢特征;
6、步驟4,通過文本提示編碼模塊接收種級類別標簽,輸出與種級類別標簽對應(yīng)的多粒度文本信息,再將多粒度文本信息轉(zhuǎn)化為獨熱編碼向量,作為文本提示向量,存儲到文本提示池pw;
7、步驟5,通過視覺提示生成模塊構(gòu)建一個可學習的視覺提示池,基于鍵值對的查詢機制,從視覺提示池中選擇最終視覺提示子集,存儲到視覺提示池pi;
8、步驟6,將文本提示池和視覺提示池共同與步驟2獲得的嵌入特征xe進行拼接,得到拼接結(jié)果
9、步驟7,將拼接結(jié)果xinput先后輸入編碼模塊fe、分類頭,輸出預測分類結(jié)果,再根據(jù)預測分類結(jié)果對增量學習模型的參數(shù)進行優(yōu)化,并隨著增量學習模型學習不同分類任務(wù)來更新增量學習模型的參數(shù),引導增量學習模型進行預測。
10、進一步地,步驟3具體包括:
11、將預處理圖像x輸入查詢特征提取模塊h(·):rh×w×c→rd,提取查詢特征h,h∈rd,r為實數(shù)集合,查詢特征h包括:
12、(1)查詢特征提取模塊的編碼模塊最后輸出的類別標記hcls=h(x)[cls];
13、(2)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征在第一個維度上取最大值得到查詢特征hmax=h(x)max;
14、(3)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征在第一個維度上取最均值得到查詢特征hmean=h(x)mean;
15、(4)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征與可學習參數(shù)wp做點積得到查詢特征hl=h(x)·wp。
16、進一步地,步驟4具體包括:
17、步驟41,從現(xiàn)有的數(shù)據(jù)集中得到圖像x中每一個種級類別標簽,依此獲得圖像x的種級類別標簽集合;
18、步驟42,根據(jù)步驟41的每一個種級類別標簽,擴充每一個種級類別標簽對應(yīng)的多粒度文本信息,多粒度文本信息包括目級類別標簽和科級類別標簽;
19、步驟43,將步驟42獲得的目級類別標簽和科級類別標簽分別轉(zhuǎn)變?yōu)楠殶峋幋a向量,作為種級類別標簽對應(yīng)的文本提示向量,存儲到文本提示池pw。
20、進一步地,步驟5中的“基于鍵值對的查詢機制,從視覺提示池中選擇最終視覺提示子集”的方法具體包括:
21、步驟51,為每一個視覺提示向量分別引入一個索引鍵值,再依據(jù)每一個視覺提示向量及對應(yīng)的鍵值,獲得視覺提示池;
22、步驟52,將步驟3輸出的查詢特征hl與視覺提示池中所有視覺提示向量進行余弦相似度計算,根據(jù)余弦相似度值大小進行降序排列,得到余弦相似度集合s={s1,...,sm},s1為第1個余弦相似度值,即數(shù)值最大的余弦相似度值,依此類推,sm第m個余弦相似度值,從余弦相似度集合s中選出前t個余弦相似度值,即余弦相似度子集s'={s1,...,st};
23、步驟53,根據(jù)余弦相似度子集s',將計算s1,…,st時對應(yīng)的視覺提示向量描述為初步視覺提示子集表示索引子集s'對應(yīng)的第i個視覺提示,然后將初步視覺提示子集ps中的單個視覺提示進行重塑,得到最終視覺提示子集pl。
24、進一步地,步驟7具體包括:
25、步驟71,將拼接結(jié)果xinput先后輸入到編碼模塊fe、分類頭g,通過分類頭將輸入的特征圖轉(zhuǎn)換為最終的分類結(jié)果,也就是得到預測分類結(jié)果g(fe(xinput)),再計算與預處理圖像x對應(yīng)的標簽y的交叉熵損失如下式(2):
26、
27、式中,l表示交叉熵損失函數(shù);
28、步驟72,將文本提示池描述為pw和視覺提示池pi組成可學習提示池,對可學習提示池進行訓練:從可學習提示池中選擇一組選擇最終視覺提示子集與圖像嵌入特征xe結(jié)合,得到特定任務(wù)的嵌入特征xinput;在訓練過程中,按照上述的視覺提示選擇過程來實現(xiàn)增量學習模型的更新,引入索引鍵值和查詢特征的損失函數(shù)如下式(3):
29、
30、式中,表示選定的最終視覺提示子集中視覺提示向量對應(yīng)的索引鍵值的集合,γ為余弦相似度函數(shù),h表示查詢特征與中索引鍵值的損失,λ>0是平衡兩項損失重要性的權(quán)重參數(shù)。
31、本發(fā)明還提供一種基于提示學習的鳥類細粒度識別增量學習裝置,其包括:
32、圖像特征提取模塊,其用于將預處理圖像x∈rh×w×c重構(gòu)變?yōu)橐粋€序列其中,p表示每個圖像塊的分辨率,c為通道數(shù),l為得到序列的長度,圖像特征提取模塊描述為fe為編碼模塊,fr為嵌入層,預處理圖像x輸入嵌入層fr,得到嵌入特征xe=fr(x)∈rd×l;
33、查詢特征提取模塊,其用于接收預處理圖像x,輸出嵌入特征h(x),再將嵌入特征h(x)與可學習參數(shù)wp做點積,獲得查詢特征;
34、文本提示編碼模塊,其用于接收種級類別標簽,輸出與種級類別標簽對應(yīng)的多粒度文本信息,再將多粒度文本信息轉(zhuǎn)化為獨熱編碼向量,作為文本提示向量,存儲到文本提示池pw;
35、視覺提示生成模塊,其用于構(gòu)建一個可學習的視覺提示池,基于鍵值對的查詢機制,從視覺提示池中選擇最終視覺提示子集,存儲到視覺提示池pi;
36、訓練模塊,其用于將文本提示池和視覺提示池共同與嵌入特征xe進行拼接,得到拼接結(jié)果將拼接結(jié)果xinput先后輸入編碼模塊fe、分類頭,輸出預測分類結(jié)果,再根據(jù)預測分類結(jié)果對增量學習模型的參數(shù)進行優(yōu)化,并隨著增量學習模型學習不同分類任務(wù)來更新增量學習模型的參數(shù),引導增量學習模型進行預測。
37、進一步地,查詢特征提取模塊具體包括:
38、將預處理圖像x輸入查詢特征提取模塊h(·):rh×w×c→rd,提取查詢特征h,h∈rd,r為實數(shù)集合,查詢特征h包括:
39、(1)查詢特征提取模塊的編碼模塊最后輸出的類別標記hcls=h(x)[cls];
40、(2)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征在第一個維度上取最大值得到查詢特征hmax=h(x)max;
41、(3)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征在第一個維度上取最均值得到查詢特征hmean=h(x)mean;
42、(4)查詢特征提取模塊的編碼模塊最后輸出的嵌入特征與可學習參數(shù)wp做點積得到查詢特征hl=h(x)wp。
43、進一步地,文本提示編碼模塊具體包括:
44、從現(xiàn)有的數(shù)據(jù)集中得到圖像x中每一個種級類別標簽,依此獲得圖像x的種級類別標簽集合;根據(jù)每一個種級類別標簽,擴充每一個種級類別標簽對應(yīng)的多粒度文本信息,多粒度文本信息包括目級類別標簽和科級類別標簽;將目級類別標簽和科級類別標簽分別轉(zhuǎn)變?yōu)楠殶峋幋a向量,作為種級類別標簽對應(yīng)的文本提示向量,存儲到文本提示池pw。
45、進一步地,視覺提示生成模塊具體包括:
46、索引鍵引入單元,其用于為每一個視覺提示向量分別引入一個索引鍵值,再依據(jù)每一個視覺提示向量及對應(yīng)的鍵值,獲得視覺提示池;
47、余弦相似度計算單元,其用于將查詢特征hl與視覺提示池中所有視覺提示向量進行余弦相似度計算,根據(jù)余弦相似度值大小進行降序排列,得到余弦相似度集合s={s1,…,sm},s1為第1個余弦相似度值,即數(shù)值最大的余弦相似度值,依此類推,sm第m個余弦相似度值,從余弦相似度集合s中選出前t個余弦相似度值,即余弦相似度子集s'={s1,...,st};
48、視覺提示重塑單元,其用于根據(jù)余弦相似度子集s',將計算s1,…,st時對應(yīng)的視覺提示向量描述為初步視覺提示子集表示索引子集s'對應(yīng)的第i個視覺提示,然后將初步視覺提示子集ps中的單個視覺提示進行重塑,得到最終視覺提示子集pl。
49、進一步地,訓練模塊具體包括:
50、交叉熵損失計算單元,其用于將拼接結(jié)果xinput先后輸入到編碼模塊fe、分類頭g,通過分類頭將輸入的特征圖轉(zhuǎn)換為最終的分類結(jié)果,也就是得到預測分類結(jié)果g(fe(xinput)),再計算與預處理圖像x對應(yīng)的標簽y的交叉熵損失如下式(2):
51、
52、式中,l表示交叉熵損失函數(shù);
53、索引鍵值和查詢特征的損失計算單元,其用于將文本提示池描述為pw和視覺提示池pi組成可學習提示池,對可學習提示池進行訓練:從可學習提示池中選擇一組選擇最終視覺提示子集與圖像嵌入特征xe結(jié)合,得到特定任務(wù)的嵌入特征xinput;在訓練過程中,按照上述的視覺提示選擇過程來實現(xiàn)增量學習模型的更新,引入索引鍵值和查詢特征的損失函數(shù)如下式(3):
54、
55、式中,表示選定的最終視覺提示子集中視覺提示向量對應(yīng)的索引鍵值的集合,γ為余弦相似度函數(shù),h表示查詢特征與中索引鍵值的損失,λ>0是平衡兩項損失重要性的權(quán)重參數(shù)。
56、本發(fā)明由于采取以上技術(shù)方案,其具有以下優(yōu)點:
57、本發(fā)明通過基于提示學習的增量學習模型實現(xiàn)鳥類細粒度圖像識別。在增量學習模型中引入可學習的視覺提示,緩解增量學習模型中災難性遺忘的現(xiàn)象;對于鳥類細粒度識別任務(wù),引入不同粒度的文本信息作為增量學習模型中的文本提示向量,與視覺提示融合,實現(xiàn)由粗到細地學習不同鳥類的特征,提升鳥類細粒度識別精度。