"""
MyModel class and its child classes.

The MyModel class enables flattening of the model parameters for tracking.
MLP and LeNet are example models. Add your own PyTorch model by inheriting
from MyModel and organizing it into the pytorch lightning style.
"""
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from torch import nn
from torch.optim import SGD, Adagrad, Adam, RMSprop


class MyModel(pl.LightningModule):
    """
    MyModel class that enables flattening of the model parameters.
    """
    def __init__(
        self,
        optimizer,
        learning_rate,
        momentum=0.9,
        custom_optimizer=None,
        gpus=0,
        num_classes=10,
        model_type="mlp",
    ):
        """
        Initialize a new MyModel.

        Arguments:
            optimizer: optimizer to use, such as "adam", "sgd", "sgd+momentum".
            learning_rate: Learning rate to use.
            momentum: Momentum to use with the "sgd+momentum" optimizer. Defaults to 0.9.
            custom_optimizer (optional): Custom optimizer to use. Defaults to None.
            gpus (optional): Number of GPUs to use for training. Defaults to 0.
            num_classes: Number of classes. Defaults to 10.
            model_type: Type of model to use. Defaults to "mlp".
        """
        super().__init__()
        self.optimizer = optimizer
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.custom_optimizer = custom_optimizer
        self.gpus = gpus
        self.model_type = model_type
        self.optim_path = []
        self.accuracy = torchmetrics.Accuracy(
            task="multiclass", num_classes=num_classes
        )
        self.training_step_outputs = []

    def configure_optimizers(self):
        # TODO: Add support for custom optimizers (e.g., SignSGD)
        if self.optimizer == "adam":
            return Adam(self.parameters(), self.learning_rate)
        elif self.optimizer == "sgd":
            return SGD(self.parameters(), self.learning_rate)
        elif self.optimizer == "sgd+momentum":
            return SGD(self.parameters(), self.learning_rate, momentum=self.momentum)
        else:
            raise ValueError(
                f"optimizer supplied is not supported: {self.optimizer}"
            )
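    # A minimal sketch of how the TODO in configure_optimizers could be
    # addressed, assuming `custom_optimizer` is a callable taking (params, lr)
    # and returning a torch.optim.Optimizer. That contract is an assumption,
    # not part of the current code, so the sketch is left commented out:
    #
    #     if self.custom_optimizer is not None:
    #         return self.custom_optimizer(self.parameters(), self.learning_rate)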

    def get_flattened_params(self):
        """
        Get flattened and concatenated parameters of the model.
        """
        params = self._get_params()
        # Create an empty tensor on the same device as the model parameters
        device = next(self.parameters()).device
        flattened_params = torch.empty(0, device=device)
        for _, param in params.items():
            flattened_params = torch.cat((flattened_params, torch.flatten(param)))
        return flattened_params

    def init_from_flattened_params(self, flattened_params):
        """Set all model parameters from the flattened form."""
        if not isinstance(flattened_params, torch.Tensor):
            raise TypeError(
                "Argument to init_from_flattened_params() must be a torch.Tensor"
            )
        shapes = self._get_param_shapes()
        state_dict = self._unflatten_to_state_dict(flattened_params, shapes)
        self.load_state_dict(state_dict, strict=True)

    def _get_param_shapes(self):
        shapes = []
        for name, param in self.named_parameters():
            shapes.append((name, param.shape, param.numel()))
        return shapes

    def _get_params(self):
        params = {}
        for name, param in self.named_parameters():
            params[name] = param.data
        return params

    def _unflatten_to_state_dict(self, flattened_params, shapes):
        state_dict = {}
        counter = 0
        for shape in shapes:
            name, tsize, tnum = shape
            param = flattened_params[counter : counter + tnum].reshape(tsize)
            state_dict[name] = torch.nn.Parameter(param)
            counter += tnum
        assert counter == len(flattened_params), "counter must reach the end of flattened parameters"
        return state_dict
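    # Usage sketch (illustrative only, using the MLP subclass defined below;
    # the argument values are arbitrary examples): flatten the parameters into
    # a single 1-D tensor and load them back, e.g. to replay points along
    # self.optim_path.
    #
    #     model = MLP(input_dim=784, num_classes=10, learning_rate=1e-3)
    #     flat = model.get_flattened_params()     # 1-D tensor of all parameters
    #     model.init_from_flattened_params(flat)  # exact round trip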


class MLP(MyModel):
    """
    A Multilayer Perceptron (MLP) model.

    Default is 1 hidden layer with dimension 50.
    """
    def __init__(
        self,
        input_dim,
        num_classes,
        learning_rate,
        num_hidden_layers=1,
        hidden_dim=50,
        optimizer="adam",
        custom_optimizer=None,
        gpus=0,
    ):
        """
        Initialize a dense MLP model.

        Arguments:
            input_dim: Number of input dimensions.
            num_classes: Number of classes or output dimensions.
            learning_rate: The learning rate to use.
            num_hidden_layers (optional): Number of hidden layers. Defaults to 1.
            hidden_dim (optional): Number of dimensions in each hidden layer. Defaults to 50.
            optimizer (optional): The optimizer to use. Defaults to "adam".
            custom_optimizer (optional): The custom optimizer to use. Defaults to None.
            gpus (optional): GPUs to use if available. Defaults to 0.
        """
        super().__init__(
            optimizer=optimizer,
            learning_rate=learning_rate,
            custom_optimizer=custom_optimizer,
            gpus=gpus,
            num_classes=num_classes,
        )
        if num_hidden_layers == 0:
            self.layers = nn.Linear(input_dim, num_classes)
        else:
            self.layers = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
            n_layers = 2
            for _ in range(num_hidden_layers - 1):
                self.layers.add_module(
                    name=f"{n_layers}", module=nn.Linear(hidden_dim, hidden_dim)
                )
                self.layers.add_module(name=f"{n_layers+1}", module=nn.ReLU())
                n_layers += 2

            self.layers.add_module(
                name=f"{n_layers}", module=nn.Linear(hidden_dim, num_classes)
            )

    def forward(self, x, apply_softmax=False):
        pred = self.layers(x)
        if apply_softmax:
            pred = F.softmax(pred, dim=1)
        return pred

    def loss_fn(self, pred, y):
        return F.cross_entropy(pred, y)

    def training_step(self, batch, batch_idx):
        """
        Training step for a batch of data.

        The model computes the loss and saves it along with the flattened model parameters.
        """
        X, y = batch

        # Early guard: inputs should be finite
        if not torch.isfinite(X).all():
            num_bad = (~torch.isfinite(X)).sum().item()
            raise RuntimeError(
                f"Non-finite inputs detected in MLP.training_step at batch_idx={batch_idx}: "
                f"{num_bad} entries are inf/NaN"
            )

        pred = self(X)

        # Early guard: model outputs should be finite
        if not torch.isfinite(pred).all():
            num_bad = (~torch.isfinite(pred)).sum().item()
            raise RuntimeError(
                f"Non-finite model outputs detected in MLP.training_step at batch_idx={batch_idx}: "
                f"{num_bad} entries are inf/NaN"
            )
        # Get model weights flattened here to append to optim_path later
        flattened_params = self.get_flattened_params()
        loss = self.loss_fn(pred, y)

        # Debug: flag non-finite loss or parameters early (for MLP)
        if not torch.isfinite(loss):
            raise RuntimeError(
                f"Non-finite loss detected in MLP.training_step at batch_idx={batch_idx}: {loss.item()}"
            )
        if not torch.isfinite(flattened_params).all():
            num_bad = (~torch.isfinite(flattened_params)).sum().item()
            raise RuntimeError(
                f"Non-finite parameters detected in MLP.training_step at batch_idx={batch_idx}: {num_bad} entries are inf/NaN"
            )

        preds = pred.max(dim=1)[1]  # class
        accuracy = self.accuracy(preds, y)

        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )

        self.log(
            "train_acc",
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        self.training_step_outputs.append(
            {"loss": loss, "accuracy": accuracy, "flattened_params": flattened_params}
        )

        return {"loss": loss, "accuracy": accuracy, "flattened_params": flattened_params}

    def on_train_epoch_end(self):
        """
        Only save the last step in each epoch.

        Appends the outputs of the final training step in
        self.training_step_outputs to self.optim_path, then clears the buffer
        for the next epoch.
        """
        self.optim_path.append(self.training_step_outputs[-1])
        self.training_step_outputs.clear()


class LeNet(MyModel):
    """
    The LeNet-5 model (LeCun et al., 1998).
    """
    def __init__(
        self,
        learning_rate,
        num_classes,
        optimizer="adam",
        custom_optimizer=None,
        gpus=0,
    ):
        """
        Initialize a LeNet model.

        Arguments:
            learning_rate: Learning rate to use.
            num_classes: Number of classes
            optimizer (optional): optimizer to use. Defaults to "adam".
            custom_optimizer (optional): custom optimizer to use. Defaults to None.
            gpus (optional): Number of GPUs for training if available. Defaults to 0.
        """
        super().__init__(
            optimizer=optimizer,
            learning_rate=learning_rate,
            custom_optimizer=custom_optimizer,
            gpus=gpus,
            num_classes=num_classes,
        )

        self.relu = nn.ReLU()
        self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=6,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )

        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )

        self.conv3 = nn.Conv2d(
            in_channels=16,
            out_channels=120,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(0, 0),
        )

        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, num_classes)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = self.relu(self.conv3(x))
        x = x.reshape(x.shape[0], -1)  # (n, 120, 1, 1) -> (n, 120)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
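    # Shape sketch, assuming 1x32x32 inputs (e.g. MNIST padded or resized to
    # 32x32, as the classic LeNet-5 expects): conv1 -> 6x28x28, pool -> 6x14x14,
    # conv2 -> 16x10x10, pool -> 16x5x5, conv3 -> 120x1x1, flatten -> 120,
    # fc1 -> 84, fc2 -> num_classes. Other spatial sizes change the flattened
    # width and break fc1.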

    def loss_fn(self, pred, y):
        return F.cross_entropy(pred, y)

    def training_step(self, batch, batch_idx):
        """
        Training step for a batch of data.

        The model computes the loss and saves it along with the flattened model parameters.
        """
        X, y = batch

        # Early guard: inputs should be finite
        if not torch.isfinite(X).all():
            num_bad = (~torch.isfinite(X)).sum().item()
            raise RuntimeError(
                f"Non-finite inputs detected in LeNet.training_step at batch_idx={batch_idx}: "
                f"{num_bad} entries are inf/NaN"
            )

        pred = self(X)

        # Early guard: model outputs should be finite
        if not torch.isfinite(pred).all():
            num_bad = (~torch.isfinite(pred)).sum().item()
            raise RuntimeError(
                f"Non-finite model outputs detected in LeNet.training_step at batch_idx={batch_idx}: "
                f"{num_bad} entries are inf/NaN"
            )
        # Get model weights flattened here to append to optim_path later
        flattened_params = self.get_flattened_params()
        loss = self.loss_fn(pred, y)

        # Debug: flag non-finite loss or parameters early (for LeNet)
        if not torch.isfinite(loss):
            raise RuntimeError(
                f"Non-finite loss detected in LeNet.training_step at batch_idx={batch_idx}: {loss.item()}"
            )
        if not torch.isfinite(flattened_params).all():
            num_bad = (~torch.isfinite(flattened_params)).sum().item()
            raise RuntimeError(
                f"Non-finite parameters detected in LeNet.training_step at batch_idx={batch_idx}: {num_bad} entries are inf/NaN"
            )

        preds = pred.max(dim=1)[1]  # class
        accuracy = self.accuracy(preds, y)

        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )

        self.log(
            "train_acc",
            accuracy,
            on_step=False,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )

        self.training_step_outputs.append(
            {"loss": loss, "accuracy": accuracy, "flattened_params": flattened_params}
        )

        return {"loss": loss, "accuracy": accuracy, "flattened_params": flattened_params}

    def on_train_epoch_end(self):
        """Only save the last step in each epoch, then clear the buffer."""
        self.optim_path.append(self.training_step_outputs[-1])
        self.training_step_outputs.clear()

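
# Illustrative sketch only: a minimal custom model following the pattern the
# module docstring describes (inherit from MyModel and organize the model in
# the PyTorch Lightning style). The class name `TinyLinear` and its argument
# choices are hypothetical and not part of the original project; MLP and LeNet
# above remain the reference implementations.
class TinyLinear(MyModel):
    """A single linear layer, kept as a template for new MyModel subclasses."""

    def __init__(self, input_dim, num_classes, learning_rate, optimizer="adam"):
        super().__init__(
            optimizer=optimizer,
            learning_rate=learning_rate,
            num_classes=num_classes,
        )
        self.layer = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.layer(x)

    def loss_fn(self, pred, y):
        return F.cross_entropy(pred, y)

    def training_step(self, batch, batch_idx):
        X, y = batch
        loss = self.loss_fn(self(X), y)
        # Track the flattened parameters the same way MLP and LeNet do.
        self.training_step_outputs.append(
            {"loss": loss, "flattened_params": self.get_flattened_params()}
        )
        return loss

    def on_train_epoch_end(self):
        self.optim_path.append(self.training_step_outputs[-1])
        self.training_step_outputs.clear()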