"""Basic Callback classes"""
import os
import pathlib
from typing import Dict
from torch.utils.tensorboard import SummaryWriter
from idrlnet.receivers import Receiver, Signal
from idrlnet.variable import Variables
__all__ = ["GradientReceiver", "SummaryReceiver", "HandleResultReceiver"]
class GradientReceiver(Receiver):
    """Register the receiver to monitor gradient norm on the Tensorboard."""

    def receive_notify(self, solver: "Solver", message):  # noqa
        """Log the total L2 gradient norm of every trainable net node.

        Does nothing unless ``Signal.TRAIN_PIPE_END`` is present in
        ``message``. For each node in ``solver.netnodes`` whose
        ``require_no_grad`` flag is falsy, the L2 norms of all parameter
        gradients are combined into a single norm and written through
        ``solver.summary_receiver`` under the tag ``gradient/total_norm`` at
        ``solver.global_step``.

        :param solver: the solver that emitted the notification.
        :param message: container of signals; only membership of
            ``Signal.TRAIN_PIPE_END`` is checked.
        """
        if Signal.TRAIN_PIPE_END not in message:
            return
        for netnode in solver.netnodes:
            if netnode.require_no_grad:
                continue
            squared_norm_sum = 0.0
            for p in netnode.net.parameters():
                # A parameter that is frozen or was not reached by the last
                # backward pass has ``grad is None``; it contributes nothing
                # to the norm and must not be dereferenced.
                if p.grad is None:
                    continue
                squared_norm_sum += p.grad.data.norm(2).item() ** 2
            total_norm = squared_norm_sum ** (1.0 / 2)
            # Sanity check kept from the original implementation: the first
            # registered receiver is expected to be the Tensorboard writer.
            assert isinstance(solver.receivers[0], SummaryWriter)
            solver.summary_receiver.add_scalar(
                "gradient/total_norm", total_norm, solver.global_step
            )
class SummaryReceiver(SummaryWriter, Receiver):
    """The receiver will be automatically registered to control the Tensorboard."""

    def __init__(self, *args, **kwargs):
        """Forward all arguments unchanged to ``SummaryWriter.__init__``."""
        SummaryWriter.__init__(self, *args, **kwargs)

    def receive_notify(self, solver: "Solver", message: Dict):  # noqa
        """Write loss values and learning rates to the Tensorboard.

        When ``message`` carries ``Signal.AFTER_COMPUTE_LOSS``, the associated
        mapping is logged as a group under ``loss_overview`` and each entry is
        also logged on its own under ``loss_component/<key>``. When
        ``message`` carries ``Signal.TRAIN_PIPE_END``, the ``lr`` of the first
        parameter group of every optimizer in ``solver.optimizers`` is logged.

        :param solver: the solver that emitted the notification; supplies
            ``global_step`` and ``optimizers``.
        :param message: mapping from signals to their payloads.
        """
        step = solver.global_step

        if Signal.AFTER_COMPUTE_LOSS in message:
            losses = message[Signal.AFTER_COMPUTE_LOSS]
            self.add_scalars("loss_overview", losses, step)
            for name in losses:
                self.add_scalar(f"loss_component/{name}", losses[name], step)

        if Signal.TRAIN_PIPE_END in message:
            for index, opt in enumerate(solver.optimizers):
                learning_rate = opt.param_groups[0]["lr"]
                self.add_scalar(f"optimizer/lr_{index}", learning_rate, step)
class HandleResultReceiver(Receiver):
    """The receiver will be automatically registered to save results on training domains."""

    def __init__(self, result_dir):
        """Remember the directory that results are written into.

        :param result_dir: target directory; created on demand when results
            are saved.
        """
        self.result_dir = result_dir

    def receive_notify(self, solver: "Solver", message: Dict):  # noqa
        """Save sampled, predicted and difference values once solving ends.

        Does nothing unless ``Signal.SOLVE_END`` is present in ``message``.
        Otherwise samples are drawn from the solver's domains, pushed through
        the graph, and for every domain key three variable sets are saved
        under ``result_dir`` as ``<key>_true``, ``<key>_pred`` and
        ``<key>_diff`` in the ``vtu``, ``np`` and ``csv`` formats.

        :param solver: the solver that emitted the notification.
        :param message: mapping from signals to their payloads.
        """
        if Signal.SOLVE_END not in message:
            return

        samples = solver.sample_variables_from_domains()
        in_var, _, _ = solver.generate_in_out_dict(samples)
        pred_out_sample = solver.forward_through_all_graph(
            in_var, solver.outvar_dict_index
        )
        diff_out_sample = {name: Variables() for name in pred_out_sample}

        results_path = pathlib.Path(self.result_dir)
        results_path.mkdir(exist_ok=True, parents=True)

        for domain in samples:
            truth = samples[domain]
            predicted = pred_out_sample[domain]
            difference = diff_out_sample[domain]

            for var_name in truth:
                if var_name in predicted:
                    difference[var_name] = predicted[var_name] - truth[var_name]
                else:
                    # Variables the network did not produce are copied from
                    # the samples into both the prediction and the diff set.
                    predicted[var_name] = truth[var_name]
                    difference[var_name] = truth[var_name]

            outputs = (("true", truth), ("pred", predicted), ("diff", difference))
            for suffix, variables in outputs:
                # A fresh format list is built for every save call.
                variables.save(
                    os.path.join(results_path, f"{domain}_{suffix}"),
                    ["vtu", "np", "csv"],
                )