סדרת ממבה: סקירות 4-6 (H3 ,S5 ,S4)
(S4)Efficiently Modeling Long Sequences with Structured State Spaces
לאט לאט הגענו למאמר הרביעי בסדרת סקירות בדרך לממבה. הפעם נסקור מאמר מ-2022 שיצא שנתיים אחרי 3 המאמרים הראשוניים שסקרנו בנושא המעניין הזה. כמובן במהלך תקופה זו יצאו כמה מאמרים מעניינים שפיתחו ארכיטקטורות מבוססות מערכות דינמיות לינאריות (ובשם כללי יותר Space-State Models- SSMs).
המאמר שנסקור לקח את הגישה הזו לגבהים חדשים והגיע לתוצאות די מרשימות עם דאטה בעל אורך הקשר ארוך (למשל עבור אות אודיו המכיל אלפי או אפילו עשרות אלפי דגימות בשנייה. אם יש לנו מטלה שדורשת התחשבות בכמה עשרות שניות של אודיו אז אנו צריכים אורך הקשר של מאות רבות של דגימות וזה די כבד עבור הטרנספורמר עם הסיבוכיות הריבועית שלו – במונחי אורך הקשר).
אוקיי, אז בואו ניזכר מהו היתרון הבולט של ארכיטקטורות מבוססות SSMs. מצד אחד בעת ההיסק (inference) של טוקן הם מונעים מאיתנו צורך להתחשב באופן מפורש בכל הדגימות הקודמות על ידי דחיסה של המידע בטוקנים הקודמים(=זיכרון) בווקטור זיכרון אחד, המתעדכן עם המערכת הדינמית הלינארית. מצד שני במהלך האימון (כשכל הטוקנים ידועים) הוא מאפשר חישוב בו זמני של כל הטוקנים הממוסכים.
דואליות עוצמתית זו התאפשרה על ידי ייצוגה של זיכרון בתור מערכת לינארית שניתן לבטא את הזיכרון המצטבר לכל טוקן כפעולה לינארית. כלומר ניתן לתאר את הפלט של עבור טוקן k על ידי הנוסחה באחת התמונות (הקטנה יותר). מטריצות בנוסחה הן הגרסאות המודסקרטות של המטריצות המופיעות בנוסחה של המערכת הדינמית המתארת את התקדמות הזכרון בזמן (טוקנים). ניתן לראות כי מה שיש לנו כאן זו רשת קונבולוציה (העלולה להיות מאוד ארוכה) שמאפשרת חישוב הייצוג של כל טוקן i.
קיבלנו את הארכיטקטורה הדואלית המתאימה גם לאימון וגם להיסק. אבל יש בעיה קטנה. עבור אורך הקשר גדול מספיק נדרשת כמות גדולה מאוד של זכרון. קודם כל אנו צריכים מטריצה A בגודל NxN (נגיד עבור N=64) עבור כל מימד של ייצוג הקלט (כי זה מה שהמערכת הדינמית שלנו ״צריכה לזכור״). אז חישוב קונבולוציה זו בצורה הישירה עבור מטריצה A כללית של המאמר של HiPPO (עבור מקרה של פולינומי Legendre שנקרא LegT תחת מכסה המנוע של המערכת הדינמית) הוא מאוד כבד ודורש הרבה זכרון.
אז מה ניתן לעשות? קודם כל אם מטריצה A היתה אלכסונית החישוב ודרישות הזכרון היו הרבה יותר צנועות. המחברים גם שמו לב כי conjugation של מטריצה A במערכת הדינמית (הכפלתה מימין ומשמאל במטריצה אוניטרית V) מוביל למערכת דינמית שקולה עם התוצאה Vx. הבעיה שמטריצה A מ-HiPPO לא ניתן לתאר בצורה V*LV כאשר L היא מטריצה אלכסונית, ו V היא מטריצה אוניטרית (נובע מכך ש A אינה קומוטטיבית עם A* כלומר לא נורמלית – זה השם אין מה לעשות).
אז הכל אבוד? מתברר שלא. מתברר ש A מ HiPPO ניתן לתאר בתור סכום של מטריצה נורמלית ומטריצה בעלת רנק נמוך (עבור LegT הרנק אפילו שווה ל-1 כלומר תוספת זו כי מכפלה חיצונית של שני וקטורים בעלי מימד Nx1). ואז המאמר מציע אלגוריתם די לא טריוויאלי עבוד חישוב של קרנל קונבולוציה ארוך המבוסס על 3 עקרונות מתמטיים:
- במקום לחשב A^l עבור כל l ניתן לחשב z-transform (מקוטע עד L) של A ואז לחשב בצורה די פשוטה את A^l על ידי הצבה של שורש שונים של 1 (המרוכבים) ב z-transform הזה.
- כאשר A הוא הפרש של מטריצה אלכסונית L ומטריצה בעלת רנק נמוך מאוד ניתן לחשב את z-transform הזה בצורה יעילה דרך זהות Woodbury שמסתכם בהיפוך של מטריצה אלכסונית.
- ניתן לבצע את כל החישובים העולים כאשר מפעילים בזהות Woodbury בצורה יעילה מאוד עם Cauchy Kernel; שזה בגדול מטריצה שנבנית בצורה מסוימת משני וקטורים
לבסוף, מבצעים את החישובים האלו עבור כל מימד של ייצוגי הטוקנים בנפרד ואז מערבבים עם שכבה לינארית (או כמה). מטריצות אלכסוניות L (למעשה וקטור), וקטורים B, C וגם P ו-Q שמכפלתם היא מטריצה בעלת נמוך מאומנות בנפרד עבור כל מימד של ייצוג הטוקנים.
זהו, יצא ארוך – הסקירה הסקירה תהיה קצרה יותר.
Simplified State Space Layers For Sequence Modeling(S5)
ממשיכים עם הסקירה החמישית כל הדרך לממבה. סקירה זו תהיה די קלילה כי היא בסך הכל מציעה שכלול לארכיטקטורת S4 שדיברנו עליה בהרחבה בסקירה הקודמת. למעשה S4 בנויה מ- H (מימד של ייצוג הטוקן) SSMs שכל אחד מהם מומש עם מערכת דינמית לינארית שדנו עליה בהרחבה בסקירות הקודמות. כל SSM מהווה בעצם זכרון עבור כל מימד של וקטור ייצוג הטוקן לאורך זמן. זמן כאן ציר הטוקנים שאותם אנחנו רוצים לזכור כדי לקבל החלטה מושכלת עבור הטוקן הנוכחי.
אם נביט בנוסחאות המתארות SSM ניתן לראות כי H מערכות SSM האלו אפשר לתאר כ-SSM אחד גדול המתואר על ידי מטריצה A בלוקית אלכסונית שבאלכסון שלה נמצאות מטריצות A_i, i=1,…H המתארות כל SSM. וקטורים B ו- C של ה- SSM הגדול הזה ניתן לבנות על ידי השרשור של וקטורי B_i ו- C_i של H המערכות SSM האלה.
כמובן שכל הסיפור הזה דורש לא מעט זכרון ולא מעט חישובים במיוחד כאשר H (מימד ייצוג הדאטה) הוא סדר גודל של כמה מאות או כמה אלפים. אז המאמר המסוקר מציע להשתמש באותה מטריצה A עבור המערכות הדינמיות המתארות זיכרון של כל מימד שך ייצוג הדאטה. גודל של מטריצה A נבחר הרבה יותר קטן מ- PH שזה גודל של מטריצה A עבור כל המימדים של ייצוג התוכן יחד (= גודל המטריצה הבלוקים האלכסונית). כמובן שבדרך זו נחסכים לנו גם הזיכרון וגם כמות החישובים הנדרשת גם בהיסק וגם באימון.
כמובן שהקטנה שכזו של מימד מטריצה A עלול לפגוע בביצועי המודל (כי אידיאלית זיכרון של מימדים שונים של ייצוג דאטה עשויים להכיל אופיינים שונים של זיכרון; נגיד, זיכרון ארוך וקצר טווח). המחברים בוחנים מספר דרכים לצמצום פגיעה זו על ידי עדכון חכם של A ועוד כמה טריקים נחמדים. המחברים למשל בוחרים אופצייה של מטריצה A בעלת מימד KP כאשר K הרבה יותר קטן מ-H.
בקיצור מאמר קליל וקל לקריאה…
Hungry Hungry Hippos: Towards Language Modeling with State Space Models(H3)
עד עכשיו ראינו מאמרים שמימשו את ארכיטקטורת SSM בתור רכיב הזכרון של המערכת. אף אחת מהמאמרים שסקרנו לא ניסה לשלב גישה זו(SSM) יחד עם מנגוננים אחרים שמוכרים לנו מעולם של עיבוד סדרות דאטה עם רשתות נוירונים. המאמר המסוקר משלב את גישת SSM, המיושמת באמצעות מערכות דינמיות לינאריות, עם מנגנון תשומת הלב הלינארי.
דיברנו על מנגנון attention הלינארי בסקירה השלישית של המאמר: Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. המאמר הזה הציע להחליף את מנגנון תשומת הלב הרגיל עם softmax של הטרנספורמרים בחישוב לינארי: (f(k)*f(q כאשר * מסמן מכפלה פנימית ו- f היא פונקציה לא לינארית. המאמר מראה כי ניתן לתאר טרנספורמר עם מנגנון זה בתור RNN ולהימנע מסיבוכיות חישוב ריבועית הרגילה של הטרנספורמרים. כלומר אין צורך להתחשב בצורה מפורשת בכל פיסות הדאטה לפני טוקן i בשביל לחזות אותו אלא כל הזיכרון של הטוקנים הקודמים נדחס ושמור בשני וקטורים.
אוקיי, אבל למה צריך בעצם לשלב ארכיטקטורות מבוססת SSM עם מנגנונים אחרים? התשובה היא פשוטה – ארכיטקטורות אלה לא מספיק טובות לכמה משימות. למשל מחברי המאמר שמו לב כי במשימות כמו Induction Head שצריך לעקוב על טוקן שבא אחרי טוקן מסוים, ארכיטקטורה זו מפגינה ביצועים לא מרשימים במיוחד. כדי להתמודד עם סוגיה זו המחברים הציעו לשלב SSM עם מטריצות A מסוימות עם מנגנון תשומת הלב הלינארי.
אז איך כל הסיפור הזה עובד? בשלב הראשון מכפילים את ייצוגי הטוקנים במטריצות Q, K ו- V כמו בטרנספורמרים. בשלב השני מפעילים SSM על המפתח k (עבור כל הטוקנים) עם מטריצה A המדמה ״זיכרון של הטוקן הקודם״(בערך A_ij=1 כאשר i – j=1 ואפס אחרת). מבחינת מנגנון תשומת הלב הלינארי זה ״מקביל״ ל (f(k למרות ש f כאן ״די לינארית״.
בשלב השלישי לוקחים Q, V והתוצאה של השלב הקודם ל h חתיכות (= ״ראשים״ במנגנון ה-attention). לאחר מכן מכפילים כל חתיכה של V בחתיכה של התוצאה של השלב הקודם (עם k) ו״מעבירים״ את התוצאות דרך SSM עם מטריצה A אלכסונית. את התוצאה מכפילים ב-q, מאחדים את כל התוצאות ומכפילים במטריצה W_O כמו שמקובל בטרנספורמרים מרובי ראשים (multi-head transformers).
בנוסף המאמר מציע מנגנון הנקרא FlashConv לחישוב חיזוי הטוקנים באופן מקבילי במהלך האימון. כמו שאתם זוכרים הקרנל קונבולוציוני שם מאוד ארוך וחישובו יכול להיות יקר גם מבחינת הזיכרון וגם מבחינת הזמן אם נעשה בצורה נאיבית. המחברים משכללים את המנגנון כאשר העיקרון המוביל הוא ניצול מקסימלי של זיכרון SRAM המהיר שיש ב-GPUs תוך מזעור של הערבות דאטה לשם (זה איטי ובד״כ מהווה צוואר בקבוק) . הזיכרון הזה לא גדול ולא ניתן לדחוף שם יותר מדי אז נדרשות שיטות מתוחכמות המפרקות את חישוב הקונבולוציה לחלקים תוך ניצול תכונות של FFT ו- IFFT. נזכיר שהחישוב הקונבולוציה מבוצע בצורה: ((c(x) = iFFT(FFT(c)*FFT(x כאשר (c(x היא קונבולוציה על x עם קרנל c.