

Implementing LLM-JEPA in PyTorch: Predict Embeddings, Not Tokens

This post implements LLM-JEPA, from the paper "LLM-JEPA: Large Language Models Meet Joint Embedding Predictive Architectures", from scratch. To be clear, what follows is a compact, minimal training script whose goal is to capture the essence of JEPA: create two views of the same text, predict the embeddings of the masked spans, and train with a representation-alignment loss.

The goal is for you to genuinely understand the method. The code is walked through piece by piece, the purpose of every function is explained and tied back to the paper's core intuition, and each block is described in enough detail that you can adapt it to your own experiments.

The Code

The entire LLM-JEPA training script lives in a single file:

It takes raw text and creates two views: the context view replaces some spans with [MASK], while the target view keeps the original text and is only supervised at the masked positions. The context encoder is trainable and is responsible for predicting the target encoder's representations at the masked positions. The target encoder is an EMA copy of the context encoder and receives no gradients. The loss is the cosine distance between predicted and target embeddings.

Example runs:

# Tiny smoke test (no downloads, random init)
python llm_jepa_train.py --smoke_test

# Train with an HF model backbone
python llm_jepa_train.py --model_name distilbert-base-uncased --steps 200 --batch_size 8

# Train on your own text file
python llm_jepa_train.py --model_name distilbert-base-uncased --text_file data.txt --steps 2000

This is a concise reference implementation, not a full repository. The encoder comes from the Transformers library.
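
To run it yourself, the only dependencies are PyTorch and Transformers (versions are not pinned here; any reasonably recent release should work):

pip install torch transformers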

import argparse  
import math  
import os  
import random  
from dataclasses import dataclass  
from typing import List, Tuple, Optional  

import torch  
import torch.nn as nn  
import torch.nn.functional as F  
from torch.utils.data import Dataset, DataLoader  

try:  
    from transformers import AutoTokenizer, AutoModel, AutoConfig  
except Exception:  
    AutoTokenizer = None  
    AutoModel = None  
    AutoConfig = None

# -----------------------------  
# Utilities  
# -----------------------------  
def set_seed(seed: int):  
    random.seed(seed)  
    torch.manual_seed(seed)  
    torch.cuda.manual_seed_all(seed)

def pick_device(device_str: str) -> torch.device:  
    if device_str == "auto":  
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")  
    return torch.device(device_str)

# -----------------------------  
# Span masking (simple + effective)  
# -----------------------------  
def sample_span_mask(  
    seq_len: int,  
    mask_ratio: float,  
    mean_span_len: int,  
    special_positions: Optional[set] = None,  
) -> torch.BoolTensor:  
    """  
    Returns a boolean mask of length seq_len indicating which positions are masked.  
    We mask contiguous spans until we reach approximately mask_ratio of tokens.  
    """  
    if special_positions is None:  
        special_positions = set()  

    mask = torch.zeros(seq_len, dtype=torch.bool)  
    if seq_len <= 0:  
        return mask  

    target_to_mask = max(1, int(round(seq_len * mask_ratio)))  
    masked = 0  

    attempts = 0  
    max_attempts = seq_len * 4  

    while masked < target_to_mask and attempts < max_attempts:  
        attempts += 1  

        span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))  
        span_len = min(span_len, seq_len)  

        start = random.randint(0, seq_len - 1)  
        end = min(seq_len, start + span_len)  

        span_positions = [i for i in range(start, end) if i not in special_positions]  
        if not span_positions:  
            continue  

        newly = 0  
        for i in span_positions:  
            if not mask[i]:  
                mask[i] = True  
                newly += 1  

        masked += newly  

    return mask

def apply_mask_to_input_ids(  
    input_ids: torch.LongTensor,  
    attention_mask: torch.LongTensor,  
    tokenizer,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Tuple[torch.LongTensor, torch.BoolTensor]:  
    """  
    Masks spans inside non-special, non-padding tokens.  
    Returns:  
      masked_input_ids: input ids with masked tokens replaced by [MASK]  
      pred_mask: boolean mask over positions where we apply JEPA loss  
    """  
    assert input_ids.dim() == 1  
    seq_len = int(attention_mask.sum().item())  

    # Identify special token positions (CLS, SEP, etc.) in the visible region  
    special_positions = set()  
    for i in range(seq_len):  
        tid = int(input_ids[i].item())  
        if tid in {  
            tokenizer.cls_token_id,  
            tokenizer.sep_token_id,  
            tokenizer.pad_token_id,  
        }:  
            special_positions.add(i)  

    pred_mask = sample_span_mask(  
        seq_len=seq_len,  
        mask_ratio=mask_ratio,  
        mean_span_len=mean_span_len,  
        special_positions=special_positions,  
    )  

    masked_input_ids = input_ids.clone()  
    mask_token_id = tokenizer.mask_token_id  
    if mask_token_id is None:  
        raise ValueError("Tokenizer has no mask_token_id. Use a model with [MASK].")  

    # Replace masked positions with [MASK]  
    masked_input_ids[:seq_len][pred_mask] = mask_token_id  

    # pred_mask should be full length (includes pads as False)  
    full_mask = torch.zeros_like(attention_mask, dtype=torch.bool)  
    full_mask[:seq_len] = pred_mask  

    return masked_input_ids, full_mask

# -----------------------------  
# Dataset  
# -----------------------------  
class TextLinesDataset(Dataset):  
    def __init__(self, texts: List[str]):  
        self.texts = [t.strip() for t in texts if t.strip()]  

    def __len__(self) -> int:  
        return len(self.texts)  

    def __getitem__(self, idx: int) -> str:  
        return self.texts[idx]

def load_texts_from_file(path: str, max_lines: Optional[int] = None) -> List[str]:  
    texts = []  
    with open(path, "r", encoding="utf-8") as f:  
        for i, line in enumerate(f):  
            if max_lines is not None and i >= max_lines:  
                break  
            texts.append(line.rstrip("\n"))  
    return texts

def default_tiny_corpus() -> List[str]:  
    return [  
        "The cat sat on the mat and looked at the window.",  
        "A quick brown fox jumps over the lazy dog.",  
        "Deep learning models can learn useful representations from raw data.",  
        "Rocket Learning builds AI tools for education in India.",  
        "Transformers use attention to mix information across tokens.",  
        "Self-supervised learning can reduce the need for labels.",  
        "JEPA trains models to predict embeddings, not tokens.",  
        "Bengaluru is a major tech hub in India.",  
        "A good system design balances simplicity and scalability.",  
        "Reading code carefully helps you understand how an idea is implemented.",  
    ]

@dataclass  
class Batch:  
    input_ids: torch.LongTensor          # [B, L]  
    attention_mask: torch.LongTensor     # [B, L]  
    masked_input_ids: torch.LongTensor   # [B, L]  
    pred_mask: torch.BoolTensor          # [B, L]  positions to compute loss on

def collate_jepa(  
    batch_texts: List[str],  
    tokenizer,  
    max_length: int,  
    mask_ratio: float,  
    mean_span_len: int,  
) -> Batch:  
    toks = tokenizer(  
        batch_texts,  
        padding=True,  
        truncation=True,  
        max_length=max_length,  
        return_tensors="pt",  
    )  
    input_ids = toks["input_ids"]              # [B, L]  
    attention_mask = toks["attention_mask"]    # [B, L]  

    masked_input_ids_list = []  
    pred_mask_list = []  

    for b in range(input_ids.size(0)):  
        mi, pm = apply_mask_to_input_ids(  
            input_ids[b],  
            attention_mask[b],  
            tokenizer,  
            mask_ratio=mask_ratio,  
            mean_span_len=mean_span_len,  
        )  
        masked_input_ids_list.append(mi)  
        pred_mask_list.append(pm)  

    masked_input_ids = torch.stack(masked_input_ids_list, dim=0)  
    pred_mask = torch.stack(pred_mask_list, dim=0)  

    return Batch(  
        input_ids=input_ids,  
        attention_mask=attention_mask,  
        masked_input_ids=masked_input_ids,  
        pred_mask=pred_mask,  
    )

# -----------------------------  
# Model: Encoder + Predictor + EMA target encoder  
# -----------------------------  
class PredictorMLP(nn.Module):  
    def __init__(self, dim: int, hidden_mult: int = 4, dropout: float = 0.0):  
        super().__init__()  
        hidden = dim * hidden_mult  
        self.net = nn.Sequential(  
            nn.Linear(dim, hidden),  
            nn.GELU(),  
            nn.Dropout(dropout),  
            nn.Linear(hidden, dim),  
        )  

    def forward(self, x: torch.Tensor) -> torch.Tensor:  
        return self.net(x)

class LLMJEPA(nn.Module):  
    def __init__(self, encoder: nn.Module, dim: int, ema_m: float = 0.99, pred_hidden_mult: int = 4):  
        super().__init__()  
        self.context_encoder = encoder  
        self.target_encoder = self._copy_encoder(encoder)  
        self.predictor = PredictorMLP(dim=dim, hidden_mult=pred_hidden_mult, dropout=0.0)  
        self.ema_m = ema_m  

        for p in self.target_encoder.parameters():  
            p.requires_grad = False  

    @staticmethod  
    def _copy_encoder(enc: nn.Module) -> nn.Module:  
        import copy  
        return copy.deepcopy(enc)  

    @torch.no_grad()  
    def ema_update(self):  
        m = self.ema_m  
        for p_ctx, p_tgt in zip(self.context_encoder.parameters(), self.target_encoder.parameters()):  
            p_tgt.data.mul_(m).add_(p_ctx.data, alpha=(1.0 - m))  

    def forward(  
        self,  
        masked_input_ids: torch.LongTensor,  
        input_ids: torch.LongTensor,  
        attention_mask: torch.LongTensor,  
        pred_mask: torch.BoolTensor,  
    ) -> torch.Tensor:  
        """  
        Returns JEPA loss (scalar).  
        We compute:  
          z_ctx = context_encoder(masked_input)  
          z_tgt = target_encoder(full input)  
          pred = predictor(z_ctx)  
          loss over positions in pred_mask  
        """  
        out_ctx = self.context_encoder(input_ids=masked_input_ids, attention_mask=attention_mask)  
        z_ctx = out_ctx.last_hidden_state  # [B, L, D]  

        with torch.no_grad():  
            out_tgt = self.target_encoder(input_ids=input_ids, attention_mask=attention_mask)  
            z_tgt = out_tgt.last_hidden_state  # [B, L, D]  

        pred = self.predictor(z_ctx)  # [B, L, D]  

        # Select masked positions  
        # pred_mask: [B, L] bool  
        masked_pred = pred[pred_mask]  # [N, D]  
        masked_tgt = z_tgt[pred_mask]  # [N, D]  

        if masked_pred.numel() == 0:  
            # Safety: if a batch ends up with no masked tokens, return zero loss  
            return pred.sum() * 0.0  

        masked_pred = F.normalize(masked_pred, dim=-1)  
        masked_tgt = F.normalize(masked_tgt, dim=-1)  

        # Cosine distance  
        loss = 1.0 - (masked_pred * masked_tgt).sum(dim=-1)  
        return loss.mean()

# -----------------------------  
# Training  
# -----------------------------  
def build_hf_encoder(model_name: str):  
    if AutoModel is None:  
        raise RuntimeError("transformers is not installed. pip install transformers")  

    config = AutoConfig.from_pretrained(model_name)  
    encoder = AutoModel.from_pretrained(model_name, config=config)  
    dim = int(config.hidden_size)  
    return encoder, dim

def build_random_encoder(vocab_size: int = 30522, dim: int = 256, layers: int = 4, heads: int = 4):  
    """  
    For smoke tests only: small Transformer encoder (random init).  
    Requires a tokenizer with vocab mapping for ids.  
    """  
    encoder_layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads, batch_first=True)  
    transformer = nn.TransformerEncoder(encoder_layer, num_layers=layers)  

    class TinyEncoder(nn.Module):  
        def __init__(self):  
            super().__init__()  
            self.emb = nn.Embedding(vocab_size, dim)  
            self.pos = nn.Embedding(512, dim)  
            self.enc = transformer  

        def forward(self, input_ids, attention_mask):  
            B, L = input_ids.shape  
            pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, L)  
            x = self.emb(input_ids) + self.pos(pos_ids)  

            # attention_mask: 1 for keep, 0 for pad  
            # transformer expects src_key_padding_mask: True for pad  
            pad_mask = attention_mask == 0  
            h = self.enc(x, src_key_padding_mask=pad_mask)  
            # lightweight stand-in for an HF-style output object exposing .last_hidden_state
            return type("Out", (), {"last_hidden_state": h})  

    return TinyEncoder(), dim

def save_checkpoint(path: str, model: LLMJEPA, optimizer: torch.optim.Optimizer, step: int):  
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)  # also handles bare filenames with no directory  
    torch.save(  
        {  
            "step": step,  
            "context_encoder": model.context_encoder.state_dict(),  
            "target_encoder": model.target_encoder.state_dict(),  
            "predictor": model.predictor.state_dict(),  
            "optimizer": optimizer.state_dict(),  
        },  
        path,  
    )

def main():  
    parser = argparse.ArgumentParser()  
    parser.add_argument("--model_name", type=str, default="distilbert-base-uncased", help="HF encoder backbone")  
    parser.add_argument("--text_file", type=str, default="", help="Path to a newline-separated text file")  
    parser.add_argument("--max_lines", type=int, default=50000)  
    parser.add_argument("--max_length", type=int, default=128)  
    parser.add_argument("--mask_ratio", type=float, default=0.3)  
    parser.add_argument("--mean_span_len", type=int, default=5)  
    parser.add_argument("--ema_m", type=float, default=0.99)  
    parser.add_argument("--pred_hidden_mult", type=int, default=4)  

    parser.add_argument("--batch_size", type=int, default=8)  
    parser.add_argument("--lr", type=float, default=2e-5)  
    parser.add_argument("--weight_decay", type=float, default=0.01)  
    parser.add_argument("--steps", type=int, default=500)  
    parser.add_argument("--warmup_steps", type=int, default=50)  
    parser.add_argument("--log_every", type=int, default=25)  
    parser.add_argument("--save_every", type=int, default=200)  
    parser.add_argument("--save_path", type=str, default="checkpoints/llm_jepa.pt")  

    parser.add_argument("--device", type=str, default="auto")  
    parser.add_argument("--seed", type=int, default=42)  
    parser.add_argument("--smoke_test", action="store_true", help="No downloads, tiny random encoder, tiny corpus")  
    args = parser.parse_args()  

    set_seed(args.seed)  
    device = pick_device(args.device)  

    if args.smoke_test:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is required even for smoke_test (for tokenizer).")  
        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  
        # Ensure mask token exists  
        if tokenizer.mask_token_id is None:  
            raise ValueError("Tokenizer must support [MASK]. Use a masked LM tokenizer.")  

        texts = default_tiny_corpus()  
        ds = TextLinesDataset(texts)  

        encoder, dim = build_random_encoder(vocab_size=int(tokenizer.vocab_size), dim=256, layers=4, heads=4)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=0.95, pred_hidden_mult=2).to(device)  

        lr = 1e-4  
    else:  
        if AutoTokenizer is None:  
            raise RuntimeError("transformers is not installed. pip install transformers")  
        tokenizer = AutoTokenizer.from_pretrained(args.model_name)  
        if tokenizer.mask_token_id is None:  
            raise ValueError(  
                "This tokenizer has no [MASK]. Pick a masked-encoder model (BERT/DeBERTa/DistilBERT)."  
            )  

        if args.text_file:  
            texts = load_texts_from_file(args.text_file, max_lines=args.max_lines)  
        else:  
            texts = default_tiny_corpus()  

        ds = TextLinesDataset(texts)  

        encoder, dim = build_hf_encoder(args.model_name)  
        model = LLMJEPA(encoder=encoder, dim=dim, ema_m=args.ema_m, pred_hidden_mult=args.pred_hidden_mult).to(device)  

        lr = args.lr  

    # DataLoader  
    def _collate(batch_texts):  
        return collate_jepa(  
            batch_texts=batch_texts,  
            tokenizer=tokenizer,  
            max_length=args.max_length,  
            mask_ratio=args.mask_ratio,  
            mean_span_len=args.mean_span_len,  
        )  

    dl = DataLoader(ds, batch_size=args.batch_size, shuffle=True, drop_last=True, collate_fn=_collate)  

    # Optimizer  
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=args.weight_decay)  

    # Simple warmup + cosine schedule  
    def lr_at(step: int) -> float:  
        if step < args.warmup_steps:  
            return float(step + 1) / float(max(1, args.warmup_steps))  
        progress = (step - args.warmup_steps) / float(max(1, args.steps - args.warmup_steps))  
        progress = min(max(progress, 0.0), 1.0)  
        return 0.5 * (1.0 + math.cos(math.pi * progress))  

    model.train()  
    running = 0.0  
    step = 0  
    data_iter = iter(dl)  

    while step < args.steps:  
        try:  
            batch = next(data_iter)  
        except StopIteration:  
            data_iter = iter(dl)  
            batch = next(data_iter)  

        # Move to device  
        input_ids = batch.input_ids.to(device)  
        attention_mask = batch.attention_mask.to(device)  
        masked_input_ids = batch.masked_input_ids.to(device)  
        pred_mask = batch.pred_mask.to(device)  

        # LR schedule  
        scale = lr_at(step)  
        for pg in optimizer.param_groups:  
            pg["lr"] = lr * scale  

        loss = model(  
            masked_input_ids=masked_input_ids,  
            input_ids=input_ids,  
            attention_mask=attention_mask,  
            pred_mask=pred_mask,  
        )  

        optimizer.zero_grad(set_to_none=True)  
        loss.backward()  
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  
        optimizer.step()  

        # EMA update after optimizer step  
        model.ema_update()  

        running += float(loss.item())  
        step += 1  

        if step % args.log_every == 0:  
            avg = running / float(args.log_every)  
            running = 0.0  
            print(f"step {step:6d} | loss {avg:.4f} | lr {optimizer.param_groups[0]['lr']:.6g}")  

        if step % args.save_every == 0:  
            save_checkpoint(args.save_path, model, optimizer, step)  
            print(f"saved checkpoint to {args.save_path} at step {step}")  

    save_checkpoint(args.save_path, model, optimizer, step)  
    print(f"training done. final checkpoint: {args.save_path}")

if __name__ == "__main__":  
    main()

What this script actually trains

It is a JEPA-style representation predictor for text.

The input is plain text lines, and two views are created for each example. The masked view (the context view) is the same sentence with some spans replaced by [MASK]; the original view (the target view) is left untouched, with no masking.

The training flow: the masked view goes through a trainable context encoder, the original view goes through a frozen target encoder, and a predictor is trained so that the context encoder's representations can predict the target encoder's representations, with the loss computed only at the masked positions. The target encoder stays stable by being updated via EMA.

This design encourages the model to learn representations that "fill in the meaning" rather than predict specific tokens.

The set_seed function

def set_seed(seed: int):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

This function makes runs reproducible.

random.seed(seed) fixes Python's random operations (which span masking relies on), torch.manual_seed(seed) fixes PyTorch's CPU randomness, and torch.cuda.manual_seed_all(seed) fixes randomness in CUDA kernels.

Both span masking and model initialization are random, so without a fixed seed every run produces different results.

The pick_device function

def pick_device(device_str: str) -> torch.device:
    if device_str == "auto":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device_str)

Returns a PyTorch device object. With --device auto it uses the GPU if one is available and falls back to the CPU otherwise; you can also pass --device cpu or --device cuda explicitly. Tensors and the model must live on the same device, which is a basic requirement.

The sample_span_mask function

def sample_span_mask(seq_len, mask_ratio, mean_span_len, special_positions=None)

One of the most important functions in the entire script.

Its goal is to create a boolean mask marking which positions in the sequence should be masked. The parameters: seq_len is the number of real tokens (excluding padding), mask_ratio is the fraction to mask (e.g. 0.3), mean_span_len is the average length of a contiguous masked span, and special_positions are positions that must never be masked (CLS, SEP, PAD).

Internally, it starts from an all-False mask and computes how many tokens need to be masked:

target_to_mask = max(1, int(round(seq_len * mask_ratio)))

so even a very short sequence masks at least one token.

It then loops, sampling spans until the quota is reached. Span lengths are drawn from an exponential distribution:

span_len = max(1, int(random.expovariate(1.0 / max(1, mean_span_len))))

which yields many short spans and a few long ones, a reasonably natural distribution. A random start position is chosen, special tokens are filtered out, and the remaining positions are marked True.

The masking strategy matters a lot for representation quality. Span masking forces the model to infer the missing meaning from the surrounding context.
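
As a quick sanity check, here is a small illustrative snippet that calls sample_span_mask from the script on its own (the exact mask you get depends on the seed):

import random
random.seed(0)
demo_mask = sample_span_mask(seq_len=20, mask_ratio=0.3, mean_span_len=5)
print(demo_mask.int().tolist())
# a list of 0s and 1s where the 1s form contiguous spans covering roughly 30% of the positions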

The apply_mask_to_input_ids function

def apply_mask_to_input_ids(input_ids, attention_mask, tokenizer, mask_ratio, mean_span_len)

It takes the token ids of one example and produces two outputs: masked_input_ids, the ids with masked positions replaced by [MASK], and pred_mask, a boolean mask marking which positions contribute to the loss.

It first computes the visible sequence length, seq_len = int(attention_mask.sum().item()); in attention_mask, real tokens are 1 and padding is 0.

Next it identifies the special-token positions: CLS and SEP must never be masked, otherwise the model tends to break. It calls sample_span_mask to pick the masked positions and replaces them with mask_token_id:

masked_input_ids[:seq_len][pred_mask] = mask_token_id

The returned pred_mask covers the full padded length, with padding positions set to False. The JEPA loss is computed only at the masked positions; all other positions are ignored.
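
To see both outputs together, here is an illustrative snippet that reuses apply_mask_to_input_ids from the script and assumes a BERT-style tokenizer has already been loaded as tokenizer:

toks = tokenizer("The cat sat on the mat and looked at the window.", return_tensors="pt")
ids, attn = toks["input_ids"][0], toks["attention_mask"][0]
masked_ids, pred_mask = apply_mask_to_input_ids(ids, attn, tokenizer, mask_ratio=0.3, mean_span_len=3)
print(tokenizer.decode(masked_ids))            # the sentence with some spans replaced by [MASK]
print(pred_mask.nonzero().flatten().tolist())  # positions where the JEPA loss will be applied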

The TextLinesDataset class

class TextLinesDataset(Dataset):
    def __init__(self, texts):
        self.texts = [t.strip() for t in texts if t.strip()]

A minimal dataset implementation: it stores a list of text lines, dropping empty lines and stripping surrounding whitespace.

__len__ returns the number of lines, and __getitem__ returns a single line of text.

load_texts_from_file reads a file line by line with an optional cap on the number of lines; it is used when --text_file is passed. default_tiny_corpus provides the built-in test corpus.
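
If you train on your own data, the expected format is simply one example per line. A hypothetical data.txt and the corresponding loading code would look like this:

# data.txt contains one training example per line, e.g.:
#   The cat sat on the mat.
#   Transformers use attention to mix information across tokens.
texts = load_texts_from_file("data.txt", max_lines=1000)
ds = TextLinesDataset(texts)
print(len(ds), ds[0])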

The Batch dataclass

@dataclass
class Batch:
    input_ids: torch.LongTensor
    attention_mask: torch.LongTensor
    masked_input_ids: torch.LongTensor
    pred_mask: torch.BoolTensor

Using a dataclass is much clearer than returning a tuple and keeps the code readable.

The collate_jepa function

This is the function the DataLoader calls to build each batch. It takes a list of raw texts and first tokenizes, pads, and truncates them:

toks = tokenizer(batch_texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

producing input_ids and attention_mask. It then calls apply_mask_to_input_ids on every example to generate the masked version and pred_mask, and finally stacks everything into [B, L] tensors and returns a Batch.

The DataLoader reads one example at a time, but training needs batches; both batching and masking happen here.
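
Outside of main(), the same wiring can be sketched with functools.partial, reusing the script's pieces and assuming a tokenizer is already loaded:

from functools import partial

collate = partial(collate_jepa, tokenizer=tokenizer, max_length=128, mask_ratio=0.3, mean_span_len=5)
dl = DataLoader(TextLinesDataset(default_tiny_corpus()), batch_size=4, shuffle=True, collate_fn=collate)
batch = next(iter(dl))
print(batch.input_ids.shape, batch.masked_input_ids.shape, batch.pred_mask.shape)  # all [B, L]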

The PredictorMLP class

The predictor head has a simple structure:

nn.Linear(dim, hidden)
nn.GELU()
nn.Dropout()
nn.Linear(hidden, dim)

It maps context representations into the target representation space, acting as a learned adapter that helps align the two sets of embeddings.

The LLMJEPA model class

The main model wrapper contains four core pieces: context_encoder, a trainable Transformer encoder; target_encoder, a deep copy of it that is not trained; predictor, the MLP; and ema_m, the EMA momentum factor.

_copy_encoder uses copy.deepcopy to make sure the target and context encoders start out identical.

ema_update slowly updates the target encoder's weights:

p_tgt = m * p_tgt + (1 - m) * p_ctx

With m = 0.99 the target changes very slowly, which stabilizes training and reduces the risk of representation collapse.
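
To get a feel for what m = 0.99 means in practice, a couple of lines of illustrative arithmetic (not part of the script):

m = 0.99
print(1.0 / (1.0 - m))  # ~100: the target is roughly an average over the last ~100 context-encoder states
print(m ** 100)         # ~0.37: a contribution from 100 steps ago still carries about 37% of its weight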

The forward pass: run the masked view through the context encoder (trainable), run the original view through the target encoder (no gradients), apply the predictor to the context output, and keep only the vectors at the masked positions:

masked_pred = pred[pred_mask]  # [N, D]
masked_tgt = z_tgt[pred_mask]  # [N, D]

This goes from [B, L, D] to [N, D], where N is the total number of masked tokens. After normalization, the cosine distance is computed:

loss = 1 - (masked_pred * masked_tgt).sum(dim=-1)
return loss.mean()

The normalization is there because cosine similarity only cares about the direction of a vector, not its magnitude.
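
The loss behaves like this in isolation (a tiny standalone example with random vectors):

a = F.normalize(torch.randn(4, 8), dim=-1)
b = F.normalize(torch.randn(4, 8), dim=-1)
print(1.0 - (a * b).sum(dim=-1))  # around 1 on average for random directions; 0 if aligned, 2 if opposite
print(1.0 - (a * a).sum(dim=-1))  # ~0: a vector is at cosine distance 0 from itself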

The build_hf_encoder function

Loads a Hugging Face encoder and returns the model along with its hidden dimension (read from config.hidden_size).

The build_random_encoder function

For smoke tests only: it builds a small Transformer encoder from scratch, with token embeddings, positional embeddings, and an encoder stack. Note that this is not a masked language model, just an encoder architecture. The returned object exposes a .last_hidden_state attribute so that it matches the HF output format.
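
If the type("Out", ...) trick in TinyEncoder.forward feels too clever, an equivalent alternative (a sketch; functionally the same for this script) is a small output dataclass:

from dataclasses import dataclass
import torch

@dataclass
class EncoderOutput:
    last_hidden_state: torch.Tensor

# ...and in TinyEncoder.forward: return EncoderOutput(last_hidden_state=h)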

Wrapping up

This implementation deliberately favors clarity over completeness, so there are no custom attention masks, multi-view datasets, or mixed objectives. As a reference implementation, though, it is a very good fit. The original LLM-JEPA paper goes further: it combines JEPA with token prediction and exploits naturally paired views such as text and code. Those design choices matter for downstream performance, but they also add complexity that can obscure the core mechanism.
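
For orientation only, a mixed objective can be as simple as a weighted sum. The snippet below is a self-contained illustration with stand-in tensors, not the paper's exact formulation:

import torch
import torch.nn.functional as F

logits = torch.randn(10, 30522)            # stand-in for LM-head outputs at 10 masked positions
targets = torch.randint(0, 30522, (10,))   # stand-in for the original token ids at those positions
jepa_loss = torch.tensor(0.42)             # stand-in for the embedding-prediction loss above
lambda_jepa = 1.0                          # relative weight of the JEPA term (a hyperparameter)
total_loss = F.cross_entropy(logits, targets) + lambda_jepa * jepa_loss
print(total_loss)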

Paper:

https://arxiv.org/abs/2509.14252

Author: azhar
