logger.py
import logging
import os

import torch
import torch.distributed as dist


class DistributedAwareLogger(logging.Logger):
    """Logger that, when torch.distributed is initialized, only emits records on
    rank 0 unless LOG_ALL_RANKS=True is set in the environment."""

    def __init__(self, name):
        super().__init__(name)
        self.log_on_all_ranks = os.environ.get("LOG_ALL_RANKS", "False") == "True"

    def _log(self, level, msg, args, exc_info=None, extra=None, stack_info=False, stacklevel=1):
        # Match the standard Logger._log signature (Python 3.8+) so callers that
        # pass stacklevel keep working.
        if dist.is_initialized():
            if dist.get_rank() == 0 or self.log_on_all_ranks:
                super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel)
        else:
            # Not running under torch.distributed: behave like a regular logger.
            super()._log(level, msg, args, exc_info, extra, stack_info, stacklevel)


# Register the custom class before any loggers are created.
logging.setLoggerClass(DistributedAwareLogger)
logging.basicConfig(level=logging.INFO)


def get_logger(name):
    """Return a DistributedAwareLogger whose level comes from LOG_LEVEL (default INFO)."""
    logger = logging.getLogger(name)
    logger.setLevel(os.environ.get("LOG_LEVEL", "INFO").upper())
    return logger


# In debug mode, print full tensors instead of truncated summaries.
if os.environ.get("LOG_LEVEL", "INFO").upper() == "DEBUG":
    torch.set_printoptions(profile="full")
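

# --- Usage sketch (hypothetical, for illustration) ---
# Assumes the script is launched with torchrun for multi-process runs, so the
# usual RANK/MASTER_ADDR env vars are present; backend choice is an assumption.
if __name__ == "__main__":
    if "RANK" in os.environ and not dist.is_initialized():
        dist.init_process_group(backend="gloo")
    log = get_logger(__name__)
    log.info("visible on rank 0, or on every rank when LOG_ALL_RANKS=True")
    log.debug("emitted only when LOG_LEVEL=DEBUG")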