close menu

מדריך: איך להקטין את צריכת הזכרון באימון מודלים עמוקים?

כשזה מגיע למודלים עמוקים: גדול יותר == טוב יותר!
אבל לצערנו הרב במקרים מסוימים: גדול מידי לא נכנס.

איך לאמן מודלים תוך שימוש בפחות זכרון

המדריך כתוב בטורצ'. ההעדפה האישית שלי היא טנזורפלו אבל עובדים עם מה שיש.. אז הקוד כאן בטורצ'.

צבירת גרדיאנטים – Gradient Accumulation

אמ;לק: במקום לנסות להכניס באצ'ים גדולים יותר לזיכרון: נעדכן את המודל רק כל כמה צעדים.

ממש פשוט: אנחנו בדרך כלל מאמנים באצ' אחר באצ': דוחפים את הבאצ'ים למודל כשבכל באצ' מספר דוגמאות. זו הדרך שמאמנים מודלים עמוקים בעולם המודרני והיא הוכיחה עצמה שוב ושוב.
הבעיה היא ש:
  • "להכפיל את מספר הדגימות" = "להכפיל את צריכת הזיכרון".

וזה לא טוב.

אז במקום לעשות את זה: בואו פשוט נקטין את הבאצ' אבל נעדכן את המודל רק כל כמה צעדים וככה נסמלץ את "הבאצ' הגדול יותר"
חשוב: כדי לסלץ באצ'ים גדולים יותר עד הסוף אנחנו צריכים נצטרך גם לחלק את הלוס ביחס המתאים לבאץ' הגדול יותר

[כי אם הינו מאמנים בבאצ' גדול – תרומת כל דוגמה היתה קטנה יותר כי יש יותר דוגמאות].

קוד:
for step, batch in enumerate(loader, 1):
outputs = model(inputs)
loss = loss_fn(outputs, targets)
if gradient_accumulation_steps > 1:
loss = loss / gradient_accumulation_steps
loss.backward()
if step % gradient_accumulation_steps == 0 or step == steps:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
model.zero_grad()

הקפאת משקולות

אמ;לק: במקום לאמן כל הפרמטרים במודל. בואו פשוט לא נאמן אותם.

לפעמים, פשוט אין לנו מספיק זיכרון לאימון כל הפרמטרים שלנו.
במצב כזה, יכול להיות שנאלץ פשוט לאמן רק חלק מהם.
איך עושים את זה בפועל?
כדי לעשות את זה בפועל טורצ' מספקת ממשק פשוט: נגדיר לפרמטרים שנרצה להקפיא: require_grad = False ואז הם לא יושפעו מהאימון.
קוד:
def freeze(module):
for parameter in module.parameters():
parameter.requires_grad = False
def get_freezed_parameters(module):
freezed_parameters = []
for name, parameter in module.named_parameters():
if not parameter.requires_grad:
freezed_parameters.append(name)
return freezed_parameters
עכשיו נוכל פשוט להקפיא שכבות במודל המאומן שלנו למשל ככה:
freeze(model.embeddings)
או
freeze(model.encoder.layer[:2])

ולמי שלא מכיר: Automatic Mixed Precision

אמ;לק: Nvidia מספקת לנו "קסמים עם typeים" ואנחנו מקבלים בתמורה אימון מהיר יותר עם צריכת זכרון קטנה יותר.

אז Automatic Mixed Precision (AMP) [1] היא שיטה לאימון מהיר במיוחד של מודלים עמוקים שעל הדרך גם צורכת פחות זיכרון.
  • הרעיון הוא להשתמש בPrecision נמוך יותר לחישוב גרדיאנטים אבל לעדכן את המודל בPrecision המלא. כדי לעבור בניהם בלי לפגוע בביצועים המאמר המקורי מציע שיטת Scaling לערכים.
איך משתמשים בזה?
בטורצ' יש ספריה מובנית (torch.cuda.amp) לעבודה עם Mixed Precision והשימוש בה נראה כך:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
for step, batch in enumerate(loader, 1):
# Lower precision calculation
with autocast(enabled=True):
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# Scaling
scaler.scale(loss).backward()
scaler.unscale_(optimizer)
scaler.step(optimizer)
scaler.update()

אופטימיזציה ב8-ביט

אמ;לק: על ידי שימוש ברעיון דומה לרעיון למעלה נוכל לחשב את כל האופטימייזר ב8-ביט ובכך נחסוך זכרון.

אם נשמור את מצב האופטימיזר בPrecision נמוך נוכל לחסוך הרבה זיכרון שכן האופטימייזר במהלך האימון יכול להגיע למצב שהוא מחזיק את כל המודל כמה פעמים בזכרון.
  • במאמר הראשון בנושא [2] פרסמו הכותבים (מטא) סדרת ניסוים גדולה בה הם בדקו את יציבות האימון ב8-ביט על גבי טווח היפרפרמטרים רחב והראו כי האימון מגיע לאותן התוצאות פחות או יותר ויציב באותה המידה.
  • כדי להשתמש באופטימייזר הזה במציאות הכותבים שחררו חבילה המכילה את כל האלגוריתמים "bitsandbytes". בחבילה זו ישנה גרסת 8 ביט של האופטימיזר הפופולרי: AdamW שפועלת ב-8-ביט: AdamW8bit.

יש גם הרבה אחרים..

החבילה כולה נמצאת כאן: https://github.com/facebookresearch/bitsandbytes
איך משתמשים בזה?
import bitsandbytes as bnb
bnb_optimizer = bnb.optim.AdamW8bit(params=model_parameters, lr=2e-5, weight_decay=0.0)

שמירת גרדיאנטים – Gradient Checkpoint

אמ;לק: במהלך הריצה קדימה ואחורה על גבי המודל לצורך חישוב הנגזרות אין לנו צורך להחזיק את כל הנגזרות של כל המודל כל הזמן, על ידי שחרור חכם של זכרון ושמירה יותר מתוחכמת לאורך הריצה נוכל לחסוך הרבה זכרון.

  • במאמר [3] מוצג טריק מתוחכם לחישוב גרדיאנטים המצריך מאיתנו להחזיק חלקים קטנים יותר מהמודל בזכרון בכל שלב בו אני מתקדמים לעומק המודל לצורך חישוב הנגזרות. כשהמודל גדול שיטה זו חוסכת דרמטית בצריכת הזכרון תוך כדי עלות חישובית נוספת יחסית קטנה.
איך משתמשים בזה?
from torch.utils.checkpoint import checkpoint
model.gradient_checkpointing_enable()
רפרנסים:
  1. [1] – הטריק של https://arxiv.org/abs/1710.03740 :Mixed Precision
  2. [2] – המאמר על אופטימיזציה ב8-ביט: https://arxiv.org/abs/2110.02861
  3. [3] – המאמר על Gradient Checkpointing כאן: https://arxiv.org/abs/1604.06174
עוד בנושא: