import json
import math
import os
from dataclasses import dataclass

import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoTokenizer



# Choose the compute backend once at import time: prefer Apple's MPS backend,
# then CUDA, and fall back to the CPU when neither is available.
DEVICE = (
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)

print(f"Using device: {DEVICE}")


## WEBHOOK UTILS START

# NOTE(review): a webhook URL is a credential and should not live in source
# control. Supply it through the DISCORD_WEBHOOK_URL environment variable
# instead (set it to an empty string to turn notifications off). The literal
# below is kept only as a backward-compatible default; rotate it and remove it.
DISCORD_WEBHOOK_URL = os.environ.get(
    "DISCORD_WEBHOOK_URL",
    "https://discord.com/api/webhooks/1429480946423300260/cLgVZMlaJ3cIW617j81-LypT5NaCKUzLAbhQlcR0dQUUg-Y8DFRQuOr2zyWcFc-yTMXU",
)


def send_discord_message(message: str) -> None:
    """Best-effort POST of `message` to the configured Discord webhook.

    Does nothing when no webhook URL is configured (None or empty string).
    Failures raised by the HTTP client (connection errors, timeouts, invalid
    URL, ...) are deliberately ignored so a notification problem never
    interrupts the caller. Anything else -- notably KeyboardInterrupt and
    SystemExit, which the previous bare ``except`` swallowed -- propagates.

    Args:
        message: Text sent as the ``content`` field of the webhook payload.
    """
    if not DISCORD_WEBHOOK_URL:
        return
    try:
        requests.post(DISCORD_WEBHOOK_URL, json={"content": message}, timeout=5)
    except requests.RequestException:
        # Notifications are optional; drop the message rather than fail.
        pass

## WEBHOOK UTILS END


# Model configuration (hyperparameters)
@dataclass
class ModelConfig:
    """Hyperparameters describing the transformer language model.

    Raises:
        ValueError: If ``d_model`` is not divisible by a positive ``n_heads``,
            or if the resulting per-head dimension is odd (the rotary
            embedding helpers in this file split each head into two halves).
    """

    vocab_size: int          # size of the token embedding table / output logits
    d_model: int = 576       # residual-stream width
    n_heads: int = 9         # attention heads; head_dim = d_model // n_heads
    n_layers: int = 12       # number of transformer blocks
    d_ff: int = 2304         # hidden width of the feed-forward layer
    max_seq_len: int = 512   # longest sequence the RoPE cache and causal mask cover
    dropout: float = 0.0     # dropout probability used inside the blocks

    def __post_init__(self):
        # Fail fast with a readable message instead of an opaque view/broadcast
        # error deep inside the attention layer.
        if self.n_heads <= 0 or self.d_model % self.n_heads != 0:
            raise ValueError(
                f"d_model ({self.d_model}) must be divisible by a positive "
                f"n_heads ({self.n_heads})"
            )
        head_dim = self.d_model // self.n_heads
        if head_dim % 2 != 0:
            raise ValueError(
                f"head_dim (d_model // n_heads = {head_dim}) must be even "
                "for rotary position embeddings"
            )

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learnable gain."""

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Per-feature scale, initialised to 1 so the layer starts as pure normalisation.
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # eps is added to the mean square (inside the square root) for stability.
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        denom = torch.sqrt(mean_square + self.eps)
        normalised = x / denom
        return self.weight * normalised

## ROPE HELPERS START
def rotate_half(x):
    """Map the last axis ``(a, b)`` to ``(-b, a)``, where a is the first half and b the rest."""
    mid = x.size(-1) // 2
    front, back = x[..., :mid], x[..., mid:]
    return torch.cat((-back, front), dim=-1)

def apply_rope(q, k, cos, sin):
    """Apply rotary position embeddings to the query and key tensors.

    ``cos`` and ``sin`` are 2-D tables; two leading singleton axes are added
    so they broadcast against ``q`` and ``k``.

    Returns:
        A ``(q_rotated, k_rotated)`` tuple.
    """
    cos_b = cos[None, None, :, :]
    sin_b = sin[None, None, :, :]

    def rotate(t):
        return (t * cos_b) + (rotate_half(t) * sin_b)

    return rotate(q), rotate(k)

def build_rope_cache(max_seq_len, head_dim, device, theta=10000.0):
    """Precompute the cosine/sine tables used by rotary position embeddings.

    Args:
        max_seq_len: Number of positions (rows) to tabulate.
        head_dim: Per-head feature width; must be even.
        device: Device on which the tables are created.
        theta: Base of the geometric frequency progression. The default
            reproduces the previously hard-coded value.

    Returns:
        ``(cos, sin)``, two float32 tensors of shape ``(max_seq_len, head_dim)``.
        The second half of the last axis repeats the first half, matching the
        half-split layout that ``rotate_half`` operates on.

    Raises:
        ValueError: If ``head_dim`` is odd. The tables would otherwise come
            out ``head_dim + 1`` wide and fail to broadcast later.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for rotary embeddings, got {head_dim}")

    exponents = torch.arange(0, head_dim, 2, device=device).float() / head_dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_seq_len, device=device).float()

    # Outer product: angles[p, i] = position p * inverse frequency i.
    angles = positions[:, None] * inv_freq[None, :]
    # Duplicate the angles once, then take cos/sin a single time each.
    angles = torch.cat([angles, angles], dim=-1)
    return angles.cos(), angles.sin()
## ROPE HELPERS END

class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and rotary embeddings."""

    def __init__(self, config):
        super().__init__()
        self.d_model = config.d_model
        self.n_heads = config.n_heads
        self.head_dim = self.d_model // self.n_heads

        # A single bias-free matmul yields queries, keys and values together.
        self.qkv = nn.Linear(self.d_model, 3 * self.d_model, bias=False)
        self.o = nn.Linear(self.d_model, self.d_model, bias=False)
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x, mask, cos, sin):
        """Attend over ``x`` of shape (batch, seq, d_model).

        ``mask`` is boolean; positions where it is False are excluded from
        attention. ``cos`` / ``sin`` are the rotary tables passed to
        ``apply_rope``.
        """
        batch, seq_len, _ = x.shape

        # (batch, seq, 3, heads, head_dim) -> (3, batch, heads, seq, head_dim)
        fused = self.qkv(x).view(batch, seq_len, 3, self.n_heads, self.head_dim)
        fused = fused.permute(2, 0, 3, 1, 4)
        q, k, v = fused[0], fused[1], fused[2]

        q, k = apply_rope(q, k, cos, sin)

        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        scores = scores.masked_fill(~mask, float("-inf"))
        weights = self.drop(F.softmax(scores, dim=-1))

        # Merge the heads back into a single (batch, seq, d_model) tensor.
        merged = (weights @ v).transpose(1, 2).reshape(batch, seq_len, self.d_model)
        return self.drop(self.o(merged))

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``w3(silu(w1(x)) * w2(x))`` followed by dropout."""

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.d_ff
        self.w1 = nn.Linear(width, hidden, bias=False)  # gate branch (passed through SiLU)
        self.w2 = nn.Linear(width, hidden, bias=False)  # linear branch
        self.w3 = nn.Linear(hidden, width, bias=False)  # projection back to d_model
        self.drop = nn.Dropout(config.dropout)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return self.drop(self.w3(gate * value))

class TransformerBlock(nn.Module):
    """Pre-norm residual block: self-attention, then the SwiGLU feed-forward."""

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        self.n1 = RMSNorm(width)
        self.n2 = RMSNorm(width)
        self.att = MultiHeadAttention(config)
        self.ff = SwiGLU(config)

    def forward(self, x, mask, cos, sin):
        # Each sub-layer sees a normalised copy of the stream and adds its
        # output back onto the un-normalised residual.
        after_attention = x + self.att(self.n1(x), mask, cos, sin)
        return after_attention + self.ff(self.n2(after_attention))


## FULL MODEL

class TransformerLM(nn.Module):
    """Decoder-only transformer language model with tied input/output embeddings."""

    def __init__(self, config):
        super().__init__()
        width, context = config.d_model, config.max_seq_len

        self.embed = nn.Embedding(config.vocab_size, width)
        self.blocks = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
        self.norm = RMSNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.embed.weight

        # Rotary tables and the causal mask are derived from the config, so
        # they are registered as non-persistent buffers (not saved in state_dict).
        rope_cos, rope_sin = build_rope_cache(context, width // config.n_heads, "cpu")
        self.register_buffer("cos", rope_cos, persistent=False)
        self.register_buffer("sin", rope_sin, persistent=False)

        causal = torch.tril(torch.ones(context, context, dtype=torch.bool))
        self.register_buffer("mask", causal.view(1, 1, context, context), persistent=False)

    def forward(self, ids, targets=None):
        """Return ``(logits, loss)``; ``loss`` is None when ``targets`` is not given."""
        _, seq_len = ids.shape

        # Trim the precomputed tables to the current sequence length.
        rope_cos, rope_sin = self.cos[:seq_len], self.sin[:seq_len]
        causal = self.mask[:, :, :seq_len, :seq_len]

        hidden = self.embed(ids)
        for block in self.blocks:
            hidden = block(hidden, causal, rope_cos, rope_sin)
        logits = self.lm_head(self.norm(hidden))

        if targets is None:
            return logits, None

        # Flatten every position into one row for the cross-entropy loss.
        flat_logits = logits.view(-1, logits.size(-1))
        loss = F.cross_entropy(flat_logits, targets.reshape(-1))
        return logits, loss

## DATASET STUFF START

class SequenceDataset(Dataset):
    """Non-overlapping fixed-length windows over a 1-D tensor of token ids.

    Trailing tokens that do not fill a complete window are dropped. When
    there are fewer than ``seq_len`` tokens the dataset is empty
    (``len(ds) == 0``) instead of failing at construction.

    Args:
        token_ids: 1-D tensor of token ids.
        seq_len: Number of tokens per item.
    """

    def __init__(self, token_ids, seq_len):
        n_seqs = len(token_ids) // seq_len
        # Pass the row count explicitly: ``view(-1, seq_len)`` raises on a
        # zero-element tensor because the -1 dimension cannot be inferred.
        self.data = token_ids[: n_seqs * seq_len].view(n_seqs, seq_len)

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return self.data[idx]

def prepare_datasets(
    seq_len,
    dataset_name="tensonaut/EPSTEIN_FILES_20K",
    tokenizer_name="gpt2",
    val_fraction=0.05,
):
    """Load a text dataset, tokenize it and cut it into fixed-length sequences.

    The first ``val_fraction`` of the documents in the ``"train"`` split is
    held out for validation and the remainder is used for training. Each side
    is joined with blank lines into one string, tokenized in a single call
    and wrapped in a ``SequenceDataset``.

    Args:
        seq_len: Tokens per sequence.
        dataset_name: Name passed to ``load_dataset``.
        tokenizer_name: Name passed to ``AutoTokenizer.from_pretrained``.
        val_fraction: Fraction of documents (taken from the start) used for
            validation.

    Returns:
        ``(tokenizer, train_ds, val_ds, len(train_ds), len(val_ds))``.
    """
    print("preparing load datset")
    raw = load_dataset(dataset_name)
    print('dataset done loading')

    # compute train / validation split
    print('compute splite')
    raw_text = raw["train"]["text"]
    # Documents before this index are validation, the rest are training.
    n_val_docs = int(len(raw_text) * val_fraction)

    print('i think this is the bottleneck')
    train_text = "\n\n".join([doc for doc in raw_text[n_val_docs:] if doc is not None])
    val_text = "\n\n".join([doc for doc in raw_text[:n_val_docs] if doc is not None])
    print('🙏')

    tok = AutoTokenizer.from_pretrained(tokenizer_name)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    print('bro')
    # Explicit dtype: torch.tensor([]) would otherwise be float32 when a split
    # tokenizes to nothing, and token ids must be integers for the embedding.
    train_ids = torch.tensor(tok.encode(train_text, add_special_tokens=False), dtype=torch.long)
    val_ids = torch.tensor(tok.encode(val_text, add_special_tokens=False), dtype=torch.long)
    print('wa')
    train_ds = SequenceDataset(train_ids, seq_len)
    val_ds = SequenceDataset(val_ids, seq_len)

    print(f"Train seqs: {len(train_ds)}, Val seqs: {len(val_ds)}")

    return tok, train_ds, val_ds, len(train_ds), len(val_ds)

def _evaluate(model, loader):
    """Return the mean per-batch loss of `model` over `loader`, or None if it yields no batches."""
    model.eval()
    total = 0.0
    steps = 0
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(DEVICE)
            _, loss = model(batch[:, :-1], batch[:, 1:])
            total += loss.item()
            steps += 1
    return total / steps if steps else None


def train():
    """Train the language model, report a validation loss and save a checkpoint.

    Steps: build the datasets, train for ``EPOCHS`` epochs with AdamW and
    gradient-norm clipping, evaluate once on the validation set, then write
    the weights, config and tokenizer name to ``transformer_lm.pt``.
    Progress is printed and sent to Discord every ``NOTIFY_EVERY`` batches.
    """
    BLOCK = 512
    SEQ = BLOCK + 1  # one extra token so inputs and shifted targets are both BLOCK long
    BATCH = 2
    EPOCHS = 1
    LR = 3e-4
    # A webhook POST on every optimizer step can block for up to the request
    # timeout and is likely to run into Discord's rate limit, so report
    # progress only every NOTIFY_EVERY batches.
    NOTIFY_EVERY = 100

    tok, train_ds, val_ds, _, _ = prepare_datasets(SEQ)

    config = ModelConfig(
        vocab_size=tok.vocab_size,
        max_seq_len=BLOCK,
    )

    model = TransformerLM(config).to(DEVICE)
    print(f"Params: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")

    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01)

    train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH, drop_last=True)
    # Progress is counted in batches, so the denominator must be the number
    # of batches, not the number of sequences in the dataset.
    n_batches = len(train_loader)

    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== EPOCH {epoch} ===")
        model.train()
        total = 0.0
        steps = 0

        for i, batch in enumerate(tqdm(train_loader)):
            batch = batch.to(DEVICE)
            inp = batch[:, :-1]
            tgt = batch[:, 1:]

            _, loss = model(inp, tgt)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            total += loss.item()
            steps += 1

            if i % NOTIFY_EVERY == 0:
                progress = f"batch {i}/{n_batches}"
                send_discord_message(progress)
                print(progress)

        if steps:
            print(f"Train loss: {total/steps:.4f}")
        else:
            print("Train loss: n/a (training set yielded no full batches)")

    # validation
    # A validation set smaller than one batch yields no batches (drop_last);
    # report that instead of dividing by zero before the checkpoint is saved.
    val_loss = _evaluate(model, val_loader)
    if val_loss is None:
        print("Val loss: n/a (validation set yielded no full batches)")
    else:
        print(f"Val loss: {val_loss:.4f}")

    # save the trained model to a file
    save_path = "transformer_lm.pt"
    torch.save(
        {
            "model_state_dict": model.state_dict(),
            "config": config.__dict__,
            "tokenizer_name": tok.name_or_path,
        },
        save_path,
    )
    print(f"Model saved to {save_path}")
    send_discord_message(f"Training complete. Model saved to {save_path}")

# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
