本技術涉及人工智能,尤其涉及一種層次文本分類方法、模型訓練方法、系統(tǒng)、設備及介質。
背景技術:
1、層次文本分類(hierarchical?text?classification,htc)是一項特殊的文本分類任務,旨在將文本數(shù)據(jù)系統(tǒng)地劃分為具有層次結構的多個類別,適用于類別間構成一定體系的場景。與傳統(tǒng)的文本分類不同,層次文本分類提供了更精細和結構化的標簽,使得分類結果更具靈活性和可解釋性。層次文本分類技術在工業(yè)界中具有廣泛的應用,能自適應復雜的類別體系,學習到不同類別間的關系,支持更細粒度的分類,是自然語言處理走向實用化的關鍵技術。
2、層次文本分類的類別體系是以預定義的層次結構所構建和存儲的,標簽之間具有天然的層次約束關系,例如上下層類別之間的“父-子”關系,同一層級且屬于同一父節(jié)點的“兄弟”關系等,層次文本分類的類別體系層次結構和語義關系復雜。例如,父類別可能與多個子類別有所關聯(lián),同一層次的類別之間可能存在部分重疊的語義。尤其是在低階層次上,標簽粒度更細,同一父節(jié)點標簽上的子標簽之間特征較為相似,難以區(qū)分,導致層次文本分類任務準確性低。
技術實現(xiàn)思路
1、本技術實施例的主要目的在于提出一種層次文本分類方法、模型訓練方法、系統(tǒng)、設備及介質,旨在提高層次文本分類任務準確性。
2、為實現(xiàn)上述目的,本技術實施例的一方面提出了一種層次文本分類模型訓練方法,包括以下步驟:
3、獲取根據(jù)預定義的標簽體系結構標注的訓練數(shù)據(jù)集;
4、根據(jù)對偶提示模板對所述訓練數(shù)據(jù)集中的樣本數(shù)據(jù)進行形式轉換,得到提示文本,其中,所述對偶提示模板包括文本數(shù)據(jù)、正層級關系字符串和負層級關系字符串,所述正層級關系字符串包括多個表征層級關系的正標簽掩碼字符,所述負層級關系字符串包括多個表征層級關系的負標簽掩碼字符;
5、將所述提示文本中的文本數(shù)據(jù)輸入初始化后的層次文本分類模型進行前向傳遞,得到正標簽掩碼字符上的正標簽掩碼詞向量、負標簽掩碼字符上的負標簽掩碼詞向量和預測結果,其中,所述預測結果包括層次標簽類別和層次標簽類別的置信分數(shù);
6、根據(jù)所述文本數(shù)據(jù)的正標簽掩碼詞向量、真實標簽向量和負標簽向量,在所述正標簽掩碼字符上進行正標簽對比學習,得到正標簽對比損失;
7、根據(jù)所述文本數(shù)據(jù)的負標簽掩碼詞向量、真實標簽向量和負標簽向量,在所述負標簽掩碼字符上進行負標簽對比學習,得到負標簽對比損失;
8、根據(jù)所述正標簽對比損失、所述負標簽對比損失和所述預測結果確定分類任務損失;
9、根據(jù)所述分類任務損失更新所述層次文本分類模型的參數(shù),得到訓練完成的層次文本分類模型。
10、在一些實施例中,所述根據(jù)所述正標簽對比損失、所述負標簽對比損失和所述預測結果確定分類任務損失,包括以下步驟:
11、根據(jù)多個正標簽掩碼字符上的所述正標簽掩碼詞向量的層級關系確定目標正標簽掩碼詞向量的正子標簽詞向量和負子標簽詞向量;
12、根據(jù)所述目標正標簽掩碼詞向量與所述正子標簽詞向量的第一相似度,以及所述目標正標簽掩碼詞向量與所述負子標簽詞向量的第二相似度確定跨層排序損失;
13、對所述跨層排序損失、所述正標簽對比損失和所述負標簽對比損失進行加權組合,得到提示標簽對比學習損失;
14、根據(jù)所述提示標簽對比學習損失和所述預測結果確定分類任務損失。
15、在一些實施例中,所述層次文本分類模型訓練方法還包括以下步驟,包括以下步驟:
16、對樣本標簽文本進行標識詞映射,得到樣本標簽的標識表示標簽,其中,所述樣本標簽包括樣本數(shù)據(jù)的真實標簽文本和負標簽文本;
17、采用預設的層級結構表示形式對具有層級結構信息的標識表示標簽進行層級展平表示處理,得到標簽序列;
18、將所述標簽序列輸入所述層次文本分類模型的特征表示層,得到樣本標簽向量,其中,所述樣本標簽向量包括真實標簽向量和負標簽向量。
19、在一些實施例中,所述負標簽文本通過以下步驟得到,包括:
20、根據(jù)所述置信分數(shù)從大到小的順序對所述層次文本分類模型輸出的層次標簽類別進行排序;
21、將排名前第一預設數(shù)量的層次標簽類別作為負標簽文本。
22、在一些實施例中,所述根據(jù)所述提示標簽對比學習損失和所述預測結果確定分類任務損失,包括以下步驟:
23、基于交叉熵損失函數(shù),根據(jù)所述層次標簽類別的層級關系預測分數(shù)確定一致性損失,其中,所述層級關系預測分數(shù)表征所述層次文本分類模型對所述層次標簽類別是否符合真實層級關系的預測分數(shù);
24、基于交叉熵損失函數(shù),根據(jù)所述層次標簽類別的語義分類預測分數(shù)確定分類正確性損失,其中,所述語義分類預測分數(shù)表征所述層次文本分類模型對所述層次標簽類別是否符合真實標簽的預測分數(shù);
25、根據(jù)所述一致性損失和所述分類正確性損失確定輔助任務損失;
26、根據(jù)所述提示標簽對比學習損失、所述輔助任務損失、模型整體分類損失和掩碼語言模型損失確定分類任務損失。
27、為實現(xiàn)上述目的,本技術實施例的另一方面提出了一種層次文本分類方法,包括以下步驟:
28、獲取待推理文本;
29、根據(jù)對偶提示模板對所述待推理文本進行形式轉換,得到輸入文本;
30、將所述輸入文本輸入層次文本分類模型得到層次文本預測結果,并根據(jù)所述層次文本預測結果確定所述待推理文本的層次分類結果,其中,所述層次文本分類模型通過前面實施例的層次文本分類模型訓練方法得到。
31、在一些實施例中,所述層次文本預測結果包括多個層級標簽類別和每個層級標簽類別的置信分數(shù),所述根據(jù)所述層次文本預測結果確定所述待推理文本的層次分類結果,包括以下步驟:
32、判斷所述層次文本預測結果中的最高置信分數(shù)是否大于預設分數(shù);
33、當所述層次文本預測結果中的最高置信分數(shù)大于預設分數(shù),則根據(jù)所述層次文本預測結果確定所述待推理文本的層次分類結果;
34、當所述層次文本預測結果中的最高置信分數(shù)小于或等于預設分數(shù),則將所述層次文本預測結果中置信分數(shù)排名前第二預設數(shù)量的層級標簽類別作為候選層級標簽;
35、根據(jù)所述候選層級標簽填充大語言模型的上下文學習提示模板,得到文本分類提問文本;
36、將所述文本分類提問文本輸入大語言模型得到文本分類答案,并對所述文本分類答案進行分析提取,得到所述待推理文本的層次分類結果。
37、為實現(xiàn)上述目的,本技術實施例的另一方面提出了一種層次文本分類模型訓練系統(tǒng),包括:
38、第一模塊,用于獲取根據(jù)預定義的標簽體系結構標注的訓練數(shù)據(jù)集;
39、第二模塊,用于根據(jù)對偶提示模板對所述訓練數(shù)據(jù)集中的樣本數(shù)據(jù)進行形式轉換,得到提示文本,其中,所述對偶提示模板包括文本數(shù)據(jù)、正層級關系字符串和負層級關系字符串,所述正層級關系字符串包括多個表征層級關系的正標簽掩碼字符,所述負層級關系字符串包括多個表征層級關系的負標簽掩碼字符;
40、第三模塊,用于將所述提示文本中的文本數(shù)據(jù)輸入初始化后的層次文本分類模型進行前向傳遞,得到正標簽掩碼字符上的正標簽掩碼詞向量、負標簽掩碼字符上的負標簽掩碼詞向量和預測結果,其中,所述預測結果包括層次標簽類別和層次標簽類別的置信分數(shù);
41、第四模塊,用于根據(jù)所述文本數(shù)據(jù)的正標簽掩碼詞向量、真實標簽向量和負標簽向量,在所述正標簽掩碼字符上進行正標簽對比學習,得到正標簽對比損失;根據(jù)所述文本數(shù)據(jù)的負標簽掩碼詞向量、真實標簽向量和負標簽向量,在所述負標簽掩碼字符上進行負標簽對比學習,得到負標簽對比損失;根據(jù)所述正標簽對比損失、所述負標簽對比損失和所述預測結果確定分類任務損失;
42、第五模塊,用于根據(jù)所述分類任務損失更新所述層次文本分類模型的參數(shù),得到訓練完成的層次文本分類模型。
43、為實現(xiàn)上述目的,本技術實施例的另一方面提出了一種電子設備,所述電子設備包括存儲器、處理器、存儲在所述存儲器上并可在所述處理器上運行的程序以及用于實現(xiàn)所述處理器和所述存儲器之間的連接通信的數(shù)據(jù)總線,所述程序被所述處理器執(zhí)行時實現(xiàn)上述實施例所述的層次文本分類模型訓練方法或者層次文本分類方法。
44、為實現(xiàn)上述目的,本技術實施例的另一方面提出了一種存儲介質,所述存儲介質為計算機可讀存儲介質,用于計算機可讀存儲,所述存儲介質存儲有一個或者多個程序,所述一個或者多個程序可被一個或者多個處理器執(zhí)行,以實現(xiàn)上述實施例所述的層次文本分類模型訓練方法或者層次文本分類方法。
45、本技術提出的層次文本分類方法、模型訓練方法、系統(tǒng)、設備及介質,其通過根據(jù)對偶提示模板對訓練數(shù)據(jù)集中的樣本數(shù)據(jù)進行形式轉換,得到提示文本,并基于該對偶提示模板執(zhí)行層次感知的提示標簽對比學習算法,即提取標簽提示詞的特征嵌入,再將標簽掩碼位上的嵌入映射到類別名稱上,進一步結合真實標簽和負標簽計算正標簽對比損失和負標簽對比損失。利用結合正標簽對比損失和負標簽對比損失的分類任務損失訓練層次文本分類模型,使得模型不僅能夠識別出正確標簽,同時還能識別出傾向于與正標簽產生混淆的錯誤負標簽,有助于提取到具有強判別力且富含層次信息的語義特征,從而提高層次文本分類任務準確性。