weak_to_strong/logger.py (67 lines of code) (raw):
import json
import os
from datetime import datetime
import wandb
def append_to_jsonl(path: str, data: dict):
    """Append ``data`` as a single JSON-encoded line to the file at ``path``.

    The file is opened in append mode, so it is created if it does not exist
    and existing lines are preserved.
    """
    with open(path, "a") as handle:
        # print() emits the serialized record followed by a "\n" terminator.
        print(json.dumps(data), file=handle)
class WandbLogger(object):
    """Buffer key/value metrics and flush them to Weights & Biases and/or a JSONL file.

    W&B logging is enabled only when the ``WANDB_PROJECT`` environment
    variable is set. A local ``log.jsonl`` is written when a ``save_path``
    keyword argument is supplied.
    """

    # Process-wide active logger; set/cleared by the module-level
    # configure()/shutdown() helpers.
    CURRENT = None
    # Path of the JSONL log file, or None when no ``save_path`` was given.
    log_path = None

    def __init__(
        self,
        **kwargs,
    ):
        """Create a logger.

        Args:
            **kwargs: Run configuration. The whole dict is passed to
                ``wandb.init`` as ``config`` when W&B is enabled. Keys read here:
                ``name`` -- ``str.format`` template for the W&B run name; it may
                reference any other kwarg plus ``datetime_now``.
                ``save_path`` -- directory for ``log.jsonl``; created if missing.

        Raises:
            FileExistsError: if ``save_path`` exists but is not a directory.
        """
        project = os.environ.get("WANDB_PROJECT")
        self.use_wandb = project is not None
        if self.use_wandb:
            run_name = None
            if "name" in kwargs:
                run_name = kwargs["name"].format(
                    **kwargs, datetime_now=datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
                )
            wandb.init(
                config=kwargs,
                project=project,
                name=run_name,
            )
        if "save_path" in kwargs:
            self.log_path = os.path.join(kwargs["save_path"], "log.jsonl")
            # exist_ok=True removes the exists()/makedirs() race, which raised
            # FileExistsError when another process created the directory first.
            os.makedirs(kwargs["save_path"], exist_ok=True)
        self._log_dict = {}

    def logkv(self, key, value):
        """Buffer ``value`` under ``key`` until the next dumpkvs()."""
        self._log_dict[key] = value

    def logkvs(self, d):
        """Buffer every key/value pair of ``d``, overwriting existing keys."""
        self._log_dict.update(d)

    def dumpkvs(self):
        """Flush buffered metrics to W&B and/or the JSONL file, then reset the buffer."""
        if self.use_wandb:
            wandb.log(self._log_dict)
        if self.log_path is not None:
            append_to_jsonl(self.log_path, self._log_dict)
        # Rebind instead of .clear() so the dict already handed to the sinks
        # above is not mutated afterwards.
        self._log_dict = {}

    def shutdown(self):
        """Finish the W&B run, if W&B logging is enabled."""
        if self.use_wandb:
            wandb.finish()
def is_configured():
    """Return True when a global WandbLogger has been installed via configure()."""
    active = WandbLogger.CURRENT
    return not (active is None)
def get_current():
    """Return the active global WandbLogger.

    Raises:
        AssertionError: if configure() has not been called. This is an
            ``assert``, so under ``python -O`` the check is skipped and None
            is returned instead.
    """
    active = WandbLogger.CURRENT
    assert active is not None, "WandbLogger is not configured"
    return active
def configure(**kwargs):
    """Install a new process-wide WandbLogger, shutting down any previous one.

    Args:
        **kwargs: Forwarded unchanged to ``WandbLogger``.

    Returns:
        The newly created logger, also stored in ``WandbLogger.CURRENT``.
    """
    if is_configured():
        previous = WandbLogger.CURRENT
        # Detach before shutting down: if shutdown() or the new WandbLogger(...)
        # raises, CURRENT must not keep pointing at an already-finished logger.
        WandbLogger.CURRENT = None
        previous.shutdown()
    WandbLogger.CURRENT = WandbLogger(**kwargs)
    return WandbLogger.CURRENT
def logkv(key, value):
    """Buffer ``value`` under ``key`` on the active global logger.

    Raises:
        AssertionError: if no logger has been configured.
    """
    get_current().logkv(key, value)
def logkvs(d):
    """Buffer every key/value pair of ``d`` on the active global logger.

    Raises:
        AssertionError: if no logger has been configured.
    """
    get_current().logkvs(d)
def dumpkvs():
    """Flush the active global logger's buffered metrics.

    Raises:
        AssertionError: if no logger has been configured.
    """
    get_current().dumpkvs()
def shutdown():
    """Shut down the active global logger and clear ``WandbLogger.CURRENT``.

    Raises:
        AssertionError: if no logger has been configured.
    """
    assert is_configured(), "WandbLogger is not configured"
    try:
        WandbLogger.CURRENT.shutdown()
    finally:
        # Clear even if the instance shutdown raises. Otherwise the module
        # stays "configured" and a later configure() would finish it again.
        WandbLogger.CURRENT = None