import argparse
import random
from pathlib import Path
from typing import Mapping, Optional

import numpy as np
import torch
import torchvision.transforms as transforms
from mlsdk import (
    Context,
    MNCoreMomentumSGD,
    MNDevice,
    set_buffer_name_in_optimizer,
    set_tensor_name_in_module,
    storage,
)
from torchvision import datasets

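# Fix all RNG seeds (PyTorch, Python, NumPy) for reproducible runs.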
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


def mnist_loaders(batch_size, eval_batch_size):
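    """Build MNIST train/eval DataLoaders that yield dict-shaped batches."""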
    # MLSDK requires input to be passed as a dictionary
    def list_to_dict(batch):
        batch = torch.utils.data.default_collate(batch)
        return {"x": batch[0], "t": batch[1]}

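    # Pad the 28x28 digits to 32x32 so each flattened image has 32 * 32 = 1024
    # features, matching ``linear1`` in MNCoreClassifier; the normalization
    # constants are the standard MNIST mean and standard deviation.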
    transform = transforms.Compose(
        [
            transforms.Pad(2),
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    train_dataset = datasets.MNIST(
        "/tmp", train=True, transform=transform, download=True
    )
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=list_to_dict,
    )
    eval_dataset = datasets.MNIST(
        "/tmp",
        train=False,
        transform=transform,
        download=True,
    )
    # In evaluation, ``drop_last`` would silently skip the final partial batch
    # and bias the accuracy, so we instead use a batch size that evenly divides
    # the number of test images (10,000) and evaluate on every sample.
    assert len(eval_dataset) % eval_batch_size == 0
    eval_loader = torch.utils.data.DataLoader(
        eval_dataset,
        batch_size=eval_batch_size,
        shuffle=False,
        collate_fn=list_to_dict,
    )
    return train_loader, eval_loader


class MNCoreClassifier(torch.nn.Module):
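    """Two-layer MLP that computes its own loss and returns dict outputs."""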
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(1024, 256)
        self.linear2 = torch.nn.Linear(256, 10)

    def forward(self, x, t, **kwargs):
        x_reshaped = x.reshape(x.size(0), -1)  # (N, 1, 32, 32) -> (N, 1024)
        x1 = self.linear1(x_reshaped)
        x2 = torch.nn.functional.relu(x1)
        x3 = self.linear2(x2)
        loss = torch.nn.functional.cross_entropy(x3, t)
        # MLSDK requires output to be returned as a dictionary
        if self.training:
            return {"loss": loss}
        else:
            return {"y": x3, "loss": loss}


def main(outdir: str, option_json_path: Optional[Path], device_str: str) -> None:
    batch_size = 64
    eval_batch_size = 125

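    # Select the MN-Core device and make its context the active one, so that
    # subsequent MLSDK calls (parameter registration, compilation) target it.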
    device = MNDevice(device_str)
    context = Context(device)
    Context.switch_context(context)

    train_loader, eval_loader = mnist_loaders(batch_size, eval_batch_size)

    model_with_loss_fn = MNCoreClassifier()
    model_with_loss_fn.train()
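    # Assign stable names to the module's tensors and register the parameters
    # with the context so MLSDK can track them across compiled steps.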
    set_tensor_name_in_module(model_with_loss_fn, "model_with_loss_fn")
    for p in model_with_loss_fn.parameters():
        context.register_param(p)

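    # Momentum SGD for MN-Core. The hyperparameters are positional; following
    # the usual SGD convention, 0.1 is presumably the learning rate and 0.9
    # the momentum (check MNCoreMomentumSGD's signature).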
    optimizer = MNCoreMomentumSGD(model_with_loss_fn.parameters(), 0.1, 0, 0.9, 1.0)
    set_buffer_name_in_optimizer(optimizer, "optimizer")
    context.register_optimizer_buffers(optimizer)

    def train_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
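        """Run one forward/backward pass and a parameter update on a batch."""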
        x = inp["x"]
        t = inp["t"]
        optimizer.zero_grad()
        output = model_with_loss_fn(x, t)
        loss = output["loss"]
        loss.backward()
        optimizer.step()
        return {"loss": loss}

    compile_options = {}
    if option_json_path is not None:
        compile_options["option_json"] = str(option_json_path)

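    # Trace-compile the training step with a representative batch; the compiled
    # artifacts are written under ``outdir / "train_step"``.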
    sample = next(iter(train_loader))
    compiled_train_step = context.compile(
        train_step,
        sample,
        storage.path(outdir) / "train_step",
        options=compile_options,
    )

    for epoch in range(10):
        loss = 0.0
        for i, sample in enumerate(train_loader):
            curr_loss = compiled_train_step(sample)["loss"].item()
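            # Incrementally update the running mean of the loss over the epoch.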
            loss += (curr_loss - loss) / (i + 1)
            if i % 100 == 0:
                print(f"epoch {epoch}, iter {i:4}, loss {loss}")
        print(f"epoch {epoch}, loss {loss}")

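    # Wait for any outstanding device work to finish before switching modes.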
    context.synchronize()

    model_with_loss_fn.eval()

    def eval_step(inp: Mapping[str, torch.Tensor]) -> Mapping[str, torch.Tensor]:
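        """Count the batch's correct predictions (argmax over the logits)."""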
        x = inp["x"]
        t = inp["t"]
        output = model_with_loss_fn(x, t)
        y = output["y"]
        _, predicted = torch.max(y, 1)
        correct = (predicted == t).sum()
        return {"correct": correct}

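    # Compile the evaluation step the same way, tracing with an eval batch.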
    sample = next(iter(eval_loader))
    compiled_eval_step = context.compile(
        eval_step,
        sample,
        storage.path(outdir) / "eval_step",
        options=compile_options,
    )
    correct = 0
    for sample in eval_loader:
        correct += compiled_eval_step(sample)["correct"].item()
    print(
        f"Correct: {correct} / {len(eval_loader.dataset)}. "
        f"Accuracy: {correct / len(eval_loader.dataset)}"
    )
    assert 0.95 < correct / len(eval_loader.dataset)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--outdir", type=str, default="/tmp/mlsdk_mnist")
    parser.add_argument("--option_json", type=Path, default=None)
    parser.add_argument("--device", type=str, default="mncore2:0")
    args = parser.parse_args()
    main(args.outdir, args.option_json, args.device)
