本發(fā)明屬于信息安全,具體是一種通信高效的聯(lián)邦學習多粒度分組微調方法。
背景技術:
1、在信息安全領域,聯(lián)邦學習是一種新興的機器學習范式,可以在數(shù)據(jù)不離開本地的前提下,完成一些需要聯(lián)合用戶或者機構執(zhí)行的模型訓練任務,從而有效地保護用戶和機構隱私安全。
2、隨著深度學習的發(fā)展,模型的性能越來越強,特別是基于預訓練模型進行開發(fā)的大語言模型和多模態(tài)模型。然而,這些模型的尺寸都非常龐大,不同用戶或者機構通過聯(lián)邦學習直接訓練這些大模型會帶來龐大的通信開銷。
3、one-shot聯(lián)邦學習是一種通信高效的聯(lián)邦學習范式,不同于多輪通信的聯(lián)邦學習方式,one-shot聯(lián)邦學習期望通過一輪的通信完成聯(lián)邦學習的訓練過程,從而大大降低了多輪通信帶來的流量開銷。現(xiàn)有的one-shot聯(lián)邦學習方法主要包括:基于集成學習的方法、知識蒸餾方法和模型融合方法。
4、基于集成學習的方法,通常需要客戶端將訓練好的完整模型傳輸給服務器,在服務器上集成為一個大的模型用于推理。這種方式的缺陷在于集成模型的大小和客戶端數(shù)量有關,客戶端數(shù)量多時集成的模型大小會非常大,總的通信開銷也會增大。基于知識蒸餾的方法,通常需要服務器上有合成數(shù)據(jù)集或者公共數(shù)據(jù)集用于服務器上的模型訓練;對于不需要額外數(shù)據(jù)集的知識蒸餾方法,通常需要額外的計算開銷?;谀P腿诤系姆椒ǎǔP枰珔⑽⒄{然后傳輸整個模型,甚至需要傳輸額外的輔助參數(shù)矩陣用于服務器上的模型融合,在使用大模型進行聯(lián)邦學習訓練時通信開銷仍然非常龐大。
5、本發(fā)明提供了一種通信高效的聯(lián)邦學習多粒度分組微調方法,以解決上述技術問題。
技術實現(xiàn)思路
1、本發(fā)明旨在至少解決現(xiàn)有技術中存在的技術問題之一;為此,本發(fā)明提出了一種通信高效的聯(lián)邦學習多粒度分組微調方法,用于解決當前聯(lián)邦學習的通信開銷大以及在數(shù)據(jù)非獨立同分布即non-iid場景下的模型性能弱的技術問題。
2、為實現(xiàn)上述目的,本發(fā)明的第一方面提供了一種通信高效的聯(lián)邦學習多粒度分組微調方法,包括:
3、s1:通過lora低秩參數(shù)矩陣構建客戶端的預訓練模型;其中,lora低秩參數(shù)矩陣包括參數(shù)矩陣a和參數(shù)矩陣b;
4、s2:客戶端對參數(shù)矩陣b進行非對稱性微調并將參數(shù)矩陣b上傳至服務器;
5、s3:計算各客戶端之間的相似度向量;
6、s4:服務器進行多粒度分組并計算組內(nèi)加權平均低秩矩陣;
7、s5:客戶端進行低秩混合專家設置并加載若干組低秩矩陣得到專家lora;
8、s6:客戶端通過低秩混合專家微調專家門控參數(shù)并上傳至服務器;
9、s7:服務器對各客戶端對應的專家門控參數(shù)進行加權平均得到平均專家門控參數(shù),并用于全局模型的推理。
10、優(yōu)選的,所述通過lora低秩參數(shù)矩陣構建客戶端的預訓練模型,包括:
11、所有客戶端使用相同的預訓練模型作為后續(xù)微調的基底模型,凍結預訓練模型參數(shù)矩陣并添加lora低秩參數(shù)矩陣;其中,lora低秩參數(shù)矩陣包括參數(shù)矩陣和參數(shù)矩陣r為實數(shù),r為低秩維度,din為預訓練模型參數(shù)矩陣的輸入維度,dout為預訓練模型參數(shù)矩陣的輸出維度。
12、優(yōu)選的,所述客戶端對參數(shù)矩陣b進行非對稱性微調并將參數(shù)矩陣b上傳至服務器,包括:
13、x1:所有客戶端使用相同的lora初始化方式,初始化完成后所有客戶端的lora參數(shù)都是相同的;
14、x2:各客戶端使用本地的數(shù)據(jù)集進行訓練預訓練模型,訓練時需要凍結參數(shù)矩陣a,只訓練參數(shù)矩陣b;
15、x3:客戶端將訓練完成的參數(shù)矩陣b上傳至服務器,參數(shù)矩陣a無須上傳。
16、本發(fā)明利用lora的非對稱性微調參數(shù)矩陣b,由于在使用lora進行模型微調的過程中參數(shù)矩陣a從輸入中提取特征,參數(shù)矩陣b根據(jù)特征輸出數(shù)據(jù),因此只微調參數(shù)矩陣b與同時微調參數(shù)矩陣a和參數(shù)矩陣b相比是同樣有效的,并且需要學習和傳輸?shù)膮?shù)量只有后者的一半;本發(fā)明中客戶端在訓練預訓練模型時需要凍結參數(shù)矩陣a,只訓練參數(shù)矩陣b;客戶端將訓練完成的參數(shù)矩陣b上傳至服務器,參數(shù)矩陣a無須上傳,有利于降低聯(lián)邦學習的通信開銷。
17、優(yōu)選的,所述計算各客戶端之間的相似度向量,包括:
18、設置客戶端數(shù)量為n,每個客戶端的預訓練模型均有h層,bn,h表示第n個客戶端第h層的參數(shù)矩陣b,則第n+1個客戶端第h層與第n個客戶端第h層對應的參數(shù)矩陣b的相似度分數(shù)的計算公式為:
19、
20、計算出總共h層的layer_score(n+1,n,h)后,第n+1個客戶端和第n個客戶端間的相似度client_score的計算公式如下:
21、
22、第n個客戶端的相似度向量αn為:
23、αn=[client_score(n,1),client_score(n,2),…,client_score(n,n)];
24、其中,n=1,2,…,n,n為客戶端總數(shù);h=1,2,…,h,h為客戶端內(nèi)預訓練模型的總層數(shù)。
25、本發(fā)明通過計算各個客戶端之間的相似度向量,相似度向量能夠表示不同客戶端中數(shù)據(jù)集之間的相似程度,方便后續(xù)根據(jù)客戶端的相似度向量進行分組。
26、優(yōu)選的,所述服務器進行多粒度分組并計算組內(nèi)加權平均低秩矩陣,包括:
27、使用相似度向量和k-means算法將參數(shù)矩陣b進行多粒度分組,得到m個分組,將分組標記為i;對分組i中組內(nèi)所有參數(shù)矩陣根據(jù)客戶端對應的數(shù)據(jù)量進行加權平均,得到組內(nèi)加權平均低秩矩陣并發(fā)送至客戶端;其中,i=1,2,…,m,m為分組總數(shù);多粒度分組的方式包括:
28、客戶端粒度分組:對客戶端內(nèi)預訓練模型所有層的相似度向量求平均向量,得到客戶端粒度的相似度向量,然后再對客戶端粒度的相似度向量使用k-means算法分成m組;
29、模型塊粒度分組:對同一客戶端相同塊內(nèi)所有層的相似度向量求平均向量,得到模型塊粒度的相似度向量,然后再對模型塊粒度的相似度向量使用k-means算法分成m組;
30、模型層粒度分組:直接對模型層粒度的相似度向量使用k-means算法分成m組。
31、本發(fā)明根據(jù)相似度向量使用k-means算法進行多粒度分組,將相似程度接近的客戶端分入同一組中,有利于提高聯(lián)邦學習在non-iid場景下的性能,并且有效降低了客戶端數(shù)量龐大時使用混合專家方法傳輸?shù)膮?shù)量。
32、優(yōu)選的,所述客戶端進行低秩混合專家設置并加載若干組低秩矩陣得到專家lora,包括:
33、客戶端下載m個組內(nèi)加權平均低秩矩陣矩陣,客戶端本地的預訓練模型保持不變,將客戶端本地的參數(shù)矩陣a復制m個,并用下載的m個組內(nèi)加權平均低秩矩陣替換本地的參數(shù)矩陣b得到m個矩陣,依次將矩陣標記為第i個專家lora,得到m個專家lora。
34、優(yōu)選的,所述客戶端通過低秩混合專家微調專家門控參數(shù)并上傳至服務器,包括:
35、客戶端凍結預訓練模型和專家lora,凍結專家lora包括凍結參數(shù)矩陣a和參數(shù)矩陣對客戶端n的專家門控參數(shù)g進行微調,微調完成后將訓練好的專家門控參數(shù)g上傳至服務器;對于混合專家模型的專家門控而言,第i個專家lora的分配概率計算公式為:
36、
37、其中,x為專家門控的輸入,gi為第i個專家lora對應的專家門控參數(shù),e為自然常數(shù),softmax()函數(shù)用于計算專家門控單元的分配概率;
38、通過公式計算模型總輸出o(x)。
39、需要說明的是,客戶端在微調專家門控參數(shù)時只需微調少量輪次,否則客戶端模型容易出現(xiàn)過擬合。
40、優(yōu)選的,所述服務器對各客戶端對應的專家門控參數(shù)進行加權平均得到平均專家門控參數(shù),包括:
41、服務器根據(jù)客戶端的數(shù)據(jù)量大小對客戶端對應的專家門控參數(shù)進行加權平均,得到平均專家門控參數(shù)將平均專家門控參數(shù)和專家lora加載到服務器的全局模型中,全局模型根據(jù)平均專家門控參數(shù)和專家lora進行推理。
42、與現(xiàn)有技術相比,本發(fā)明的有益效果是:
43、1.本發(fā)明利用lora的非對稱性微調參數(shù)矩陣b,由于在使用lora進行模型微調的過程中參數(shù)矩陣a從輸入中提取特征,參數(shù)矩陣b根據(jù)特征輸出數(shù)據(jù),因此只微調參數(shù)矩陣b與同時微調參數(shù)矩陣a和參數(shù)矩陣b相比是同樣有效的,并且需要學習和傳輸?shù)膮?shù)量只有后者的一半;本發(fā)明中客戶端在訓練預訓練模型時需要凍結參數(shù)矩陣a,只訓練參數(shù)矩陣b;客戶端將訓練完成的參數(shù)矩陣b上傳至服務器,參數(shù)矩陣a無須上傳,有利于降低聯(lián)邦學習的通信開銷;此外,通過lora結合混合專家的訓練方式提升了模型在non-iid場景下的精度。
44、2.本發(fā)明通過計算各個客戶端之間的相似度向量,根據(jù)相似度向量使用k-means算法進行多粒度分組,相似度向量能夠表示不同客戶端中數(shù)據(jù)集之間的相似程度,將相似程度接近的客戶端分入同一組中,有利于提高聯(lián)邦學習在non-iid場景下的性能。