סדרת ממבה: איך הכל התחיל? 3 הסקירות הראשונות

Legendre Memory Units: Continuous-Time Representation in Recurrent Neural Networks
המאמר הראשון בסדרה שלנו מנסה לטפל בעיה הראשונה של RNNs כלומר באי יכולת של רשתות אלו לדחוס את הזכרון (קלט בחלון ההקשר) בצורה מספיק טובה. המאמר מציע גישה מקורית ומעניינת שמקורה במערכות דינמיות (Dynamic Systems) לבניית ייצוג הזכרון. נניח שיש לנו פונקציית קלט רציפה u(t) ואנו רוצים לבנות מערכת ש״זוכרת את הפונקציה זו״ כלומר בונה ייצוגה כך שיהיה לי לשחזרה באופן מדיוק. תזכרו שכדי לתאר קלט דיסקרטי כמו טקסט אנו צריכים רק לעשות דיסקרטיזציה או לדגום את הפונקציה הזו).
המאמר בונה מערכת דינמית המתוארת על ידי משוואה דיפרנציאלית לינארית (מערכת דינמית, משוואה 1 במאמר) כאשר (m(t הוא וקטור הזכרון ו- (u(t כאמור הקלט (כרגע חד ממדי). מתברר שעבור בחירה מסוימת של מטריצת A במשוואה של המערכת הדינמית ניתן לתאר את הקלט (בפרק זמן מסוים) על ידי שילוב של פונקציית הזכרון (m(t ופונקציות מתמטיות הנקראות פולינומים של Legendre (משוואה 3 במאמר). כלומר ניתן לתאר את כל מה שקרה מבחינת הקלט עד זמן מסוים על ידי פונקציה (m(t – וזה בדיוק מה שרצינו, נכון?
אולם הדאטה שלנו דיסקרטי (טוקנים נגיד) אז צריך לעשות דיסקרטיזציה (דגימה) לגישה הזו. כלומר במקום פונקציות רציפות תהיה לנו סדרת הקלט u_t ווקטור הזכרון m_t. גם מטריצות במערכת הדינמית שלנו צריכות לעבור דיסקרטיזציה (השערוך הרגיל של הנגזרת/גרדיאנט) ואז נקבל נוסחה רקורסיבית עבור m_t כפונקציה של u_t ו- m_t-1. ניתן לתאר את את הדגימות עד t=T על ידי נוסחת נסיגה הזו.
זהו זה – יש לנו רשת בסגנון RNN כאשר הזכרון ממודל על ידי דיסקרטיזציה של מערכת דינמית, המחשבת מקמדי של פולינומי Legendre ובאופן זה עבד לא רע אי שם ב 2020.
HiPPO: Recurrent Memory with Optimal Polynomial Projections
הגענו למאמר השני בסדרה – המאמר הזה מאוד חשוב כי הוא פיתח בסיס מתמטי מוצק המשמש כל המודלים מבוססים על מערכות דינמיות לינאריות כולל כמובן ממבה. המאמר הזה קצת (די הרבה) כבד מתמטית אך אנסה לעשות כמיטב יכולתי כדי לנסות להעביר לכם את המסר העיקרי שהוא מביא איתו.
בסקירה הקודמת דיברנו על איך ניתן לבנות וקטור זכרון (m(t בעלת יכולת לשחזר פונקצית קלט (u(x ל-x מאינטרוול [0, t]; כאן t מסמן גודל חלון הקשר (כלומר אורך הזכרון). פונקצית (m(t ממודלת על ידי מערכת דינמית לינארית ושילובה עם פולינומי Legendre משחזר לנו את הקלט u. נעיר שאנו עובדים עם הגרסאות הדיסקרטיות של המודלים האלו שהן בעצם נוסחת נסיגה עבור סדרת וקטורי הזכרון m_t.
המאמר המסוקר מנסח מסגרת מתמטית כללית עבור בעיית ייצוג הזכרון של פונקצית קלט (u(x בתחום [0, t]. הנה מתחיל הסיבוך: קודם כל פולינומי Legendre הם מקרה פרטי של פונקציות אורתוגונליות במרחב הילברט (יותר נכון מרחב פונקציונלי L של לבג – המקרה הפרטי של הילברט) המצויד בנוסף בפונקציית מידה mu. אוקיי, מה הדבר הזה אומר בעצם? ממש בגדול זה מרחב של פונקציות שהמכפלה הפנימית ביניהן מוגדרת בתור אינטגרל של מכפלתן תחת מידה mu (במקרה הפשוט ביותר מידה mu שווה ל 1 זהותית ואנו מקבלים אינטגרל Riemann רגיל של המכפלה אבל עבור mu מורכבים יותר כמו Riemann-Stieltjes). פונקציות אורתוגונלית במרחב החמוד הזה מוגדרות בתור אלו שהמכפלה הפנימית שלהן שווה ל 0 (תחת מידה mu). פולינומי Legenge הן אורתוגונליים תחת מידה mu השווה ל- 1/t ב- [0, t] ואפס בכל מקום אחר.
אז נניח שיש לנו N פונקציות אורתוגונליות l_i(x), i=1,…, N במרחבנו החמוד. ועכשיו המטרה היא לתאר את הקלט (u(x ב-[0, t] על ידי l_i(x), i=1,…, N. כלומר אנו רוצים לבנות סכום ממושקל (u*(x של l_i(x) עם מקדמים מסוימים (שימו לב שעבור t-ים שונים מקבלים וקטורי מקדמים שונים וכך שיש לנו כאן פונקציה וקטורית של המקדמים התלויה ב-t).
(u*(x צריך לקרב בצורה טובה את הקלט (u(x (כלומר למזער שגיאה ביניהן ב-[0, t]). והדיוק מחושב בתור אינטגרל של ההפרש הריבועי בין (u*(x ו- (u(x תחת מידה mu (כאמור היא שווה ל- 1/t ב- [0, t] עבור כל x ואפס בכל מקום אחר עבור פולינומי Legendre אבל כמובן קיימות עוד אפשרויות). איך נחשב מקדמים הממזערים את ההפרש הזה? לא כזה מסובך: מקדם i שווה למכפלה פנימית (=אינטגרל) בין פונקצייה מספר i לפונקצית קלט u תחת אותה מידה mu.
עכשיו איך כל זה קשור למערכות דינמיות לינאריות החמודות שלנו? מתברר כי מערכת דינמית לינארית שתיארנו בסקירה הקודמת עבור וקטור (m(t מתארת את המקדמים של ייצוג הקלט באמצעות N פולינומי Legendre אורתוגונליים תחת מידה mu שהגדרנו לפני. ו-N זה המימד של וקטור הזכרון (m(t תחת מידה mu הדורשת קרבה אחידה (=זכרון אחיד) בין u* ו- u ב- [0, t].
אם נגדיר מידה mu להיות פונקציה (exp(x-t עבור t נתון, מערכת דינמית לינארית אחרת תתאר לנו מקדמים של פולינומי Laguerre (אורתוגונליים תחת mu הזו). שימו לב שמידה זו מגדירה זיכרון הדועך מעריכית כלומר ככל שעובר הזמן מזמן הנוכחי t, הזכרון הולך ונהיה מעומעם יותר.
בנוסף המאמר גם מדבר על שיטות דיסקרטיזציה של מערכת דינמית זו וגם דן בקשר בינה לבין RNNs.
אוקיי, עכשיו סיכום במשפט אחד של המאמר הדי כבד הזה. המחברים בנו מסגרת מתמטית למידול בעיית הזכרון של פונקציית קלט שישמש אותנו מאחורי הקלעים לבניית מודלי attention כל הדרך לממבה.
Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention
אחרי הסקירה הקודמת הכבדה מאוד מחכה לנו היום סקירה קלילה (הסקירה הבאה הולכת להיות די כבדה). כמו שכבר אמרנו אחד החסרונות הבולטים של הטרנספורמר היא הסיבוכיות הריבועית שלו במונחי אורך הקלט (= מספר איברים בסדרת הקלט). הסיבוכיות הזו בא על ידי ביטוי גם במהלך האימון וגם במהלך ההיסק (inference). סיבוכיות ריבועית זאת כואבת במיוחד בזמן ההיסק כאשר אין לנו יכולת לחזות מספר טוקנים בו זמנית כי לחיזוי טוקן n אנו צריכים לדעת את ה-(n-1) הטוקנים הראשונים. האם ניתן להפוך את הטרנספורמר לסוג של RNN במהלך ההיסק כאשר כל הזיכרון על הטוקנים הקודמים נדחס לכמה וקטורים בודדים (וקטור זכרון ווקטור של המצב)?
הטרנספורמר המקורי אינו מאפשר אופן חישוב כזה כי הוא מכיל פעולה לא לינארית (softmax) בתוך מנגנון תשומת הלב שלו. ניתן לראות די בקלות שלא ניתן לעקוף את מגבלת הסיבוכיות הריבועית שלו ללא שינוי של אופן חישוב של תשומת הלב. המאמר המסוקר מציע להחליף את חישוב הסופטמקס במנגנון זה בחישוב לינארי (מכפלת מטריצות) המחושבות על ידי הפעלת פונקציה לא לינארית phi על וקטורי השאילתות Q ושל וקטורי המפתחות K. מי שעוד זוכר מה זה KT)Kernel Trick) מבין מה שנעשה כאן הוא KT בכיוון ההפוך.
כמובן שאנו מאבדים כאן מהעוצמה של המנגנון תשומת הלב הרגיל אבל זה יעזור לנו לפתור את סוגיית הסיבוכיות הריבועית בזמן ההיסק. למעשה המחברים מוכיחים (ראו את התמונה למעלה) כי ניתן לממש את המנגנון הזה לסדרתי בעל סיבוכיות לינארית במונחי אורך הקלט. כמובן בזמן האימון ניתן לחשב חיזוי של כמה טוקנים בו זמנית (לפי היכולת החישובית שעומדת לרשותנו) וליהנות מהיתרון של מנגנון תשומת הלב הרגיל.
כלומר יש לנו טרנספורמר (מוחלש כמובן) באימון ו- RNN בהיסק. בהמשך נראה ניתן לשפר את הגישה הזו עם SSMs) state-space models).