
import torch

from torchmetrics import Metric
from monai.losses import SSIMLoss


# SSIM computation makes use of the MONAI toolkit, available at:
# https://github.com/Project-MONAI/MONAI


class SSIM(Metric):
    """
    Accumulate the structural similarity index (SSIM) and a caller-supplied loss.

    Each ``update`` adds one SSIM value (computed as ``1 - SSIMLoss``) and one
    loss value to running sums and increments a step counter; ``compute``
    returns the per-step averages of both.
    """

    def __init__(self, spatial_dims):
        """
        Initialize SSIM metric.
        :param spatial_dims: number of spatial dimensions, forwarded to MONAI's SSIMLoss
        """
        super().__init__(dist_sync_on_step=False)
        self.spatial_dims = spatial_dims
        # Data range handed to SSIMLoss. A value of 1 presumably assumes intensities
        # normalized to [0, 1] -- TODO confirm. This is a plain attribute, not a
        # registered state, so Metric.to()/.cuda() will not move it; update() moves
        # it to the inputs' device instead of pinning it to a device here.
        self.data_range = torch.ones(1)
        self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
        self.add_state("ssim", default=torch.zeros(1), dist_reduce_fx="sum")

    @staticmethod
    def _as_volume(x):
        """
        Rearrange a tensor (d0, d1, d2, d3, ...) so dims 0 and 3 end up as (d1, d3, d2, d0)
        and prepend a batch dimension of size 1.

        For a 4D input this yields (1, d1, d3, d2, d0). Presumably the input is a
        stack of 2D slices shaped (D, C, H, W), giving one (N=1, C, W, H, D) volume
        so SSIM is evaluated in 3D -- confirm against the caller.
        """
        return torch.transpose(torch.transpose(x, 0, 3), 0, 1).unsqueeze(0)

    def update(self, prediction, target, loss):
        """
        Update step for SSIM metric.
        :param prediction: predicted tensor, same layout as target
        :param target: target tensor
        :param loss: loss value computed by the caller for this step (tensor or number)
        """
        # The running sums are only used for reporting; detaching prevents them from
        # holding every step's autograd graph alive (unbounded memory growth).
        if isinstance(loss, torch.Tensor):
            loss = loss.detach()
        # Follow the inputs' device (CPU, cuda:N under DDP, ...) rather than a device
        # fixed at construction time.
        data_range = self.data_range.to(target.device)
        # Compute SSIM before touching any state so a failure here leaves the
        # accumulated states consistent.
        with torch.no_grad():
            ssim_loss = SSIMLoss(spatial_dims=self.spatial_dims)(
                self._as_volume(target),
                self._as_volume(prediction),
                data_range,
            )
        self.steps += 1
        self.loss += loss
        self.ssim += 1 - ssim_loss

    def compute(self):
        """
        Compute step for SSIM metric and loss.
        :return: tuple (average SSIM, average loss), each a 1-element tensor;
                 both are 0/0 (nan) if update was never called
        """
        return self.ssim / self.steps, self.loss / self.steps
