
import os

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from data_loading.data_module import DataModule
from generator.generator import UGen
from utils.args import get_main_args
from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path

if __name__ == "__main__":
    args = get_main_args()
    set_granularity()  # Increase maximum fetch granularity of L2 to 128 bytes
    set_cuda_devices(args)
    if args.seed is not None:
        seed_everything(args.seed)
    data_module = DataModule(args)
    if args.exec_mode == "predict":
        data_module.setup()  # call setup for pytorch_lightning compatibility
    ckpt_path = verify_ckpt_path(args)

    model = UGen(args)
    callbacks = [RichProgressBar(), ModelSummary(max_depth=2)]
    logger = False
    if args.exec_mode == "train":
        if args.tb_logs:
            logger = TensorBoardLogger(
                save_dir=f"{args.results}/tb_logs",
                name=f"task={args.task}_dim={args.dim}_fold={args.fold}_precision={16 if args.amp else 32}",
                default_hp_metric=False,
                version=0,
            )
        if args.save_ckpt:
            callbacks.append(
                ModelCheckpoint(
                    dirpath=f"{args.ckpt_store_dir}/checkpoints/fold{args.fold}",
                    filename="{epoch}-{ssim:.2f}",
                    monitor="ssim",
                    mode="max",
                    save_last=True,
                )
            )

    trainer = Trainer(
        logger=logger,  # logger for experiment tracking
        default_root_dir=args.results,  # default path for logs and weights when no logger or ckpt callback is passed
        benchmark=True,
        deterministic=False,   # sets whether PyTorch operations must use deterministic algorithms
        max_epochs=args.epochs,
        precision=16 if args.amp else 32,
        gradient_clip_val=args.gradient_clip_val,
        enable_checkpointing=args.save_ckpt,
        callbacks=callbacks,
        num_sanity_val_steps=0,    # sanity check runs 0 validation batches before starting the training routine
        accelerator="gpu",
        devices=args.gpus,
        num_nodes=args.nodes,
        strategy="ddp" if args.gpus > 1 else None,
    )

    if args.exec_mode == "train":
        trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
    elif args.exec_mode == "predict":
        if args.save_preds:
            # define prediction directory
            ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
            dir_name = f"predictions_{ckpt_name}"
            dir_name += f"_task={model.args.task}_fold={model.args.fold}"
            save_dir = os.path.join(args.results, dir_name)
            model.save_dir = save_dir
            make_empty_dir(save_dir)

        model.args = args
        trainer.test(model, dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path)
