最近代码展示

分享最新的技术实现与算法优化,每一行代码都追求极致的性能与优雅的实现。

RENT

RENT熵最小化

熵估计

完整的RENT算法实现，包含三种熵最小化方法（EM-FT / EM-RL / EM-INF），以及作为对比基线的自适应温度缩放方法。

Python / PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import math


class EntropyMinimization:
    """Base class for the RENT entropy-minimization variants.

    Loads a causal language model with its tokenizer and provides helpers
    that compute the Shannon entropy of the model's next-token distributions.
    """

    def __init__(self, model_name='gpt2'):
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)

        # Reuse EOS as the pad token when the tokenizer has none, so padding works.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def token_level_entropy(self, logits):
        """Return H = -sum_v p_v * log p_v of softmax(logits) along the last dim.

        Args:
            logits: unnormalized scores whose last dimension is the vocabulary.

        Returns:
            Tensor shaped like ``logits`` without its last dimension.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        # exp(log_softmax) is the softmax, so a single normalization suffices.
        return -(log_probs.exp() * log_probs).sum(dim=-1)

    def sequence_level_entropy(self, logits, attention_mask):
        """Return the mean token entropy of each sequence over its unmasked positions.

        Args:
            logits: scores shaped like ``attention_mask`` plus a trailing
                vocabulary dimension.
            attention_mask: 0/1 mask; positions with 0 are ignored.

        Returns:
            One value per row (``dim=0``). Rows without any valid token get 0
            instead of NaN.
        """
        token_entropies = self.token_level_entropy(logits)
        mask = attention_mask.to(token_entropies.dtype)
        # Clamp so an all-padding row yields 0 / 1 = 0 rather than 0 / 0 = NaN.
        valid_tokens = mask.sum(dim=1).clamp(min=1)
        return (token_entropies * mask).sum(dim=1) / valid_tokens

class EMFTTrainer(EntropyMinimization):
    """EM-FT: fine-tune the model by minimizing token-level entropy on its own samples."""

    def __init__(self, model_name='gpt2', lr=1e-5):
        super().__init__(model_name)
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)

    def generate_training_data(self, prompts, num_samples=4, max_new_tokens=100, temperature=0.8):
        """Sample ``num_samples`` continuations per prompt.

        Args:
            prompts: iterable of prompt strings.
            num_samples: continuations sampled per prompt.
            max_new_tokens: maximum number of generated tokens per sample.
            temperature: sampling temperature.

        Returns:
            dict with 'input_ids' and 'attention_mask' lists. Each entry is a
            CPU tensor of shape (1, seq_len); lengths may differ between entries.
        """
        self.model.eval()
        training_data = {'input_ids': [], 'attention_mask': []}

        with torch.no_grad():
            for prompt in prompts:
                inputs = self.tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                for _ in range(num_samples):
                    outputs = self.model.generate(
                        **inputs,
                        max_new_tokens=max_new_tokens,
                        do_sample=True,
                        temperature=temperature,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
                    training_data['input_ids'].append(outputs.cpu())
                    training_data['attention_mask'].append(torch.ones_like(outputs).cpu())

        return training_data

    def _pad_and_concat(self, tensors, pad_value):
        """Right-pad (1, L_i) tensors to the longest length and concatenate them on dim 0."""
        max_len = max(t.shape[1] for t in tensors)
        return torch.cat(
            [F.pad(t, (0, max_len - t.shape[1]), value=pad_value) for t in tensors],
            dim=0,
        )

    def compute_entropy_loss(self, batch):
        """Return the mean token-level entropy over the unmasked positions of ``batch``."""
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)

        logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits
        token_entropies = self.token_level_entropy(logits)

        # Padding positions are excluded from the average.
        return token_entropies[attention_mask.bool()].mean()

    def train_step(self, batch):
        """Run one optimizer step on ``batch`` and return the loss as a float."""
        self.model.train()
        self.optimizer.zero_grad()

        loss = self.compute_entropy_loss(batch)
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, prompts, num_epochs=3, batch_size=2):
        """Generate self-samples for ``prompts`` and fine-tune on their entropy."""
        print("生成训练数据...")
        training_data = self.generate_training_data(prompts)
        if not training_data['input_ids']:
            return

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0
        # Samples have different lengths: pad them and mask the padding out.
        input_ids = self._pad_and_concat(training_data['input_ids'], pad_id)
        attention_mask = self._pad_and_concat(training_data['attention_mask'], 0)

        batch_starts = range(0, input_ids.shape[0], batch_size)
        num_batches = len(batch_starts)

        print(f"开始训练,总批次数: {num_batches}")

        for epoch in range(num_epochs):
            total_loss = 0

            for i, start_idx in enumerate(batch_starts):
                end_idx = start_idx + batch_size
                batch = {
                    "input_ids": input_ids[start_idx:end_idx],
                    "attention_mask": attention_mask[start_idx:end_idx],
                }

                loss = self.train_step(batch)
                total_loss += loss

                if i % 10 == 0:
                    print(f"Epoch {epoch+1}, Batch {i+1}/{num_batches}, Loss: {loss:.4f}")

            avg_loss = total_loss / num_batches
            print(f"Epoch {epoch+1} 完成, 平均损失: {avg_loss:.4f}")

class EMRLLoss(EntropyMinimization):
    """EM-RL: RL-style loss that uses negative entropy as the reward plus a KL penalty."""

    def __init__(self, model_name='gpt2', beta=0.001):
        super().__init__(model_name)
        self.beta = beta  # weight of the KL penalty

    def compute_reward(self, logits, attention_mask, reward_type='token'):
        """Return the reward, i.e. the negative entropy.

        With ``reward_type == 'token'`` the entropy is averaged over every
        unmasked token in the batch. Any other value averages the per-sequence
        mean entropies instead.
        """
        if reward_type == 'token':
            token_entropies = self.token_level_entropy(logits)
            return -token_entropies[attention_mask.bool()].mean()
        return -self.sequence_level_entropy(logits, attention_mask).mean()

    def compute_kl_penalty(self, logits, ref_logits, attention_mask):
        """Return the mean KL(ref || policy) over unmasked positions."""
        log_probs = F.log_softmax(logits, dim=-1)
        # The reference distribution is a fixed anchor: never backpropagate into it.
        ref_log_probs = F.log_softmax(ref_logits.detach(), dim=-1)

        # F.kl_div(input, target) = sum target * (log target - input), i.e.
        # KL(target || input). log_target=True keeps the reference in log space.
        kl_div = F.kl_div(
            input=log_probs,
            target=ref_log_probs,
            reduction='none',
            log_target=True,
        ).sum(dim=-1)

        return kl_div[attention_mask.bool()].mean()

    def compute_loss(self, logits, ref_logits, attention_mask, reward_type='token'):
        """Return ``(loss, reward, kl)`` where ``loss = -reward + beta * KL``.

        ``loss`` is a tensor for backpropagation; ``reward`` and ``kl`` are
        Python floats for logging.
        """
        reward = self.compute_reward(logits, attention_mask, reward_type)
        kl_penalty = self.compute_kl_penalty(logits, ref_logits, attention_mask)

        total_loss = -reward + self.beta * kl_penalty

        return total_loss, reward.item(), kl_penalty.item()

class EMINFInference(EntropyMinimization):
    """EM-INF: inference-time entropy minimization on the next-token logits."""

    def __init__(self, model_name='gpt2'):
        super().__init__(model_name)

    def optimize_logits(self, logits, delta=0.3, num_steps=15, lr=0.1):
        """Lower the entropy of ``logits`` with up to ``num_steps`` Adam steps.

        Optimization stops as soon as the mean entropy is at or below
        ``delta``. Autograd must be enabled. Returns the optimized logits, detached.
        """
        optimized_logits = logits.detach().clone().to(self.device).requires_grad_(True)
        optimizer = optim.Adam([optimized_logits], lr=lr)

        for _ in range(num_steps):
            optimizer.zero_grad()
            current_entropy = self.token_level_entropy(optimized_logits.unsqueeze(0)).mean()

            # Only minimize while the entropy is above the threshold.
            if current_entropy.item() <= delta:
                break
            current_entropy.backward()
            optimizer.step()

        return optimized_logits.detach()

    def adaptive_temperature_scaling(self, logits, alpha=0.5, delta=0.3):
        """Baseline: divide ``logits`` by a temperature to hit a target entropy.

        The target is ``max(alpha * H(logits), delta)``. The temperature is found
        by bisection in [0.1, 1.0]. If the entropy is already at or below the
        target, the logits are returned unchanged, because scaling would only
        raise the entropy.
        """
        original_entropy = self.token_level_entropy(logits.unsqueeze(0)).mean().item()
        target_entropy = max(alpha * original_entropy, delta)
        if original_entropy <= target_entropy:
            return logits

        # The entropy of softmax(logits / T) is non-decreasing in T, and the
        # target lies below the T = 1 entropy, so T is searched in [0.1, 1].
        low_temp, high_temp = 0.1, 1.0
        tolerance = 1e-4
        max_iters = 20

        for _ in range(max_iters):
            mid_temp = (low_temp + high_temp) / 2
            current_entropy = self.token_level_entropy((logits / mid_temp).unsqueeze(0)).mean().item()

            if abs(current_entropy - target_entropy) < tolerance:
                break

            if current_entropy > target_entropy:
                high_temp = mid_temp  # still too flat: lower the temperature
            else:
                low_temp = mid_temp  # too peaked: raise the temperature

        return logits / mid_temp

    def generate_with_em_inf(self, prompt, method='logit_optimization', max_length=100, **kwargs):
        """Sample up to ``max_length`` new tokens after ``prompt``.

        Args:
            method: 'logit_optimization' (EM-INF), 'adaptive_temp' (baseline);
                any other value samples from the raw logits.
            kwargs: ``delta``, ``num_steps``, ``lr`` for logit optimization;
                ``alpha``, ``delta`` for adaptive temperature.

        Returns:
            The decoded prompt plus continuation, with special tokens skipped.
        """
        self.model.eval()

        inputs = self.tokenizer(prompt, return_tensors="pt")
        generated_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]

                if method == 'logit_optimization':
                    # optimize_logits needs autograd, which is disabled in this loop.
                    with torch.enable_grad():
                        next_token_logits = self.optimize_logits(
                            next_token_logits,
                            delta=kwargs.get('delta', 0.3),
                            num_steps=kwargs.get('num_steps', 15),
                            lr=kwargs.get('lr', 0.1),
                        )
                elif method == "adaptive_temp":
                    next_token_logits = self.adaptive_temperature_scaling(
                        next_token_logits,
                        alpha=kwargs.get('alpha', 0.5),
                        delta=kwargs.get('delta', 0.3),
                    )

                probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                generated_ids = torch.cat([generated_ids, next_token], dim=1)
                # Extend the mask with the tokenizer's mask dtype (not float).
                attention_mask = torch.cat(
                    [attention_mask, torch.ones(1, 1, dtype=attention_mask.dtype, device=self.device)],
                    dim=1,
                )

                if next_token.item() == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
                        
RENT 无监督强化学习 LLM PyTorch
更新于 2025年11月16日 查看完整代码 →
ICM

ICM无监督引出

优化评分函数

ICM通过优化评分函数 U(D) = α·P_θ(D) − I(D) 来评估标签集 D 的质量，其中 P_θ(D) 为互预测性，I(D) 为逻辑不一致样本对的数量。算法使用模拟退火启发式搜索和一致性修复来改进样本标签。

Python / PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import random
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import logging
import json
from tqdm import tqdm
import time
import argparse
from collections import defaultdict, deque
import hashlib
import re

_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the ICM classes below.
logger = logging.getLogger(__name__)


@dataclass
class ICMConfig:
    """Hyper-parameters for the ICM label search."""
    alpha: float = 50.0  # weight of mutual predictability in the score
    initial_temp: float = 10.0  # initial simulated-annealing temperature
    final_temp: float = 0.01  # temperature floor
    cooling_rate: float = 0.99  # geometric temperature decay per iteration
    init_samples: int = 8  # number of randomly labeled seed examples
    max_iterations: int = 1000
    batch_size: int = 4
    max_context_examples: int = 5  # max in-context demonstrations
    context_window: int = 2048  # model context window in tokens
    fix_steps: int = 10  # consistency-repair iterations
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed: int = 620
    cache_size: int = 1000  # probability-cache capacity
    priority_weight: float = 100.0  # sampling weight for related examples

    def __post_init__(self):
        # Catch plain 'cuda' as well as indexed devices such as 'cuda:1'.
        if str(self.device).startswith("cuda") and not torch.cuda.is_available():
            logging.getLogger(__name__).warning('cuda不可用, 使用cpu')
            self.device = "cpu"

class ProbabilityCache:
    """Bounded LRU cache of label-probability dicts keyed by (x, labels, context).

    Entries live in a plain dict whose insertion order doubles as recency
    order: the first key is the least recently used. Every operation is O(1).
    """

    def __init__(self, max_size=1000):
        self.max_size = max_size  # <= 0 disables caching
        self.cache = {}

    def _get_key(self, x, labels, context):
        """Hash the input, the label set (order-insensitive) and the context examples."""
        context_str = '|'.join(f"{d['x']}:{d['y']}" for d in context)
        labels_str = ','.join(sorted(labels))
        content = f"{x}|{labels_str}|{context_str}"
        return hashlib.md5(content.encode()).hexdigest()

    def get(self, x, labels, context):
        """Return the cached probabilities or None, marking the entry as most recently used."""
        key = self._get_key(x, labels, context)
        if key not in self.cache:
            return None
        # Re-insert to move the key to the most-recent end.
        value = self.cache.pop(key)
        self.cache[key] = value
        return value

    def set(self, x, labels, context, probs):
        """Store ``probs``, evicting least-recently-used entries when full.

        Re-setting an existing key updates it in place without evicting others.
        """
        if self.max_size <= 0:
            return
        key = self._get_key(x, labels, context)
        if key in self.cache:
            del self.cache[key]
        else:
            while len(self.cache) >= self.max_size:
                del self.cache[next(iter(self.cache))]
        self.cache[key] = probs

class ConsistencyChecker(ABC):
    """Abstract base for task-specific logical-consistency rules between labeled examples."""

    @abstractmethod
    def check_pair(self, x_i, y_i, x_j, y_j):
        """Return True if the two labeled examples are consistent, False on conflict."""
        pass

    def get_inconsistent_pairs(self, data):
        """List index pairs (i, j), i < j, of conflicting examples in ``data``.

        Only examples that share the same extracted question are compared.
        Pairs are ordered by question group (first appearance), then by i, then by j.
        """
        groups = defaultdict(list)
        for idx, item in enumerate(data):
            groups[self.extract_question(item['x'])].append(idx)

        return [
            (a, b)
            for members in groups.values()
            for pos, a in enumerate(members)
            for b in members[pos + 1:]
            if not self.check_pair(data[a]['x'], data[a]['y'], data[b]['x'], data[b]['y'])
        ]

    @abstractmethod
    def get_consistent_labels(self, x_i, x_j):
        """Return every logically allowed (y_i, y_j) label combination."""
        pass

    def extract_question(self, text):
        """Return the question part of ``text``.

        If the text contains 'Question:', this is everything before 'Claim:'
        (stripped). Otherwise it is the first line.
        """
        if "Question:" in text:
            head, _, _ = text.partition("Claim:")
            return head.strip()
        first_line, _, _ = text.partition("\n")
        return first_line

class MathConsistencyChecker(ConsistencyChecker):
    """Consistency rules for math-answer verification.

    Two claims about the same question that give numerically different final
    answers cannot both be labeled 'True'. Both being 'False' is allowed.
    """

    def __init__(self):
        # Phrases after which the final answer is searched for.
        self.answer_patterns = ['The answer is', '答案是', 'Answer:']

    def extract_answer(self, text):
        """Return the first number within 20 characters after an answer phrase, or None.

        Phrases are tried in order; a phrase whose window holds no number is skipped.
        """
        for pattern in self.answer_patterns:
            if pattern in text:
                start = text.find(pattern) + len(pattern)
                numbers = re.findall(r'-?\d+\.?\d*', text[start:start + 20])
                if numbers:
                    return numbers[0]
        return None

    def _answers_differ(self, ans_i, ans_j):
        """Return True when both answers exist and differ numerically ('5.' equals '5')."""
        if ans_i is None or ans_j is None:
            return False
        try:
            return float(ans_i) != float(ans_j)
        except ValueError:
            return ans_i != ans_j

    def check_pair(self, x_i, y_i, x_j, y_j):
        """Return False only when different answers to one question are both labeled 'True'."""
        if self.extract_question(x_i) != self.extract_question(x_j):
            return True

        if self._answers_differ(self.extract_answer(x_i), self.extract_answer(x_j)):
            # At most one of two different answers can be correct, but both may be
            # wrong. This agrees with get_consistent_labels, which allows ('False', 'False').
            return not (y_i == 'True' and y_j == 'True')
        return True

    def get_consistent_labels(self, x_i, x_j):
        """Return the label pairs allowed for (x_i, x_j)."""
        same_question = self.extract_question(x_i) == self.extract_question(x_j)
        if same_question and self._answers_differ(self.extract_answer(x_i), self.extract_answer(x_j)):
            # Same question, different answers: they cannot both be True.
            return [('True', 'False'), ('False', 'True'), ('False', 'False')]
        return [('True', 'True'), ('True', 'False'), ('False', 'True'), ('False', 'False')]

class AlpacaConsistencyChecker(ConsistencyChecker):
    """Consistency rules for A/B comparison texts marked 'Response A:' / 'Response B:'.

    Two texts that show the same two responses in swapped order cannot both be
    labeled 'True'.
    """

    def extract_responses(self, text):
        """Return the stripped (response A, response B) parts of ``text``.

        Returns (None, None) when either marker is missing, or when the segment
        after 'Response A:' does not split into exactly two parts on 'Response B:'.
        """
        if 'Response A:' not in text or 'Response B:' not in text:
            return None, None
        segments = text.split('Response A:')[1].split('Response B:')
        if len(segments) != 2:
            return None, None
        first, second = segments
        return first.strip(), second.strip()

    def _is_swapped_pair(self, x_i, x_j):
        """True when all four responses are non-empty and x_j swaps x_i's A and B."""
        a_i, b_i = self.extract_responses(x_i)
        a_j, b_j = self.extract_responses(x_j)
        if not all((a_i, b_i, a_j, b_j)):
            return False
        return a_i == b_j and b_i == a_j

    def check_pair(self, x_i, y_i, x_j, y_j):
        """Return False only for a swapped pair that is labeled 'True' twice."""
        both_true = y_i == "True" and y_j == "True"
        return not (self._is_swapped_pair(x_i, x_j) and both_true)

    def get_consistent_labels(self, x_i, x_j):
        """Return the allowed label pairs; ('True', 'True') is excluded for a swapped pair."""
        every_pair = [('True', 'True'), ('True', 'False'), ('False', 'True'), ('False', 'False')]
        if self._is_swapped_pair(x_i, x_j):
            return every_pair[1:]
        return every_pair

class ICMCore:
    """Internal Coherence Maximization: search for labels maximizing U(D) = alpha * P(D) - I(D).

    P(D) (mutual predictability) is the mean log-probability the frozen model
    assigns to each example's label when the other labeled examples are shown
    as in-context demonstrations. I(D) is the number of inconsistent pairs
    reported by ``consistency_checker``. Examples are labeled one at a time,
    and each proposal is accepted or rejected by a simulated-annealing rule.
    """

    def __init__(self, model, tokenizer, config, consistency_checker):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.checker = consistency_checker
        self.cache = ProbabilityCache(config.cache_size)
        # Full set of candidate labels. run() fills it in so that P(D) is
        # normalized over every label, not only the labels present in D.
        self.label_space = []

        # The model is only used for scoring, so freeze it.
        for param in self.model.parameters():
            param.requires_grad = False
        self.model.to(self.config.device)
        self.model.eval()

        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(config.seed)

        logger.info(f"ICMCore initialized on {config.device}")

    def build_prompt(self, x, context):
        """Format the last ``max_context_examples`` context items as demonstrations, then ``x``."""
        n_examples = self.config.max_context_examples
        # context[-0:] would be the whole list, so 0 must mean "no demonstrations".
        demos = context[-n_examples:] if n_examples > 0 else []
        parts = [f"Input: {ex['x']}\nLabel: {ex['y']}" for ex in demos]
        parts.append(f"Input: {x}\nLabel:")
        return "\n\n".join(parts)

    @torch.no_grad()
    def _label_likelihood(self, prompt, label, fallback):
        """Return exp(mean token log-prob) of ``' ' + label`` following ``prompt``.

        Returns ``fallback`` when the label tokens cannot be located, or when
        they have no preceding token to condition on.
        """
        encoded = self.tokenizer(
            prompt + ' ' + label,
            return_tensors='pt',
            return_offsets_mapping=True,
        )
        offsets = encoded['offset_mapping'][0].tolist()
        input_ids = encoded['input_ids']
        attention_mask = encoded.get('attention_mask')
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Label tokens are the non-empty character spans that start at or after
        # the end of the prompt (the separating space may be part of the first one).
        prompt_chars = len(prompt)
        positions = [
            idx for idx, (start, end) in enumerate(offsets)
            if start >= prompt_chars and end > start
        ]
        if not positions:
            return fallback

        # Truncate from the left so the label at the end is never cut off.
        overflow = input_ids.shape[1] - self.config.context_window
        if overflow > 0:
            input_ids = input_ids[:, overflow:]
            attention_mask = attention_mask[:, overflow:]
            positions = [pos - overflow for pos in positions]
        # Each label token must have a preceding token whose logits predict it.
        if positions[0] < 1:
            return fallback

        device = self.config.device
        logits = self.model(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
        ).logits[0]
        pos = torch.tensor(positions, device=logits.device)
        targets = input_ids[0].to(logits.device)[pos]
        # The logits at position t are the prediction for the token at t + 1.
        log_probs = torch.log_softmax(logits[pos - 1].float(), dim=-1)
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        return float(np.exp(token_log_probs.mean().item()))

    @torch.no_grad()
    def get_label_prob_batch(self, x_list, labels, context_list):
        """Return, for each (x, context), a dict mapping every label to its normalized probability.

        A label's raw score is the geometric mean of its token probabilities.
        Results are cached per (x, label set, context).
        """
        results = []
        fallback = 1.0 / len(labels) if labels else 0.0

        for x, context in zip(x_list, context_list):
            cached = self.cache.get(x, labels, context)
            if cached is not None:
                results.append(cached)
                continue

            prompt = self.build_prompt(x, context)
            probs = {label: self._label_likelihood(prompt, label, fallback) for label in labels}

            total = sum(probs.values())
            if total > 0:
                probs = {k: v / total for k, v in probs.items()}
            else:
                probs = {label: fallback for label in labels}

            self.cache.set(x, labels, context, probs)
            results.append(probs)

        return results

    @torch.no_grad()
    def get_label_prob(self, x, labels, context):
        """Return the label probabilities for a single example."""
        return self.get_label_prob_batch([x], labels, [context])[0]

    def compute_mutual_predictability(self, data):
        """Return P(D): the mean log P(y_i | x_i, D without i); 0.0 for at most one example."""
        if len(data) <= 1:
            return 0.0

        # Normalize over the full label space, not only the labels present in D.
        # Otherwise a D that uses a single label gets a perfect log-probability of 0.
        label_space = sorted(
            set(getattr(self, 'label_space', None) or []) | {item['y'] for item in data},
            key=str,
        )

        total_log_prob = 0.0
        batch_size = max(1, min(self.config.batch_size, len(data)))

        for start_idx in tqdm(range(0, len(data), batch_size), desc='Mutual Pred', leave=False):
            batch_data = data[start_idx:start_idx + batch_size]
            x_list = [item['x'] for item in batch_data]
            # Leave-one-out: each example is predicted from all the others.
            context_list = [
                data[:start_idx + offset] + data[start_idx + offset + 1:]
                for offset in range(len(batch_data))
            ]

            all_probs = self.get_label_prob_batch(x_list, label_space, context_list)
            for item, probs in zip(batch_data, all_probs):
                total_log_prob += np.log(max(probs.get(item['y'], 1e-10), 1e-10))

        return total_log_prob / len(data)

    def compute_consistency_penalty(self, data):
        """Return I(D): the number of inconsistent pairs in ``data``."""
        return len(self.checker.get_inconsistent_pairs(data))

    def compute_score(self, data):
        """Return U(D) = alpha * P(D) - I(D); -inf for empty data."""
        if not data:
            return -np.inf

        mutual = self.compute_mutual_predictability(data)
        penalty = self.compute_consistency_penalty(data)
        return self.config.alpha * mutual - penalty

    def fix_inconsistencies(self, data):
        """Greedily relabel inconsistent pairs whenever that raises U(D).

        Runs at most ``config.fix_steps`` rounds. Each round samples up to 10
        inconsistent pairs. For each pair it tries every consistent label
        combination and keeps the best one only if it strictly improves the
        score. Items are shallow-copied, so the input list is not modified.
        """
        if len(data) <= 1:
            return data

        data = [item.copy() for item in data]
        # Score of the current labeling, kept up to date as fixes are accepted.
        current_score = self.compute_score(data)
        improved = True
        steps = 0

        while improved and steps < self.config.fix_steps:
            improved = False
            pairs = self.checker.get_inconsistent_pairs(data)
            if not pairs:
                break

            for i, j in random.sample(pairs, min(10, len(pairs))):
                label_options = self.checker.get_consistent_labels(data[i]['x'], data[j]['x'])
                orig_labels = (data[i]['y'], data[j]['y'])

                best_score, best_labels = -np.inf, None
                for y_i_new, y_j_new in label_options:
                    data[i]['y'], data[j]['y'] = y_i_new, y_j_new
                    score = self.compute_score(data)
                    if score > best_score:
                        best_score, best_labels = score, (y_i_new, y_j_new)

                if best_labels is not None and best_score > current_score:
                    data[i]['y'], data[j]['y'] = best_labels
                    current_score = best_score
                    improved = True
                else:
                    data[i]['y'], data[j]['y'] = orig_labels

            steps += 1

        return data

    def select_example(self, unlabeled_indices, labeled, unlabeled_data):
        """Pick the next unlabeled example and return ``(text, index)``.

        The pick is uniform with probability 0.3, or whenever nothing is labeled
        yet. Otherwise, examples whose question matches that of a random
        labeled example get weight ``config.priority_weight`` and all others 1.0.
        """
        candidates = list(unlabeled_indices)
        if not labeled or random.random() < 0.3:
            idx = random.choice(candidates)
            return unlabeled_data[idx], idx

        ref_question = self.checker.extract_question(random.choice(labeled)['x'])
        related_weight = max(float(self.config.priority_weight), 0.0)
        weights = [
            related_weight if self.checker.extract_question(unlabeled_data[idx]) == ref_question else 1.0
            for idx in candidates
        ]

        if sum(weights) > 0:
            idx = random.choices(candidates, weights=weights, k=1)[0]
        else:
            idx = random.choice(candidates)

        return unlabeled_data[idx], idx

    def run(self, unlabeled, labels, max_iter=None):
        """Label ``unlabeled`` texts with values from ``labels`` by simulated annealing on U(D).

        Returns the labeled items as dicts with 'x', 'y' and 'idx' (an index into ``unlabeled``).
        """
        max_iter = max_iter or (min(len(unlabeled), self.config.max_iterations))
        self.label_space = list(labels)

        unlabeled_indices = set(range(len(unlabeled)))
        labeled = []

        init_indices = random.sample(list(unlabeled_indices), min(self.config.init_samples, len(unlabeled)))

        for idx in init_indices:
            labeled.append({
                'x': unlabeled[idx],
                'y': random.choice(labels),
                'idx': idx
            })
            unlabeled_indices.remove(idx)

        # Repair the random seed labels before the search starts.
        labeled = self.fix_inconsistencies(labeled)
        current_score = self.compute_score(labeled)

        logger.info(f"初始化完成: {len(labeled)} 个样本, 评分={current_score:.4f}")

        start_time = time.time()
        progress_bar = tqdm(range(max_iter), desc="ICM优化进度")

        for iteration in progress_bar:
            if not unlabeled_indices:
                break

            temp = max(self.config.final_temp, self.config.initial_temp * (self.config.cooling_rate ** iteration))

            # Choose the next example and give it the model's most likely label.
            x_new, idx_new = self.select_example(unlabeled_indices, labeled, unlabeled)
            probs = self.get_label_prob(x_new, labels, labeled)
            y_new = max(probs, key=probs.get)

            # Score the candidate labeling after repairing new inconsistencies.
            new_item = {'x': x_new, 'y': y_new, 'idx': idx_new}
            temp_labeled = self.fix_inconsistencies(labeled + [new_item])
            new_score = self.compute_score(temp_labeled)

            delta = new_score - current_score

            # Metropolis acceptance: always take improvements, sometimes take regressions.
            if delta > 0:
                accept = True
            else:
                accept_prob = np.exp(delta / temp) if temp > 0 else 0
                accept = random.random() < accept_prob

            if accept:
                labeled = temp_labeled
                current_score = new_score
                unlabeled_indices.remove(idx_new)

                if iteration % 50 == 0:
                    logger.info(f"迭代 {iteration}: 评分={current_score:.4f}, "
                                f"变化量={delta:.4f}, 温度={temp:.4f}")

            progress_bar.set_postfix({
                'score': f'{current_score:.3f}',
                'labeled': len(labeled),
                'temp': f'{temp:.3f}'
            })

        elapsed = time.time() - start_time
        logger.info(f"优化完成: {len(labeled)} 个标注样本, 最终评分={current_score:.4f}, "
                    f"总耗时={elapsed:.2f}秒")

        return labeled
ICM 无监督强化学习 LLM PyTorch
更新于 2025年11月17日 查看完整代码 →
IN

INTUITOR

奖励塑造

INTUITOR算法的PyTorch实现：基于GRPO框架，用模型的自我确定性（self-certainty）替代外部奖励，从而可在无标签数据上进行强化学习，促进智能体的探索过程。

Python / PyTorch
import torch
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from dataclasses import dataclass
from torch.utils.data import Dataset

@dataclass
class IntuitorConfig:
    """Hyper-parameters for INTUITOR training with GRPO."""
    model_name: str = 'gpt2'
    batch_size: int = 128
    num_candidates: int = 7  # responses sampled per prompt (GRPO group size)
    kl_penalty: float = 0.005  # weight of the KL term against the reference model
    lr: float = 3e-5
    max_length: int = 1024  # max new tokens per sampled response
    num_epochs: int = 3
    clip_epsilon: float = 0.2  # ratio clipping range
    advantage_scale: float = 1.0
    use_online_self_certainty: bool = True

class SelfCertaintyCalculator:
    """Self-certainty: KL(U || p) between a uniform distribution U and the model's distribution p."""

    def __init__(self, vocab_size):
        self.vocab_size = vocab_size
        self.uniform_dist = torch.ones(vocab_size) / vocab_size

    def compute_token_level_self_certainty(self, logits):
        """Return KL(U || softmax(logits)) along the last dimension.

        Uses the closed form KL(U || p) = -mean_v log p_v - log V. V is the size
        of the last dimension of ``logits``, which may differ from ``vocab_size``.
        log_softmax keeps tiny probabilities exact instead of clamping them
        with an epsilon.
        """
        log_probs = F.log_softmax(logits, dim=-1)
        log_num_classes = float(np.log(logits.shape[-1]))
        return -log_probs.mean(dim=-1) - log_num_classes

    def compute_self_certainty(self, logits, attention_mask):
        """Return the mean token self-certainty of each sequence over its unmasked positions.

        Args:
            logits: scores shaped like ``attention_mask`` plus a vocabulary dimension.
            attention_mask: 0/1 weights; rows summing to 0 give 0.
        """
        token_self_certainty = self.compute_token_level_self_certainty(logits)
        mask = attention_mask.to(token_self_certainty.dtype)

        seq_lengths = mask.sum(dim=1)
        return (token_self_certainty * mask).sum(dim=1) / (seq_lengths + 1e-8)
    
class GRPOTrainer:
    """GRPO trainer that uses the policy's own self-certainty as the reward (INTUITOR)."""

    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.self_certainty_calculator = SelfCertaintyCalculator(len(tokenizer))

        # Frozen reference model for the KL term, on the policy's device.
        device = next(model.parameters()).device
        self.reference_model = AutoModelForCausalLM.from_pretrained(self.config.model_name).to(device)
        self.reference_model.eval()

        for param in self.reference_model.parameters():
            param.requires_grad = False

        self.optimizer = optim.AdamW(self.model.parameters(), lr=self.config.lr)

    def compute_advantages(self, self_certainty_scores):
        """Return group-normalized advantages (score - mean) / std, scaled by ``advantage_scale``.

        A group with fewer than two scores has no spread and gets zero advantages.
        """
        if self_certainty_scores.numel() < 2:
            return torch.zeros_like(self_certainty_scores)
        centered = self_certainty_scores - self_certainty_scores.mean()
        advantages = centered / (self_certainty_scores.std() + 1e-8)
        return advantages * self.config.advantage_scale

    def compute_kl_penalty(self, logits, ref_logits, attention_mask):
        """Return the per-sequence masked mean of KL(ref || policy) over positions."""
        log_probs = F.log_softmax(logits, dim=-1)
        ref_log_probs = F.log_softmax(ref_logits, dim=-1)

        kl_divergence = F.kl_div(
            input=log_probs,
            target=ref_log_probs,
            reduction='none',
            log_target=True
        ).sum(dim=-1)

        kl_divergence = kl_divergence * attention_mask
        seq_lengths = attention_mask.sum(dim=1)
        return kl_divergence.sum(dim=1) / (seq_lengths + 1e-8)

    def generate_candidates(self, input_ids, num_candidates):
        """Sample ``num_candidates`` responses for every prompt row of ``input_ids``.

        All prompt rows are assumed to occupy the same ``input_ids.shape[1]``
        columns (no extra handling of prompt padding).

        Returns:
            ``(candidates_ids, candidates_mask)``, each of shape
            (batch * num_candidates, max_len). Rows are grouped by prompt, hold
            prompt + response, and are right-padded with EOS. The mask is 1 for
            the prompt and for response tokens up to and including the first
            EOS, and 0 afterwards.
        """
        self.model.eval()

        device = next(self.model.parameters()).device
        input_ids = input_ids.to(device)
        eos_id = self.tokenizer.eos_token_id
        prompt_len = input_ids.shape[1]

        sequences, masks = [], []
        with torch.no_grad():
            for prompt in input_ids:
                expanded_input = prompt.unsqueeze(0).repeat(num_candidates, 1)

                outputs = self.model.generate(
                    expanded_input,
                    max_new_tokens=self.config.max_length,
                    num_return_sequences=1,
                    do_sample=True,
                    temperature=0.7,
                    pad_token_id=eos_id,
                    eos_token_id=eos_id,
                    attention_mask=torch.ones_like(expanded_input)
                )

                # Tokens after the first EOS are padding added by generate().
                generated = outputs[:, prompt_len:]
                is_eos = (generated == eos_id).long()
                after_first_eos = (is_eos.cumsum(dim=1) - is_eos) > 0
                response_mask = (~after_first_eos).long()

                sequences.append(outputs)
                masks.append(torch.cat([torch.ones_like(expanded_input), response_mask], dim=1))

        # Different prompts may finish at different lengths: pad to a common width.
        max_len = max(seq.shape[1] for seq in sequences)
        candidates_ids = torch.cat(
            [F.pad(seq, (0, max_len - seq.shape[1]), value=eos_id) for seq in sequences], dim=0
        )
        candidates_mask = torch.cat(
            [F.pad(m, (0, max_len - m.shape[1]), value=0) for m in masks], dim=0
        )
        return candidates_ids, candidates_mask

    def compute_gradients(self, input_ids, attention_mask, candidate_ids, candidate_masks):
        """Run one GRPO update on sampled candidates and return scalar metrics.

        Args:
            input_ids: prompt ids; its width is the prompt length shared by all candidates.
            attention_mask: prompt mask; accepted for interface compatibility, unused.
            candidate_ids: (batch * num_candidates, seq_len) prompt + response
                ids, grouped by prompt, as returned by ``generate_candidates``.
            candidate_masks: matching mask; 0 marks padding.
        """
        self.model.train()
        batch_size = input_ids.shape[0]
        num_candidates = self.config.num_candidates
        prompt_len = input_ids.shape[1]
        eps = self.config.clip_epsilon

        logits = self.model(
            candidate_ids,
            attention_mask=candidate_masks,
            output_hidden_states=False,
            output_attentions=False,
            return_dict=True
        ).logits
        # The logit at position t predicts token t + 1: align the predictions with their targets.
        logits = logits[:, :-1, :]
        targets = candidate_ids[:, 1:]

        # Only response tokens (after the prompt) enter the reward, objective and KL.
        response_mask = candidate_masks[:, 1:].to(logits.dtype).clone()
        response_mask[:, :max(prompt_len - 1, 0)] = 0

        with torch.no_grad():
            # The reward is a constant for the policy gradient.
            self_certainty_scores = self.self_certainty_calculator.compute_self_certainty(
                logits.detach(), response_mask
            )
            # Group-relative advantages, normalized within each prompt's candidates.
            advantages = torch.cat([
                self.compute_advantages(group)
                for group in self_certainty_scores.view(batch_size, num_candidates)
            ])

            ref_logits = self.reference_model(
                candidate_ids,
                attention_mask=candidate_masks,
                output_hidden_states=False,
                output_attentions=False,
                return_dict=True
            ).logits[:, :-1, :]
            ref_log_probs = F.log_softmax(ref_logits, dim=-1)

        log_probs = F.log_softmax(logits, dim=-1)
        token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

        # The candidates were sampled from the current policy, so the "old" policy
        # of the importance ratio is the current one without gradient.
        ratios = torch.exp(token_log_probs - token_log_probs.detach())

        token_advantages = advantages.unsqueeze(1).to(ratios.dtype)
        surrogate1 = ratios * token_advantages
        surrogate2 = torch.clamp(ratios, 1 - eps, 1 + eps) * token_advantages
        per_token_objective = torch.min(surrogate1, surrogate2)

        # Exact per-token KL(ref || policy) over the vocabulary.
        per_token_kl = F.kl_div(
            log_probs, ref_log_probs, reduction='none', log_target=True
        ).sum(dim=-1)

        lengths = response_mask.sum(dim=1).clamp(min=1)
        policy_objective = (per_token_objective * response_mask).sum(dim=1) / lengths
        kl_penalties = (per_token_kl * response_mask).sum(dim=1) / lengths

        grpo_objective = policy_objective - self.config.kl_penalty * kl_penalties
        loss = -grpo_objective.mean()

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()

        metrics = {
            'loss': loss.item(),
            'self_certainty_mean': self_certainty_scores.mean().item(),
            'self_certainty_std': self_certainty_scores.std().item(),
            'advantages_mean': advantages.mean().item(),
            'kl_penalty_mean': kl_penalties.mean().item(),
            'grpo_objective': grpo_objective.mean().item()
        }

        return metrics
INTUITOR 奖励塑造 自监督强化学习
更新于 2025年11月12日 查看完整代码 →
CM

CoAT-MCTS

MCTS、CoAT、TIP、TPO

结合蒙特卡洛树搜索（MCTS）与偏好优化的测试时扩展框架：通过CoAT框架进行上下文感知的自适应思考，使用TIP惩罚思路切换，并利用TPO进行迭代优化。

Python / PyTorch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np
import re
import torch
import torch.nn.functional as F
from tqdm import tqdm
import random

# 解析函数
def parse_solution(response):
    """Extract the final ``(x, y)`` pair from a free-form solution text.

    Matching is lenient: optional whitespace around the equals sign, and either
    the ASCII ``=`` or the full-width ``＝`` is accepted. The last occurrence of
    each variable wins, since the final answer usually comes last.

    Returns:
        ``(x, y)`` as floats, or ``(None, None)`` if either value is missing.
    """
    x_values = re.findall(r'x\s*[=＝]\s*(-?\d+\.?\d*)', response)
    y_values = re.findall(r'y\s*[=＝]\s*(-?\d+\.?\d*)', response)

    if not x_values or not y_values:
        return None, None

    try:
        return float(x_values[-1]), float(y_values[-1])
    except ValueError:
        # Defensive: the regex only captures numeric text, but never let a
        # malformed number escape as an exception to the reward loop.
        return None, None



tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    pad_token="<|endoftext|>",  # set the pad token explicitly
)

# Make sure a pad token exists; fall back to the eos token otherwise.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use the GPU when one is available instead of hard-coding "cuda", which
# crashes on CPU-only machines. Every helper below moves its inputs to this
# global `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the model onto a single device and move it explicitly. The original
# combined device_map="auto" with a later model.to(device); a model dispatched
# by accelerate may refuse (or warn about) being moved afterwards.
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
).to(device)
#============================= 实现 CoAT ==================================
class CoATNode:
    """A single node of the CoAT search tree.

    Stores the text generated at this step together with the MCTS statistics
    (visit count and accumulated value) used for UCT selection.
    """

    def __init__(self, parent=None, prompt='', context='', depth=0):
        # Tree links.
        self.parent = parent
        self.children = []
        # MCTS statistics.
        self.visit_count = 0
        self.total_value = 0
        # Text state: the shared initial prompt and this node's own generated text.
        self.prompt = prompt
        self.context = context
        # Retrieved associative knowledge; starts as an empty list and may be
        # overwritten by the search framework.
        self.associative_memory = []
        self.depth = depth
        self.is_expanded = False

    def uct_value(self, exploration_weight=1.414):
        """Return the UCT score; unvisited nodes score +inf so they are tried first."""
        visits = self.visit_count
        if visits == 0:
            return float('inf')
        mean_value = self.total_value / visits
        bonus = exploration_weight * np.sqrt(np.log(self.parent.visit_count) / (visits + 1e-6))
        return mean_value + bonus

    def best_child(self):
        """Return the child with the highest UCT score (first one wins ties)."""
        return max(self.children, key=lambda candidate: candidate.uct_value())
    
# CoAT MCTS
class CoATFramework:
    """Chain-of-Associated-Thoughts (CoAT) Monte Carlo tree search over LLM output.

    Nodes are expanded by sampling 3 continuations from the model. Nodes are
    scored by inverse perplexity of the full reasoning path, plus a bonus for
    retrieved associative knowledge. Scores are backed up to the root.

    Args:
        model: causal LM used for generation and perplexity scoring.
        tokenizer: tokenizer matching ``model``.
        max_iter: number of MCTS iterations.
        max_depth: maximum tree depth for expansion and random rollouts.
        num_simulations: random rollouts per iteration from the selected node.
    """

    def __init__(self, model, tokenizer, max_iter=100, max_depth=5, num_simulations=50):
        self.model = model
        self.tokenizer = tokenizer
        self.max_iter = max_iter
        self.max_depth = max_depth
        self.num_simulations = num_simulations
        self.external_brain = self.init_external_brain()

    def init_external_brain(self):
        """Return the static math knowledge base: technique name -> worked hint."""
        return {
            "消元法": "联立方程消去变量:方程1 + 方程2 → 3x = 9 → x=3",
            "代入法": "从方程1解出y=8-x,代入方程2得 2x - (8-x) =1 → x=3",
            "验证步骤": "将x=3代入原方程验证:3 + y=8 → y=5"
        }

    @staticmethod
    def _path_text(node):
        """Return the shared prompt followed by every generated step from the root down to ``node``."""
        steps = []
        current = node
        while current is not None:
            steps.append(current.context)
            current = current.parent
        return node.prompt + "".join(reversed(steps))

    def retrieve_associative_memory(self, context):
        """Return the knowledge entry for the first technique keyword found in ``context``.

        The keywords are stems ("消元") while the knowledge-base keys are full
        names ("消元法"), so entries are matched by prefix. An exact ``dict.get``
        on the stem can never succeed. Returns ``""`` when nothing matches.
        """
        keywords = ["消元", "代入", "验证", "解"]
        for kw in keywords:
            if kw not in context:
                continue
            for name, knowledge in self.external_brain.items():
                if name.startswith(kw):
                    return knowledge
        return ""

    def evaluate_node(self, node):
        """Score a node: inverse perplexity of its full path plus 0.2 per knowledge item."""
        # Generation quality of the whole reasoning path, not just the last step.
        ppl = compute_perplexity(self.model, self.tokenizer, self._path_text(node))
        gen_score = 1 / (ppl + 1e-6)

        # Associative-memory bonus: +0.2 per retrieved item. Memory is a single
        # string entry when set by expand_node, or a list on a fresh node;
        # len() of a string would count characters instead of items.
        memory = node.associative_memory
        if isinstance(memory, str):
            num_items = 1 if memory else 0
        else:
            num_items = len(memory)
        am_score = 0.2 * num_items
        return gen_score + am_score

    def expand_node(self, node):
        """Sample 3 continuations of the node's path and attach them as children."""
        if node.is_expanded:
            return
        # Condition on the full path (prompt + all earlier steps), so deeper
        # nodes continue the same line of reasoning.
        input_text = self._path_text(node)
        inputs = self.tokenizer(input_text, return_tensors="pt").to(self.model.device)
        prompt_len = inputs["input_ids"].shape[-1]
        outputs = self.model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=100,
            num_return_sequences=3,  # 3 candidates per node
            temperature=0.7,
            do_sample=True,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )
        for seq in outputs:
            # Decode only the newly generated tokens. Slicing the decoded text by
            # len(input_text) breaks whenever tokenization does not round-trip.
            new_text = self.tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
            child = CoATNode(
                parent=node,
                prompt=node.prompt,
                context=new_text,
                depth=node.depth + 1
            )
            child.associative_memory = self.retrieve_associative_memory(child.context)
            node.children.append(child)
        node.is_expanded = True

    def simulate(self, node):
        """Random rollout from ``node`` down to ``max_depth``; return the leaf's score."""
        current_depth = node.depth
        while current_depth < self.max_depth:
            if not node.children:
                self.expand_node(node)
            if not node.children:
                break  # nothing left to expand
            node = random.choice(node.children)
            current_depth += 1
        return self.evaluate_node(node)

    def backpropagate(self, node, value):
        """Add one visit and ``value`` to ``node`` and every ancestor."""
        while node is not None:
            node.visit_count += 1
            node.total_value += value
            node = node.parent

    def mcts_search(self, root):
        """Run ``max_iter`` MCTS iterations from ``root`` and return the chosen child."""
        for _ in range(self.max_iter):
            # Selection: follow UCT down to a leaf.
            node = root
            while node.children:
                node = node.best_child()

            # Expansion.
            if node.depth < self.max_depth and not node.is_expanded:
                self.expand_node(node)

            # Simulation: average several random rollouts.
            total_sim_value = 0
            for _ in range(self.num_simulations):
                total_sim_value += self.simulate(node)
            avg_sim_value = total_sim_value / self.num_simulations

            # Backpropagation.
            self.backpropagate(node, avg_sim_value)

        # Final choice: the most-visited child, with ties broken by mean value.
        # UCT adds an exploration bonus that favours rarely visited nodes, so it
        # is the wrong criterion for picking the answer.
        visited = [child for child in root.children if child.visit_count > 0]
        if not visited:
            return root.best_child()
        return max(visited, key=lambda child: (child.visit_count, child.total_value / child.visit_count))
    
# CoAT生成函数
def generate_with_coat(prompt, coat):
    """Run CoAT MCTS for ``prompt`` and return the selected reasoning path as text.

    The result contains the generated step of every node on the path from the
    root (excluded) to the chosen node, one per line. It is followed by the
    chosen node's associative knowledge, if any.

    Args:
        prompt: the problem prompt that seeds the search tree.
        coat: a ``CoATFramework`` instance that performs expansion and search.
    """
    # Build and initially expand the search tree.
    root = CoATNode(prompt=prompt, context="")
    coat.expand_node(root)

    best_node = coat.mcts_search(root)

    # Collect the generated steps from best_node back up to (but excluding) the
    # root. best_node's own step is included here exactly once; the original
    # also prepended it separately, duplicating the final step.
    path = []
    current_node = best_node
    while current_node.parent is not None:
        path.append(current_node.context)
        current_node = current_node.parent
    path.reverse()

    full_response = "\n".join(path)
    if best_node.associative_memory:
        full_response += f"\n[关联知识] {best_node.associative_memory}"
    return full_response


#============================= 实现 TIP Logits 处理器==================================
# 定义思路切换相关的触发词
switch_tokens = [
    '另一种方法', 'alternatively', '或者', '换一种思路',
    '但是', '另一方面', '然而'
]

# Map each trigger phrase to a token id. Multi-character phrases are generally
# not single vocabulary entries, so convert_tokens_to_ids() would return the
# unk id (or None for tokenizers without one) and the TIP penalty would never
# hit the real tokens. Instead, use the first sub-token of each phrase's
# encoding. This is an approximation: that sub-token may also start other words.
# Duplicates are dropped while keeping first-seen order.
switch_tokens_ids = list(dict.fromkeys(
    ids[0]
    for ids in (tokenizer.encode(phrase, add_special_tokens=False) for phrase in switch_tokens)
    if ids
))
from transformers import LogitsProcessor

class TIPLogitsProcessor(LogitsProcessor):
    """Thought-switching penalty (TIP) logits processor.

    Subtracts ``alpha`` from the logits of the switch tokens while generation is
    within ``beta`` tokens of the start of the current thought. A new thought
    starts whenever the last generated token is a switch token. The first thought
    starts at the end of the prompt, so long prompts do not exhaust the window
    before generation begins. State is tracked per batch row.

    Args:
        switch_token_ids: ids of tokens that mark a thought switch. ``None``
            entries (tokens the tokenizer could not resolve) are ignored.
        alpha: penalty subtracted from each switch-token logit.
        beta: length, in tokens, of the penalty window after a thought starts.
    """

    def __init__(self, switch_token_ids, alpha=3.0, beta=300):
        # Drop unresolved (None) ids and duplicates; a None id used as an index
        # would insert a new axis instead of selecting a column.
        self.switch_token_ids = list(dict.fromkeys(
            int(token_id) for token_id in switch_token_ids if token_id is not None
        ))
        self.alpha = alpha  # penalty strength
        self.beta = beta  # penalty window length
        # Per-row start position of the current thought; set on the first call.
        self.current_thought_start = None
        self._last_len = None

    def __call__(self, input_ids, scores):
        batch_size, cur_len = input_ids.shape

        # The first call of a generate() run (the sequence did not grow since
        # the last call, or the batch size changed) starts the first thought at
        # the end of the prompt.
        if (self.current_thought_start is None
                or self._last_len is None
                or cur_len <= self._last_len
                or len(self.current_thought_start) != batch_size):
            self.current_thought_start = [cur_len] * batch_size
        self._last_len = cur_len

        if not self.switch_token_ids:
            return scores

        # A switch token just emitted starts a new thought in that row.
        switch_set = set(self.switch_token_ids)
        for row, token in enumerate(input_ids[:, -1].tolist()):
            if token in switch_set:
                self.current_thought_start[row] = cur_len

        # Penalize switch tokens in rows still inside their penalty window.
        rows = [row for row, start in enumerate(self.current_thought_start)
                if cur_len < start + self.beta]
        if not rows:
            return scores
        scores = scores.clone()
        row_idx = torch.tensor(rows, device=scores.device)
        col_idx = torch.tensor(self.switch_token_ids, device=scores.device)
        scores[row_idx[:, None], col_idx[None, :]] -= self.alpha
        return scores


# 设置生成参数
def generate_with_cot(prompt):
    """Sample a chain-of-thought continuation of ``prompt`` with the TIP penalty applied.

    Returns only the newly generated text; the prompt is not echoed back. The
    prompts used in this script contain sample answers ("x=3, y=5",
    "3x=9 → x=3"). If they were echoed, parse_solution() could read those
    values and reward a response that never solved the problem.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[-1]

    # Inject the TIP logits processor to discourage early thought switching.
    logits_processor = [TIPLogitsProcessor(switch_tokens_ids, alpha=3.0, beta=300)]

    outputs = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=512,
        temperature=0.85,
        top_p=0.9,
        repetition_penalty=1.2,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        logits_processor=logits_processor
    )
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# The equation system to solve.
problem = "\n".join(["解方程组:", "方程1: x + y = 8", "方程2: 2x - y = 1"])

# Chain-of-thought prompt (revised to leave room for improvement).
# NOTE(review): the format example in requirement 3 is the actual solution
# (x=3, y=5), so the prompt leaks the answer the reward function checks for.
cot_prompt = "\n".join([
    "",
    "请逐步解决以下问题:",
    "",
    problem,
    "",
    "分步推理要求:",
    "1. 明确标注步骤编号",
    "2. 展示完整的代数运算过程",
    "3. 最终解用方框标出(如:x=3, y=5)",
    "",
])


# 困惑度计算函数
def compute_perplexity(model, tokenizer, text):
    """Return the perplexity ``exp(mean next-token NLL)`` of ``text`` under ``model``.

    Inputs are placed on the model's own device, so the function does not depend
    on a global ``device``. Returns ``inf`` when ``text`` tokenizes to fewer than
    two tokens, because there is no next-token prediction to score; the original
    produced NaN in that case.
    """
    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = next(model.parameters()).device
    inputs = tokenizer(text, return_tensors="pt").to(target_device)
    tokens = inputs["input_ids"]
    if tokens.shape[-1] < 2:
        return float("inf")

    with torch.no_grad():
        logits = model(**inputs).logits

    # Position t predicts token t+1. cross_entropy fuses log_softmax + nll_loss;
    # upcast to float32 so half-precision logits do not lose accuracy.
    nll = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)).float(),
        tokens[:, 1:].reshape(-1),
        reduction='mean',
    )
    return torch.exp(nll).item()


response = generate_with_cot(cot_prompt)
# Perplexity of the generated response rather than of the prompt, so the number
# is comparable with the TPO result below, which scores the optimized response.
ppl = compute_perplexity(model, tokenizer, response)
print("============ 原始模型 =============")
print(f"输入:{cot_prompt}...\n输出:{response}...\n困惑度:{ppl:.2f}")

with open('pre_response.txt', 'w', encoding='utf-8') as f:
    f.write(f"输入:{cot_prompt}\n输出:{response}\n困惑度:{ppl:.2f}")


# ============================ TPO =========================
num_iterations = 3  # number of TPO iterations
num_candidates = 5  # candidate responses generated per round
ground_truth = (3, 5)  # true solution of the equation system


def reward_function(parsed_solution, response_text=None):
    """Score a parsed ``(x, y)`` solution against ``ground_truth``.

    Args:
        parsed_solution: ``(x, y)`` numbers. It may also be ``None``, or a tuple
            containing ``None``; ``parse_solution`` returns ``(None, None)`` when
            it finds no answer.
        response_text: optional text the solution was parsed from. Each
            occurrence of a phrase from ``switch_tokens`` costs 0.05. When
            omitted, no switch penalty is applied.

    Returns:
        ``-1.0`` for an unparseable answer. Otherwise
        ``max(0, 1 - L1_error / 8 - penalty)``.
    """
    if parsed_solution is None or any(value is None for value in parsed_solution):
        return -1.0  # invalid answer
    x_pred, y_pred = parsed_solution
    # L1 error normalized by 8; errors beyond that are clamped to 0 below.
    max_error = 8
    error = (abs(x_pred - ground_truth[0]) + abs(y_pred - ground_truth[1])) / max_error
    # Thought-switch penalty, computed on the response being scored (the
    # original read the unrelated module-level baseline `response`).
    penalty = 0.0
    if response_text:
        switch_count = sum(response_text.count(token) for token in switch_tokens)
        penalty = 0.05 * switch_count
    return max(0.0, 1.0 - error - penalty)


def _score_response(resp):
    """Parse ``resp`` and return its reward, including its own switch penalty."""
    x, y = parse_solution(resp)
    return reward_function((x, y), resp)


def tpo_optimization(initial_prompt):
    """Run Textual Preference Optimization (TPO) on ``initial_prompt``.

    The initial candidates come from CoAT search. Each iteration then asks the
    model to contrast the best and worst cached responses, and samples new
    candidates conditioned on that feedback.

    Returns:
        ``(best_response, last_feedback)``. ``last_feedback`` is ``""`` when
        ``num_iterations`` is 0.
    """
    coat = CoATFramework(model, tokenizer)
    # Initial candidates from the CoAT tree search.
    candidates = [generate_with_coat(initial_prompt, coat) for _ in range(num_candidates)]
    cache = [(resp, _score_response(resp)) for resp in candidates]  # (response, reward)

    feedback = ""
    for _ in tqdm(range(num_iterations), desc='TPO Train'):
        # Best and worst responses seen so far.
        best_resp = max(cache, key=lambda item: item[1])[0]
        worst_resp = min(cache, key=lambda item: item[1])[0]

        # Ask the model for textual feedback (improvement suggestions).
        feedback_prompt = f"""
        以下是两个解方程组的示例:

        **优秀示例**:
        {best_resp}

        **较差示例**:
        {worst_resp}

        请分析优秀示例的优点和较差示例的不足,并提出改进建议:
        1. 步骤完整性(是否遗漏验证步骤)
        2. 计算准确性(是否存在算术错误)
        3. 表达清晰度(是否使用明确标记)

        **强制修正要求**:
        - 必须验证消元步骤:3x=9 → x=3
        - 若出现矛盾结论,必须重新计算
        """
        feedback_prompt += """
        **注意**:反馈需满足以下要求:
        - 分析必须具体,避免复述解题过程
        - 改进建议不超过3条
        - 强制使用LaTeX公式标注关键步骤
        """
        feedback = generate_with_cot(feedback_prompt)

        # New candidates conditioned on the feedback.
        new_candidates = [generate_with_cot(f"{initial_prompt}\n改进建议:{feedback}")
                          for _ in range(num_candidates)]
        cache.extend((resp, _score_response(resp)) for resp in new_candidates)

    # Highest-scoring response overall.
    return max(cache, key=lambda item: item[1])[0], feedback

print("============ TPO模型 =============")
# Run TPO, then report its final feedback, the best response, and that response's perplexity.
optimized_response, feedback = tpo_optimization(cot_prompt)
ppl = compute_perplexity(model, tokenizer, optimized_response)
report_fields = (feedback, optimized_response, ppl)
print("输入:{}...\n输出:{}...\n困惑度:{:.2f}".format(*report_fields))

with open('tpo_response.txt', mode='w', encoding='utf-8') as report_file:
    report_file.write("输入:{}\n输出:{}\n困惑度:{:.2f}".format(*report_fields))
                        
TTS MCTS CoAT TIP TPO
更新于 2025年4月 查看完整代码 →
MoE

MoE动态Top-k路由

稀疏激活机制

基于Gumbel-Softmax和容量因子的动态Top-k路由机制，实现专家网络的稀疏激活和负载均衡。

Python / PyTorch
class DynamicTopkRouter(nn.Module):
    """Top-k MoE router with Gumbel exploration and a per-expert capacity limit.

    Args:
        hidden_dim: size of the input's last dimension.
        num_experts: number of experts to route between.
        top_k: experts selected per token.
        capacity_factor: slack multiplier applied to the even-split expert capacity.
    """

    def __init__(self, hidden_dim, num_experts, top_k=2,
                 capacity_factor=1.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

        # Router projection: one logit per expert.
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)

    def forward(self, x, training=True):
        """Route each token of ``x`` (shape ``[B, L, hidden_dim]``) to its top-k experts.

        Returns a dict with:
            routing_weights: ``[B, L, top_k]`` softmax weights over the selected experts.
                These are not zeroed for dropped slots; use ``routing_mask`` for that.
            routing_indices: ``[B, L, top_k]`` selected expert ids.
            routing_mask: ``[B, L, E]`` bool. True where the token is routed to the
                expert and kept within that expert's capacity.
            expert_counts: ``[E]`` float tensor with the number of tokens kept per expert.
        """
        batch_size, seq_len, _ = x.shape
        num_tokens = batch_size * seq_len

        router_logits = self.router(x)  # [B, L, E]

        # Gumbel-top-k: perturb the logits *before* selection so that training
        # explores different experts. Adding noise after topk only reweights the
        # experts that were already chosen.
        if training:
            gumbel_noise = -torch.log(
                -torch.log(torch.rand_like(router_logits) + 1e-9) + 1e-9
            )
            selection_logits = router_logits + gumbel_noise
        else:
            selection_logits = router_logits

        top_k_logits, top_k_indices = torch.topk(selection_logits, self.top_k, dim=-1)
        routing_weights = F.softmax(top_k_logits, dim=-1)

        # Per-expert capacity: an even share of all B*L*top_k assignments,
        # scaled by capacity_factor, and at least one slot.
        capacity = max(1, int(self.capacity_factor * num_tokens * self.top_k / self.num_experts))

        routing_mask = torch.zeros_like(router_logits, dtype=torch.bool)
        routing_mask.scatter_(-1, top_k_indices, True)

        # 1-based queue position of each token at each expert, in flattened
        # (batch, seq) order. Tokens beyond an expert's capacity are dropped,
        # keeping the earliest arrivals.
        flat_mask = routing_mask.reshape(num_tokens, self.num_experts)
        queue_pos = torch.cumsum(flat_mask.long(), dim=0)
        flat_mask = flat_mask & (queue_pos <= capacity)
        routing_mask = flat_mask.reshape(batch_size, seq_len, self.num_experts)

        expert_counts = flat_mask.sum(dim=0).float()

        return {
            'routing_weights': routing_weights,
            'routing_indices': top_k_indices,
            'routing_mask': routing_mask,
            'expert_counts': expert_counts
        }
MoE 动态路由 Gumbel-Softmax 负载均衡
更新于 2025年4月 查看完整代码 →
FA

Triton FlashAttention-2 前向传播

高效注意力机制实现

使用Triton实现的FlashAttention-2前向传播内核，通过分块计算与在线Softmax避免显式构造N×N注意力矩阵，将注意力机制的额外内存复杂度从O(N²)降低到O(N)。

Python / Triton
import torch
import triton
import triton.language as tl

@triton.jit
def flash_attention_v2(
    # Tensor pointers.
    q_ptr, k_ptr, v_ptr, o_ptr,
    # Metadata.
    seq_len, head_dim: tl.constexpr,
    # Strides of the seq (m) and head (h) axes; head_dim is assumed to have unit stride.
    q_stride_m, q_stride_h,
    k_stride_m, k_stride_h,
    v_stride_m, v_stride_h,
    o_stride_m, o_stride_h,
    # Meta-parameters.
    BLOCK_M: tl.constexpr,  # query rows per program
    BLOCK_N: tl.constexpr,  # key/value rows per inner step
    NUM_HEADS: tl.constexpr,  # number of heads (unused inside the kernel)
    IS_CAUSAL: tl.constexpr,  # apply a causal mask
):
    """FlashAttention-2 forward for one (head, query block) program.

    Computes ``softmax(Q K^T / sqrt(head_dim)) V`` block by block with an online
    softmax. Accumulators are kept unnormalized and divided by the running
    denominator once, at the end.
    """
    # 1. Program ids and offsets.
    pid_head = tl.program_id(0)  # head index
    pid_m = tl.program_id(1)  # query-block index within the head

    start_m = pid_m * BLOCK_M
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, head_dim)
    mask_m = offs_m < seq_len

    # 2. Online-softmax state.
    m_i = tl.full((BLOCK_M, ), float('-inf'), dtype=tl.float32)  # running row max
    l_i = tl.zeros((BLOCK_M, ), dtype=tl.float32)  # running denominator
    acc = tl.zeros((BLOCK_M, head_dim), dtype=tl.float32)  # unnormalized output

    # 3. Load the Q block. The feature axis uses unit stride; the original
    #    multiplied it by the head stride and read the wrong elements.
    q_ptrs = q_ptr + pid_head * q_stride_h + offs_m[:, None] * q_stride_m + offs_d[None, :]
    q = tl.load(q_ptrs, mask=mask_m[:, None], other=0.0).to(tl.float32)

    # 4. Iterate over K/V blocks.
    for start_n in range(0, seq_len, BLOCK_N):
        cols = start_n + offs_n
        valid_n = cols < seq_len  # keys inside the sequence

        k_ptrs = k_ptr + pid_head * k_stride_h + cols[:, None] * k_stride_m + offs_d[None, :]
        k = tl.load(k_ptrs, mask=valid_n[:, None], other=0.0).to(tl.float32)
        v_ptrs = v_ptr + pid_head * v_stride_h + cols[:, None] * v_stride_m + offs_d[None, :]
        v = tl.load(v_ptrs, mask=valid_n[:, None], other=0.0).to(tl.float32)

        # Scaled scores Q K^T.
        s = tl.dot(q, tl.trans(k))
        s = s * (1.0 / tl.sqrt(tl.cast(head_dim, tl.float32)))

        # Mask out-of-range keys (offset by start_n) and, if causal, future keys.
        # Only keys are masked, so padded query rows stay finite (no NaN); they
        # are discarded on store.
        if IS_CAUSAL:
            keep = valid_n[None, :] & (offs_m[:, None] >= cols[None, :])
        else:
            keep = valid_n[None, :]
        s = tl.where(keep, s, float('-inf'))

        # Online softmax update with unnormalized accumulators.
        m_new = tl.maximum(m_i, tl.max(s, axis=1))
        alpha = tl.exp(m_i - m_new)  # rescale factor for the old state
        p = tl.exp(s - m_new[:, None])
        l_i = alpha * l_i + tl.sum(p, axis=1)
        acc = acc * alpha[:, None] + tl.dot(p.to(v.dtype), v)
        m_i = m_new

    # 5. Normalize once and store the valid rows.
    o = acc / l_i[:, None]
    o_ptrs = o_ptr + pid_head * o_stride_h + offs_m[:, None] * o_stride_m + offs_d[None, :]
    tl.store(o_ptrs, o.to(o_ptr.dtype.element_ty), mask=mask_m[:, None])

def call_flash_attention_v2(q, k, v, is_causal=False):
    """Launch the FlashAttention-2 forward kernel.

    Args:
        q, k, v: tensors shaped ``[seq_len, num_heads, head_dim]`` with identical shapes.
        is_causal: apply a causal mask.

    Returns:
        A tensor shaped like ``q`` holding the attention output.

    Raises:
        AssertionError: if ``q`` is not 3-D.
        ValueError: if the shapes differ, or ``head_dim`` is not a power of two >= 16
            (``tl.arange`` and ``tl.dot`` require this).
    """
    assert q.dim() == 3, "Input should be [seq_len, num_heads, head_dim]"
    if k.shape != q.shape or v.shape != q.shape:
        raise ValueError("q, k and v must share the shape [seq_len, num_heads, head_dim]")
    seq_len, num_heads, head_dim = q.shape
    if head_dim < 16 or head_dim & (head_dim - 1):
        raise ValueError("head_dim must be a power of two >= 16")

    # The kernel indexes head_dim with unit stride.
    q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
    o = torch.empty_like(q)

    config = {
        'BLOCK_M': 128,
        'BLOCK_N': 64,
        'num_warps': 8,
        'num_stages': 3,
    }

    # Grid: one program per (head, query block).
    grid = (num_heads, triton.cdiv(seq_len, config['BLOCK_M']))

    flash_attention_v2[grid](
        q, k, v, o,
        seq_len, head_dim,
        # (seq stride, head stride). For [seq, heads, dim], stride(0) is the
        # seq axis; the original passed these two strides swapped.
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        o.stride(0), o.stride(1),
        NUM_HEADS=num_heads,
        IS_CAUSAL=is_causal,
        **config
    )
    return o
Triton FlashAttention GPU优化 内存优化
更新于 2025年9月 查看完整代码 →
DI

StreamingDataIterator

流式数据加载

用于大语言模型训练的流式数据迭代器，支持动态批处理、样本打包和内存高效的数据加载。

Python
class StreamingDataIterator:
    """Stream JSONL samples from disk, tokenize them, and yield padded batches.

    Each non-empty line of ``data_path`` must be a JSON object with a ``text``
    field. Samples are tokenized into an in-memory buffer of up to
    ``buffer_size`` items. Each batch is right-padded to its own longest sequence.

    Args:
        data_path: path to a UTF-8 JSONL file.
        tokenizer: object exposing ``encode(text, truncation=..., max_length=...)``
            and ``pad_token_id``.
        max_length: truncation length for each sample.
        batch_size: samples per batch.
        buffer_size: number of tokenized samples kept in memory.
        drop_last: if True, skip a final batch smaller than ``batch_size``.
    """

    def __init__(self, data_path, tokenizer, max_length=2048,
                 batch_size=8, buffer_size=10000, drop_last=False):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.drop_last = drop_last

        # Tokenized samples and the index of the next unconsumed one.
        self.buffer = []
        self.buffer_cursor = 0

        self.data_generator = self._create_generator()

    def _create_generator(self):
        """Return a generator yielding one parsed JSON object per non-empty line."""
        import json

        def generator():
            with open(self.data_path, 'r', encoding='utf-8') as f:
                for line in f:
                    stripped = line.strip()
                    if stripped:
                        yield json.loads(stripped)
        return generator()

    def _tokenize_sample(self, sample):
        """Tokenize ``sample['text']``, truncated to ``max_length`` tokens."""
        text = sample['text']
        tokens = self.tokenizer.encode(
            text, truncation=True, max_length=self.max_length
        )
        return tokens

    def _load_to_buffer(self):
        """Refill the buffer from the stream, keeping samples not yet consumed.

        The original reset the buffer to empty, silently dropping the unconsumed
        tail on every refill.
        """
        self.buffer = self.buffer[self.buffer_cursor:]
        self.buffer_cursor = 0
        # Always allow at least one full batch to be buffered.
        target = max(self.buffer_size, self.batch_size)
        while len(self.buffer) < target:
            try:
                sample = next(self.data_generator)
            except StopIteration:
                break
            self.buffer.append(self._tokenize_sample(sample))

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch as ``{'input_ids', 'attention_mask'}`` long tensors."""
        if self.buffer_cursor + self.batch_size > len(self.buffer):
            self._load_to_buffer()

        remaining = len(self.buffer) - self.buffer_cursor
        if remaining == 0 or (self.drop_last and remaining < self.batch_size):
            raise StopIteration

        batch = self.buffer[self.buffer_cursor:self.buffer_cursor + self.batch_size]
        self.buffer_cursor += len(batch)

        # Tokenizers without a pad token would put None into the tensor; fall
        # back to eos, then 0. Padding positions are masked out anyway.
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = getattr(self.tokenizer, 'eos_token_id', None)
        if pad_id is None:
            pad_id = 0

        # Dynamic padding to the longest sequence in this batch.
        max_len = max(len(seq) for seq in batch)
        padded_batch = []
        attention_masks = []
        for seq in batch:
            padding_length = max_len - len(seq)
            padded_batch.append(seq + [pad_id] * padding_length)
            attention_masks.append([1] * len(seq) + [0] * padding_length)

        return {
            'input_ids': torch.tensor(padded_batch, dtype=torch.long),
            'attention_mask': torch.tensor(attention_masks, dtype=torch.long)
        }
数据流 LLM训练 内存优化 动态批处理
更新于 2025年4月20日 查看完整代码 →