ReZero -אלגוריתם לאימון מהיר של רשתות עמוקות במיוחד

אני כמעט תמיד מתעצבן כשיש עבודה שטוענת שהיא "מגדירה את ה-Resnet מחדש". בדרך כלל מדובר באיזשהי אקטיבציה חדשה (מישהו שמע מ-Mish?) אבל לרוב יש לעבודות האלה אחת משלוש בעיות:
- החוקרים ניסו לאמן רק על משימה אחת (בדרך כלל קלסיפיקציה של תמונות)
- יש איזשהו טריידאוף שהוא לא תמיד ברור (האימון נהיה מהיר יותר, אבל התוצאות פחות טובות)
- אין קוד פתוח.
הבעיה השלישית היא כמובן הכי חמורה, כי כדי שאני אנסה להטמיע מאמר בתוך פרוייקט שאני עובד עליו כדאי שזה יהיה משהו קל להטמעה. בעיה מספר אחת גם חמורה כי אני רוצה לדעת שגם אם אני כבר השקעתי את הזמן להשתמש בטריק אז שהסיכויים גבוהים שזה באמת יעזור.
אז עם הפתיח הזה, בואו נדבר על:
ReZero is All You Need: Fast Convergence at Large Depth
Bachlechner, B. Majumder, H. Mao, G. Cottrell, J. McAuley (UC San Diego, 2020)
הקדמה:
מאז ראשית הימים (2015?) ידוע שיש שני רצונות מתנגשים לכל מי שמתעסק ברשתות. מצד אחד יש רצון לעשות את הרשת עמוקה יותר, כי (לא נתווכח על זה עכשיו) רשתות עמוקות מגיעות לביצועים גבוהים יותר. מנגד, ככל שהרשתות נהיות עמוקות יותר קשה יותר לאמן אותן, בגלל שלל בעיות. למרות שהיו יודעים לפתור חלק מהבעיות עדיין מתקיים כמעט תמיד הכלל שרשת עמוקה יותר = אימון יותר ארוך.
כדי לפתוח את הבעיה הזאת, הוצעו לאורך השנים מספר רב של פתרונות, ובכללי יש שלושה דרכים לטפל בבעיות האימון של רשתות עמוקות:
א. Initialization – כשהתחלתי לעבוד עם רשתות ב-2015, אלה היו 99 אחוז מהבעיות שהיו צצות תמיד. זה תחום שהיו בו פעם "כמה תשובות" והיום כולנו משתמשים פשוט ב-Relu בתור אקטיבציה ולכן ב-Xavier initialization שנקרא קצת שונה בכל חבילת קוד. קשה לזכור עד כמה זה עזר לאימון להתכנס, אבל תנסו להוריד את זה ותראו שזה הופך אימון של רשתות עם מעל 10-15 שכבות לכמעט בלתי אפשרי.
ב. Per-layer mormalization – שיטה שנכנסה טיפה מאוחר יותר ובעצם נותנת לנו לשלוט בממוצע ובסטיית התקן שכל שכבה ברשת מייצרת. זה גם הופך את האימון ליותר מהיר, וגם גורם לחלקים עמוקים יותר ברשת להתאמן בלי בעיה. רובנו משתמשים ב-BatchNorm, בעוד שחלקנו כבר עברו ל-LayerNorm או ל-GroupNorm.
ג. Residual connection – הרעיון שהוצג במקור על ידי על ידי החוקרים מ-Microsoft ב-2015 הוא בבסיסו רעיון מאוד פשוט: נשים הרבה מאוד שכבות ואם חלק מהן מתחרפנות במהלך האימון, פשוט ניתן לרשת מכניזם "לדלג" מעליהן ואולי אחר כך ששאר הרשת תתייצב אז השכבות האלה יחזרו לתת עבודה. אין ספק שזה אחד הרעיונות שיש בהם את השימוש הרב ביותר בתעשייה ולמאמר המקורי יש כמעט 50K ציטוטים (אחד המאמרים הכי מצוטטים בתחום) ואולי מליון מימושים אונליין. באופן טבעי, בגלל היתרון של הקשרים האלה, נעשתה עבודה רבה בנסיון לממש פתרונות טובים יותר, אך לדעתי האישית מעטים מהם באמת מעניינים.
מה שמשותף לכל שלושת הפתרונות האלה: כולם קלים למימוש ואפשר בקלות יחסית להשתמש בהם במגוון של ארכיטקטורות.
ועכשיו, נגיע ל-ReZero (Residual with Zero initialization):
אמל"ק:
הטריק הוא לעשות שני דברים: דבר ראשון להכפיל את את ה-Output של כל שכבה במשקולות נלמדות (קצת כמו שעושים ב-BN רק בלי ה-Bias) והאחרי ההכפלה הזאת לחבר לתוצאה של ה-Residual connection. דבר שני, לדאוג שהערך של המשקולות הנלמדות האלה יהיה אפס ב-Initialization.
מה הרעיון מאחורי זה?
הרעיון שעומד בסיס הקונספט הזה מערב קצת מטריצות ואלגברה לינארית ובעיקר הרבה נפנופי ידיים (לדעתי הטיעון המתמטי לא מרגיש שלם, אבל אני גם לא הבן-אדם לשפוט את זה). עם זאת, המטרה היא תמיד לנסות לשמר תוך כדי האימון תכונה שנקראת Dynamical Isometry. באופן כללי זה אומר שאם אנחנו מסתכלים על מטריצת היעקוביאן (Jacobian matrix) של כל שכבה, אנחנו נרצה מאוד שההתפלגות של כל הערכים תהיה קרובה ל-1. ישנן סיבות רצות לכך שנרצה שההתפלגות הזאת תשמר, אבל באופן כללי הרציונל הוא שזה עוזר למידע להגיע לעומק הרשת ב-Forward pass ועוזר לגרדיאנטים לחזור אחורה ב-Backward pass. למעשה, כל הרעיונות שציינו קודם (Initialization, Residual connections ו-BatchNormalization) בעצם תורמים לשימור ההתפלגות הזאת, גם אם לא בצורה ישירה. אם אתם מתעיינים בתיאורה שעומדת מאחורי הנושא, יש מאמר מבריק של Google brain מ-2018 שמשחק עם הרעיון הזה וגם מנסה לקבע את ההתפלגות הזאת באימון. כדי להדגים את האפקטיביות של העקרון הם מרכיבים רשת עם 10K שכבות ומראים איך הם עדיין מצליחים לגרום לה להתאמן, הישג מרשים במיוחד.
במאמר (Rezero) הכותבים מסבירים למה זה בעייתי לשמר את התכונה הזאת ברשתות שמכילות ReLU או מנגנונים של Attention והם גם מסבירים למה לדעתם הרשת מקיימת את ההתפלגות הזאת. לפחות ב-Initialization זה טריויאלי שהרשת מקיימת את התכונה הזאת (כי ה-Residual יחד עם שכבה שמוציאה רק אפסים זה בעצם Identity) אבל לא ראיתי התייחסות במאמר למה בהכרח ה-Dynamical Isometry נשמר גם בזמן ריצה.
עזוב שנייה את המתמטיקה, זה נשמע כמו טריק פשוט, זה בכלל עובד?
אז כבר אמרתי, כשמראים שיטה כזאת רק על סוג אחד של בעיה זה בדרך כלל לא משכנע אותי ואני לא טורח לבדוק, אבל הפעם הכותבים של המאמר עשו כמה ניסויים שהופכים את המאמר למשכנע:
א. Resnet-18 שמסווגת Cifar-10 שהגיעה לאותו Accuracy ב-30 אחוז פחות זמן. דוגמא שמטרתה היא פשוט להראות שזה עובד.
ב. ניסו לאמן מגוון של Transformers עם כל שיטות הנורמליזציה (Post-Norm, Pre-Norm, ו-GPT-Norm יש פירוט של מה כל אלה אומרים בפוסט של ים) והראו שאותה הרשת, בתוספת הטריק שלהם מגיעה לאותם ביצועים בחצי מזמן האימון. בנוסף לזה, הם הצליחו גם לאמן Transformer עם 128 שכבות, בלי שום נורמליזציה ושום טריקים, שזה משהו שעד עכשיו היה מאוד מאתגר לעשות.
ג. סתם בשביל הכיף, הם אמנו גם את הרשת עם ה-10K שכבות ממקודם.
החשיבות של המאמר בסוך היא (לדעתי) שהטריק הזה גורם למודלים ממש גדולים ועמוקים, להיות יותר נגישים וקלים לאימון. הקונספט של Rezero קל מאוד למימוש ופשוט מאוד להוסיף אותו למודלים שאתם עובדים עליהם בחברות או בתחומי המחקר שלכם. במאמר הם גם מציינים שאפשר לשלב את השכבה של ה-Rezero עם שכבות נורמליזציה שונות כמו LayerNorm או BatchNorm כדי לשפר ביצועים ושהמנגנון שלהם לא מפריע בזה.
המאמר:
https://arxiv.org/pdf/2003.04887.pdf
כרגיל, כותב על מאמרים רק עם קוד פתוח (דווקא לא ברמה גבוהה מידי, אבל לפחות משהו):
https://github.com/majumderb/rezero
אני אשמח לשמוע את דעתכם\מחשבותכם\תגובתכם.