Sharing the latest technical implementations and algorithm optimizations, where every line of code pursues performance and elegance.
Entropy Minimization
A complete implementation of the RENT recipe, covering three entropy-minimization methods (EM-FT / EM-RL / EM-INF) plus an adaptive temperature-scaling baseline.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import math

class EntropyMinimization:
def __init__(self, model_name='gpt2'):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
self.model = AutoModelForCausalLM.from_pretrained(model_name).to(self.device)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
def token_level_entropy(self, logits):
"""计算标记级熵"""
probs = F.softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim=-1)
entropy = -torch.sum(probs * log_probs, dim=-1)
return entropy
    def sequence_level_entropy(self, logits, attention_mask):
        """Compute sequence-level entropy."""
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        # Per-position entropy
        token_entropies = -torch.sum(probs * log_probs, dim=-1)
        # Average entropy over valid (non-padding) tokens
        valid_tokens = attention_mask.sum(dim=1)
        sequence_entropies = (token_entropies * attention_mask).sum(dim=1) / valid_tokens
        return sequence_entropies
class EMFTTrainer(EntropyMinimization):
"""EM-FT: 基于标记级熵的微调"""
def __init__(self, model_name='gpt2',lr=1e-5):
super().__init__(model_name)
self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=lr)
def generate_training_data(self, prompts, num_samples=4):
self.model.eval()
training_data = {'input_ids': [], 'attention_mask': []}
with torch.no_grad():
for prompt in prompts:
inputs = self.tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
inputs = {k: v.to(self.device) for k, v in inputs.items()}
for _ in range(num_samples):
outputs = self.model.generate(
**inputs,
max_length=inputs['input_ids'].shape[1] + 100,
do_sample=True,
temperature=0.8,
pad_token_id=self.tokenizer.eos_token_id
)
training_data['input_ids'].append(outputs.cpu())
training_data['attention_mask'].append(torch.ones_like(outputs).cpu())
return training_data
    def compute_entropy_loss(self, batch):
        input_ids = batch['input_ids'].to(self.device)
        attention_mask = batch['attention_mask'].to(self.device)
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        # Per-position entropy
        token_entropies = self.token_level_entropy(logits)
        # Keep only non-padding positions
        valid_entropies = token_entropies[attention_mask.bool()]
        # The mean entropy is the loss
        return valid_entropies.mean()
def train_step(self, batch):
self.model.train()
self.optimizer.zero_grad()
loss = self.compute_entropy_loss(batch)
loss.backward()
self.optimizer.step()
return loss.item()
    def train(self, prompts, num_epochs=3, batch_size=2):
        print("Generating training data...")
        training_data = self.generate_training_data(prompts)
        # Sampled sequences differ in length; right-pad to a common length before concatenating
        pad_id = self.tokenizer.pad_token_id
        max_len = max(t.shape[1] for t in training_data['input_ids'])
        input_ids = torch.cat([F.pad(t, (0, max_len - t.shape[1]), value=pad_id) for t in training_data['input_ids']], dim=0)
        attention_mask = torch.cat([F.pad(m, (0, max_len - m.shape[1]), value=0) for m in training_data['attention_mask']], dim=0)
        num_batches = math.ceil(input_ids.shape[0] / batch_size)
        print(f"Starting training, total batches: {num_batches}")
for epoch in range(num_epochs):
total_loss = 0
for i in range(num_batches):
start_idx = i * batch_size
end_idx = min((i + 1) * batch_size, input_ids.shape[0])
batch = {
"input_ids": input_ids[start_idx:end_idx],
"attention_mask": attention_mask[start_idx:end_idx]
}
loss = self.train_step(batch)
total_loss += loss
if i % 10 == 0:
print(f"Epoch {epoch+1}, Batch {i+1}/{num_batches}, Loss: {loss:.4f}")
avg_loss = total_loss / num_batches
print(f"Epoch {epoch+1} 完成, 平均损失: {avg_loss:.4f}")
class EMRLLoss(EntropyMinimization):
"""EM-RL: 基于强化学习的熵最小化损失函数"""
def __init__(self, model_name='gpt2', beta=0.001):
super().__init__(model_name)
self.beta = beta
def compute_reward(self, logits, attention_mask, reward_type='token'):
if reward_type == 'token':
token_entropies = self.token_level_entropy(logits)
valid_entropies = token_entropies[attention_mask.bool()]
reward = -valid_entropies.mean()
else:
sequence_entropies = self.sequence_level_entropy(logits, attention_mask)
reward = -sequence_entropies.mean()
return reward
def compute_kl_penalty(self, logits, ref_logits, attention_mask):
log_probs = F.log_softmax(logits, dim=-1)
ref_probs = F.softmax(ref_logits, dim=-1)
kl_div = F.kl_div(
input=log_probs,
target=ref_probs,
reduction='none'
).sum(dim=-1)
valid_kl = kl_div[attention_mask.bool()]
return valid_kl.mean()
def compute_loss(self, logits, ref_logits, attention_mask, reward_type='token'):
        # Reward = negative entropy
        reward = self.compute_reward(logits, attention_mask, reward_type)
        # KL penalty against the reference policy
        kl_penalty = self.compute_kl_penalty(logits, ref_logits, attention_mask)
        # Total loss = -reward + beta * KL penalty
        total_loss = -reward + self.beta * kl_penalty
return total_loss, reward.item(), kl_penalty.item()
class EMINFInference(EntropyMinimization):
"""EM-INF: 推理时logits优化"""
def __init__(self, model_name='gpt2'):
super().__init__(model_name)
    def optimize_logits(self, logits, delta=0.3, num_steps=15, lr=0.1):
        """Optimize logits to reduce entropy."""
        # Make a gradient-enabled copy of the logits
        optimized_logits = logits.clone().to(self.device).requires_grad_(True)
        optimizer = optim.Adam([optimized_logits], lr=lr)
        for step in range(num_steps):
            optimizer.zero_grad()
            # Current entropy of the candidate distribution
            current_entropy = self.token_level_entropy(optimized_logits.unsqueeze(0)).mean()
            # Minimize entropy only while it exceeds the threshold
            if current_entropy > delta:
                loss = current_entropy
                loss.backward()
                optimizer.step()
            else:
                break
return optimized_logits.detach()
    def adaptive_temperature_scaling(self, logits, alpha=0.5, delta=0.3):
        """Adaptive temperature scaling (comparison baseline)."""
        original_entropy = self.token_level_entropy(logits.unsqueeze(0)).mean()
        target_entropy = max(alpha * original_entropy, delta)
        # Binary search for the best temperature. A higher temperature flattens the
        # distribution and raises entropy, so when entropy is too high we lower it.
        low_temp, high_temp = 0.1, 10.0
        tolerance = 1e-4
        max_iters = 20
        for _ in range(max_iters):
            mid_temp = (low_temp + high_temp) / 2
            scaled_logits = logits / mid_temp
            current_entropy = self.token_level_entropy(scaled_logits.unsqueeze(0)).mean()
            if abs(current_entropy - target_entropy) < tolerance:
                break
            if current_entropy > target_entropy:
                high_temp = mid_temp  # entropy too high -> search lower temperatures
            else:
                low_temp = mid_temp
        return logits / mid_temp
    def generate_with_em_inf(self, prompt, method='logit_optimization', max_length=100, **kwargs):
        """Generate text with EM-INF."""
        self.model.eval()
        inputs = self.tokenizer(prompt, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)
        attention_mask = inputs["attention_mask"].to(self.device)
        generated_ids = input_ids.clone()
        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated_ids, attention_mask=attention_mask)
                next_token_logits = outputs.logits[:, -1, :]
                if method == 'logit_optimization':
                    # EM-INF: per-step logit optimization
                    with torch.enable_grad():
                        optimized_logits = self.optimize_logits(
                            next_token_logits,
                            delta=kwargs.get('delta', 0.3),
                            num_steps=kwargs.get('num_steps', 15)
                        )
                elif method == "adaptive_temp":
                    # Adaptive temperature scaling (baseline)
                    optimized_logits = self.adaptive_temperature_scaling(
                        next_token_logits,
                        alpha=kwargs.get('alpha', 0.5),
                        delta=kwargs.get('delta', 0.3)
                    )
                else:
                    optimized_logits = next_token_logits
probs = F.softmax(optimized_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
generated_ids = torch.cat([generated_ids, next_token], dim=1)
attention_mask = torch.cat([
attention_mask,
torch.ones(1, 1, device=self.device)
], dim=1)
if next_token.item() == self.tokenizer.eos_token_id:
break
return self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
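A minimal usage sketch for the classes above; the prompts and hyperparameters here are illustrative assumptions, not part of the original recipe:

if __name__ == '__main__':
    prompts = ["The capital of France is", "2 + 2 equals"]  # toy prompts (hypothetical)
    # EM-FT: fine-tune on self-generated samples by minimizing token-level entropy
    ft_trainer = EMFTTrainer('gpt2', lr=1e-5)
    ft_trainer.train(prompts, num_epochs=1, batch_size=2)
    # EM-INF: inference-time logit optimization, no weight updates
    em_inf = EMINFInference('gpt2')
    print(em_inf.generate_with_em_inf(prompts[0], method='logit_optimization', delta=0.3))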
Optimizing a Scoring Function
ICM evaluates the quality of a label set D by optimizing a scoring function U(D) = α·P_θ(D) - I(D) that combines mutual predictability with logical consistency. Sample labels are refined via simulated-annealing search plus consistency repair.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import random
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
import logging
import json
from tqdm import tqdm
import time
import argparse
from collections import defaultdict, deque
import hashlib
import re
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
@dataclass
class ICMConfig:
    alpha: float = 50.0              # weight on mutual predictability
    initial_temp: float = 10.0       # initial simulated-annealing temperature
    final_temp: float = 0.01         # minimum temperature
    cooling_rate: float = 0.99       # temperature decay rate
    init_samples: int = 8            # number of randomly labeled seed samples
    max_iterations: int = 1000
    batch_size: int = 4
    max_context_examples: int = 5    # max in-context examples
    context_window: int = 2048       # model context window
    fix_steps: int = 10              # consistency-repair iterations
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    seed: int = 620
    cache_size: int = 1000           # probability-cache size
    priority_weight: float = 100.0   # sampling weight for related samples
def __post_init__(self):
if self.device == "cuda" and not torch.cuda.is_available():
            logger.warning('CUDA unavailable, falling back to CPU')
self.device = "cpu"
class ProbabilityCache:
"""缓存概率计算结果以提高效率"""
def __init__(self, max_size=1000):
self.max_size = max_size
self.cache = {}
self.access_order = deque()
def _get_key(self, x, labels, context):
context_str = '|'.join([f"{d['x']}:{d['y']}" for d in context])
labels_str = ','.join(sorted(labels))
content = F"{x}|{labels_str}|{context_str}"
return hashlib.md5(content.encode()).hexdigest()
def get(self, x, labels, context):
key = self._get_key(x, labels, context)
if key in self.cache:
self.access_order.remove(key)
self.access_order.appendleft(key)
return self.cache[key]
return None
def set(self, x, labels, context, probs):
key = self._get_key(x, labels, context)
if len(self.cache) >= self.max_size:
oldest_key = self.access_order.pop()
del self.cache[oldest_key]
self.cache[key] = probs
self.access_order.appendleft(key)
class ConsistencyChecker(ABC):
    """Abstract base class for logical-consistency checks."""
    @abstractmethod
    def check_pair(self, x_i, y_i, x_j, y_j):
        """Check whether two labels are consistent (True = consistent, False = conflict)."""
        pass
def get_inconsistent_pairs(self, data):
inconsistent = []
        # Map each question to the indices of its samples
        question_to_indices = defaultdict(list)
        for i, item in enumerate(data):
            q = self.extract_question(item['x'])
            question_to_indices[q].append(i)
        # Only compare sample pairs that share the same question
for indices in question_to_indices.values():
if len(indices) > 1:
for i_idx, i in enumerate(indices):
for j in indices[i_idx+1:]:
if not self.check_pair(data[i]['x'], data[i]['y'], data[j]['x'], data[j]['y']):
inconsistent.append((i, j))
return inconsistent
    @abstractmethod
    def get_consistent_labels(self, x_i, x_j):
        """Return all logically possible label combinations."""
        pass
    def extract_question(self, text):
        """Extract the question part of an input."""
        if "Claim:" in text:
            return text.split("Claim:")[0].strip()
        return text.split("\n")[0] if "\n" in text else text
class MathConsistencyChecker(ConsistencyChecker):
"""数学验证任务一致性检查"""
def __init__(self):
self.answer_patterns = ['The answer is', '答案是', 'Answer:']
def extract_answer(self, text):
for pattern in self.answer_patterns:
if pattern in text:
start = text.find(pattern) + len(pattern)
answer_part = text[start:start+20]
numbers = re.findall(r'-?\d+\.?\d*', answer_part)
if numbers:
return numbers[0]
return None
def check_pair(self, x_i, y_i, x_j, y_j):
q_i, q_j = self.extract_question(x_i), self.extract_question(x_j)
if q_i != q_j:
return True
ans_i, ans_j = self.extract_answer(x_i), self.extract_answer(x_j)
if ans_i and ans_j and ans_i != ans_j:
if (y_i == 'True' and y_j == 'True') or (y_i == 'False' and y_j == 'False'):
return False
return True
def get_consistent_labels(self, x_i, x_j):
q_i, q_j = self.extract_question(x_i), self.extract_question(x_j)
ans_i, ans_j = self.extract_answer(x_i), self.extract_answer(x_j)
if q_i == q_j and ans_i and ans_j and ans_i != ans_j:
            # Same question, different answers: they cannot both be True
return [('True', 'False'), ('False', 'True'), ('False', 'False')]
return [('True', 'True'), ('True', 'False'), ('False', 'True'), ('False', 'False')]
class AlpacaConsistencyChecker(ConsistencyChecker):
def extract_responses(self, text):
if 'Response A:' in text and 'Response B:' in text:
parts = text.split('Response A:')[1].split('Response B:')
if len(parts) == 2:
return parts[0].strip(), parts[1].strip()
return None, None
def check_pair(self, x_i, y_i, x_j, y_j):
resp_a_i, resp_b_i = self.extract_responses(x_i)
resp_a_j, resp_b_j = self.extract_responses(x_j)
if resp_a_i and resp_b_i and resp_a_j and resp_b_j:
            # Same response pair presented in reversed order
if (resp_a_i == resp_b_j and resp_b_i == resp_a_j):
if y_i == "True" and y_j == "True":
return False
return True
def get_consistent_labels(self, x_i, x_j):
resp_a_i, resp_b_i = self.extract_responses(x_i)
resp_a_j, resp_b_j = self.extract_responses(x_j)
if (resp_a_i and resp_b_i and resp_a_j and resp_b_j and
resp_a_i == resp_b_j and resp_b_i == resp_a_j):
            # Antisymmetric case: they cannot both be True
return [('True', 'False'), ('False', 'True'), ('False', 'False')]
return [('True', 'True'), ('True', 'False'), ('False', 'True'), ('False', 'False')]
class ICMCore:
def __init__(self, model, tokenizer, config, consistency_checker):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.checker = consistency_checker
self.cache = ProbabilityCache(config.cache_size)
for param in self.model.parameters():
param.requires_grad = False
self.model.to(self.config.device)
self.model.eval()
random.seed(self.config.seed)
np.random.seed(self.config.seed)
torch.manual_seed(self.config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(config.seed)
logger.info(f"ICMCore initialized on {config.device}")
def build_prompt(self, x, context):
parts = []
for ex in context[-self.config.max_context_examples:]:
parts.append(f"Input: {ex['x']}\nLabel: {ex['y']}")
parts.append(f"Input: {x}\nLabel:")
return "\n\n".join(parts)
@torch.no_grad()
def get_label_prob_batch(self, x_list, labels, context_list):
"""批量计算条件概率"""
results = []
for i, (x, context) in enumerate(zip(x_list, context_list)):
cached = self.cache.get(x, labels, context)
if cached is not None:
results.append(cached)
continue
prompt = self.build_prompt(x, context)
prompt_ids = self.tokenizer.encode(prompt, return_tensors='pt')
probs = {}
for label in labels:
full_text = prompt + ' ' + label
inputs = self.tokenizer(full_text, return_tensors='pt', truncation=True, max_length=self.config.context_window, return_offsets_mapping=True).to(self.config.device)
prompt_len = len(prompt_ids[0])
label_start = None
offset_mapping = inputs['offset_mapping'][0]
for idx, (start, end) in enumerate(offset_mapping):
if start >= len(prompt) and label_start is None:
label_start = idx
break
if label_start is None:
                probs[label] = 1.0 / len(labels)  # fallback
continue
outputs = self.model(**{k: v for k, v in inputs.items() if k != 'offset_mapping'})
logits = outputs.logits
label_ids = self.tokenizer.encode(label, add_special_tokens=False)
log_prob = 0.0
valid_tokens = 0
for j, token_id in enumerate(label_ids):
pos = label_start + j
if pos < logits.shape[1]:
token_log_prob = torch.log_softmax(logits[0, pos], dim=-1)[token_id].item()
log_prob += token_log_prob
valid_tokens += 1
if valid_tokens > 0:
                probs[label] = np.exp(log_prob / valid_tokens)  # geometric mean of token probabilities
else:
probs[label] = 1.0 / len(labels)
            # Normalize
total = sum(probs.values())
if total > 0:
probs = {k: v/total for k, v in probs.items()}
else:
probs = {label: 1.0/len(labels) for label in labels}
self.cache.set(x, labels, context, probs)
results.append(probs)
return results
@torch.no_grad()
def get_label_prob(self, x, labels, context):
"""单个样本的条件概率计算"""
return self.get_label_prob_batch([x], labels, [context])[0]
def compute_mutual_predictability(self, data):
"""计算互预测性 P_θ(D)"""
if len(data) <= 1:
return 0.0
total_log_prob = 0.0
batch_size = min(self.config.batch_size, len(data))
for start_idx in tqdm(range(0, len(data), batch_size), desc='Mutual Pred', leave=False):
end_idx = min(start_idx + batch_size, len(data))
batch_data = data[start_idx: end_idx]
x_list = [item['x'] for item in batch_data]
context_list = []
for i, item in enumerate(batch_data):
context = [d for j, d in enumerate(data) if j != start_idx + i]
context_list.append(context)
all_labels = list(set(item['y'] for item in data))
all_probs = self.get_label_prob_batch(x_list, all_labels, context_list)
for i, probs in enumerate(all_probs):
true_labels = batch_data[i]['y']
prob = probs.get(true_labels, 1e-10)
total_log_prob += np.log(prob)
return total_log_prob / len(data)
def compute_consistency_penalty(self, data):
"""计算一致性惩罚 I(D)"""
return len(self.checker.get_inconsistent_pairs(data))
def compute_score(self, data):
"""计算U(D) = α·P(D) - I(D)"""
if not data:
return -np.inf
mutual = self.compute_mutual_predictability(data)
penalty = self.compute_consistency_penalty(data)
score = self.config.alpha * mutual - penalty
return score
def fix_inconsistencies(self, data):
if len(data) <= 1:
return data
data = [item.copy() for item in data]
improved = True
steps = 0
while improved and steps < self.config.fix_steps:
improved = False
pairs = self.checker.get_inconsistent_pairs(data)
if not pairs:
break
            # Sample a few inconsistent pairs to repair
sample_pairs = random.sample(pairs, min(10, len(pairs)))
for i, j in sample_pairs:
x_i, x_j = data[i]['x'], data[j]['x']
label_options = self.checker.get_consistent_labels(x_i, x_j)
best_score = -np.inf
best_labels = None
current_score = self.compute_score(data)
orig_y_i, orig_y_j = data[i]['y'], data[j]['y']
for y_i_new, y_j_new in label_options:
data[i]['y'], data[j]['y'] = y_i_new, y_j_new
score = self.compute_score(data)
if score > best_score:
best_score = score
best_labels = (y_i_new, y_j_new)
if best_labels and best_score > current_score:
data[i]['y'], data[j]['y'] = best_labels
improved = True
else:
data[i]['y'], data[j]['y'] = orig_y_i, orig_y_j
steps += 1
return data
def select_example(self, unlabeled_indices, labeled, unlabeled_data):
if not labeled or random.random() < 0.3:
idx = random.choice(list(unlabeled_indices))
return unlabeled_data[idx], idx
        # Sample based on question similarity
ref_item = random.choice(labeled)
ref_question = self.checker.extract_question(ref_item['x'])
scores = []
candidates = []
for idx in unlabeled_indices:
question = self.checker.extract_question(unlabeled_data[idx])
            # Crude similarity: 1.0 for an exact question match, else 0.1
similarity = 1.0 if question == ref_question else 0.1
scores.append(similarity)
candidates.append(idx)
        # Weighted sampling
if sum(scores) > 0:
idx = random.choices(candidates, weights=scores, k=1)[0]
else:
idx = random.choice(candidates)
return unlabeled_data[idx], idx
def run(self, unlabeled, labels, max_iter=None):
"""通过模拟退火搜索优化标注, 最大化U(D)评分"""
max_iter = max_iter or (min(len(unlabeled), self.config.max_iterations))
unlabeled_indices = set(range(len(unlabeled)))
labeled = []
init_indices = random.sample(list(unlabeled_indices), min(self.config.init_samples, len(unlabeled)))
for idx in init_indices:
labeled.append({
'x': unlabeled[idx],
'y': random.choice(labels),
'idx': idx
})
unlabeled_indices.remove(idx)
        # Initial consistency repair
        labeled = self.fix_inconsistencies(labeled)
        current_score = self.compute_score(labeled)
        logger.info(f"Initialization done: {len(labeled)} samples, score={current_score:.4f}")
start_time = time.time()
        progress_bar = tqdm(range(max_iter), desc="ICM optimization")
for iteration in progress_bar:
if not unlabeled_indices:
break
temp = max(self.config.final_temp, self.config.initial_temp * (self.config.cooling_rate ** iteration))
            # Pick a new sample to label
            x_new, idx_new = self.select_example(unlabeled_indices, labeled, unlabeled)
            # Predict its most likely label under the current context
            probs = self.get_label_prob(x_new, labels, labeled)
            y_new = max(probs, key=probs.get)
            # Build a temporary dataset that includes the new annotation
            new_item = {'x': x_new, 'y': y_new, 'idx': idx_new}
            temp_labeled = labeled + [new_item]
            temp_labeled = self.fix_inconsistencies(temp_labeled)  # repair any new inconsistencies
            new_score = self.compute_score(temp_labeled)
            delta = new_score - current_score  # score change
            # Simulated-annealing acceptance rule
            accept = False
            if delta > 0:
                accept = True
            else:
                accept_prob = np.exp(delta / temp) if temp > 0 else 0
                accept = random.random() < accept_prob
            if accept:
                # Accept the new annotation
                labeled = temp_labeled
                current_score = new_score
                unlabeled_indices.remove(idx_new)
            if iteration % 50 == 0:
                logger.info(f"Iteration {iteration}: score={current_score:.4f}, "
                            f"delta={delta:.4f}, temperature={temp:.4f}")
progress_bar.set_postfix({
'score': f'{current_score:.3f}',
'labeled': len(labeled),
'temp': f'{temp:.3f}'
})
        elapsed = time.time() - start_time
        logger.info(f"Optimization done: {len(labeled)} labeled samples, final score={current_score:.4f}, "
                    f"elapsed={elapsed:.2f}s")
return labeled
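A minimal usage sketch for ICMCore; the model name and the toy claims are assumptions for illustration:

if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained('gpt2')
    model = AutoModelForCausalLM.from_pretrained('gpt2')
    config = ICMConfig(max_iterations=50, init_samples=2)
    icm = ICMCore(model, tokenizer, config, MathConsistencyChecker())
    unlabeled = [
        "Question: What is 2+3? Claim: The answer is 5",
        "Question: What is 2+3? Claim: The answer is 6",
        "Question: What is 4*4? Claim: The answer is 16",
        "Question: What is 4*4? Claim: The answer is 15",
    ]
    labeled = icm.run(unlabeled, labels=['True', 'False'])
    print(json.dumps(labeled, ensure_ascii=False, indent=2))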
Reward Shaping
A PyTorch implementation of the INTUITOR algorithm on the GRPO framework. It replaces the external reward with self-certainty, the average KL divergence between the uniform distribution and the model's next-token distribution, so exploration can be reinforced on unlabeled data.
import torch
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from dataclasses import dataclass
from torch.utils.data import Dataset
@dataclass
class IntuitorConfig:
    model_name: str = 'gpt2'
    batch_size: int = 128
    num_candidates: int = 7
    kl_penalty: float = 0.005
    lr: float = 3e-5
    max_length: int = 1024
    num_epochs: int = 3
    clip_epsilon: float = 0.2
    advantage_scale: float = 1.0
    use_online_self_certainty: bool = True
class SelfCertaintyCalculator:
def __init__(self, vocab_size):
self.vocab_size = vocab_size
self.uniform_dist = torch.ones(vocab_size) / vocab_size
def compute_self_certainty(self, logits, attention_mask):
"""计算自我确定性分数"""
# 获取概率分布
probs = F.softmax(logits, dim=-1) # [batch_size, seq_len, vocab_size]
# 将均匀分布移到设备上
uniform = self.uniform_dist.to(logits.device).unsqueeze(0).unsqueeze(0) # [1, 1, vocab_size]
# 计算每个位置的KL散度:KL(uniform || p)
kl_divergence = F.kl_div(
input=torch.log(probs + 1e-8),
target=uniform.expand_as(probs),
reduction='none',
log_target=False
).sum(dim=-1) # [batch_size, seq_len]
kl_divergence = kl_divergence * attention_mask
# 计算序列平均自我确定性
seq_lengths = attention_mask.sum(dim=1)
self_certainty = kl_divergence.sum(dim=1) / (seq_lengths + 1e-8)
return self_certainty
def compute_token_level_self_certainty(self, logits):
"""计算token级别的自我确定性"""
probs = F.softmax(logits, dim=-1)
uniform = self.uniform_dist.to(logits.device).unsqueeze(0).unsqueeze(0)
token_self_certainty = F.kl_div(
input=torch.log(probs + 1e-8),
target=uniform.expand_as(probs),
reduction='none',
log_target=False
).sum(dim=-1)
return token_self_certainty
class GRPOTrainer:
def __init__(self, model, tokenizer, config):
self.model = model
self.tokenizer = tokenizer
self.config = config
self.self_certainty_calculator = SelfCertaintyCalculator(len(tokenizer))
        # Frozen reference model (for the KL penalty)
device = next(model.parameters()).device
self.reference_model = AutoModelForCausalLM.from_pretrained(self.config.model_name).to(device)
self.reference_model.eval()
for param in self.reference_model.parameters():
param.requires_grad = False
self.optimizer = optim.AdamW(self.model.parameters(), lr=self.config.lr)
def compute_advantages(self, self_certainty_scores):
advantages = (self_certainty_scores - self_certainty_scores.mean()) / (self_certainty_scores.std() + 1e-8)
return advantages * self.config.advantage_scale
def compute_kl_penalty(self, logits, ref_logits, attention_mask):
probs = F.softmax(logits, dim=-1)
ref_probs = F.softmax(ref_logits, dim=-1)
kl_divergence = F.kl_div(
input=torch.log(probs + 1e-8),
target=ref_probs,
reduction='none',
log_target=False
).sum(dim=-1)
kl_divergence = kl_divergence * attention_mask
seq_lengths = attention_mask.sum(dim=1)
avg_kl = kl_divergence.sum(dim=1) / (seq_lengths + 1e-8)
return avg_kl
def generate_candidates(self, input_ids, num_candidates):
self.model.eval()
device = next(self.model.parameters()).device
input_ids = input_ids.to(device)
batch_size = input_ids.shape[0]
all_candidates_ids = []
all_attention_mask = []
with torch.no_grad():
for i in range(batch_size):
expanded_input = input_ids[i:i+1].repeat(num_candidates, 1)
outputs = self.model.generate(
expanded_input,
max_new_tokens=self.config.max_length,
num_return_sequences=1,
do_sample=True,
temperature=0.7,
pad_token_id=self.tokenizer.eos_token_id,
eos_token_id=self.tokenizer.eos_token_id,
attention_mask=torch.ones_like(expanded_input)
)
input_len = expanded_input.shape[1]
generated_ids = outputs[:, input_len:]
full_sequence = torch.cat([expanded_input, generated_ids], dim=1)
attention_mask = torch.ones_like(full_sequence)
all_candidates_ids.append(full_sequence)
all_attention_mask.append(attention_mask)
candidates_ids = torch.cat(all_candidates_ids, dim=0)
candidates_mask = torch.cat(all_attention_mask, dim=0)
return candidates_ids, candidates_mask
def compute_gradients(self, input_ids, attention_mask, candidate_ids, candidate_masks):
self.model.train()
batch_size = input_ids.shape[0]
num_candidates = self.config.num_candidates
        # Forward pass over the candidate sequences
candidates_outputs = self.model(
candidate_ids,
attention_mask=candidate_masks,
output_hidden_states=False,
output_attentions=False,
return_dict=True
)
candidates_logits = candidates_outputs.logits
        # Self-certainty rewards for each candidate
        self_certainty_scores = self.self_certainty_calculator.compute_self_certainty(candidates_logits, candidate_masks)
        # Group-relative advantages, normalized within each prompt's candidate group
        advantages = []
for i in range(batch_size):
batch_scores = self_certainty_scores[i*num_candidates:(i+1)*num_candidates]
batch_advantages = self.compute_advantages(batch_scores)
advantages.append(batch_advantages)
advantages = torch.cat(advantages)
        # Reference-model logits (no gradients)
with torch.no_grad():
ref_outputs = self.reference_model(
candidate_ids,
attention_mask=candidate_masks,
output_hidden_states=False,
output_attentions=False,
return_dict=True
)
ref_logits = ref_outputs.logits
        # Log-probabilities for the importance-sampling ratios
        candidate_probs = F.log_softmax(candidates_logits, dim=-1)
with torch.no_grad():
ref_probs = F.log_softmax(ref_logits, dim=-1)
        # The continuation starts right after the prompt, whose length is fixed:
        # every candidate row was generated from the same repeated prompt
        input_lens = [input_ids.shape[1]] * (batch_size * num_candidates)
        # Per-token importance-sampling ratios over the generated continuation.
        # Note the shift: logits at position t-1 predict the token at position t.
        ratios = []
        kl_penalties = []
        for i in range(batch_size * num_candidates):
            input_len = input_lens[i]
            gen_len = candidate_ids.shape[1] - input_len
            if gen_len <= 0:
                ratios.append(torch.tensor(0.0).to(candidate_ids.device))
                kl_penalties.append(torch.tensor(0.0).to(candidate_ids.device))
                continue
            # Log-probs over the generated span (shifted by one position)
            gen_candidate_probs = candidate_probs[i, input_len-1:input_len+gen_len-1]
            gen_ref_probs = ref_probs[i, input_len-1:input_len+gen_len-1]
            gen_tokens = candidate_ids[i, input_len:input_len+gen_len]
            # Importance-sampling ratio between the current policy and the reference
candidate_token_probs = torch.gather(
gen_candidate_probs, -1, gen_tokens.unsqueeze(-1)
).squeeze(-1)
ref_token_probs = torch.gather(
gen_ref_probs, -1, gen_tokens.unsqueeze(-1)
).squeeze(-1)
ratio = torch.exp(candidate_token_probs - ref_token_probs)
ratios.append(ratio.mean())
            # KL-divergence penalty on the generated span
kl_penalty = F.kl_div(
candidate_probs[i, input_len:],
ref_probs[i, input_len:],
reduction='batchmean',
log_target=True
)
kl_penalties.append(kl_penalty)
ratios = torch.stack(ratios)
kl_penalties = torch.stack(kl_penalties)
        # Clipped surrogate objective
        advantages = advantages.to(ratios.device)
        surrogate1 = ratios * advantages
        surrogate2 = torch.clamp(ratios, 1 - self.config.clip_epsilon, 1 + self.config.clip_epsilon) * advantages
        # GRPO objective
        grpo_objective = torch.min(surrogate1, surrogate2) - self.config.kl_penalty * kl_penalties
loss = -grpo_objective.mean()
self.optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
self.optimizer.step()
metrics = {
'loss': loss.item(),
'self_certainty_mean': self_certainty_scores.mean().item(),
'self_certainty_std': self_certainty_scores.std().item(),
'advantages_mean': advantages.mean().item(),
'kl_penalty_mean': kl_penalties.mean().item(),
'grpo_objective': grpo_objective.mean().item()
}
return metrics
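A minimal single-step training sketch for GRPOTrainer; the prompt and sizes are illustrative assumptions (a single prompt avoids prompt-padding issues in generate_candidates):

if __name__ == '__main__':
    config = IntuitorConfig(model_name='gpt2', batch_size=1, num_candidates=4, max_length=32)
    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(config.model_name)
    trainer = GRPOTrainer(model, tokenizer, config)
    enc = tokenizer(["Q: What is 7 * 8? A:"], return_tensors='pt')  # toy prompt
    cand_ids, cand_masks = trainer.generate_candidates(enc['input_ids'], config.num_candidates)
    metrics = trainer.compute_gradients(enc['input_ids'], enc['attention_mask'], cand_ids, cand_masks)
    print(metrics)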
MCTS, CoAT, TIP, TPO
A test-time scaling framework combining Monte Carlo tree search with preference optimization: context-aware adaptive thinking via the CoAT framework, thought-switching penalties via TIP, and iterative refinement via TPO.
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import re
import torch
import torch.nn.functional as F
from tqdm import tqdm
import random
# Parsing helper
def parse_solution(response):
    # Loosely match all "x=" / "y=" values (half- or full-width equals) and keep the last occurrence
    x_values = re.findall(r'x\s*[=＝]\s*(-?\d+\.?\d*)', response)
    y_values = re.findall(r'y\s*[=＝]\s*(-?\d+\.?\d*)', response)
    if not x_values or not y_values:
        return None, None
    try:
        x = float(x_values[-1])
        y = float(y_values[-1])
        return x, y
    except ValueError:
        return None, None
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
    pad_token="<|endoftext|>"  # set the pad token explicitly
)
model = AutoModelForCausalLM.from_pretrained(
"deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
device_map="auto",
)
# Make sure pad_token is valid; fall back to eos_token if missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# With device_map="auto" the model is already placed; derive the device from its parameters
device = next(model.parameters()).device
#============================= CoAT implementation ==================================
class CoATNode:
    def __init__(self, parent=None, prompt='', context='', depth=0):
        self.parent = parent
        self.children = []            # child nodes
        self.visit_count = 0          # visit count
        self.total_value = 0          # accumulated value
        self.prompt = prompt          # initial prompt
        self.context = context       # content generated at this node
        self.associative_memory = ''  # retrieved associative memory (text)
        self.depth = depth            # node depth
        self.is_expanded = False      # whether the node has been expanded
    def uct_value(self, exploration_weight=1.414):
        """Compute the UCT value."""
        if self.visit_count == 0:
            return float('inf')  # prioritize unvisited nodes
        exploitation = self.total_value / self.visit_count
        exploration = exploration_weight * np.sqrt(np.log(self.parent.visit_count) / (self.visit_count + 1e-6))
        return exploitation + exploration
    def best_child(self):
        """Pick the child with the highest UCT value."""
        return max(self.children, key=lambda x: x.uct_value())
# CoAT MCTS
class CoATFramework:
    def __init__(self, model, tokenizer, max_iter=100, max_depth=5, num_simulations=50):
        self.model = model
        self.tokenizer = tokenizer
        self.max_iter = max_iter
        self.max_depth = max_depth
        self.num_simulations = num_simulations
        self.external_brain = self.init_external_brain()
    def init_external_brain(self):
        """Initialize the external math knowledge base (Chinese entries match the model's output language)."""
        return {
            "消元法": "联立方程消去变量:方程1 + 方程2 → 3x = 9 → x=3",
            "代入法": "从方程1解出y=8-x,代入方程2得 2x - (8-x) =1 → x=3",
            "验证步骤": "将x=3代入原方程验证:3 + y=8 → y=5"
        }
    def retrieve_associative_memory(self, context):
        """Dynamic associative-memory retrieval by keyword match against the knowledge-base keys."""
        keywords = ["消元", "代入", "验证", "解"]
        for kw in keywords:
            if kw in context:
                for key, value in self.external_brain.items():
                    if kw in key:
                        return value
        return ""
    def evaluate_node(self, node):
        """Evaluate node value."""
        # Generation quality: inverse perplexity
        full_text = node.prompt + node.context
        ppl = compute_perplexity(self.model, self.tokenizer, full_text)
        gen_score = 1 / (ppl + 1e-6)
        # Associated-content bonus: +0.2 if any memory was retrieved
        am_score = 0.2 if node.associative_memory else 0.0
        return gen_score + am_score
    def expand_node(self, node):
        """Expand a node: generate candidate continuations and attach memory."""
        if node.is_expanded:
            return
        # Generate candidates conditioned on the current context
        input_text = node.prompt + node.context
        inputs = self.tokenizer(input_text, return_tensors="pt").to(device)
        inputs["attention_mask"] = inputs.input_ids.ne(tokenizer.pad_token_id).int()
        outputs = self.model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,
            max_new_tokens=100,
            num_return_sequences=3,  # 3 candidates per node
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
        # Create child nodes
        for seq in outputs:
            child_text = self.tokenizer.decode(seq, skip_special_tokens=True)
            child = CoATNode(
                parent=node,
                prompt=node.prompt,
                context=child_text[len(input_text):],  # keep only the newly generated part
                depth=node.depth+1
            )
            # Associative-memory retrieval
            child.associative_memory = self.retrieve_associative_memory(child.context)
            node.children.append(child)
        node.is_expanded = True
    def simulate(self, node):
        """Monte Carlo rollout: random walk to a leaf, then evaluate."""
        current_depth = node.depth
        while current_depth < self.max_depth:
            if not node.children:
                self.expand_node(node)
            if not node.children:
                break  # nothing left to expand
            node = random.choice(node.children)  # pick a random child
            current_depth += 1
        return self.evaluate_node(node)
    def backpropagate(self, node, value):
        """Propagate the value back up the tree."""
        while node is not None:
            node.visit_count += 1
            node.total_value += value
            node = node.parent
    def mcts_search(self, root):
        """Run Monte Carlo tree search."""
        for _ in range(self.max_iter):
            # Selection
            node = root
            while node.children:
                node = node.best_child()
            # Expansion
            if node.depth < self.max_depth and not node.is_expanded:
                self.expand_node(node)
            # Simulation
            total_sim_value = 0
            for _ in range(self.num_simulations):
                sim_value = self.simulate(node)
                total_sim_value += sim_value
            avg_sim_value = total_sim_value / self.num_simulations
            # Backpropagation
            self.backpropagate(node, avg_sim_value)
        # Final answer: pick the most-visited child (UCT would favor unvisited nodes)
        best_node = max(root.children, key=lambda c: c.visit_count)
        return best_node
# CoAT generation entry point
def generate_with_coat(prompt, coat):
    # Initialize the search tree
    root = CoATNode(prompt=prompt, context="")
    coat.expand_node(root)  # initial expansion
    # MCTS search
    best_node = coat.mcts_search(root)
    # Build the final response
    full_response = best_node.context
    if best_node.associative_memory:
        full_response += f"\n[关联知识] {best_node.associative_memory}"
    # Walk back up to reconstruct the full path
    path = []
    current_node = best_node
    while current_node.parent:
        path.append(current_node.context)
        current_node = current_node.parent
    path.reverse()
    full_response = "\n".join(path) + "\n" + full_response
    return full_response
#============================= TIP logits processor ==================================
# Trigger words that signal a thought switch (kept in the model's output language)
switch_tokens = [
    '另一种方法', 'alternatively', '或者', '换一种思路',
    '但是', '另一方面', '然而'
]
# Map each phrase to its first token id; convert_tokens_to_ids would return
# the unknown token for multi-token phrases, so encode each phrase instead
switch_tokens_ids = [tokenizer.encode(w, add_special_tokens=False)[0] for w in switch_tokens]
from transformers import LogitsProcessor
class TIPLogitsProcessor(LogitsProcessor):
    def __init__(self, switch_token_ids, alpha=3.0, beta=300):
        self.switch_token_ids = switch_token_ids
        self.alpha = alpha  # penalty strength
        self.beta = beta    # penalty duration (in tokens)
        self.current_thought_start = 0  # start position of the current thought
    def __call__(self, input_ids, scores):
        # Detect whether a new thought was just triggered
        last_token = input_ids[0][-1].item()
        if last_token in self.switch_token_ids:
            self.current_thought_start = input_ids.shape[-1]  # record where the new thought starts
        # Is the current position inside the penalty window?
        current_position = input_ids.shape[-1]
        if current_position < self.current_thought_start + self.beta:
            # Penalize switch tokens
            for token_id in self.switch_token_ids:
                scores[:, token_id] -= self.alpha
        return scores
# Generation with the TIP processor attached
def generate_with_cot(prompt):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    inputs["attention_mask"] = inputs.input_ids.ne(tokenizer.pad_token_id).int()
    # Attach the TIP logits processor
    logits_processor = [TIPLogitsProcessor(switch_tokens_ids, alpha=3.0, beta=300)]
outputs = model.generate(
inputs.input_ids,
attention_mask=inputs.attention_mask,
max_new_tokens=512,
temperature=0.85,
top_p=0.9,
repetition_penalty=1.2,
do_sample=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
        logits_processor=logits_processor  # inject the TIP processor
)
return tokenizer.decode(outputs[0], skip_special_tokens=True)
problem = "解方程组:\n方程1: x + y = 8\n方程2: 2x - y = 1"
# Revised CoT prompt (deliberately leaves room for improvement; kept in Chinese as model input)
cot_prompt = f"""
请逐步解决以下问题:
{problem}
分步推理要求:
1. 明确标注步骤编号
2. 展示完整的代数运算过程
3. 最终解用方框标出(如:x=3, y=5)
"""
# Perplexity helper
def compute_perplexity(model, tokenizer, text):
inputs = tokenizer(text, return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
log_probs = F.log_softmax(logits, dim=-1)
tokens = inputs["input_ids"]
nll = F.nll_loss(log_probs[:, :-1].contiguous().view(-1, log_probs.size(-1)),
tokens[:, 1:].contiguous().view(-1),
reduction='mean')
return torch.exp(nll).item()
response = generate_with_cot(cot_prompt)
ppl = compute_perplexity(model, tokenizer, response)
print("============ Base model =============")
print(f"Input: {cot_prompt}\nOutput: {response}\nPerplexity: {ppl:.2f}")
with open('pre_response.txt', 'w', encoding='utf-8') as f:
    f.write(f"Input: {cot_prompt}\nOutput: {response}\nPerplexity: {ppl:.2f}")
# ============================ TPO =========================
num_iterations = 3    # number of TPO iterations
num_candidates = 5    # candidate responses per round
ground_truth = (3, 5) # true solution of the system

# Reward function: score by solution correctness
def reward_function(response_text, parsed_solution):
    if parsed_solution is None or parsed_solution[0] is None or parsed_solution[1] is None:
        return -1.0  # penalty for an unparseable solution
    x_pred, y_pred = parsed_solution
    # Normalize the absolute error into [0, 1]; max_error is a rough problem-specific bound
    max_error = 8
    error = (abs(x_pred - ground_truth[0]) + abs(y_pred - ground_truth[1])) / max_error
    # Thought-switching penalty, computed on this candidate's own text
    switch_count = sum(1 for token in switch_tokens if token in response_text)
    penalty = 0.05 * switch_count  # 0.05 per switch
    return max(0.0, 1.0 - error - penalty)
def tpo_optimization(initial_prompt):
    cache = []  # (response, reward) cache
    coat = CoATFramework(model, tokenizer)
    # Initial candidate generation
    candidates = [generate_with_coat(initial_prompt, coat) for _ in range(num_candidates)]
    for resp in candidates:
        x, y = parse_solution(resp)
        score = reward_function(resp, (x, y))
        cache.append((resp, score))
    # TPO iterations
    for _ in tqdm(range(num_iterations), desc='TPO Train'):
        # Pick the best and worst responses
        best_resp = max(cache, key=lambda x: x[1])[0]
        worst_resp = min(cache, key=lambda x: x[1])[0]
        # Generate textual feedback (improvement suggestions)
feedback_prompt = f"""
以下是两个解方程组的示例:
**优秀示例**:
{best_resp}
**较差示例**:
{worst_resp}
请分析优秀示例的优点和较差示例的不足,并提出改进建议:
1. 步骤完整性(是否遗漏验证步骤)
2. 计算准确性(是否存在算术错误)
3. 表达清晰度(是否使用明确标记)
**强制修正要求**:
- 必须验证消元步骤:3x=9 → x=3
- 若出现矛盾结论,必须重新计算
"""
feedback_prompt += """
**注意**:反馈需满足以下要求:
- 分析必须具体,避免复述解题过程
- 改进建议不超过3条
- 强制使用LaTeX公式标注关键步骤
"""
        feedback = generate_with_cot(feedback_prompt)
        # Generate new candidates conditioned on the feedback
        new_candidates = [generate_with_cot(f"{initial_prompt}\n改进建议:{feedback}")
                          for _ in range(num_candidates)]
        # Update the cache
        for resp in new_candidates:
            x, y = parse_solution(resp)
            score = reward_function(resp, (x, y))
            cache.append((resp, score))
    # Return the highest-scoring response
    return max(cache, key=lambda x: x[1])[0], feedback
print("============ TPO模型 =============")
# 运行TPO优化
optimized_response, feedback = tpo_optimization(cot_prompt)
ppl = compute_perplexity(model, tokenizer, optimized_response)
print(f"输入:{feedback}...\n输出:{optimized_response}...\n困惑度:{ppl:.2f}")
with open('tpo_response.txt','w', encoding='utf-8') as f:
f.write(f"输入:{feedback}\n输出:{optimized_response}\n困惑度:{ppl:.2f}")
Sparse Activation
A dynamic top-k routing mechanism built on Gumbel-Softmax noise and a capacity factor, implementing sparse expert activation with load balancing.
import torch
import torch.nn as nn
import torch.nn.functional as F

class DynamicTopkRouter(nn.Module):
    def __init__(self, hidden_dim, num_experts, top_k=2,
                 capacity_factor=1.2):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_experts = num_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor
        # Router weights
        self.router = nn.Linear(hidden_dim, num_experts, bias=False)
    def forward(self, x, training=True):
        batch_size, seq_len, hidden_dim = x.shape
        # Routing scores
        router_logits = self.router(x)  # [B, L, E]
        # Top-k selection
        top_k_logits, top_k_indices = torch.topk(
            router_logits, self.top_k, dim=-1
        )
        # Gumbel noise for stochastic (sparse) routing during training
        if training:
            gumbel_noise = -torch.log(
                -torch.log(torch.rand_like(top_k_logits) + 1e-9) + 1e-9
            )
            top_k_logits = top_k_logits + gumbel_noise
        # Routing weights
        routing_weights = F.softmax(top_k_logits, dim=-1)
        # Per-expert capacity (tokens per expert across the whole batch)
        capacity = int(self.capacity_factor * batch_size * seq_len / self.num_experts)
        # Routing mask
        routing_mask = torch.zeros_like(router_logits, dtype=torch.bool)
        routing_mask.scatter_(-1, top_k_indices, True)
        # Enforce the capacity limit: tokens beyond capacity are dropped
        expert_counts = torch.zeros(self.num_experts, device=x.device)
        for i in range(self.num_experts):
            expert_mask = routing_mask[..., i].reshape(-1)
            positions = torch.where(expert_mask)[0]
            expert_counts[i] = positions.numel()
            if positions.numel() > capacity:
                expert_mask[positions[capacity:]] = False
                routing_mask[..., i] = expert_mask.reshape(batch_size, seq_len)
return {
'routing_weights': routing_weights,
'routing_indices': top_k_indices,
'routing_mask': routing_mask,
'expert_counts': expert_counts
}
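A minimal smoke test for the router; shapes and hyperparameters are arbitrary demo values:

if __name__ == '__main__':
    torch.manual_seed(0)
    router = DynamicTopkRouter(hidden_dim=64, num_experts=4, top_k=2, capacity_factor=1.2)
    x = torch.randn(2, 16, 64)  # [batch, seq_len, hidden]
    out = router(x, training=True)
    print(out['routing_weights'].shape)         # [2, 16, 2]: weights over the selected experts
    print(out['routing_mask'].sum(dim=(0, 1)))  # tokens kept per expert after the capacity limit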
Efficient Attention Implementation
A FlashAttention-2 forward kernel written in Triton. Tiled computation plus an online softmax keep the memory footprint of attention at O(N) instead of the O(N²) needed to materialize the full score matrix.
import torch
import triton
import triton.language as tl
@triton.jit
def flash_attention_v2(
    # Tensor pointers
    q_ptr, k_ptr, v_ptr, o_ptr,
    # Metadata
    seq_len, head_dim: tl.constexpr,
    # Memory strides for layout [seq_len, num_heads, head_dim]
    q_stride_m, q_stride_h,  # Q strides (sequence, head)
    k_stride_m, k_stride_h,  # K strides
    v_stride_m, v_stride_h,  # V strides
    o_stride_m, o_stride_h,  # O strides
    # Hyperparameters
    BLOCK_M: tl.constexpr,   # Q tile size
    BLOCK_N: tl.constexpr,   # K/V tile size
    NUM_HEADS: tl.constexpr, # number of heads
    IS_CAUSAL: tl.constexpr  # causal masking flag
):
    # 1. Program IDs and initial offsets
    pid_head = tl.program_id(0)  # head index
    pid_m = tl.program_id(1)     # Q-block index within the head
    # Start position of the current Q block
    start_m = pid_m * BLOCK_M
    offs_m = start_m + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    # 2. Initialize accumulators
    m_i = tl.full((BLOCK_M, ), float('-inf'), dtype=tl.float32)
    l_i = tl.zeros((BLOCK_M, ), dtype=tl.float32)
    acc = tl.zeros((BLOCK_M, head_dim), dtype=tl.float32)
    # 3. Load the Q block (the last dim is contiguous, stride 1)
    q_offset = pid_head * q_stride_h + offs_m[:, None] * q_stride_m
    q = tl.load(q_ptr + q_offset + tl.arange(0, head_dim)[None, :] * 1,
                mask=(offs_m[:, None] < seq_len) & (tl.arange(0, head_dim)[None, :] < head_dim),
                other=0.0).to(tl.float32)
    # 4. Main loop over K/V blocks
    for start_n in range(0, (seq_len + BLOCK_N - 1) // BLOCK_N * BLOCK_N, BLOCK_N):
        # 4.1 Valid range of the current K/V block
        valid_n = start_n + offs_n < seq_len
        # 4.2 Load the K block
        k_offset = pid_head * k_stride_h + (start_n + offs_n)[:, None] * k_stride_m
        k = tl.load(k_ptr + k_offset + tl.arange(0, head_dim)[None, :] * 1,
                    mask=valid_n[:, None] & (tl.arange(0, head_dim)[None, :] < head_dim),
                    other=0.0).to(tl.float32)
        # 4.3 Load the V block
        v_offset = pid_head * v_stride_h + (start_n + offs_n)[:, None] * v_stride_m
        v = tl.load(v_ptr + v_offset + tl.arange(0, head_dim)[None, :] * 1,
                    mask=valid_n[:, None] & (tl.arange(0, head_dim)[None, :] < head_dim),
                    other=0.0).to(tl.float32)
        # 4.4 Compute QK^T (Tensor Core path via tl.dot); scale by 1/sqrt(head_dim)
        s = tl.dot(q, tl.trans(k))
        s = s * (1.0 / (head_dim ** 0.5))
        # Sequence-length mask: positions past seq_len must not contribute
        mask_m = offs_m < seq_len                # valid Q positions
        mask_n = (start_n + offs_n) < seq_len    # valid K positions
        seq_mask = mask_m[:, None] & mask_n[None, :]
        # 4.5 Merge in the causal mask if requested, then mask invalid scores to -inf
        if IS_CAUSAL:
            causal_mask = (offs_m[:, None]) >= (start_n + offs_n[None, :])
            seq_mask = seq_mask & causal_mask
        s = tl.where(seq_mask, s, float('-inf'))
        # 4.6 Online softmax update
        # Running row maximum
        m_curr = tl.maximum(tl.max(s, axis=1), m_i)
        # Rescale factor for the old state and exponentials of the current block
        alpha = tl.exp(m_i - m_curr)        # decay for the old accumulator
        beta = tl.exp(s - m_curr[:, None])  # current-block exponentials
        l_curr = alpha * l_i + tl.sum(beta, axis=1)
        # Accumulate unnormalized output; normalization happens once at the end
        acc = acc * alpha[:, None] + tl.dot(beta.to(v.dtype), v)
        # 4.7 Carry the running statistics forward
        m_i = m_curr
        l_i = l_curr
    # Final normalization and store to global memory
    o = acc / l_i[:, None]
    o_offset = pid_head * o_stride_h + offs_m[:, None] * o_stride_m
    tl.store(o_ptr + o_offset + tl.arange(0, head_dim)[None, :] * 1,
             o.to(o_ptr.dtype.element_ty),
             mask=(offs_m[:, None] < seq_len) & (tl.arange(0, head_dim)[None, :] < head_dim))
def call_flash_attention_v2(q, k, v, is_causal=False):
    assert q.dim() == 3, "Input should be [seq_len, num_heads, head_dim]"
    seq_len, num_heads, head_dim = q.shape
    o = torch.empty_like(q)
    config = {
        'BLOCK_M': 128,
        'BLOCK_N': 64,
        'num_warps': 8,
        'num_stages': 3,
    }
    # Grid: one program per (head, Q block)
    grid = (num_heads, triton.cdiv(seq_len, config['BLOCK_M']))
    flash_attention_v2[grid](
        q, k, v, o,
        seq_len, head_dim,
        # Strides for contiguous [seq_len, num_heads, head_dim] tensors:
        # stride(0) steps along the sequence (m), stride(1) along the heads (h)
        q.stride(0), q.stride(1),
        k.stride(0), k.stride(1),
        v.stride(0), v.stride(1),
        o.stride(0), o.stride(1),
        NUM_HEADS=num_heads,
        IS_CAUSAL=is_causal,
        **config
    )
    return o
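A hedged correctness check against PyTorch's reference attention (requires a CUDA device; sizes are demo values):

if __name__ == '__main__':
    torch.manual_seed(0)
    q = torch.randn(256, 4, 64, device='cuda')  # [seq_len, num_heads, head_dim]
    k, v = torch.randn_like(q), torch.randn_like(q)
    out = call_flash_attention_v2(q, k, v, is_causal=True)
    # The reference expects [..., heads, seq, dim], so permute in and out
    ref = torch.nn.functional.scaled_dot_product_attention(
        q.permute(1, 0, 2), k.permute(1, 0, 2), v.permute(1, 0, 2), is_causal=True
    ).permute(1, 0, 2)
    print(torch.allclose(out, ref, atol=1e-2))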
Streaming Data Loading
A streaming data iterator for LLM training that supports dynamic batching and memory-efficient loading (sample packing is left as an extension).
import json
import torch

class StreamingDataIterator:
    def __init__(self, data_path, tokenizer, max_length=2048,
                 batch_size=8, buffer_size=10000):
        self.data_path = data_path
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        # Initialize the buffer
        self.buffer = []
        self.buffer_cursor = 0
        # Data-stream generator
        self.data_generator = self._create_generator()
def _create_generator(self):
"""创建数据流生成器"""
def generator():
with open(self.data_path, 'r', encoding='utf-8') as f:
for line in f:
if line.strip():
yield json.loads(line.strip())
return generator()
def _tokenize_sample(self, sample):
"""将样本转换为token序列"""
text = sample['text']
tokens = self.tokenizer.encode(
text, truncation=True, max_length=self.max_length
)
return tokens
def _load_to_buffer(self):
"""加载数据到缓冲区"""
self.buffer = []
while len(self.buffer) < self.buffer_size:
try:
sample = next(self.data_generator)
tokens = self._tokenize_sample(sample)
self.buffer.append(tokens)
except StopIteration:
break
def __iter__(self):
return self
def __next__(self):
"""获取下一个批次"""
if self.buffer_cursor + self.batch_size > len(self.buffer):
self._load_to_buffer()
self.buffer_cursor = 0
if len(self.buffer) == 0:
raise StopIteration
        # Slice out the batch
batch = self.buffer[
self.buffer_cursor:self.buffer_cursor + self.batch_size
]
self.buffer_cursor += self.batch_size
        # Dynamic padding to the longest sequence in the batch
max_len = max(len(seq) for seq in batch)
padded_batch = []
attention_masks = []
for seq in batch:
padding_length = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_token_id] * padding_length
padded_batch.append(padded_seq)
attention_mask = [1] * len(seq) + [0] * padding_length
attention_masks.append(attention_mask)
return {
'input_ids': torch.tensor(padded_batch, dtype=torch.long),
'attention_mask': torch.tensor(attention_masks, dtype=torch.long)
}
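A minimal usage sketch; the JSONL path is a placeholder and each line is assumed to be a {"text": "..."} record:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
iterator = StreamingDataIterator('corpus.jsonl', tokenizer, max_length=512, batch_size=4)  # hypothetical file
for batch in iterator:
    print(batch['input_ids'].shape, batch['attention_mask'].shape)
    break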