[{"content":"==================\nthis is probably a nothing post, but as i'm trying to solve a problem (wanting to make a robot lawyer go brrr), i often end up on sidequests.\nthis sidequest just happened to be making my grpo implementation faster. to do that, i needed to improve upon a base grpo implementation and basically iterate until it's faster (or until i lose momentum, get bored, or both). pretty simple really.\nthe Second Brain Algorithm^tm is as follows:\n1) read the code/paper\n2) ask gpt5-pro any clarifying questions if something doesn't make sense in my brain\n3) think about ways it could be faster\n4) ask gpt5-pro to implement each one so i can test them\n5) run through bugs until each iteration works\n6) go back and do some reading on things/tricks that make kernels faster\n7) ask gpt5-pro to implement the tricks\n8) repeat until i have a branch of the “faster tree” that shows promise (eg, the branch of the 20+ ideas that hasn't been pruned yet after applying effort)\nit's really just a mental breadth-first search.\ni went down a few rabbit holes trying my best to beat cublas matmuls and said forget it. i'm sure i could have if i wanted to spend a long time on it, but i didn't want to lose momentum, so i decided to use cublas, write a softmax kernel, and go from there.\nbasically, i got heavily inspired by liger-kernel's chunking/streaming and flash attention's accumulator tricks, and i wanted to see if i could combine them in some way to make things faster. things ended up working out pretty well in my experimentation, and i wanted to write down what i did before moving on to the next sidequest. here is the gist:\n“regular” grpo:\n1) generate a group of answers for each prompt and score them. each score, normalized within its group (subtract the group mean, divide by the group std), is that answer's advantage.\n2) compare the new policy to the old one with a ratio (think “did we get more confident or less?”). 
\n3) clip that ratio so we don't overreact to noisy scores.\n4) optionally add a KL penalty to stay close to a reference model.\n5) average the per-token loss and backprop.\ntwo big problems:\n1) a big matmul: projecting every position onto the full vocabulary\n2) lots of dead (padding) tokens\nso how do we fix/speed that up? with FlashGRPO, located here: https://github.com/KayneWest/flashgrpo. the basic idea is this:\n1) pack only real tokens (no full V matrix materialization): pack the non-masked tokens into a tight matrix so we don't waste compute on padding. vanilla grpo typically forms the full logits $[N,V]$ (or a full log-softmax) per step and then gathers the target log-probs, which costs $O(N \\cdot V)$ memory traffic. FlashGRPO instead:\n- packs tokens to keep only valid rows, $N \\le B \\cdot T$ (mask + ignore index).\n- streams the vocabulary in tiles of width $C$ (where $C$ is auto-tuned) and runs a cublas gemm $X[N,K] \\times W_{v_0:v_1}[C,K]^T \\to [N,C]$.\n- consumes each tile immediately with a triton reducer that performs an online log-sum-exp (tracking per-row $(m,s)$) and gathers the target logit $z$ if it falls in the current tile.\nno global $[N,V]$ tensor is ever written, so temporary memory stays at $O(N \\cdot C)$ instead of $O(N \\cdot V)$ while preserving full-vocab accuracy.\n2) online LSE + softcap for numerics. the triton reducer maintains $(m, s)$ so that $\\mathrm{LSE} = m + \\log s$ is exact over all tiles. an optional softcap clamps logits to $\\pm \\mathrm{softcap}$ during both the forward and backward passes. temperature is applied in place during the reduction to avoid extra gemms.\n3) overlapped execution (eg dual streams). the cublas gemm for tile $t+1$ runs concurrently with the triton reduction for tile $t$, using ping-pong buffers (two buffers that swap roles every step: the gemm writes the next tile into one while the reducer reads the current tile from the other) and cuda events. 
this hides the reduction latency and keeps the gemm saturated.\n4) custom backward that never forms $p(\\cdot)$ explicitly. instead of backpropagating through a full log-softmax, FlashGRPO reconstructs per-tile gradients:\n$$ \\frac{\\partial \\ell}{\\partial \\mathrm{logits}_{i,v}} = \\underbrace{\\frac{\\partial \\ell}{\\partial \\log p_i}}_{\\mathrm{grad\\_row}} \\cdot \\left( \\mathbb{1}[v = y_i] \\cdot g_t - p_{i,v} \\cdot g_{i,v} \\right) $$\nwhere $y_i$ is row $i$'s target token, $p_{i,v} = \\exp(\\mathrm{logits}_{i,v} - \\mathrm{LSE}_i)$ is its softmax probability, and $g$ is the softcap gate ($g_t$ at the target logit, $g_{i,v}$ at logit $v$). each tile's $[N,C]$ gradient block $G$ is contracted locally as the tiles stream across $V$ ($dX \\mathrel{+}= G\\,W_{v_0:v_1}$ and $dW_{v_0:v_1} = G^T X$), so no $[N,V]$ probability matrix is ever realized. gradients are accumulated in fp32 and cast back to the parameter dtype at the end.\n5) the ratio/clip gradient is selective and stable. FlashGRPO computes the token-wise $\\log p - \\log p_{\\mathrm{old}}$ in fp32 and clamps it to $[-10,10]$ before exponentiating. the gradient flows only through whichever branch (clipped or unclipped) the objective selects.\n6) bf16/fp16 friendly with fp32 critical paths. inputs can be bf16/fp16; the reduction/LSE and the accumulators are fp32. the chunk size is auto-selected given $N$, $V$, dtype bytes, and a MB budget.\nTL;DR:\ntriton kernels reduce_chunk_online_lse (forward LSE + target gather), finalize_grpo_1d (ratio/clip + advantage), and the backward pair bwd_dx_from_tile / bwd_dwdb_from_tile handle all row-packing, softcap gating, and fp32 accumulation; scheduling uses ping-pong buffers with CUDA events; chunk sizing is chosen via choose_chunk_size given a memory budget.\nInstallation
git clone https://github.com/KayneWest/flashgrpo.git
cd flashgrpo
pip install -e .
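to make items 1, 2 and 5 concrete, here's a minimal pure-PyTorch sketch of the same math: a tiled, streamed log-sum-exp with a target gather, plus the clamped, clipped ratio. it is not FlashGRPO's code and doesn't use its API: chunked_target_logprobs and grpo_token_loss are hypothetical names, and there's no triton, no cublas dual-stream overlap, no softcap, no KL term, and no custom backward (autograd just differentiates through the loop, so unlike FlashGRPO its backward still keeps every tile alive). it only shows that the online $(m, s)$ update gives the exact full-vocab log-prob while computing one $[N, C]$ tile at a time.

```python
import torch


def chunked_target_logprobs(X, W, targets, chunk_size=1024, temperature=1.0):
    # X: [N, K] packed hidden states (padding rows already dropped)
    # W: [V, K] output-embedding weight; targets: [N] int64 token ids
    # Returns log p(target) per row without building the full [N, V] logits.
    N, V = X.shape[0], W.shape[0]
    m = torch.full((N,), float('-inf'), device=X.device)  # running row max
    s = torch.zeros(N, device=X.device)                    # running sum of exp(z - m)
    z_t = torch.zeros(N, device=X.device)                  # target logit, set when its tile arrives
    for v0 in range(0, V, chunk_size):
        v1 = min(v0 + chunk_size, V)
        z = (X @ W[v0:v1].T).float() / temperature         # one [N, C] tile of logits, in fp32
        m_new = torch.maximum(m, z.max(dim=1).values)
        # rescale the old sum to the new max, then add this tile's contribution
        s = s * torch.exp(m - m_new) + torch.exp(z - m_new[:, None]).sum(dim=1)
        m = m_new
        in_tile = (targets >= v0) & (targets < v1)
        local = (targets - v0).clamp(0, v1 - v0 - 1)       # safe index even when not in tile
        z_t = torch.where(in_tile, z.gather(1, local[:, None]).squeeze(1), z_t)
    return z_t - (m + torch.log(s))                        # log p = z_target - LSE


def grpo_token_loss(logp, old_logp, adv, eps_low=0.2, eps_high=0.2):
    # per-token clipped surrogate; adv must broadcast to logp's shape
    log_ratio = (logp - old_logp).float().clamp(-10.0, 10.0)  # clamp before exp
    ratio = torch.exp(log_ratio)
    unclipped = ratio * adv
    clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high) * adv
    # min picks one branch per token, so the gradient flows only through that branch
    return -torch.minimum(unclipped, clipped)
```

on random inputs, chunked_target_logprobs should match torch.log_softmax(X @ W.T, dim=-1) gathered at the targets to within float32 tolerance for any chunk_size, including ones that don't divide V.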
Basic Usage
import torch
from flashgrpo import GRPOCuBLASxTritonPacked

# Initialize GRPO operator
grpo_op = GRPOCuBLASxTritonPacked(
    temperature=1.0,   # Temperature scaling
    epsilon_low=0.2,   # Lower clipping bound
    epsilon_high=0.2,  # Upper clipping bound
    delta=None,        # Optional delta clipping
    softcap=30.0,      # Logit clamping
    chunk_size=4096,   # Fixed chunk size (auto if None)
    max_temp_mb=768,   # Memory budget for auto chunking
    beta=0.04,         # KL regularization coefficient
)

# Your model forward pass (autocast disabled)
with torch.autocast('cuda', enabled=False):
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        output_hidden_states=True,
    )

# Extract components
H = outputs.hidden_states[-1][:, :-1, :].to(torch.bfloat16)  # [B,T,K]
W = model.get_output_embeddings().weight                      # [V,K]
targets = batch['input_ids'][:, 1:]                           # [B,T]
old_logps = batch['old_logps'][:, :-1]                        # [B,T]
completion_mask = batch['completion_mask'][:, :-1]            # [B,T]
advantages = batch['advantages']                              # [B]
ref_logps = get_reference_logps(batch)                        # [B,T]

# Compute GRPO loss
loss = grpo_op(
    H, W, targets, old_logps, completion_mask, advantages,
    ref_per_token_logps=ref_logps,
)
loss.backward()
optimizer.step()
optimizer.zero_grad()
tada \u0026ndash; sidequest completed and back to creating a robot thomas jefferson\n","date":"2025-10-09","permalink":"/posts/flashgrpo/","title":"FlashGRPO"},{"content":"==================\nlets get wild  i write with zero autocorrect. literally freeballing. as i\u0026rsquo;ve said before, if the shoggoth will replicate me, may it do so in all my glory.  
on to the show…\nI started reading the Anthropic blog post “Circuit Tracing: Revealing Computational Graphs in Language Models” and thought it would be cool to replicate it on gpt2-small.\nfor those who didn't read it, the abstract is as follows:\nWe introduce a method to uncover mechanisms underlying behaviors of language models. We produce graph descriptions of the model’s computation on prompts of interest by tracing individual computational steps in a “replacement model”. This replacement model substitutes a more interpretable component (here, a “cross-layer transcoder”) for parts of the underlying model (here, the multi-layer perceptrons) that it is trained to approximate. We develop a suite of visualization and validation tools we use to investigate these “attribution graphs” supporting simple behaviors of an 18-layer language model, and lay the groundwork for a companion paper applying these methods to a frontier model, Claude 3.5 Haiku.\nbuilding a replacement model
import os
import sys
import json
import glob
import math
import random
from typing import List, Tuple, Dict, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from datasets import load_dataset
from transformer_lens import HookedTransformer, HookedTransformerConfig
from tqdm import tqdm

create some data with transformer lens
the CLT is trained on the model's internal activations, so rather than extracting them on the fly during training, we'll shortcut this by running the model's forward pass once and storing the activations in a dataset.

def sample_fineweb_data(n_samples=5000):
    '''
    Streams the FineWeb dataset and yields up to n_samples items.
    '''
    ds_stream = load_dataset(
        'HuggingFaceFW/fineweb',
        name='CC-MAIN-2024-10',
        split='train',
        streaming=True,
    )
    # ds_stream is an iterable dataset. We'll just take the first N examples
    data_iter = iter(ds_stream)
    for i in range(n_samples):
        try:
            item = next(data_iter)
            yield item
        except StopIteration:
            break


model_name = 'gpt2-small'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HookedTransformer.from_pretrained(model_name, device=device)

resid_cache = {}
mlp_out_cache = {}


def hook_resid_pre(activation, hook):
    resid_cache[hook.name] = activation.detach().cpu()


def hook_mlp_out(activation, hook):
    mlp_out_cache[hook.name] = activation.detach().cpu()


# Register hooks for each layer
for layer_idx in range(model.cfg.n_layers):
    # e.g. 'blocks.0.hook_resid_pre'
    resid_name = f'blocks.{layer_idx}.hook_resid_pre'
    mlp_name = f'blocks.{layer_idx}.hook_mlp_out'
    model.add_hook(resid_name, hook_resid_pre, 'fwd')
    model.add_hook(mlp_name, hook_mlp_out, 'fwd')

tokenizer = AutoTokenizer.from_pretrained('gpt2')
# GPT-2 has no pad token, so reuse EOS for padding
tokenizer.pad_token = tokenizer.eos_token


def collect_hidden_states_and_save(
    text_list,
    batch_idx,
    output_dir='clt_data',
    max_seq_len=128,
):
    '''
    text_list: list of raw strings for a single batch
    batch_idx: which batch index we are on
    We'll tokenize, run a forward pass, and store the hidden states in an .npz file.
    '''
    enc = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=max_seq_len,
        return_tensors='pt',
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    # Clear old caches
    resid_cache.clear()
    mlp_out_cache.clear()

    with torch.no_grad():
        _ = model(enc['input_ids'])  # triggers hooks

    # Gather the hidden states from resid_cache and mlp_out_cache.
    # Each has shape [batch_size, seq_len, d_model], e.g. resid_cache['blocks.0.hook_resid_pre']
    layer_data = {}
    for layer_idx in range(model.cfg.n_layers):
        rname = f'blocks.{layer_idx}.hook_resid_pre'
        mname = f'blocks.{layer_idx}.hook_mlp_out'
        # shape: [batch, seq, d_model]
        resid_arr = resid_cache[rname].numpy()
        mlp_arr = mlp_out_cache[mname].numpy()
        layer_data[f'resid_{layer_idx}'] = resid_arr
        layer_data[f'mlp_{layer_idx}'] = mlp_arr

    # Save to a single file
    os.makedirs(output_dir, exist_ok=True)
    out_path = os.path.join(output_dir, f'batch_{batch_idx}.npz')
    np.savez_compressed(out_path, **layer_data)
    print(f'Saved {out_path} with shape: {resid_arr.shape}')


def main_collect_5k(output_dir='clt_data', batch_size=16):
    # Reuses the module-level model, hooks, and tokenizer defined above;
    # loading a second model here would register hooks that never fire.
    buffer = []
    batch_index = 0
    for doc_idx, record in tqdm(enumerate(sample_fineweb_data(n_samples=5000))):
        text = record['text']  # FineWeb stores the document text in 'text'
        buffer.append(text)
        if len(buffer) >= batch_size:
            collect_hidden_states_and_save(buffer, batch_index, output_dir)
            buffer = []
            batch_index += 1
    # leftover
    if buffer:
        collect_hidden_states_and_save(buffer, batch_index, output_dir)


# let's collect 5k samples
main_collect_5k()

create a dataset iterator

def normalize_vector(vec, eps=1e-9):
    # vec shape: [..., d_model]; L2-normalize along the last dim
    norm = vec.norm(dim=-1, keepdim=True).clamp_min(eps)
    return vec / norm


class CLTHiddenStateDataset(Dataset):
    '''
    One item per saved batch_*.npz file. Each item is (resid_list, mlp_list):
    one [batch, seq, d_model] tensor per layer, L2-normalized per token.
    '''
    def __init__(self, data_dir='clt_data', layer_count=12):
        self.files = sorted(glob.glob(f'{data_dir}/batch_*.npz'))
        self.layer_count = layer_count

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        resid_list = []
        mlp_list = []
        with np.load(path) as npz:
            for layer_idx in range(self.layer_count):
                # each stored array has shape [b, s, d]
                r = torch.from_numpy(npz[f'resid_{layer_idx}']).float()
                m = torch.from_numpy(npz[f'mlp_{layer_idx}']).float()
                resid_list.append(normalize_vector(r))
                mlp_list.append(normalize_vector(m))
        return resid_list, mlp_list

cross-layer transcoder

class RectangleFunction(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return ((x > -0.5) & (x < 0.5)).float()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[(x <= -0.5) | (x >= 0.5)] = 0
        return grad_input


class JumpReLUFunction(autograd.Function):
    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        ctx.bandwidth = bandwidth
        threshold = torch.exp(log_threshold)
        return x * (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = (x > threshold).float() * grad_output
        # pseudo-derivative w.r.t. threshold, times d(threshold)/d(log_threshold) = threshold
        threshold_grad = (
            -(threshold / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        ) * threshold
        return x_grad, threshold_grad, None  # None for bandwidth


class JumpReLU(nn.Module):
    def __init__(self, feature_size, bandwidth, device='cpu'):
        super().__init__()
        self.log_threshold = nn.Parameter(torch.zeros(feature_size, device=device))
        self.bandwidth = bandwidth

    def forward(self, x):
        return JumpReLUFunction.apply(x, self.log_threshold, self.bandwidth)


class StepFunction(autograd.Function):
    @staticmethod
    def forward(ctx, x, log_threshold, bandwidth):
        ctx.save_for_backward(x, log_threshold)
        ctx.bandwidth = bandwidth
        threshold = torch.exp(log_threshold)
        return (x > threshold).float()

    @staticmethod
    def backward(ctx, grad_output):
        x, log_threshold = ctx.saved_tensors
        bandwidth = ctx.bandwidth
        threshold = torch.exp(log_threshold)
        x_grad = torch.zeros_like(x)
        # pseudo-derivative w.r.t. threshold, times d(threshold)/d(log_threshold) = threshold
        threshold_grad = (
            -(1.0 / bandwidth)
            * RectangleFunction.apply((x - threshold) / bandwidth)
            * grad_output
        ) * threshold
        return x_grad, threshold_grad, None  # None for bandwidth


def uniform_init(tensor, limit):
    with torch.no_grad():
        tensor.uniform_(-limit, limit)


class CrossLayerTranscoder(nn.Module):
    def __init__(self, layer_feature_list: List[int], d_model: int, bandwidth: float = 1.0, device='cpu'):
        '''
        layer_feature_list: e.g. [128, 256, 128, ...], length = n_layers,
            specifying how many features each layer has
        d_model: hidden dimension of the underlying Transformer
        bandwidth: controls the JumpReLU pseudo-derivatives
        device: 'cpu' or 'cuda'
        '''
        super().__init__()
        self.n_layers = len(layer_feature_list)
        self.layer_feature_list = layer_feature_list
        self.d_model = d_model
        self.bandwidth = bandwidth
        self.device = device

        # (1) ENCODERS:
        # A separate W_enc[i] for each layer i, shape [layer_feature_list[i], d_model],
        # stored in an nn.ParameterList so forward code can index W_enc[i].
        self.W_enc = nn.ParameterList()
        for i, feat_count in enumerate(layer_feature_list):
            limit = 1.0 / math.sqrt(feat_count)
            param = nn.Parameter(torch.empty(feat_count, d_model, device=device))
            uniform_init(param, limit)
            self.W_enc.append(param)

        # (2) DECODERS:
        # For each (src -> tgt) pair with src <= tgt, a matrix [layer_feature_list[src], d_model].
        # All live in one nn.ParameterList; index_map[src][tgt] gives the position (None if tgt < src).
        self.W_dec = nn.ParameterList()
        self.index_map = []
        dec_limit = 1.0 / math.sqrt(self.n_layers * d_model)
        idx_counter = 0
        for src_layer in range(self.n_layers):
            row = []
            src_feat = layer_feature_list[src_layer]
            for tgt_layer in range(self.n_layers):
                if tgt_layer >= src_layer:
                    dec_param = nn.Parameter(torch.empty(src_feat, d_model, device=device))
                    uniform_init(dec_param, dec_limit)
                    self.W_dec.append(dec_param)
                    row.append(idx_counter)
                    idx_counter += 1
                else:
                    row.append(None)
            self.index_map.append(row)

        # (3) JumpReLUs:
        # Layers may have different feature counts, so we keep one JumpReLU per layer.
        self.jumps = nn.ModuleList()
        for i, feat_count in enumerate(layer_feature_list):
            self.jumps.append(JumpReLU(feat_count, bandwidth, device=device))

    def forward(self, resid_streams: List[torch.Tensor]) -> List[torch.Tensor]:
        '''
        resid_streams[i]: shape [batch, seq, d_model], for layer i
        Returns: list of length n_layers, each [batch, seq, d_model],
        the reconstruction of each layer's MLP output
        '''
        batch_size, seq_len, _ = resid_streams[0].shape
        all_activations = []  # a^ℓ for each layer

        # 1) ENCODING
        for i in range(self.n_layers):
            x = resid_streams[i]  # [batch, seq, d_model]
            W_enc_mat = self.W_enc[i]  # [layer_feature_list[i], d_model]
            # => a_pre shape [batch, seq, layer_feature_list[i]]
            a_pre = torch.einsum('bsd,nd->bsn', x, W_enc_mat)
            # JumpReLU for layer i
            a_post = self.jumps[i](a_pre)
            all_activations.append(a_post)

        # 2) DECODING
        # y^ℓ_hat = sum_{ℓ'<=ℓ} W_dec^(ℓ'->ℓ) * a^(ℓ')
        mlp_recon = []
        for tgt_layer in range(self.n_layers):
            recon = torch.zeros(batch_size, seq_len, self.d_model, device=self.device)
            for src_layer in range(tgt_layer + 1):
                dec_idx = self.index_map[src_layer][tgt_layer]
                W_dec_mat = self.W_dec[dec_idx]  # [layer_feature_list[src_layer], d_model]
                a_src = all_activations[src_layer]  # [batch, seq, layer_feature_list[src_layer]]
                # => [batch, seq, d_model]
                recon_part = torch.einsum('bsn,nd->bsd', a_src, W_dec_mat)
                recon += recon_part
            mlp_recon.append(recon)
        return mlp_recon

    def forward_with_preacts(
        self, resid_streams: List[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        '''
        Same as forward(), but also returns the raw pre-activations a_pre for each layer.
        '''
        batch_size, seq_len, _ = resid_streams[0].shape
        all_a_pre = []
        all_a_post = []
        for i in range(self.n_layers):
            x = resid_streams[i]
            W_enc_mat = self.W_enc[i]
            a_pre = torch.einsum('bsd,nd->bsn', x, W_enc_mat)
            a_post = self.jumps[i](a_pre)
            all_a_pre.append(a_pre)
            all_a_post.append(a_post)

        mlp_recon = []
        for tgt_layer in range(self.n_layers):
            recon = torch.zeros(batch_size, seq_len, self.d_model, device=self.device)
            for src_layer in range(tgt_layer + 1):
                dec_idx = self.index_map[src_layer][tgt_layer]
                W_dec_mat = self.W_dec[dec_idx]
                a_src = all_a_post[src_layer]
                recon_part = torch.einsum('bsn,nd->bsd', a_src, W_dec_mat)
                recon += recon_part
            mlp_recon.append(recon)
        return mlp_recon, all_a_pre

    @classmethod
    def from_hookedtransformer(cls, hmodel, layer_feature_list: List[int], bandwidth=1.0, device='cpu'):
        '''
        hmodel: a HookedTransformer (from transformer_lens)
        layer_feature_list: e.g. [128, 256, ...], length = hmodel.cfg.n_layers
        '''
        L = hmodel.cfg.n_layers
        d_model = hmodel.cfg.d_model
        if len(layer_feature_list) != L:
            raise ValueError(f'layer_feature_list must have length {L}, got {len(layer_feature_list)}')
        return cls(
            layer_feature_list=layer_feature_list,
            d_model=d_model,
            bandwidth=bandwidth,
            device=device,
        )

so a cross-layer transcoder is a module that takes in a list of residual streams (one per layer) and returns a list of reconstructed MLP outputs (one per layer). 
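the flat ParameterList plus index_map in CrossLayerTranscoder is one way to store the triangular (source ≤ target) decoders. below is a sketch of an alternative layout, assuming the feature activations are already JumpReLU outputs: keep one decoder tensor per source layer with a target axis of length n_layers - src, so a single einsum per source layer writes into every later layer. clt_decode is a hypothetical helper, not something from the paper or from transformer_lens.

```python
import torch


def clt_decode(acts, W_dec):
    # acts[src]: [batch, seq, n_feat] feature activations (post-JumpReLU) of source layer src
    # W_dec[src]: [n_feat, n_layers - src, d_model]; slot j decodes into target layer src + j
    n_layers = len(acts)
    batch, seq, _ = acts[0].shape
    d_model = W_dec[0].shape[-1]
    recon = [acts[0].new_zeros(batch, seq, d_model) for _ in range(n_layers)]
    for src in range(n_layers):
        # one einsum gives this layer's contribution to every target layer >= src: [t, b, s, d]
        contrib = torch.einsum('bsn,ntd->tbsd', acts[src], W_dec[src])
        for j in range(n_layers - src):
            recon[src + j] = recon[src + j] + contrib[j]
    return recon
```

for gpt2-small's 12 layers, both layouts hold 12 · 13 / 2 = 78 (source, target) decoder blocks; the stacked layout just trades the index_map lookup for a slice along the target axis.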
as the paper says, its goal is to reconstruct the activations of the MLPs in the model.\nto re-explain an already well-explained concept (from the blog):\n- each feature reads from the residual stream of the layer it's in using a linear encoder\n- a given layer's features help reconstruct the MLP outputs of that layer and of every later layer, using a separate linear decoder for each target layer\n- all features are trained jointly, so the output of an MLP is reconstructed from the features of its own layer and all earlier layers\ntraining the cross-layer transcoder\nto train the CLT, there are two loss functions:\n- a reconstruction loss summed over all layers\n- a sparsity penalty summed over all layers\nthe reconstruction loss is the obvious one: we want to reconstruct the MLP outputs from the features, and we use MSE to do this.\nthe sparsity penalty is a bit more complex, but the idea is that we want to encourage the model to use as few features as possible to reconstruct the MLP outputs. this matters for two reasons:\n- it promotes interpretability: without this constraint, the model can use many features to reconstruct the MLP outputs, which makes it harder to attribute a specific behavior to a single feature. networks exhibit polysemanticity, where a single neuron can have multiple meanings; the sparsity constraint pushes us toward monosemanticity, where a single feature has a single meaning.\n- it reduces noise in the eventual attribution graphs: having too many active features leads to incredibly dense attribution graphs, which are harder to analyze.

def advanced_sparsity_loss(
    clt: CrossLayerTranscoder,
    preacts_list: List[torch.Tensor],
    c: float = 1.0,
    lambda_spars: float = 1e-3,
):
    '''
    Example: sum_{src_layer, i} tanh( c * ||W_dec_{i}|| * mean(a_pre) )
    We'll do a simplified version, ignoring multi-step or partial layering.
    '''
    device = preacts_list[0].device
    L = clt.n_layers
    penalty = torch.zeros((), device=device)

    # for each source layer
    for src_layer in range(L):
        a_pre = preacts_list[src_layer]  # shape [batch, seq, n_feat_this_layer]
        # average pre-activation across batch, seq => [n_feat_this_layer]
        a_mean = a_pre.mean(dim=(0, 1))

        # for each tgt_layer >= src_layer
        for tgt_layer in range(src_layer, L):
            dec_idx = clt.index_map[src_layer][tgt_layer]
            W_dec_mat = clt.W_dec[dec_idx]  # shape [layer_feature_list[src_layer], d_model]
            dec_norm = W_dec_mat.norm(dim=1)  # shape [layer_feature_list[src_layer]]
            raw_vals = c * dec_norm * a_mean
            penalty_layer = torch.tanh(raw_vals).sum()
            penalty += penalty_layer

    return lambda_spars * penalty


def get_sparsity_scale(current_step, total_steps, lambda_final):
    # linear ramp from 0 to lambda_final across total_steps
    scale = min(1.0, current_step / max(1, total_steps - 1))
    return scale * lambda_final


def sum_of_relu_neg(preacts_list):
    '''
    preacts_list: a list of Tensors, each shape [batch, seq, n_features],
    containing the pre-activation values for each layer.
    Returns a scalar: the sum over all layers, batches, and feature dims of ReLU(-preactivation).
    '''
    total_loss = torch.tensor(0.0, device=preacts_list[0].device)
    for a_pre in preacts_list:
        # shape: [batch, seq, n_features]
        negvals = torch.relu(-a_pre)  # ReLU(-x) = max(0, -x)
        total_loss += negvals.sum()
    return total_loss


def train_clt(
    clt,
    dataloader,
    num_epochs,
    lambda_spars_final=1e-3,
    preact_loss_coef=3e-6,
    total_steps=None,
):
    '''
    clt: CrossLayerTranscoder module (uniform init; JumpReLU log-thresholds start at 0, i.e. threshold 1)
    dataloader: yields (resid_list, mlp_list) already normalized
    '''
    optimizer = torch.optim.Adam(clt.parameters(), lr=1e-4)
    if total_steps is None:
        total_steps = num_epochs * len(dataloader)

    global_step = 0
    for epoch in range(num_epochs):
        print(f'Epoch {epoch} of {num_epochs}')
        for batch_idx, (resid_batch_list, mlp_batch_list) in enumerate(dataloader):
            print(f'Batch {batch_idx} of {len(dataloader)}')

            # Move to GPU. Each dataset item is already a whole [b, s, d] batch,
            # so [0] drops the extra leading dim the DataLoader adds with batch_size=1.
            for l in range(len(resid_batch_list)):
                resid_batch_list[l] = resid_batch_list[l].to(clt.device)[0]
                mlp_batch_list[l] = mlp_batch_list[l].to(clt.device)[0]

            # forward pass => per-layer reconstructions and pre-activations a^ℓ
            recon_list, preactivations = clt.forward_with_preacts(resid_batch_list)

            # preact loss: penalize negative pre-activations
            L_preact = preact_loss_coef * sum_of_relu_neg(preactivations)

            # MSE summed over layers
            mse = torch.tensor(0.0, device=clt.device)
            for i in range(len(recon_list)):
                diff = recon_list[i] - mlp_batch_list[i]
                mse += diff.pow(2).mean()

            # advanced sparsity, with lambda ramped linearly up to lambda_spars_final
            current_lambda_spars = get_sparsity_scale(global_step, total_steps, lambda_spars_final)
            L_spars = advanced_sparsity_loss(
                clt,
                preactivations,  # list of [batch, seq, n_features], one per layer
                c=1.0,
                lambda_spars=current_lambda_spars,
            )

            total_loss = mse + L_spars + L_preact

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()
            global_step += 1

        print(f'Epoch {epoch} done. last total_loss={total_loss.item():.4f} MSE={mse.item():.4f}')


dataset = CLTHiddenStateDataset('clt_data')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4)

model_name = 'gpt2-small'
hmodel = HookedTransformer.from_pretrained(model_name)
layer_feature_list = [128] * hmodel.cfg.n_layers  # or something custom
clt = CrossLayerTranscoder.from_hookedtransformer(
    hmodel, layer_feature_list, bandwidth=1.0, device='cuda'
)
train_clt(clt, dataloader, 2)

replacement modeling\nbasically, a replacement model swaps out the MLPs in the model for the CLT, while also adding an error-correction term that accounts for the error in the reconstruction.\nthe way anthropic frames it, this rewrites the underlying model in terms of sparser, more interpretable units.\nand the way anthropic says we should view the local replacement model is this:\nit's a fully connected neural network spanning across tokens that allows us to do interpretability research:\n- effectively, it's a graph.\n- it's the union of the CLT features active at every token position.\n- its weights are the summed interactions over all the linear paths from one feature to another.

def get_layer_index_from_hook(hook_name: str):
    '''
    A helper that extracts the integer layer index from a name like
    'blocks.3.hook_mlp_out' or 'blocks.3.hook_resid_mid', etc.
    Returns None if not recognized.
    '''
    # e.g. 'blocks.3.hook_mlp_out' -> 3
    if 'blocks.' not in hook_name:
        return None
    try:
        after_blocks = hook_name.split('blocks.')[1]
        i_str = after_blocks.split('.')[0]  # e.g. '3'
        return int(i_str)
    except (IndexError, ValueError):
        return None


class LocalReplacementModel(nn.Module):
    '''
    A local replacement model that:
      1) Uses `base_transformer` for the general structure,
      2) Freezes LN outputs & attention patterns with pre-recorded data,
      3) Replaces each MLP with CLT + error corrections,
      4) Records the cross-layer transcoder's feature activations in self.clt_activations.
    '''
    def __init__(
        self,
        base_transformer: HookedTransformer,
        clt: 'CrossLayerTranscoder',
        error_corrections: Dict[str, torch.Tensor],
        ln_scales: Dict[str, torch.Tensor],
        attn_patterns: Dict[str, torch.Tensor],
        device='cuda',
    ):
        super().__init__()
        self.base = base_transformer
        self.clt = clt
        self.error_corrections = error_corrections  # e.g. { 'blocks.3.hook_mlp_out': tensor(...) }
        self.ln_scales = ln_scales  # e.g. { 'blocks.3.ln1.hook_normalized': tensor(...) }
        self.attn_patterns = attn_patterns  # e.g. { 'blocks.3.attn.hook_pattern': tensor(...) }
        self.device = device

        # The MLP inputs for each layer (the 'hook_mlp_in' values)
        self.mlp_inputs = [None] * clt.n_layers
        # The cross-layer feature activations for each layer
        # (shape [batch, seq, n_features_of_layer]), filled by the MLP-out hook
        self.clt_activations = [None] * clt.n_layers

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        '''
        1. Freeze LN by injecting recorded LN outputs.
        2. Freeze QK by injecting stored attention patterns.
        3. Replace MLP with CLT + error corrections.
        4. Store CLT feature activations in self.clt_activations.
        '''
        # ===== Define hooking functions =====
        def freeze_ln_hook(activation, hook):
            '''
            Hook on e.g. 'blocks.0.ln1.hook_normalized':
            We replace the LN output with our pre-saved LN output (from an earlier pass).
            '''
            name = hook.name
            if name in self.ln_scales:
                return self.ln_scales[name].to(self.device)
            return activation

        def freeze_attn_pattern_hook(activation, hook):
            '''
            Hook on e.g. 'blocks.0.attn.hook_pattern':
            We replace the QK-softmax pattern with pre-saved patterns.
            '''
            name = hook.name
            if name in self.attn_patterns:
                return self.attn_patterns[name].to(self.device)
            return activation

        def mlp_in_hook(activation, hook):
            '''
            E.g. 'blocks.3.hook_mlp_in' => store for layer 3
            '''
            layer_idx = get_layer_index_from_hook(hook.name)
            if layer_idx is not None:
                self.mlp_inputs[layer_idx] = activation.detach().clone()
            return activation

        def mlp_out_hook(activation, hook):
            '''
            E.g. 'blocks.3.hook_mlp_out' => we:
              1) feed self.mlp_inputs[L] into the CLT for layer L,
              2) add the error correction,
              3) also fill self.clt_activations.
            '''
            layer_idx = get_layer_index_from_hook(hook.name)
            if layer_idx is None:
                return activation

            x = self.mlp_inputs[layer_idx]
            if x is None:
                return activation  # fallback

            # Feed dummy residuals so that only layer_idx's input is nonzero
            dummy_resids = []
            for i in range(self.clt.n_layers):
                if i == layer_idx:
                    dummy_resids.append(x)
                else:
                    dummy_resids.append(torch.zeros_like(x))

            # forward_with_preacts returns the reconstructions and the *pre*-activations
            mlp_recon_list, all_a_pre = self.clt.forward_with_preacts(dummy_resids)
            # mlp_recon_list => list of length n_layers, each shape [batch, seq, d_model]
            # all_a_pre      => list of length n_layers, each shape [batch, seq, n_feat(i)]

            # The reconstructed MLP-out for this layer
            recon_layer = mlp_recon_list[layer_idx]

            # Store the feature activations for *all* layers by applying each
            # layer's JumpReLU to its pre-activations ([batch, seq, #features])
            for i, a_pre in enumerate(all_a_pre):
                self.clt_activations[i] = self.clt.jumps[i](a_pre).detach().clone()

            # Then add the 'error correction', if any
            err_key = hook.name  # e.g. 
\u0026#34;blocks.3.hook_mlp_out\u0026#34; if err_key in self.error_corrections: e = self.error_corrections[err_key].to(self.device) recon_layer = recon_layer + e return recon_layer # ===== Attach the hooks ===== hooks = [] # Freeze LN for i in range(self.base.cfg.n_layers): ln_name = f\u0026#34;blocks.{i}.ln1.hook_normalized\u0026#34; if ln_name in self.base.hook_dict: h_ln = self.base.add_hook(ln_name, freeze_ln_hook, \u0026#34;fwd\u0026#34;) hooks.append(h_ln) # Freeze attention pattern for i in range(self.base.cfg.n_layers): attn_name = f\u0026#34;blocks.{i}.attn.hook_pattern\u0026#34; if attn_name in self.base.hook_dict: h_attn = self.base.add_hook(attn_name, freeze_attn_pattern_hook, \u0026#34;fwd\u0026#34;) hooks.append(h_attn) # Intercept MLP in/out for i in range(self.clt.n_layers): in_name = f\u0026#34;blocks.{i}.hook_mlp_in\u0026#34; out_name = f\u0026#34;blocks.{i}.hook_mlp_out\u0026#34; if in_name in self.base.hook_dict: hi = self.base.add_hook(in_name, mlp_in_hook, \u0026#34;fwd\u0026#34;) hooks.append(hi) if out_name in self.base.hook_dict: ho = self.base.add_hook(out_name, mlp_out_hook, \u0026#34;fwd\u0026#34;) hooks.append(ho) # ===== Run forward pass ===== logits = self.base(tokens) # remove hooks after finishing for h in hooks: if h is not None: h.remove() return logits def build_local_replacement_model_with_cache( base_model: HookedTransformer, clt: CrossLayerTranscoder, prompt: str, device=\u0026#34;cuda\u0026#34; ): \u0026#34;\u0026#34;\u0026#34; 1) forward pass on base_model, store LN outputs, attention patterns, MLP in/out 2) compute error corrections 3) return a LocalReplacementModel that re-uses LN \u0026amp; attn and MLP is replaced with CLT+error \u0026#34;\u0026#34;\u0026#34; layer_count = base_model.cfg.n_layers tokens = base_model.to_tokens(prompt, prepend_bos=True).to(device) # -- Use a filter that picks up the hook names we want. 
-- # We want LN outputs, attn pattern, mlp_in, mlp_out # For LN: \u0026#34;blocks.{i}.ln1.hook_normalized\u0026#34; # For attn: \u0026#34;blocks.{i}.attn.hook_pattern\u0026#34; # For MLP in/out: \u0026#34;blocks.{i}.hook_mlp_in\u0026#34; / \u0026#34;blocks.{i}.hook_mlp_out\u0026#34; def activation_filter(name: str): # Return True if we want to store this hook in the cache # Return False otherwise if \u0026#34;.ln1.hook_normalized\u0026#34; in name: return True if \u0026#34;.attn.hook_pattern\u0026#34; in name: return True if \u0026#34;.hook_mlp_in\u0026#34; in name: return True if \u0026#34;.hook_mlp_out\u0026#34; in name: return True return False # run_with_cache returns (logits, cache). # cache is a HookedTransformerCache containing the stored activations logits, cache = base_model.run_with_cache( tokens, return_type=\u0026#34;logits\u0026#34;, # or \u0026#34;none\u0026#34; if you don\u0026#39;t need final logits names_filter=activation_filter ) ln_scales = {} attn_patterns = {} mlp_in_cache = {} mlp_out_cache = {} # read from cache for each layer\u0026#39;s LN out, attn pattern, MLP in/out for i in range(layer_count): ln_key = f\u0026#34;blocks.{i}.ln1.hook_normalized\u0026#34; attn_key = f\u0026#34;blocks.{i}.attn.hook_pattern\u0026#34; in_key = f\u0026#34;blocks.{i}.hook_mlp_in\u0026#34; out_key = f\u0026#34;blocks.{i}.hook_mlp_out\u0026#34; if ln_key in cache: ln_scales[ln_key] = cache[ln_key].detach().clone() if attn_key in cache: attn_patterns[attn_key] = cache[attn_key].detach().clone() if in_key in cache: mlp_in_cache[in_key] = cache[in_key].detach().clone() if out_key in cache: mlp_out_cache[out_key] = cache[out_key].detach().clone() # 2) Build error corrections by comparing the model\u0026#39;s MLP out to the CLT recon error_corrections = {} for i in range(clt.n_layers): layer_in_name = f\u0026#34;blocks.{i}.hook_mlp_in\u0026#34; layer_out_name = f\u0026#34;blocks.{i}.hook_mlp_out\u0026#34; # Only do this if we actually have MLP in/out for that layer if 
layer_in_name in mlp_in_cache and layer_out_name in mlp_out_cache: mlp_in = mlp_in_cache[layer_in_name].to(device) mlp_out_true = mlp_out_cache[layer_out_name].to(device) # minimal forward in CLT for just this layer dummy_inputs = [torch.zeros_like(mlp_in) for _ in range(clt.n_layers)] dummy_inputs[i] = mlp_in clt_outputs = clt(dummy_inputs) # list of shape [n_layers] clt_layer_out = clt_outputs[i] # [batch, seq, d_model] diff = mlp_out_true - clt_layer_out error_corrections[layer_out_name] = diff.detach().clone() # 3) Build the final \u0026#34;frozen\u0026#34; local model local_model = LocalReplacementModel( base_transformer=base_model, clt=clt, error_corrections=error_corrections, ln_scales=ln_scales, attn_patterns=attn_patterns, device=device ) return local_model Attribution Graphs these are representations of how different computational components (specifically the interpretable features) contribute to the final output of the model.\ngraphs contain 4 nodes:\n output: the final output tokens of the model input: the input embedding tokens to the model intermediate: clt features at each prompt token error: the remaining output unexplained by the CLT  edges represent linear attributions; they originate from the input node and end at the output node and indicate a direct, linear influence from one node to another. the idea is that the activation of any feature (or token) is decomposed into the sum of its incoming contributions from these edges, allowing a clear, linear causal interpretation.\nthe graph is constructed by backward attribution via jacobians.\nso we want to decompose the output into the sum of contributions from earlier components. 
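To see why this decomposition works, here is a toy sketch in plain Python. It is not the post's implementation; the numbers, the 2-dimensional residual and the node names are made up. Three source features write `a_s * decoder_s` into the residual stream. A frozen linear map `M` stands in for the linearized path between layers, and a target feature reads the result through its encoder vector `enc_t`. Because every step is linear, the backward jacobian of the target with respect to the residual is just `M^T enc_t`. The per-source contributions `a_s * (decoder_s · grad)` then add up exactly to the target's pre-activation.

```python
# Toy sketch (hypothetical numbers, not the post's code): for purely linear
# paths, target pre-activation == sum over sources of a_s * (decoder_s . grad).

def dot(u, v):
    return sum(x * y for x, y in zip(u, v))

def matvec(M, v):
    return [dot(row, v) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

# Source features: (activation a_s, decoder vector written into the residual).
sources = {
    'feat_A': (1.0, [0.5, 1.5]),
    'feat_B': (2.0, [-1.0, 0.0]),
    'feat_C': (-0.5, [2.0, -0.5]),
}
# Frozen linear map standing in for the linearized path between layers.
M = [[1.0, 0.5],
     [0.0, 3.0]]
# Encoder (input) vector of the target feature.
enc_t = [2.0, -1.0]

# Forward: sum the decoder writes into the residual, apply M, read out with enc_t.
resid = [0.0, 0.0]
for a_s, dec_s in sources.values():
    resid = [r + a_s * d for r, d in zip(resid, dec_s)]
target_preact = dot(enc_t, matvec(M, resid))

# Backward: d(target)/d(resid) = M^T enc_t, then one dot product per source.
grad = matvec(transpose(M), enc_t)
edges = {name: a_s * dot(dec_s, grad) for name, (a_s, dec_s) in sources.items()}

print(edges, sum(edges.values()), target_preact)
```

The three edges come out as -2.0, -4.0 and -2.5, and they sum to the target's pre-activation of -8.5. Suppose instead that a ReLU or an unfrozen LayerNorm scale sat between the residual and the target. The gradient would then depend on the input itself, and the contributions would no longer sum exactly. That is why the backward pass freezes attention patterns and stops gradients at the nonlinearities.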
because the influence is spread over many different paths, backward jacobians are the natural tool: they quantify how sensitive the target's pre-activation is to changes in the source's activation.

< explain math/gradients >

for every target t in our graph, we inject its corresponding input vector into the residual stream at the appropriate layer/token position. we then perform a backward pass on the underlying model with two modifications: 1) stop-gradients are inserted into the non-linear parts of the model, and 2) frozen attention patterns are used, so that the only way the input can affect the output is through the linear paths in the graph. for any source node, its contribution is the dot product between its decoder vector and the gradient signal that flows from the target back to the source, scaled by the activation of the source node.

by computing these backward jacobians, we effectively obtain a set of linear influence weights that tell us how much each feature contributes to the target. doing this lets you map out a circuit (in the mech interp sense).

anthropic also describes a graph pruning algorithm with the following pseudocode:

function compute_normalized_adjacency_matrix(graph):
    # Convert graph to adjacency matrix A
    # A[j, i] = weight from i to j (note the transposition)
    A = convert_graph_to_adjacency_matrix(graph)
    A = absolute_value(A)
    # Normalize each row to sum to 1
    row_sums = sum(A, axis=1)
    row_sums = maximum(row_sums, 1e-8)  # Avoid division by zero
    A = diagonal_matrix(1/row_sums) @ A
    return A

function prune_nodes_by_indirect_influence(graph, threshold):
    A = compute_normalized_adjacency_matrix(graph)
    # Calculate the indirect influence matrix: B = (I - A)^-1 - I
    # This is a more efficient way to compute A + A^2 + A^3 …
    B = inverse(identity_matrix(size=A.shape[0]) - A) - identity_matrix(size=A.shape[0])
    # Get weights for logit nodes.
    # This is 0 if a node is a non-logit node and equal to the probability for logit nodes
    logit_weights = get_logit_weights(graph)
    # Calculate influence on logit nodes for each node
    influence_on_logits = matrix_multiply(B, logit_weights)
    # Sort nodes by influence
    sorted_node_indices = argsort(influence_on_logits, descending=True)
    # Calculate cumulative influence
    cumulative_influence = cumulative_sum(
        influence_on_logits[sorted_node_indices]) / sum(influence_on_logits)
    # Keep nodes with cumulative influence up to threshold
    nodes_to_keep = cumulative_influence <= threshold
    # Create new graph with only kept nodes and their edges
    return create_subgraph(graph, nodes_to_keep)

# Edge pruning by thresholded influence
function prune_edges_by_thresholded_influence(graph, threshold):
    # Get normalized adjacency matrix
    A = compute_normalized_adjacency_matrix(graph)
    # Calculate influence matrix (as before)
    B = estimate_indirect_influence(A)
    # Get logit node weights (as before)
    logit_weights = get_logit_weights(graph)
    # Calculate node scores (influence on logits)
    node_score = matrix_multiply(B, logit_weights)
    # Edge score is weighted by the logit influence of the target node
    edge_score = A * node_score[:, None]
    # Calculate edges to keep based on thresholded cumulative score
    sorted_edges = sort(edge_score.flatten(), descending=True)
    cumulative_score = cumulative_sum(sorted_edges) / sum(sorted_edges)
    threshold_index = index_where(cumulative_score >= threshold)
    edge_mask = edge_score >= sorted_edges[threshold_index]
    # Create new graph with pruned adjacency matrix
    pruned_adjacency = A * edge_mask
    return create_subgraph_from_adjacency(graph, pruned_adjacency)

my implementation of the attribution graph code is as follows:

class NodeType:
    EMBEDDING = 'embedding'
    FEATURE = 'feature'
    ERROR = 'error'
    LOGIT = 'logit'


class AttributionNode:
    '''
    A node in the attribution graph, e.g. a single feature in some layer,
    or an embedding node, or a logit node, etc.
    '''

    def __init__(
        self,
        node_type: str,
        name: str,
        layer_idx: int = None,
        context_pos: int = None,
        logit_index: int = None
    ):
        self.node_type = node_type
        self.name = name

        # Set layer_idx to a sentinel if not provided
        if layer_idx is None:
            if node_type == NodeType.EMBEDDING:
                self.layer_idx = -1
            elif node_type == NodeType.LOGIT:
                self.layer_idx = 9999
            else:
                self.layer_idx = 0
        else:
            self.layer_idx = layer_idx

        self.context_pos = context_pos
        self.logit_index = logit_index

        # Storing these as placeholders
        self.activation = None
        self.output_vector = None
        self.input_vector = None

        # Must be an integer for feature nodes:
        self.feature_index = None  # we must fill this in ourselves

        # We'll store an integer ID in the graph
        self.id = None

    def __repr__(self):
        return f'Node<{self.node_type}:{self.name}:{self.id}>'


class AttributionGraph:
    '''
    A container for all nodes and edges in the attribution graph.
    '''

    def __init__(self):
        self.nodes = []
        self.edges = {}  # adjacency: edges[u] = list of (v, weight)

    def add_node(self, node: AttributionNode):
        node.id = len(self.nodes)
        self.nodes.append(node)
        self.edges[node.id] = []
        return node.id

    def add_edge(self, src_id: int, tgt_id: int, weight: float):
        self.edges[src_id].append((tgt_id, weight))

    def get_num_nodes(self):
        return len(self.nodes)


def build_graph_nodes(
    local_model,  # e.g. a LocalReplacementModel
    prompt: str,
    top_k=3,
    feature_threshold=0.01
):
    '''
    1) Forward pass => local_model.clt_activations is populated
    2) Build an AttributionGraph
    3) Add logit nodes for top_k final predictions
    4) Add feature nodes for each cross-layer feature whose activation > threshold
    '''
    G = AttributionGraph()
    tokens = local_model.base.to_tokens(prompt, prepend_bos=True).to(local_model.device)
    with torch.no_grad():
        logits = local_model(tokens)  # fills local_model.clt_activations

    # Add top-k logit nodes
    final_logits = logits[0, -1, :]
    probs = F.softmax(final_logits, dim=-1)
    top_vals, top_inds = torch.topk(final_logits, top_k)
    logit_node_ids = []
    for rank in range(top_k):
        tok_id = top_inds[rank].item()
        p = probs[tok_id].item()
        node = AttributionNode(
            node_type=NodeType.LOGIT,
            name=f'logit_{tok_id} (p={p:.3f})',
            layer_idx=None,  # sets layer_idx=9999
            context_pos=tokens.shape[1] - 1,  # last position
            logit_index=tok_id
        )
        nid = G.add_node(node)
        logit_node_ids.append(nid)

    # Create feature nodes from local_model.clt_activations
    feature_nodes = {}
    for i in range(local_model.clt.n_layers):
        acts_i = local_model.clt_activations[i]  # shape [1, seq, n_feat_i]
        if acts_i is None:
            continue
        acts_i = acts_i[0]  # shape [seq, n_feat_i] if batch=1
        seq_len, n_feat_i = acts_i.shape
        for pos in range(seq_len):
            for feat_j in range(n_feat_i):
                val = acts_i[pos, feat_j].item()
                if val > feature_threshold:
                    # Build a feature node => layer i, position pos, feature index feat_j
                    node = AttributionNode(
                        node_type=NodeType.FEATURE,
                        name=f'feat_L{i}_f{feat_j}_pos{pos}',
                        layer_idx=i,
                        context_pos=pos,
                        logit_index=None
                    )
                    node.activation = val
                    # Key fix: assign the integer feature index
                    node.feature_index = feat_j  # <--- CRUCIAL
                    node_id = G.add_node(node)
                    feature_nodes[(i, pos, feat_j)] = node_id

    cache_dict = {
        'logit_node_ids': logit_node_ids,
\u0026#34;feature_nodes\u0026#34;: feature_nodes } return G, cache_dict def compute_logit_injection_vector(local_model, vocab_idx): \u0026#34;\u0026#34;\u0026#34; Suppose you do \u0026#39;logits = W_U * final_resid\u0026#39;. Then the gradient for (logit[vocab_idx] - mean(logit)) w.r.t. final_resid is: injection_vec = W_U[vocab_idx] - average(W_U) or something along these lines. \u0026#34;\u0026#34;\u0026#34; W_U = local_model.base.W_U # e.g. unembedding weights shape [d_model, vocab_size] d_model, vocab_size = W_U.shape # minimal example: w_target = W_U[:, vocab_idx] # shape [d_model] w_mean = W_U.mean(dim=1) # shape [d_model] injection_vec = (w_target - w_mean) return injection_vec def run_backward_with_injection( local_model: LocalReplacementModel, injection_vec: torch.Tensor, layer_idx: int, real_tokens: torch.Tensor, token_pos: int = None, freeze_ln=True, freeze_attn=True, ): \u0026#34;\u0026#34;\u0026#34; Runs a backward pass in \u0026#39;local_model\u0026#39; (a LocalReplacementModel) with an \u0026#39;injection_vec\u0026#39; placed into the chosen layer (layer_idx) residual stream. We now pass \u0026#39;real_tokens\u0026#39; so that the dimension of the residual matches the dimension we used in the normal forward pass. Args: local_model: The local replacement model (with LN \u0026amp; attn patterns frozen). injection_vec: a Tensor of shape [d_model] or [batch, seq, d_model], specifying the \u0026#34;gradient signal\u0026#34; we want to inject. layer_idx: which layer\u0026#39;s residual we inject into. real_tokens: the actual token IDs used in the normal forward pass (must have the same shape as originally used by local_model). token_pos: which token position to inject into. If None, we broadcast across all positions. freeze_ln: whether to freeze LN denominators in the backward pass. freeze_attn: whether to freeze QK patterns in the backward pass. 
Returns: grad_dict: a dictionary of shape {(layer_idx, pos): residual_grad_vector} containing the gradient w.r.t. each layer\u0026#39;s residual stream after injection. \u0026#34;\u0026#34;\u0026#34; # (A) clamp layer_idx if needed nL = local_model.base.cfg.n_layers if layer_idx \u0026gt;= nL: # Perhaps for logit nodes, set it to nL - 1 layer_idx = nL - 1 # 1) Zero out old gradients local_model.zero_grad(set_to_none=True) # 2) Optionally freeze LN or QK params so they do not accumulate gradient # This ensures that no gradient accumulates in LN or QK parameters.  # Just be mindful that some models might have LN named differently (ln_f or ln_final).  if freeze_ln or freeze_attn: for name, param in local_model.base.named_parameters(): if freeze_ln and (\u0026#34;ln\u0026#34; in name): param.requires_grad_(False) if freeze_attn and (\u0026#34;.W_Q\u0026#34; in name or \u0026#34;.W_K\u0026#34; in name): param.requires_grad_(False) # 3) Hook the chosen layer\u0026#39;s residual via a forward pass with the *real tokens* storage = {} def store_resid_hook(resid, hook): # resid shape is [batch, seq, d_model] storage[\u0026#34;resid\u0026#34;] = resid return resid # By default for GPT-2 style, the \u0026#34;pre-MLP residual\u0026#34; is often \u0026#34;blocks.{layer_idx}.hook_resid_mid\u0026#34; # or \u0026#34;blocks.{layer_idx}.hook_resid_pre\u0026#34;. For GPT-2 from TransformerLens, it\u0026#39;s typically: # \u0026#34;blocks.{layer_idx}.hook_resid_mid\u0026#34; # but the user\u0026#39;s code might vary. We\u0026#39;ll try \u0026#34;hook_resid_mid\u0026#34; here: resid_name = f\u0026#34;blocks.{layer_idx}.hook_resid_mid\u0026#34; if resid_name not in local_model.base.hook_dict: # fallback or raise error # e.g. 
check if \u0026#34;hook_resid_post\u0026#34; is present alt_name = f\u0026#34;blocks.{layer_idx}.hook_resid_post\u0026#34; if alt_name in local_model.base.hook_dict: resid_name = alt_name else: raise KeyError(f\u0026#34;No valid mid/post resid hook for layer {layer_idx}\u0026#34;) # register the forward hook handle = local_model.base.add_hook(resid_name, store_resid_hook, \u0026#34;fwd\u0026#34;) with torch.enable_grad(): # Perform the forward pass with the REAL tokens so shapes match _ = local_model(real_tokens) if handle is not None: handle.remove() if \u0026#34;resid\u0026#34; not in storage: raise ValueError( f\u0026#34;Could not find residual for layer {layer_idx} - check your \u0026#34; f\u0026#34;hook name (is it \u0026#39;hook_resid_mid\u0026#39; or \u0026#39;hook_resid_pre\u0026#39;?)\u0026#34; ) resid_L = storage[\u0026#34;resid\u0026#34;] # shape [batch, seq, d_model] # We want gradient to flow from the dummy loss -\u0026gt; resid_L # 4) Construct the dummy loss # If injection_vec is [d_model], we broadcast across [batch, seq, d_model]. shape = resid_L.shape if injection_vec.dim() == 1 and injection_vec.shape[0] == shape[-1]: expanded_injection = injection_vec.view(1, 1, -1).expand_as(resid_L) elif injection_vec.shape == shape: expanded_injection = injection_vec else: raise ValueError(f\u0026#34;injection_vec shape mismatch: got {injection_vec.shape}, \u0026#34; f\u0026#34;but resid_L is {shape}.\u0026#34;) if token_pos is not None: # zero out gradient for all positions except \u0026#39;token_pos\u0026#39; mask = torch.zeros_like(resid_L) mask[:, token_pos, :] = 1.0 expanded_injection = expanded_injection * mask dummy_loss = (resid_L * expanded_injection).sum() dummy_loss.backward(retain_graph=True) # 5) gather gradient in a dictionary grad_dict = {} # For demonstration, we only store the gradient for exactly (layer_idx, token_pos) # If you want all positions, you can store them all. 
# We\u0026#39;ll do a single key =\u0026gt; shape [batch, seq, d_model] if resid_L.grad is not None: grad_dict[(layer_idx, \u0026#34;all\u0026#34;)] = resid_L.grad.detach().cpu() else: grad_dict[(layer_idx, \u0026#34;all\u0026#34;)] = None return grad_dict def is_upstream_layer(src_node, tgt_node): \u0026#34;\u0026#34;\u0026#34; Example helper: returns True if src_node is strictly earlier in (layer, position) than tgt_node. You can define your own logic for \u0026#34;which nodes can feed into which\u0026#34; in your graph. \u0026#34;\u0026#34;\u0026#34; # e.g. if src_node.layer_idx \u0026lt; tgt_node.layer_idx, or if equal layer but earlier tokenpos # etc. We\u0026#39;ll just do a dummy check: return (src_node.layer_idx \u0026lt; tgt_node.layer_idx) or ( src_node.layer_idx == tgt_node.layer_idx and src_node.context_pos \u0026lt;= tgt_node.context_pos ) def get_token_embedding(token_id, local_model): \u0026#34;\u0026#34;\u0026#34; Suppose your model has a wte (word token embed) of shape [vocab_size, d_model]. Then the embedding is wte[token_id]. \u0026#34;\u0026#34;\u0026#34; return local_model.base.embed[token_id] # for example def compute_direct_edges_for_node( G, node_id: int, local_model: nn.Module, tokens: torch.Tensor, freeze_ln=True, freeze_attn=True, epsilon=1e-8, ): \u0026#34;\u0026#34;\u0026#34; For a given node in the attribution graph, we compute the direct edges from *all* upstream nodes in G to `node_id`. We do a custom backward pass in the local replacement model, with LN denominators and QK patterns frozen (stop-grad). Pseudocode steps: 1) Build the injection vector for the `node_id`—this depends on whether it\u0026#39;s a Feature node, Logit node, etc. 2) Insert that injection vector into the residual stream at the correct layer (or final-layer residual for a logit). 3) Run a custom backward pass that accumulates `.grad` in the model\u0026#39;s residual streams (and zero in LN denominators, etc.). 
4) For each source node in G, compute the direct edge weight using either: w = source_activation * sum_{k} [ W_dec^{(source)}^T * grad_for_layer_k * W_enc^{(target)} ] or for embeddings / error nodes, a direct dot-product with the residual grad. 5) Add edge to G if abs(weight) \u0026gt; threshold. \u0026#34;\u0026#34;\u0026#34; node = G.nodes[node_id] node_type = node.node_type # --------------------------------------------------------------------- # 1) Build the injection vector for the \u0026#34;target node\u0026#34; # --------------------------------------------------------------------- if node_type == NodeType.LOGIT: # For a logit node, commonly we do \u0026#34;gradient w.r.t. (logit_tok - mean_logit)\u0026#34;. target_logit_idx = node.logit_index # e.g. ID in vocab # Suppose we have the final-layer residual dimension = d_model # We make an injection vector of shape [d_model], or [1, d_model] injection_vec = compute_logit_injection_vector(local_model, target_logit_idx) elif node_type == NodeType.FEATURE: # For a feature node, we typically want to inject that feature\u0026#39;s input vector # into the residual stream at the layer it reads from. That is: # v_in^L = W_enc[featureID] (the encoder weights) # For cross-layer transcoders, each feature has an encoder layer. We assume # node stores e.g. node.layer_idx, node.feature_idx, etc. layer_idx = node.layer_idx feat_idx = node.feature_idx # e.g. index in that layer\u0026#39;s set of features injection_vec = local_model.clt.W_enc[layer_idx][feat_idx].detach().clone() elif node_type == NodeType.ERROR: # For error nodes, we might define injection_vec as the error node\u0026#39;s # \u0026#34;output vector\u0026#34; in the residual stream. # e.g. v_out = MLP_out_true - CLT_out. 
We\u0026#39;ll just pretend we have it stored: injection_vec = node.output_vector.detach().clone() elif node_type == NodeType.EMBEDDING: # For an embedding node, we can do something like: token_embed = get_token_embedding(node.token_id, local_model) injection_vec = token_embed.detach().clone() else: raise ValueError(f\u0026#34;Unknown node_type = {node_type}\u0026#34;) # We\u0026#39;ll reshape to [d_model] or [1,d_model] if needed injection_vec = injection_vec.detach().view(-1) injection_vec.requires_grad_(False) # --------------------------------------------------------------------- # 2) Insert the injection vector into the residual stream # We\u0026#39;ll do a custom backward pass that sets MLP out, LN denominators, QK patterns # to no_grad or zero_grad, etc. # --------------------------------------------------------------------- # For example, if the node is a LOGIT node, we put injection_vec at the final-layer residual # If it\u0026#39;s a FEATURE in layer L, we put the injection_vec in the residual of layer L # We\u0026#39;ll store them in a \u0026#39;tensor\u0026#39; that the model sees as the backward pass \u0026#34;incoming gradient\u0026#34;. # We\u0026#39;ll define a helper function. In a real code, you\u0026#39;d define a specialized # \u0026#34;run_backward_with_injection\u0026#34; that does your custom hooking logic: grad_dict = run_backward_with_injection( local_model=local_model, injection_vec=injection_vec, layer_idx=node.layer_idx if node_type == NodeType.FEATURE else local_model.base.cfg.n_layers, # ^ if logit or error or embedding, we might set this to final layer real_tokens=tokens, freeze_ln=freeze_ln, freeze_attn=freeze_attn, ) # grad_dict: Suppose it returns a dict mapping {(layer_idx, token_pos): residual_grad_tensor} # or something similar for each layer. You may also want the raw MLP-input grads in # your local replacement model or the partial derivatives w.r.t. CLT decoders, etc. 
# --------------------------------------------------------------------- # 3) For each source node, compute direct edge weight # --------------------------------------------------------------------- # We\u0026#39;ll show how to do it for a \u0026#34;feature\u0026#34; source node (the typical case). # We\u0026#39;ll also handle \u0026#34;embedding\u0026#34; and \u0026#34;error\u0026#34; node. (They differ because # features have to multiply by their activation in this prompt.) # We\u0026#39;ll do a loop over all nodes in G. In practice, you might only want to # handle nodes at earlier layers, or earlier token positions, etc. for src_id, src_node in enumerate(G.nodes): if src_id == node_id: continue if not is_upstream_layer(src_node, node): # you might have logic to skip obviously irrelevant nodes continue w = 0.0 if src_node.node_type == NodeType.FEATURE: # The formula from the text is roughly: # # A_{s -\u0026gt; t} = a_s * sum_{ℓ in [src_layer..(tgt_layer-1)]} [ # W_dec(src_feat)ᵀ * grad[ℓ] * W_enc(tgt_feat) # ] # # But we already have \u0026#34;grad[ℓ] * W_enc(tgt_feat)\u0026#34; from the injection pass, # plus we can do \u0026#34;W_dec(src_feat)ᵀ . that\u0026#34; ... # We\u0026#39;ll do it more explicitly in code: src_activation = src_node.activation # a_s from the prompt s_layer = src_node.layer_idx s_feat_idx = src_node.feature_idx # Grab decoders from local_model.clt for that feature # For each layer in [s_layer..(node.layer_idx-1)], we can do something: # out_vec = local_model.clt.W_dec[ index_map[s_layer][ℓ] ][ s_feat_idx ] # partial = torch.dot( out_vec, grad_dict[ℓ, node.pos] ) # shape both [d_model] # Summation of partial across the relevant layers. 
sum_val = 0.0 # We\u0026#39;ll define target_layer = node.layer_idx if node is feature # or local_model.base.cfg.n_layers if node is logit end_layer = (node.layer_idx if node_type == NodeType.FEATURE else local_model.base.cfg.n_layers) for mid_layer in range(s_layer, end_layer): dec_idx = local_model.clt.index_map[s_layer][mid_layer] out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # shape [d_model] # Suppose grad_dict[(mid_layer, node.context_pos)] is the gradient for that token/resid # We\u0026#39;ll do a dot product: if (mid_layer, src_node.context_pos) not in grad_dict: continue # might not exist grad_vec = grad_dict[(mid_layer, src_node.context_pos)] partial = torch.dot(out_vec, grad_vec) sum_val += partial.item() # Now multiply by a_s: w = src_activation * sum_val elif src_node.node_type == NodeType.EMBEDDING: # Then the direct edge is basically  # w = (Emb_src)^T grad[ src_layer, context_pos ] # Embeddings read into layer 0 (or wherever your model does). # Or you might store them differently.  # If your local model lumps them into the \u0026#34;resid at layer 0,\u0026#34; you can do: emb_vec = src_node.embedding_vector layer_for_emb = 0 # typically we treat the embedding as \u0026#34;layer 0\u0026#34;  pos = src_node.context_pos if (layer_for_emb, pos) in grad_dict: grad_vec = grad_dict[(layer_for_emb, pos)] w = torch.dot(emb_vec, grad_vec).item() else: w = 0.0 elif src_node.node_type == NodeType.ERROR: # Error nodes have no upstream inputs. By definition in the text, # they \u0026#34;pop out of nowhere.\u0026#34; So typically we do NOT add edges in from anything. # Conversely, if you\u0026#39;re computing edges from error to *this node*, you do: # w = error_vec^T gradVec # for the relevant layer. 
Example: err_out_vec = src_node.output_vector # shape [d_model] err_layer = src_node.layer_idx pos = src_node.context_pos if (err_layer, pos) in grad_dict: grad_vec = grad_dict[(err_layer, pos)] w = torch.dot(err_out_vec, grad_vec).item() else: w = 0.0 # If abs(w) is big enough, we add an edge if abs(w) \u0026gt; 1e-2: # TODO: change this threshold to 1e-4 G.add_edge(src_id, node_id, w) # done. G is updated with edges from source_node -\u0026gt; node_id. # You\u0026#39;d repeat this function for each node in the graph for which # you want to discover incoming edges. return G # done def build_attribution_graph_for_prompt(local_model, prompt_tokens): # 1) build the nodes G, cache = build_graph_nodes(prompt_tokens, local_model) # 2) for each \u0026#34;target node\u0026#34; in logit_node_ids, do partial backward for lnid in cache[\u0026#34;logit_node_ids\u0026#34;]: compute_direct_edges_for_node(G, lnid, local_model, prompt_tokens) # 3) for each feature node, if you want edges going into it, do partial backward too for (layer, pos, feat_i), node_id in cache[\u0026#34;feature_nodes\u0026#34;].items(): compute_direct_edges_for_node(G, node_id, local_model, prompt_tokens) # 4) done building adjacency return G def build_adjacency_matrix(G: AttributionGraph): \u0026#34;\u0026#34;\u0026#34; Returns adjacency matrix A of shape [N, N], where A[j,i] = sum of edges from i-\u0026gt;j This is typically used for the pruning step. \u0026#34;\u0026#34;\u0026#34; N = G.get_num_nodes() A = np.zeros((N,N), dtype=np.float32) for i in range(N): for (j, w) in G.edges[i]: A[j,i] += w return A def prune_graph( G: AttributionGraph, logit_nodes: List[int], threshold_nodes=0.8, threshold_edges=0.98 ): \u0026#34;\u0026#34;\u0026#34; 1) Build adjacency matrix 2) compute B = (I - A)^-1 - I 3) For each node, sum up influence on logit nodes 4) keep top X% =\u0026gt; subgraph 5) then re-build adjacency, do edge-level pruning similarly. 
\u0026#34;\u0026#34;\u0026#34; # 1) adjacency A = build_adjacency_matrix(G) N = A.shape[0] # 2) compute B # we might want to clamp negative edges or do absolute value. The paper does \u0026#34;abs + row-normalize\u0026#34; # For simplicity, do exactly as the paper says: we\u0026#39;ll do a quick absolute + row normalization: A_abs = np.abs(A) row_sums = A_abs.sum(axis=1, keepdims=True) row_sums[row_sums \u0026lt; 1e-10] = 1.0 A_norm = A_abs / row_sums # Then B = (I - A_norm)^-1 - I I = np.eye(N, dtype=np.float32) M = (I - A_norm) M_inv = np.linalg.inv(M) B = M_inv - I # shape [N, N] # 3) logit influence # We have logit_nodes; we can weight them by their probability, or just sum them # Suppose we do a simple sum logit_mask = np.zeros((N,), dtype=np.float32) for lnid in logit_nodes: logit_mask[lnid] = 1.0 # A[j,i] is the edge i-\u0026gt;j, so B[j,i] is the total path weight from i to j. # Influence of node i on the logits = sum_j logit_mask[j] * B[j,i], i.e. logit_mask @ B # (B @ logit_mask would be all zeros, since logit nodes have no outgoing edges) node_influence = logit_mask @ B # shape [N] # 4) sort nodes by influence order = np.argsort(-node_influence) # descending csum = np.cumsum(node_influence[order]) total = csum[-1] if csum[-1]\u0026gt;0 else 1.0 cutoff_val = threshold_nodes*total keep_mask = np.zeros((N,), dtype=bool) for idx in range(len(order)): if csum[idx] \u0026lt;= cutoff_val: keep_mask[order[idx]] = True else: break # logit nodes score 0 here (no outgoing edges), so always keep them explicitly keep_mask[logit_nodes] = True # We now form a new subgraph with only those nodes. (Edges from or to \u0026#34;kept\u0026#34; nodes.) # Then we do a second pass for edge pruning using the same logic but with B for edges. # For brevity, we\u0026#39;ll just do node-level here.  # In practice, you\u0026#39;d want to do the 2-step approach from the text.  
new_G = AttributionGraph() old_to_new = {} for i in range(N): if keep_mask[i]: new_i = new_G.add_node(G.nodes[i]) old_to_new[i] = new_i for i in range(N): if keep_mask[i]: for (j, w) in G.edges[i]: if keep_mask[j]: new_G.add_edge(old_to_new[i], old_to_new[j], w) return new_G def compute_all_direct_edges_layer( local_model, tokens, layer_idx, target_nodes, source_nodes, threshold=1e-4, freeze_ln=True, freeze_attn=True, ): \u0026#34;\u0026#34;\u0026#34; - Gathers the injection vectors for many \u0026#39;target_nodes\u0026#39; that read from \u0026#39;layer_idx\u0026#39;. - Runs single_pass_grad_for_targets exactly ONCE, returning a 4D array of partial derivatives. - For each source_node in `source_nodes`, does a dot product to get the direct edge from source -\u0026gt; each target. - Returns a list of (source_id, target_id, weight). NOTE: This function is a simplified example. It only demonstrates: - FEATURE nodes as targets, or LOGIT nodes as targets, for building injection vectors. - EMBEDDING or FEATURE nodes as sources (we do a simple dot product). - You can adapt it for your code (e.g. ERROR nodes, multi-layer partial sums, etc.). \u0026#34;\u0026#34;\u0026#34; # 1) Build the injection matrix for the target_nodes injection_list = [] # We\u0026#39;ll keep track of which target_node corresponds to each row # so we can output edges in the correct order target_index_map = [] d_model = local_model.base.cfg.d_model for i, tnode in enumerate(target_nodes): if tnode.node_type == \u0026#34;feature\u0026#34;: # e.g. 
local_model.clt.W_enc[layer_idx][feat_idx] feat_idx = tnode.feature_index inj_vec = local_model.clt.W_enc[layer_idx][feat_idx].detach().clone() elif tnode.node_type == \u0026#34;logit\u0026#34;: # logit injection inj_vec = compute_logit_injection_vector(local_model, tnode.logit_index).detach().clone() else: # If you had embeddings or error nodes as targets, you\u0026#39;d define them here raise ValueError(f\u0026#34;Unsupported target node type {tnode.node_type}\u0026#34;) injection_list.append(inj_vec) target_index_map.append(i) # shape = [n_targets, d_model] injection_mat = torch.stack(injection_list, dim=0) # 2) Single pass =\u0026gt; gradient wrt this layer\u0026#39;s residual grad_wrt_resid = single_pass_grad_for_targets( local_model=local_model, tokens=tokens, layer_idx=layer_idx, target_injection_vectors=injection_mat, freeze_ln=freeze_ln, freeze_attn=freeze_attn, ) # shape [batch, seq, d_model, n_targets] # We\u0026#39;ll assume batch=1 for simplicity # (If batch\u0026gt;1, you might do a sum over all batch elements or handle them differently.) 
grad_wrt_resid = grad_wrt_resid[0] # shape [seq, d_model, n_targets] seq_len, d_model, n_targets = grad_wrt_resid.shape # 3) Build edges from source_nodes to target_nodes edges = [] # list of (src_id, tgt_id, weight) for src_node in source_nodes: src_id = src_node.id if src_node.node_type == \u0026#34;embedding\u0026#34;: # Then the embedding vector is local_model.base.W_E[idx], for token idx # We\u0026#39;ll do a dot product with grad_wrt_resid at the correct context_pos emb_idx = src_node.logit_index # or however you store it # Actually, we might store it as src_node.output_vector or something # For demonstration: token_id = src_node.logit_index emb_vec = local_model.base.W_E[token_id] # shape [d_model]; W_E is [d_vocab, d_model] (the Embed module itself can\u0026#39;t be indexed) # Or if you stored an actual .embedding_vector in the node, use that emb_vec = emb_vec.detach().clone() pos = src_node.context_pos if pos \u0026lt; 0 or pos \u0026gt;= seq_len: # skip if out of range continue # shape [d_model, n_targets] grad_slice = grad_wrt_resid[pos] # shape [d_model, n_targets] # w = dot(emb_vec, grad_slice[:, i]) for each target i # =\u0026gt; shape [n_targets] w_vals = torch.einsum(\u0026#34;d, dn-\u0026gt;n\u0026#34;, emb_vec, grad_slice) # w_vals is shape [n_targets] # we can apply threshold, or store them all for t_i, wval in enumerate(w_vals): val = wval.item() if abs(val) \u0026gt;= threshold: tgt_id = target_nodes[t_i].id edges.append((src_id, tgt_id, val)) elif src_node.node_type == \u0026#34;feature\u0026#34;: # Suppose we do a simpler, single-layer approach: # direct edge ~ src_activation * dot(W_dec[src_feat], grad_slice) # We\u0026#39;ll find the correct position, assume it\u0026#39;s src_node.context_pos pos = src_node.context_pos s_feat_idx = src_node.feature_index # The src activation is typically stored in node.activation src_activation = src_node.activation # The \u0026#34;decoder\u0026#34; vector is local_model.clt.W_dec[index_map[src_layer][layer_idx]][s_feat_idx], etc. 
# For simplicity, let\u0026#39;s assume src_node.layer_idx == layer_idx # in real code, you might do a loop from src_layer..layer_idx-1 if pos \u0026lt; 0 or pos \u0026gt;= seq_len: continue # shape [d_model, n_targets] grad_slice = grad_wrt_resid[pos] # shape [d_model, n_targets] # The decoder vector for that feature -\u0026gt; layer_idx dec_idx = local_model.clt.index_map[src_node.layer_idx][layer_idx] if dec_idx is None: # skip if e.g. src_node.layer_idx\u0026gt;layer_idx continue out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # shape [d_model] # Then do dot(out_vec, grad_slice) =\u0026gt; shape [n_targets] partial_vals = torch.einsum(\u0026#34;d, dn-\u0026gt;n\u0026#34;, out_vec, grad_slice) # Multiply by the activation w_vals = src_activation * partial_vals # shape [n_targets] for t_i, wval in enumerate(w_vals): val = wval.item() if abs(val) \u0026gt;= threshold: tgt_id = target_nodes[t_i].id edges.append((src_id, tgt_id, val)) else: # If you had other node types (ERROR, etc.), do similarly pass return edges def build_direct_edges_single_pass_debug( G: AttributionGraph, local_model: LocalReplacementModel, tokens: torch.Tensor, threshold=1e-4, ): \u0026#34;\u0026#34;\u0026#34; Debug version: We add print statements to confirm the shape of each injection vector for each target node. We also print the final stacked shape, to confirm it\u0026#39;s [n_targets, d_model]. \u0026#34;\u0026#34;\u0026#34; n_layers = local_model.base.cfg.n_layers d_model = local_model.base.cfg.d_model print(f\u0026#34;[DEBUG] n_layers={n_layers}, d_model={d_model}\u0026#34;) # Bucket target nodes by layer layer_to_targets = {L: [] for L in range(n_layers+1)} for node in G.nodes: if node.node_type == NodeType.LOGIT: layer_to_targets[n_layers].append(node) elif node.node_type == NodeType.FEATURE: layer_to_targets[node.layer_idx].append(node) else: pass # e.g. 
embedding/error =\u0026gt; sources only # One pass per layer for layer_i in range(n_layers+1): target_nodes = layer_to_targets[layer_i] if len(target_nodes) == 0: continue print(f\u0026#34;\\n[DEBUG] Layer={layer_i}, #targets={len(target_nodes)}\u0026#34;) # 2A) Build injection_mat =\u0026gt; shape [n_targets, d_model] injection_list = [] for idx, tnode in enumerate(target_nodes): if tnode.node_type == NodeType.FEATURE: feat_idx = tnode.feature_index w_enc_matrix = local_model.clt.W_enc[layer_i] # shape [num_features, d_model], e.g. [128,768] print(f\u0026#34; [DEBUG] Target node {idx}: FEATURE layer={layer_i}, feat_idx={feat_idx}\u0026#34;) print(f\u0026#34; w_enc_matrix.shape = {tuple(w_enc_matrix.shape)}\u0026#34;) # The injection vector must be shape [d_model], not bigger inj_vec = w_enc_matrix[feat_idx] # shape =\u0026gt; [768] print(f\u0026#34; inj_vec.shape = {tuple(inj_vec.shape)}\u0026#34;) injection_list.append(inj_vec.detach().clone()) elif tnode.node_type == NodeType.LOGIT: print(f\u0026#34; [DEBUG] Target node {idx}: LOGIT layer={layer_i}, logit_index={tnode.logit_index}\u0026#34;) inj_vec = compute_logit_injection_vector(local_model, tnode.logit_index) print(f\u0026#34; inj_vec.shape = {tuple(inj_vec.shape)}\u0026#34;) injection_list.append(inj_vec.detach().clone()) else: print(f\u0026#34; [DEBUG] Skipping target node {idx}: node_type={tnode.node_type}\u0026#34;) continue if len(injection_list) == 0: print(f\u0026#34; [DEBUG] No valid feature/logit targets for layer={layer_i}, skipping.\u0026#34;) continue # shape =\u0026gt; [n_targets, d_model] injection_mat = torch.stack(injection_list, dim=0) print(f\u0026#34;[DEBUG] injection_mat final shape =\u0026gt; {tuple(injection_mat.shape)}\u0026#34;) # If layer_i == n_layers =\u0026gt; clamp to n_layers-1 actual_layer = min(layer_i, n_layers - 1) print(f\u0026#34;[DEBUG] actual_layer = {actual_layer}\u0026#34;) # 2B) single backward pass =\u0026gt; gradient wrt resid at actual_layer grad_wrt_resid = 
single_pass_grad_for_targets_debug( local_model=local_model, tokens=tokens, layer_idx=actual_layer, target_injection_vectors=injection_mat, freeze_ln=True, freeze_attn=True, ) # =\u0026gt; shape [batch=1, seq, d_model, n_targets] grad_wrt_resid = grad_wrt_resid[0] # =\u0026gt; [seq, d_model, n_targets] seq_len, _, n_targets = grad_wrt_resid.shape # 2C) For each source node... for src_id, src_node in enumerate(G.nodes): if not is_upstream_layer_layer(src_node, layer_i): continue for t_idx, tgt_node in enumerate(target_nodes): pos = tgt_node.context_pos if (pos \u0026lt; 0) or (pos \u0026gt;= seq_len): continue grad_slice = grad_wrt_resid[pos, :, t_idx] # =\u0026gt; shape [d_model] w_val = 0.0 if src_node.node_type == NodeType.EMBEDDING: emb_vec = src_node.output_vector # shape [d_model] w_val = float(torch.dot(emb_vec, grad_slice)) elif src_node.node_type == NodeType.FEATURE: a_s = src_node.activation s_feat_idx = src_node.feature_index s_layer = src_node.layer_idx target_index = min(layer_i, n_layers - 1) dec_idx = local_model.clt.index_map[s_layer][target_index] if dec_idx is not None: out_vec = local_model.clt.W_dec[dec_idx][s_feat_idx] # [d_model] partial = float(torch.dot(out_vec, grad_slice)) w_val = a_s * partial if abs(w_val) \u0026gt;= threshold: G.add_edge(src_id, tgt_node.id, w_val) return G def single_pass_grad_for_targets_debug( local_model, tokens, layer_idx, target_injection_vectors, freeze_ln=True, freeze_attn=True, ): \u0026#34;\u0026#34;\u0026#34; Same as single_pass_grad_for_targets but with debug prints about shapes. 
\u0026#34;\u0026#34;\u0026#34; d_model = local_model.base.cfg.d_model print(f\u0026#34;\\n[DEBUG single_pass_grad_for_targets] =\u0026gt; layer_idx={layer_idx}\u0026#34;) print(f\u0026#34; target_injection_vectors.shape = {list(target_injection_vectors.shape)}\u0026#34;) local_model.zero_grad(set_to_none=True) # Optionally freeze LN or QK if freeze_ln or freeze_attn: for name, param in local_model.base.named_parameters(): if freeze_ln and (\u0026#34;ln\u0026#34; in name): param.requires_grad_(False) if freeze_attn and (\u0026#34;.W_Q\u0026#34; in name or \u0026#34;.W_K\u0026#34; in name): param.requires_grad_(False) # Hook storage = {} def store_resid_hook(resid, hook): storage[\u0026#34;resid\u0026#34;] = resid return resid # E.g. blocks.{layer_idx}.hook_resid_mid resid_name = f\u0026#34;blocks.{layer_idx}.hook_resid_mid\u0026#34; if resid_name not in local_model.base.hook_dict: # fallback resid_name = f\u0026#34;blocks.{layer_idx}.hook_resid_post\u0026#34; handle = local_model.base.add_hook(resid_name, store_resid_hook, \u0026#34;fwd\u0026#34;) with torch.enable_grad(): _ = local_model(tokens) if handle is not None: handle.remove() if \u0026#34;resid\u0026#34; not in storage: raise ValueError(f\u0026#34;Could not find residual for layer {layer_idx}: {resid_name}\u0026#34;) resid_L = storage[\u0026#34;resid\u0026#34;] print(f\u0026#34; resid_L.shape = {tuple(resid_L.shape)}\u0026#34;) # Now we verify the injection_vec shape is [n_targets, d_model]. 
n_targets = target_injection_vectors.shape[0] if target_injection_vectors.ndim != 2 or target_injection_vectors.shape[1] != d_model: raise RuntimeError( f\u0026#34;[DEBUG ERROR] Expected injection_vectors shape [n_targets, {d_model}], \u0026#34; f\u0026#34;but got {list(target_injection_vectors.shape)}\u0026#34; ) # 3) build partial losses # transpose [n_targets, d_model] -\u0026gt; [d_model, n_targets] first; a bare .view() would scramble the entries whenever n_targets \u0026gt; 1 injection_mat = target_injection_vectors.T.reshape(1, 1, d_model, n_targets) expanded_resid = resid_L.unsqueeze(-1) # =\u0026gt; [batch, seq, d_model, 1] product = expanded_resid * injection_mat # =\u0026gt; [batch, seq, d_model, n_targets] partial_losses = product.sum(dim=[0,1,2]) # =\u0026gt; shape [n_targets] # 4) Provide grad_outputs =\u0026gt; identity, so we get full Jacobian outs = partial_losses # shape [n_targets] jac_list = [] for i in range(n_targets): unit_vec = torch.zeros_like(outs) # shape [n_targets] unit_vec[i] = 1.0 # Now partial derivative of outs·unit_vec =\u0026gt; outs[i] grad_i = torch.autograd.grad( outputs=outs, # [n_targets] inputs=resid_L, # [batch, seq, d_model] grad_outputs=unit_vec, # also [n_targets] retain_graph=True, )[0] # =\u0026gt; shape [batch, seq, d_model] jac_list.append(grad_i.unsqueeze(-1)) # =\u0026gt; [batch, seq, d_model, 1] print(f\u0026#34;index={i}, grad_i.shape={tuple(grad_i.shape)}\u0026#34;) # Finally cat them =\u0026gt; [batch, seq, d_model, n_targets] grad_wrt_resid = torch.cat(jac_list, dim=-1) print(f\u0026#34; =\u0026gt; grad_wrt_resid.shape = {tuple(grad_wrt_resid.shape)}\\n\u0026#34;) return grad_wrt_resid putting it together so far #### collect 5k samples main_collect_5k() #### train the CLT dataset = CLTHiddenStateDataset(\u0026#34;clt_data\u0026#34;) dataloader = DataLoader(dataset, batch_size=1, shuffle=True, num_workers=4) model_name = \u0026#34;gpt2-small\u0026#34; hmodel = HookedTransformer.from_pretrained(model_name) layer_feature_list = [128]*hmodel.cfg.n_layers # or something custom cross_layer_transcoder = CrossLayerTranscoder.from_hookedtransformer( hmodel, 
layer_feature_list, bandwidth=1.0, device=\u0026#34;cuda\u0026#34; ) train_clt(cross_layer_transcoder, dataloader, 2) #### build the local replacement model prompt = \u0026#34;Hello, I\u0026#39;d like you to help me plan my wedding please.\u0026#34; local_replacement_model = build_local_replacement_model_with_cache(hmodel, cross_layer_transcoder, prompt, device=\u0026#34;cuda\u0026#34;) #### Let\u0026#39;s compare outputs of the original model vs. the local replacement with torch.no_grad(): tokens = hmodel.to_tokens(prompt, prepend_bos=True).to(\u0026#34;cuda\u0026#34;) orig_logits = hmodel(tokens) rep_logits = local_replacement_model(tokens) print(\u0026#34;Are the logits close?\u0026#34;, torch.allclose(orig_logits, rep_logits, atol=1e-5)) tokens = hmodel.to_tokens(prompt, prepend_bos=True) #### 1) Build node set G, cache = build_graph_nodes(local_replacement_model, prompt, top_k=3) #### 2) compute direct edges for each logit node or MLP node build_direct_edges_single_pass_debug(G, local_replacement_model, tokens, threshold=1e-4) #### 3) prune pruned_G = prune_graph(G, cache[\u0026#34;logit_node_ids\u0026#34;], threshold_nodes=0.8, threshold_edges=0.98) there ya go. some breakdown of anthropic\u0026rsquo;s paper with the assistance of some of my dearest friends: sonnet 3.7 thinking, o1-pro, gpt-4o, o3-mini-high, and long time homie, sonnet 3.5.\ni\u0026rsquo;ll do a part 2 on visualizing/labeling/patching/interventions/etc. ","date":"2025-04-12","permalink":"/posts/circuit_tracing/","title":"Replicating The Circuit Kings"},{"content":"==================\nthis is sort of a strange blog bc it\u0026rsquo;s more actually just a stream of consciousness into a random exploration I had.\nfrom what I understand, one of the complaints of individuals working at big labs is that there\u0026rsquo;s less \u0026lsquo;curiosity\u0026rsquo; driven research, and more bet driven research. 
in this paradigm, you have a bet on a direction and you\u0026rsquo;re trying to make it work, so you explore that direction instead of wandering down whatever path interests you. when you\u0026rsquo;re betting bagillions of dollars, this makes sense, but little old me sometimes just likes to wander.\nfor ideas/intuition to fully develop in my brain (which hopefully then compounds into better future ideas), I usually need to take dumb ideas all the way to their inevitable failure, even if the idea itself is dumber than a box of rocks; otherwise the intuition doesn\u0026rsquo;t properly develop. the cool thing about these random walks (and all research these days) is that o1/sonnet/r1/grok/etc have made this process much easier. eg, instead of painstakingly coding up a new idea and losing interest completely by the end of it, you just sort of \u0026ldquo;vibe research\u0026rdquo;: you control the generative path, the llm implements the pieces of the idea you want, and you see where it goes. so the idea starts at one end and sometimes finishes at a completely unrelated one.\nideagen i was thinking about how to get smaller and smaller models that are as fast as humanly possible without something like distillation, focusing purely on the architecture. something that could potentially be used alongside a memory unit for an agent (we\u0026rsquo;ll get there in another blog) and is super fast at inference time. 
the idea stemmed from 1) how to make attention faster and 2) the fact that if you\u0026rsquo;re ever going to own your own personal AI running on your own hardware, with some kind of human/alien-like memory, you\u0026rsquo;d never want to reset a conversation/agentic flow; it would just have an infinite memory reservoir of all your interactions with it. so basically it comes down to 1) what is Noam Shazeer doing architecturally, and 2) how can we make attention faster.\nasking those questions leads me to some form of Multi Query Attention (MQA) and some form of compression (eg, DeepSeek\u0026rsquo;s MLA, which compresses the kv cache). so let\u0026rsquo;s make MQA faster. ok, so what\u0026rsquo;s a small didly on how we can make MQA faster in the easiest way possible? i dunno, top of mind: make it low-rank or something. sure, why not. let\u0026rsquo;s see what kind of trouble we can get ourselves into. logically this will reduce expressiveness, but why the hell not try it for fun.\nwhen you think about it, it\u0026rsquo;s basically lora without the lora: the entire attention layer is built from a low-rank parameterization instead of learning a low-rank update on top of frozen weights. almost like training a lora from scratch.\nhasn\u0026rsquo;t this been done before? probably, but who cares. i\u0026rsquo;m not even going to look into it because i just want to see where this line of thinking naturally goes. let\u0026rsquo;s see what happens.\nmqlra: multi-query low-rank attention MQA basically means each head keeps its own query projection, while all heads share a single key projection and a single value projection. 
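To put rough numbers on why sharing K/V and going low-rank is attractive, here is a back-of-the-envelope sketch. It assumes the same shapes as the code in this post (dense per-head projections of shape [d_model, d_k], low-rank factors of shape [d_model, rank] and [rank, d_k]), with no biases and the output projection ignored; the sizes d_model=768, d_k=64, n_heads=12, rank=16, seq_len=1024 are illustrative picks, not settings from this experiment.

```python
# Back-of-the-envelope Q/K/V parameter and KV-cache counts for MHA vs MQA vs MQLRA.
# Assumed: no biases, output projection ignored, sizes below are illustrative only.


def qkv_params(d_model, d_k, n_heads, rank=None, shared_kv=False):
    '''Count Q/K/V projection parameters for one attention layer.'''
    # dense projection: d_model x d_k; low-rank: (d_model x rank) + (rank x d_k)
    per_proj = d_model * d_k if rank is None else rank * (d_model + d_k)
    n_q = n_heads                        # queries are always per-head
    n_kv = 1 if shared_kv else n_heads   # MQA / MQLRA share one K and one V
    return per_proj * (n_q + 2 * n_kv)


def kv_cache_floats(seq_len, d_k, n_heads, shared_kv=False):
    '''Cached floats per layer per sequence: one K and one V row per token per KV head.'''
    n_kv = 1 if shared_kv else n_heads
    return 2 * seq_len * n_kv * d_k


d_model, d_k, n_heads, rank, seq_len = 768, 64, 12, 16, 1024
print('qkv params   mha:', qkv_params(d_model, d_k, n_heads))
print('qkv params   mqa:', qkv_params(d_model, d_k, n_heads, shared_kv=True))
print('qkv params mqlra:', qkv_params(d_model, d_k, n_heads, rank=rank, shared_kv=True))
print('kv cache     mha:', kv_cache_floats(seq_len, d_k, n_heads))
print('kv cache mqa/mqlra:', kv_cache_floats(seq_len, d_k, n_heads, shared_kv=True))
```

With these sizes the shared K/V shrinks the per-layer KV cache by a factor of n_heads (12x), which is largely where MQA's decoding speedup comes from. The low-rank factors then cut the Q/K/V projection parameters by roughly another 3.7x relative to plain MQA, at the cost of the expressiveness mentioned above.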
here\u0026rsquo;s how MQA compares to standard MHA:\nimport math import torch import torch.nn.functional as F # example sizes, arbitrary, just so the snippet runs d_model, d_k, n_heads = 64, 16, 4 x = torch.randn(2, 8, d_model) # (B, T, d_model) # Create separate learnable weight matrices per head for Q, K, V W_Q = torch.nn.ParameterList([ torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads) ]) W_K = torch.nn.ParameterList([ torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads) ]) W_V = torch.nn.ParameterList([ torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads) ]) def scaled_dot_product_attention(Q, K, V): \u0026#34;\u0026#34;\u0026#34; Q: (B, T, d_k) K: (B, T, d_k) V: (B, T, d_k) \u0026#34;\u0026#34;\u0026#34; # batch matrix multiply =\u0026gt; (B, T, T) scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k) attn_weights = F.softmax(scores, dim=-1) # (B, T, T) # multiply by V =\u0026gt; (B, T, d_k) output = torch.bmm(attn_weights, V) return output, attn_weights def standard_multi_head_attention(x, W_Q_list, W_K_list, W_V_list): \u0026#34;\u0026#34;\u0026#34; x: (B, T, d_model) Returns: (B, T, n_heads * d_k) \u0026#34;\u0026#34;\u0026#34; head_outputs = [] for h in range(n_heads): Q = x @ W_Q_list[h] # (B, T, d_k) K = x @ W_K_list[h] # (B, T, d_k) V = x @ W_V_list[h] # (B, T, d_k) out, _ = scaled_dot_product_attention(Q, K, V) head_outputs.append(out) # Concatenate along the last dimension return torch.cat(head_outputs, dim=-1) # (B, T, n_heads*d_k) out_std = standard_multi_head_attention(x, W_Q, W_K, W_V) print(\u0026#34;Standard Multi-Head output shape:\u0026#34;, out_std.shape) # Create per-head Q, but only ONE K and V WQ_mq = torch.nn.ParameterList([ torch.nn.Parameter(torch.randn(d_model, d_k)) for _ in range(n_heads) ]) WK_mq = torch.nn.Parameter(torch.randn(d_model, d_k)) # Shared across heads WV_mq = torch.nn.Parameter(torch.randn(d_model, d_k)) # Shared across heads def multi_query_attention(x, WQ_list, WK, WV): \u0026#34;\u0026#34;\u0026#34; x: (B, T, d_model) Returns: (B, T, n_heads * d_k) -- same shape as standard multi-head \u0026#34;\u0026#34;\u0026#34; K = x @ WK # (B, T, d_k), shared 
V = x @ WV # (B, T, d_k), shared head_outputs = [] for h in range(n_heads): Q_h = x @ WQ_list[h] # (B, T, d_k) out_h, _ = scaled_dot_product_attention(Q_h, K, V) head_outputs.append(out_h) return torch.cat(head_outputs, dim=-1) out_mq = multi_query_attention(x, WQ_mq, WK_mq, WV_mq) print(\u0026#34;Multi-Query Attention output shape:\u0026#34;, out_mq.shape) ok, so now let\u0026rsquo;s take that same idea and just make these matrices low-rank.\nclass MQLRA(nn.Module): \u0026#34;\u0026#34;\u0026#34; Simplified Multi-Query Attention + Low-Rank factorization. - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h). - We have a *single* K factorization: (W1K, W2K) - We have a *single* V factorization: (W1V, W2V) - This implementation uses standard PyTorch operations for readability. Input shape: X: [B, L, D_in] We\u0026#39;ll produce an attention output shape of: [B, L, D_out * n_heads] (by concatenating the n_heads outputs). \u0026#34;\u0026#34;\u0026#34; def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None): super().__init__() self.D_in = D_in self.D_out = D_out self.n_heads = n_heads self.rank = rank self.causal = causal if scale is not None: self.scale = float(scale) else: self.scale = 1.0 / math.sqrt(D_out) # Q: separate for each head # Each head: W1Q =\u0026gt; [D_in, rank], W2Q =\u0026gt; [rank, D_out] self.W1Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(D_in, rank) * 0.02) for _ in range(n_heads) ]) self.W2Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(rank, D_out) * 0.02) for _ in range(n_heads) ]) # K: single self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02) # V: single self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02) def forward(self, X, attn_bias=None): \u0026#34;\u0026#34;\u0026#34; X =\u0026gt; [B, L, D_in] attn_bias =\u0026gt; optional, shape broadcastable to [B, 
L, L] (added to every head\u0026#39;s scores) Returns: [B, L, n_heads*D_out] \u0026#34;\u0026#34;\u0026#34; batch_size, seq_len, _ = X.shape device = X.device dtype = X.dtype # Compute K and V once (shared across all heads) # Low-rank projections for K, V partial_k = X @ self.W1K # [B, L, rank] K = partial_k @ self.W2K # [B, L, D_out] partial_v = X @ self.W1V # [B, L, rank] V = partial_v @ self.W2V # [B, L, D_out] # Prepare output tensor to accumulate results from all heads out = torch.zeros(batch_size, seq_len, self.n_heads * self.D_out, device=device, dtype=dtype) # Process each head separately for h in range(self.n_heads): # Low-rank projection for Q (head-specific) partial_q = X @ self.W1Q_heads[h] # [B, L, rank] Q = partial_q @ self.W2Q_heads[h] # [B, L, D_out] # Compute attention scores # [B, L, D_out] @ [B, D_out, L] -\u0026gt; [B, L, L] scores = torch.bmm(Q, K.transpose(1, 2)) * self.scale # Apply causal mask if needed if self.causal: # Create causal mask and apply it mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1).bool() scores.masked_fill_(mask, -float(\u0026#39;inf\u0026#39;)) # Apply attention bias if provided if attn_bias is not None: scores = scores + attn_bias # Softmax and apply attention attn_weights = torch.softmax(scores, dim=-1) # [B, L, L] head_output = torch.bmm(attn_weights, V) # [B, L, D_out] # Add to output tensor at the correct position out[:, :, h * self.D_out:(h + 1) * self.D_out] = head_output return out for a super easy implementation, here\u0026rsquo;s how this fits into Keller Jordan\u0026rsquo;s modded nanogpt: we take our MQLRA, format it the way Keller formats his CausalSelfAttention class, and then simply drop it into the script:\n# ----------------------------------------------------------------------------- # Custom operators: FP8 matmul by @YouJiacheng @torch.library.custom_op(\u0026#34;nanogpt::mm\u0026#34;, mutates_args=()) def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) 
-\u0026gt; tuple[Tensor, Tensor, Tensor]: @torch.compile def impl(x: Tensor, w: Tensor): assert x.is_contiguous() and w.is_contiguous() x_f8 = x.div(x_s).to(torch.float8_e4m3fn) w_f8 = w.div(w_s).to(torch.float8_e4m3fn) out = torch._scaled_mm( x_f8, w_f8.T, out_dtype=torch.bfloat16, scale_a=x.new_tensor(x_s, dtype=torch.float32), scale_b=x.new_tensor(w_s, dtype=torch.float32), use_fast_accum=True, ) return out, x_f8, w_f8 return impl(x, w) @mm_op.register_fake def _(x: Tensor, w: Tensor, *_): assert x.ndim == w.ndim == 2 assert x.shape[1] == w.shape[1] assert x.device == w.device assert x.is_contiguous() and w.is_contiguous() return x @ w.T, x.to(torch.float8_e4m3fn), w.to(torch.float8_e4m3fn) @torch.library.custom_op(\u0026#34;nanogpt::mm_backward\u0026#34;, mutates_args=()) def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -\u0026gt; tuple[Tensor, Tensor]: @torch.compile def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor): assert grad.is_contiguous() x_inv_s = grad.new_tensor(x_s, dtype=torch.float32) w_inv_s = grad.new_tensor(w_s, dtype=torch.float32) grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32) grad_f8 = grad.div(grad_s).to(torch.float8_e5m2) grad_x = torch._scaled_mm( grad_f8, w_f8.T.contiguous().T, out_dtype=torch.bfloat16, scale_a=grad_inv_s, scale_b=w_inv_s, use_fast_accum=False, ) # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768) grad_w = torch._scaled_mm( x_f8.T.contiguous(), grad_f8.T.contiguous().T, out_dtype=torch.float32, scale_a=x_inv_s, scale_b=grad_inv_s, use_fast_accum=False, ).T return grad_x, grad_w return impl(g, x_f8, w_f8) @mm_backward_op.register_fake def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_): return x_f8.to(torch.bfloat16), w_f8.T.contiguous().T.to(torch.float32) def backward(ctx, grad_out: Tensor, *_): x_f8, w_f8 = ctx.saved_tensors x_s, w_s, grad_s = ctx.scales grad_x, grad_w = torch.ops.nanogpt.mm_backward( grad_out, x_f8, w_f8, x_s, w_s, grad_s ) return 
grad_x, grad_w, None, None, None

def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    *_, x_s, w_s, grad_s = inputs
    _, x_f8, w_f8 = output
    ctx.save_for_backward(x_f8, w_f8)
    ctx.scales = x_s, w_s, grad_s
    ctx.set_materialize_grads(False)

mm_op.register_autograd(backward, setup_context=setup_context)

# skipping the incredibly based Muon optimizer for brevity

# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model

def norm(x: Tensor):
    return F.rms_norm(x, (x.size(-1),))

class CastedLinear(nn.Linear):
    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        self.use_fp8 = use_fp8
        self.x_s = x_s
        self.w_s = w_s
        self.grad_s = grad_s

    def reset_parameters(self) -> None:
        std = 0.5 * (self.in_features ** -0.5) # 0.5 is a bit better than the default 1/sqrt(3)
        bound = (3 ** 0.5) * std
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, x: Tensor):
        if self.use_fp8 and self.training:
            _x = x.flatten(0, -2)
            out: Tensor = torch.ops.nanogpt.mm(_x, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
            return out.reshape(*x.shape[:-1], -1)
        else:
            return F.linear(x, self.weight.type_as(x))

class Rotary(nn.Module):
    def __init__(self, dim: int, max_seq_len: int):
        super().__init__()
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=dim//4, dtype=torch.float32)
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(dim//4)])
        t = torch.arange(max_seq_len, dtype=torch.float32)
        theta = torch.einsum("i,j -> ij", t, angular_freq)
        self.cos = nn.Buffer(theta.cos(), persistent=False)
        self.sin = nn.Buffer(theta.sin(), persistent=False)

    def forward(self, x_BTHD: Tensor):
        assert self.cos.size(0) >= x_BTHD.size(-3)
        cos, sin = self.cos[None, :x_BTHD.size(-3), None, :], self.sin[None, :x_BTHD.size(-3), None, :]
        x1, x2 = x_BTHD.to(dtype=torch.float32).chunk(2, dim=-1)
        y1 = x1 * cos + x2 * sin
        y2 = x1 * (-sin) + x2 * cos
        return torch.cat((y1, y2), 3).type_as(x_BTHD)

class CausalSelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, head_dim=128):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        hdim = num_heads * head_dim
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKV weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        self.qkv_w = nn.Parameter(torch.empty(3, hdim, dim).uniform_(-bound, bound))
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))
        self.rotary = Rotary(head_dim, max_seq_len)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor, ve: Tensor | None, block_mask: BlockMask):
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "Must use batch size = 1 for FlexAttention"
        q, k, v = F.linear(x, self.qkv_w.flatten(end_dim=1).type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = self.rotary(q), self.rotary(k)
        if ve is not None:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) # @KoszarskyB & @Grad62304977
        else: # skip mid-layers token value embeddings by @YouJiacheng
            v = self.lambdas[0] * v
        # scale the attention logits by given constant, instead of the default head_dim**-0.5, by @leloykun
        # inspired by learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        y = flex_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), block_mask=block_mask, scale=0.12).transpose(1, 2)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = self.c_proj(y)
        return y

class MLP(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        hdim = 4 * dim
        self.c_fc = CastedLinear(dim, hdim)
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_() # zero init suggested by @Grad62304977

    def forward(self, x: Tensor):
        x = self.c_fc(x)
        x = F.relu(x).square() # https://arxiv.org/abs/2109.08668v2; ~1-2% better than GELU; suggested by @SKYLINEZ007 and @Grad62304977
        x = self.c_proj(x)
        return x

class Block(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = CausalSelfAttention(dim, num_heads, max_seq_len) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x

# -----------------------------------------------------------------------------
# The main model

def next_multiple_of_n(v: float | int, *, n: int):
    return next(x for x in range(n, int(v) + 1 + n, n) if x >= v)

class GPT(nn.Module):
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([Block(model_dim, num_heads, max_seq_len, i) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 128
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)

        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)

        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')
        return loss

class MQLRAFlexAttention(nn.Module):
    """
    Multi-Query Low-Rank Attention + FlexAttention that acts as a drop-in replacement
    for CausalSelfAttention.

    - We replicate the same structure:
        * a rotary embedding
        * a final CastedLinear (c_proj)
        * a lambdas parameter for combining 've' (value embeddings)
        * multi-head query, single K/V
    - We still do QK norm, then apply rotary, then call flex_attention.
    - The difference: we use (W1Q_h, W2Q_h) for Q per head, plus a single (W1K, W2K) and (W1V, W2V) for K, V.
    """
    def __init__(
        self,
        dim: int,
        num_heads: int,
        max_seq_len: int,
        head_dim=128,
        rank=16,     # low-rank dimension
        scale=0.12,  # same default scale as CausalSelfAttention
    ):
        """
        Args:
            dim: model dimension (input and output).
            num_heads: number of heads.
            max_seq_len: for rotary embedding.
            head_dim: dimension per head (like in CausalSelfAttention).
            rank: low-rank dimension for Q, K, V factorization.
            scale: scale factor for attention logits, e.g. 0.12 as in CausalSelfAttention.
        """
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.rank = rank
        self.scale = scale

        # total hidden dimension after merging heads
        hdim = num_heads * head_dim

        # optional initialization approach
        std = 0.5 * (dim ** -0.5)
        bound = (3 ** 0.5) * std

        # We keep lambdas for mixing in 've'
        self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5]))

        # Rotary embedding
        self.rotary = Rotary(head_dim, max_seq_len)

        # final projection after attention
        self.c_proj = CastedLinear(hdim, dim)
        self.c_proj.weight.detach().zero_()

        # Q: separate (W1Q, W2Q) for each head
        # shape: W1Q => [dim, rank], W2Q => [rank, head_dim]
        # We'll store them in ParameterList to keep them distinct per head.
        self.W1Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])
        self.W2Q_heads = nn.ParameterList([
            nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))
            for _ in range(num_heads)
        ])

        # K: single (W1K, W2K)
        self.W1K = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
        self.W2K = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))

        # V: single (W1V, W2V)
        self.W1V = nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound))
        self.W2V = nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound))

    def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask):
        """
        x: [B, T, dim]
        ve: optional value embedding; must have as many elements as V (it is reshaped with view_as)
        block_mask: for flex_attention
        Returns: [B, T, dim] after c_proj
        """
        B, T, dim = x.shape
        # If your flex_attention kernel only works for B=1:
        assert B == 1, "Must use batch size = 1 for FlexAttention"

        # 1) Compute Q, K, V in multi-query style
        # Q => [B, T, num_heads, head_dim]
        # single K => broadcast to [B, T, num_heads, head_dim]
        # single V => broadcast similarly

        # Q (separate per head)
        Q_list = []
        for h in range(self.num_heads):
            partial_q = x.matmul(self.W1Q_heads[h])    # [B, T, rank]
            q_h = partial_q.matmul(self.W2Q_heads[h])  # [B, T, head_dim]
            Q_list.append(q_h)
        # stack => [B, T, num_heads, head_dim]
        q = torch.stack(Q_list, dim=2)

        # single K
        partial_k = x.matmul(self.W1K)   # [B, T, rank]
        K_ = partial_k.matmul(self.W2K)  # [B, T, head_dim]
        # expand to match num_heads => [B, T, num_heads, head_dim]
        k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # single V
        partial_v = x.matmul(self.W1V)   # [B, T, rank]
        v_ = partial_v.matmul(self.W2V)  # [B, T, head_dim]
        v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim)

        # 2) norm Q, K + rotary
        # same approach as CausalSelfAttention: norm(q), norm(k), then rotary
        q = norm(q)
        k = norm(k)
        q = self.rotary(q)
        k = self.rotary(k)

        # 3) incorporate 've' into v if provided
        if ve is not None:
            # ve arrives as [T, dim]; view_as reshapes it to v's [B, T, num_heads, head_dim],
            # which requires B == 1 and dim == num_heads * head_dim.
            # Same mixing as CausalSelfAttention:
            v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v)
        else:
            v = self.lambdas[0] * v

        # 4) call flex_attention
        # flex_attention expects shape [B, n_heads, T, head_dim], so we transpose(1, 2)
        q_t = q.transpose(1, 2)  # => [B, num_heads, T, head_dim]
        k_t = k.transpose(1, 2)
        v_t = v.transpose(1, 2)

        # same call as in CausalSelfAttention, with the scale taken from self.scale
        y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale)
        # y_t => [B, num_heads, T, head_dim]

        # transpose back => [B, T, num_heads, head_dim]
        y = y_t.transpose(1, 2).contiguous()

        # 5) reshape => [B, T, (num_heads * head_dim)]
        y = y.view(B, T, self.num_heads * self.head_dim)

        # final projection => [B, T, dim]
        y = self.c_proj(y)
        return y

# Manually add gradient checkpointing to Block
class MQLRAFlexBlock(nn.Module):
    def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int):
        super().__init__()
        # skip attention of blocks.7 (the 8th layer) by @YouJiacheng
        self.attn = MQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None
        self.mlp = MLP(dim)
        self.lambdas = nn.Parameter(torch.tensor([1., 0.]))

    def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        # Use checkpoint instead of direct execution
        # (checkpoint_block is not defined in this excerpt; presumably a wrapper around torch.utils.checkpoint)
        return checkpoint_block(self._forward, x, ve, x0, block_mask)

    def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask):
        x = self.lambdas[0] * x + self.lambdas[1] * x0
        if self.attn is not None:
            x = x + self.attn(norm(x), ve, block_mask)
        x = x + self.mlp(norm(x))
        return x
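a minimal sketch of the shape bookkeeping in MQLRAFlexAttention. it assumes NumPy as a CPU stand-in for torch and uses made-up toy sizes. the stacked `W1Q`/`W2Q` arrays (one `[H, dim, rank]` array instead of a ParameterList) are hypothetical. it checks three things:

- the per-head Q loop can be swapped for two batched `einsum` contractions.
- the single shared K is simply broadcast across heads (multi-query).
- pre-multiplying `W1K @ W2K` into one matrix gives the same K by associativity.

```python
import numpy as np

# hypothetical toy sizes, not the post's config
B, T, dim, num_heads, head_dim, rank = 1, 8, 64, 4, 16, 4
rng = np.random.default_rng(0)

x = rng.standard_normal((B, T, dim))
# per-head low-rank Q factors stacked on a leading head axis (hypothetical layout)
W1Q = rng.standard_normal((num_heads, dim, rank))       # [H, dim, rank]
W2Q = rng.standard_normal((num_heads, rank, head_dim))  # [H, rank, head_dim]
# single shared K factors (multi-query)
W1K = rng.standard_normal((dim, rank))                  # [dim, rank]
W2K = rng.standard_normal((rank, head_dim))             # [rank, head_dim]


def q_loop(x):
    """Per-head loop, mirroring the Q_list loop in MQLRAFlexAttention.forward."""
    heads = [(x @ W1Q[h]) @ W2Q[h] for h in range(num_heads)]  # each [B, T, head_dim]
    return np.stack(heads, axis=2)                             # [B, T, H, head_dim]


def q_batched(x):
    """Same Q with two batched contractions instead of 2 * H separate matmuls."""
    partial = np.einsum("btd,hdr->bthr", x, W1Q)      # [B, T, H, rank]
    return np.einsum("bthr,hrk->bthk", partial, W2Q)  # [B, T, H, head_dim]


# multi-query K: one low-rank projection, broadcast (a view, not a copy) to every head
k_single = (x @ W1K) @ W2K                                                 # [B, T, head_dim]
k = np.broadcast_to(k_single[:, :, None, :], (B, T, num_heads, head_dim))  # [B, T, H, head_dim]

# fused K: precompute W1K @ W2K once; x @ (W1K @ W2K) == (x @ W1K) @ W2K
K_matrix = W1K @ W2K    # [dim, head_dim]
k_fused = x @ K_matrix  # [B, T, head_dim]

low_rank_params = W1K.size + W2K.size  # (dim + head_dim) * rank
fused_params = K_matrix.size           # dim * head_dim

print(q_loop(x).shape, np.allclose(q_loop(x), q_batched(x)), np.allclose(k_single, k_fused))
print(low_rank_params, fused_params)
```

at these toy sizes the two K factors hold (64 + 16) * 4 = 320 parameters, while the fused matrix holds 64 * 16 = 1024. in torch, the batched Q would be the same two `torch.einsum` calls on stacked parameters.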
class ModifiedGPTWithMQLRAFlex(nn.Module):
    """Modified GPT model that uses MQLRAFlexAttention (via MQLRAFlexBlock) with configurable rank"""
    def __init__(self, vocab_size: int, num_layers: int, num_heads: int, model_dim: int, max_seq_len: int, rank: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, model_dim)
        # token value embeddings by @KoszarskyB - inspired by @Grad62304977's value residual implementation following https://arxiv.org/abs/2410.17897
        # value embedding code simplification inspired by @ragulpr https://github.com/KellerJordan/modded-nanogpt/pull/78
        self.value_embeds = nn.ModuleList([nn.Embedding(vocab_size, model_dim) for _ in range(3)])
        self.blocks = nn.ModuleList([MQLRAFlexBlock(model_dim, num_heads, max_seq_len, i, rank) for i in range(num_layers)])
        # there are only 50257 unique GPT-2 tokens; we extend to nearest multiple of 128 for efficiency.
        # suggested to me by @Grad62304977. this originates from Karpathy's experiments.
        self.lm_head = CastedLinear(model_dim, next_multiple_of_n(vocab_size, n=128), use_fp8=True, x_s=(model_dim**0.5)/448, w_s=24/448, grad_s=1/448)
        self.lm_head.weight.detach().zero_() # @Grad62304977
        # Add learnable skip connection weights for decoder layers
        assert num_layers % 2 == 0
        self.skip_weights = nn.Parameter(torch.ones(num_layers//2))
        self._last_loss_value = None

    def create_blockmasks(self, input_seq: Tensor, sliding_window_num_blocks: Tensor):
        BLOCK_SIZE = 64
        docs = (input_seq == 50256).cumsum(0)

        def document_causal(b, h, q_idx, kv_idx):
            causal_mask = q_idx >= kv_idx
            document_mask = docs[q_idx] == docs[kv_idx]
            return causal_mask & document_mask

        def dense_to_ordered(dense_blockmask: Tensor):
            num_blocks = dense_blockmask.sum(dim=-1, dtype=torch.int32)
            indices = dense_blockmask.argsort(dim=-1, descending=False, stable=True).flip(-1).to(torch.int32)
            return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

        # manual block mask creation by @YouJiacheng
        assert len(input_seq) % BLOCK_SIZE == 0
        NUM_BLOCKS = len(input_seq) // BLOCK_SIZE
        block_idx = torch.arange(NUM_BLOCKS, dtype=torch.int32, device="cuda")
        causal_blockmask_any = block_idx[:, None] >= block_idx
        causal_blockmask_all = block_idx[:, None] > block_idx
        docs_low = docs.view(-1, BLOCK_SIZE)[:, 0].contiguous()
        docs_high = docs.view(-1, BLOCK_SIZE)[:, -1].contiguous()
        document_blockmask_any = (docs_low[:, None] <= docs_high) & (docs_high[:, None] >= docs_low)
        document_blockmask_all = (docs_low[:, None] == docs_high) & (docs_high[:, None] == docs_low)
        blockmask_any = causal_blockmask_any & document_blockmask_any
        blockmask_all = causal_blockmask_all & document_blockmask_all
        partial_kv_num_blocks, partial_kv_indices = dense_to_ordered(blockmask_any & ~blockmask_all)
        full_kv_num_blocks, full_kv_indices = dense_to_ordered(blockmask_all)

        def build_bm(window_size_blocks: Tensor) -> BlockMask:
            return BlockMask.from_kv_blocks(
                torch.clamp_max(partial_kv_num_blocks, torch.clamp_min(window_size_blocks - full_kv_num_blocks, 1)),
                partial_kv_indices,
                torch.clamp_max(full_kv_num_blocks, window_size_blocks - 1),
                full_kv_indices,
                BLOCK_SIZE=BLOCK_SIZE,
                mask_mod=document_causal,
            )

        # Long-short SWA block masks by @leloykun & @YouJiacheng, adapted from suggestion by @Grad62304977, following Gemma 2 paper
        return build_bm(sliding_window_num_blocks), build_bm(sliding_window_num_blocks // 2)

    def forward(self, input_seq: Tensor, target_seq: Tensor, sliding_window_num_blocks: Tensor):
        assert input_seq.ndim == 1

        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        # 012 ... 012 structure on token value embeddings by @YouJiacheng, improved on @leloykun's U-net structure
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]
        assert len(ve) == len(self.blocks)

        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]
        assert len(block_masks) == len(self.blocks)

        x = x0 = norm(self.embed(input_seq)[None]) # use of norm here by @Grad62304977

        # U-net design by @brendanh0gan
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            x = self.blocks[i](x, ve[i], x0, block_masks[i])
            if i < n:
                skip_connections.append(x)

        x = norm(x)
        logits = self.lm_head(x).float()
        # @Grad62304977 added tanh softcapping following Gemma 2 paper, @KoszarskyB reduced it from 30 to 15, @YouJiacheng shifted it by +15 (2*sigmoid(2*x)=tanh(x)+1)
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target_seq, reduction='sum' if self.training else 'mean')

        # If we're in training mode, track the loss components
        if self.training and dist.get_rank() == 0:
            # Store the loss components for analysis
            self._last_loss_value = float(loss.item())

        return loss

    def generate(self, idx, max_new_tokens, temperature=0.8, top_k=40):
        """Generate text from the model"""
        self.eval()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Crop idx to manageable context if needed
                idx_cond = idx if idx.size(1) <= 1024 else idx[:, -1024:]
                # Get logits for next token prediction
                logits, _ = self._forward_generation(idx_cond)
                logits = logits[:, -1, :] / temperature # Focus on last (real, unpadded) position
                # Apply top-k filtering
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('inf')
                # Sample from distribution
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)
                # Append to sequence
                idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def _forward_generation(self, input_seq):
        """Handle generation case where we only need logits"""
        # Make sure input is properly formatted
        if input_seq.ndim == 2:  # [B, T]
            batch_size = input_seq.size(0)
            if batch_size > 1:
                # Only use the first batch item for simplicity in generation
                input_seq = input_seq[0]
            else:
                input_seq = input_seq.view(-1)

        # Use a reasonable window size for generation
        sliding_window_num_blocks = torch.tensor(8, dtype=torch.int32, device=input_seq.device)

        # Ensure input is padded to BLOCK_SIZE multiple
        BLOCK_SIZE = 64  # Same as in create_blockmasks
        original_len = len(input_seq)
        remainder = original_len % BLOCK_SIZE
        if remainder != 0:
            pad_length = BLOCK_SIZE - remainder
            padding = torch.zeros(pad_length, dtype=input_seq.dtype, device=input_seq.device)
            input_seq = torch.cat([input_seq, padding])

        # Get token embeddings for the padded sequence
        x = x0 = norm(self.embed(input_seq)[None])  # use of norm here matches forward()

        # Create block masks for attention with padded sequence
        long_bm, short_bm = self.create_blockmasks(input_seq, sliding_window_num_blocks)
        block_masks = [long_bm, short_bm, short_bm, short_bm, long_bm, short_bm, short_bm, long_bm, short_bm, short_bm, short_bm, long_bm]

        # Get value embeddings for the padded sequence
        ve = [value_embed(input_seq) for value_embed in self.value_embeds]
        ve = [ve[0], ve[1], ve[2]] + [None] * (len(self.blocks) - 6) + [ve[0], ve[1], ve[2]]

        # Process through blocks like in forward method
        skip_connections = []
        n = len(self.skip_weights)
        for i in range(len(self.blocks)):
            if i >= n:
                x = x + self.skip_weights[i - n] * skip_connections.pop()
            # Only pass value embeddings to blocks that expect them
            current_ve = ve[i] if i < len(ve) and ve[i] is not None else None
            current_mask = block_masks[i] if i < len(block_masks) else None
            # Disable gradient checkpointing during generation
            # (MQLRAFlexBlock as defined above has no use_checkpoint attribute, so it takes the else branch)
            if hasattr(self.blocks[i], 'use_checkpoint'):
                original_checkpoint = self.blocks[i].use_checkpoint
                self.blocks[i].use_checkpoint = False
                x = self.blocks[i](x, current_ve, x0, current_mask)
                self.blocks[i].use_checkpoint = original_checkpoint
            else:
                x = self.blocks[i](x, current_ve, x0, current_mask)
            if i < n:
                skip_connections.append(x)

        # Final norm and project to logits
        x = norm(x)
        logits = self.lm_head(x).float()
        # apply the same sigmoid softcap as forward() so sampling sees the training-time distribution
        logits = 30 * torch.sigmoid(logits / (7.5 * x.size(-1)**0.5))
        # drop the padded positions so that logits[:, -1] is the last real token
        logits, x = logits[:, :original_len], x[:, :original_len]
        # Return logits for generation and the hidden state
        return logits, x

flash/flex attention

i was already deep down the rabbit hole, so i figured: why not drop flash/flex attention in on top of the low-rank attention matmuls and see if we could get an even larger speed boost?

# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# assume we have the flash attention triton definition:
# *Experimental* implementation of FlashAttention in Triton.
# Tested with triton==2.0.0.dev20221202.
#...
#...
# flash_attn_func = FlashAttnFunc.apply # from flash_attn_triton import flash_attn_func class MQLRAFlashAttention(nn.Module): \u0026#34;\u0026#34;\u0026#34; Multi-Query Attention + Low-Rank factorization + FlashAttention. - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h). - We have a *single* K factorization: (W1K, W2K) - We have a *single* V factorization: (W1V, W2V) - Then we call the standard Triton FlashAttention kernel on the resulting Q, K, V. Input shape: X: [B, L, D_in] We\u0026#39;ll produce an attention output shape of: [B, L, D_out * n_heads] (by concatenating the n_heads outputs). We store parameters for Q in `self.W1Q_heads[i], self.W2Q_heads[i]`, and single K, V in `self.W1K, W2K, W1V, W2V`. \u0026#34;\u0026#34;\u0026#34; def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None): super().__init__() self.D_in = D_in self.D_out = D_out self.n_heads = n_heads self.rank = rank self.causal = causal if scale is not None: self.scale = float(scale) else: self.scale = 1.0 / math.sqrt(D_out) # Q: separate for each head # Each head: W1Q =\u0026gt; [D_in, rank], W2Q =\u0026gt; [rank, D_out] self.W1Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(D_in, rank) * 0.02) for _ in range(n_heads) ]) self.W2Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(rank, D_out) * 0.02) for _ in range(n_heads) ]) # K: single self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02) # V: single self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02) def forward(self, X, attn_bias=None): \u0026#34;\u0026#34;\u0026#34; X =\u0026gt; [B, L, D_in] attn_bias =\u0026gt; optional, shape broadcastible to [B, n_heads, L, L] if needed We\u0026#39;ll produce out =\u0026gt; [B, L, n_heads*D_out]. 
Implementation steps: 1) Build Q (n_heads) =\u0026gt; shape [B, L, n_heads, D_out] 2) Build K =\u0026gt; shape [B, L, D_out], replicate to [B, L, 1, D_out] 3) Build V =\u0026gt; shape [B, L, D_out], replicate to [B, L, 1, D_out] 4) Call flash-attn forward =\u0026gt; out =\u0026gt; shape [B, L, n_heads, D_out] 5) Reshape to [B, L, n_heads*D_out]. \u0026#34;\u0026#34;\u0026#34; B, L, _ = X.shape # Step 1) Q =\u0026gt; [B, L, n_heads, D_out] # We\u0026#39;ll compute each head\u0026#39;s Q in Python. Then stack along \u0026#34;n_heads\u0026#34;. # partial_Q_h = X @ W1Q_heads[h] =\u0026gt; [B,L,rank] # Q_h = partial_Q_h @ W2Q_heads[h] =\u0026gt; [B,L,D_out] Q_list = [] for h in range(self.n_heads): partial_q = X.matmul(self.W1Q_heads[h]) # [B, L, rank] q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, L, D_out] Q_list.append(q_h) # stack =\u0026gt; [B, L, n_heads, D_out] Q = torch.stack(Q_list, dim=2) # Step 2) K =\u0026gt; [B,L,D_out], then expand to [B,L,1,D_out] partial_k = X.matmul(self.W1K) # [B,L,rank] K_ = partial_k.matmul(self.W2K) # [B,L,D_out] # For multi-query: we have a single K =\u0026gt; we can just unsqueeze dim=2 =\u0026gt; n_heads=1 # But if we want flash to see n_heads == self.n_heads, we replicate: # shape =\u0026gt; [B, L, n_heads, D_out] # but logically the data is the same. We\u0026#39;ll do a .unsqueeze(2).expand(...) # or we can do a .repeat_interleave, but that costs memory. Alternatively we can do # a trick: we pass n_heads=1 to flash, but that won\u0026#39;t match Q\u0026#39;s shape. 
So we must replicate: # This is the simplest approach: K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # Step 3) V =\u0026gt; same approach partial_v = X.matmul(self.W1V) # [B,L,rank] V_ = partial_v.matmul(self.W2V) # [B,L,D_out] V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # Step 4) flash-attn =\u0026gt; out =\u0026gt; [B, L, n_heads, D_out] # The flash_attn code expects Q =\u0026gt; [B, L, n_heads, headdim] # so we pass Q,K,V in that shape. We can pass an optional bias if it matches shape [B,n_heads,L,L]. Q = Q.contiguous() K = K.contiguous() V = V.contiguous() print(\u0026#34;Q shape:\u0026#34;, Q.shape, \u0026#34;stride:\u0026#34;, Q.stride()) print(\u0026#34;K shape:\u0026#34;, K.shape, \u0026#34;stride:\u0026#34;, K.stride()) print(\u0026#34;V shape:\u0026#34;, V.shape, \u0026#34;stride:\u0026#34;, V.stride()) out = flash_attn_func(Q, K, V, attn_bias, self.causal, self.scale) # out =\u0026gt; [B, L, nheads, D_out] # Step 5) reshape =\u0026gt; [B,L,n_heads*D_out] out = out.reshape(B, L, self.n_heads*self.D_out) return out class MQLRAFlexAttention(nn.Module): \u0026#34;\u0026#34;\u0026#34; Multi-Query Attention + Low-Rank factorization + FlashAttention. - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h). - We have a *single* K factorization: (W1K, W2K) - We have a *single* V factorization: (W1V, W2V) - Then we call the standard Triton FlashAttention kernel on the resulting Q, K, V. Input shape: X: [B, L, D_in] We\u0026#39;ll produce an attention output shape of: [B, L, D_out * n_heads] (by concatenating the n_heads outputs). We store parameters for Q in `self.W1Q_heads[i], self.W2Q_heads[i]`, and single K, V in `self.W1K, W2K, W1V, W2V`. 
\u0026#34;\u0026#34;\u0026#34; def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None): super().__init__() self.D_in = D_in self.D_out = D_out self.n_heads = n_heads self.rank = rank self.causal = causal if scale is not None: self.scale = float(scale) else: self.scale = 1.0 / math.sqrt(D_out) # Q: separate for each head # Each head: W1Q =\u0026gt; [D_in, rank], W2Q =\u0026gt; [rank, D_out] self.W1Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(D_in, rank) * 0.02) for _ in range(n_heads) ]) self.W2Q_heads = nn.ParameterList([ nn.Parameter(torch.randn(rank, D_out) * 0.02) for _ in range(n_heads) ]) # K: single self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02) # V: single self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02) def forward(self, X, attn_bias=None): \u0026#34;\u0026#34;\u0026#34; X =\u0026gt; [B, L, D_in] attn_bias =\u0026gt; optional, shape broadcastible to [B, n_heads, L, L] if needed We\u0026#39;ll produce out =\u0026gt; [B, L, n_heads*D_out]. Implementation steps: 1) Build Q (n_heads) =\u0026gt; shape [B, L, n_heads, D_out] 2) Build K =\u0026gt; shape [B, L, D_out], replicate to [B, L, 1, D_out] 3) Build V =\u0026gt; shape [B, L, D_out], replicate to [B, L, 1, D_out] 4) Call flash-attn forward =\u0026gt; out =\u0026gt; shape [B, L, n_heads, D_out] 5) Reshape to [B, L, n_heads*D_out]. \u0026#34;\u0026#34;\u0026#34; B, L, _ = X.shape # Step 1) Q =\u0026gt; [B, L, n_heads, D_out] # We\u0026#39;ll compute each head\u0026#39;s Q in Python. Then stack along \u0026#34;n_heads\u0026#34;. 
# partial_Q_h = X @ W1Q_heads[h] =\u0026gt; [B,L,rank] # Q_h = partial_Q_h @ W2Q_heads[h] =\u0026gt; [B,L,D_out] Q_list = [] for h in range(self.n_heads): partial_q = X.matmul(self.W1Q_heads[h]) # [B, L, rank] q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, L, D_out] Q_list.append(q_h) # stack =\u0026gt; [B, L, n_heads, D_out] Q = torch.stack(Q_list, dim=2) # Step 2) K =\u0026gt; [B,L,D_out], then expand to [B,L,1,D_out] partial_k = X.matmul(self.W1K) # [B,L,rank] K_ = partial_k.matmul(self.W2K) # [B,L,D_out] # For multi-query: we have a single K =\u0026gt; we can just unsqueeze dim=2 =\u0026gt; n_heads=1 # But if we want flash to see n_heads == self.n_heads, we replicate: # shape =\u0026gt; [B, L, n_heads, D_out] # but logically the data is the same. We\u0026#39;ll do a .unsqueeze(2).expand(...) # or we can do a .repeat_interleave, but that costs memory. Alternatively we can do # a trick: we pass n_heads=1 to flash, but that won\u0026#39;t match Q\u0026#39;s shape. So we must replicate: # This is the simplest approach: K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # Step 3) V =\u0026gt; same approach partial_v = X.matmul(self.W1V) # [B,L,rank] V_ = partial_v.matmul(self.W2V) # [B,L,D_out] V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # Step 4) flash-attn =\u0026gt; out =\u0026gt; [B, L, n_heads, D_out] # The flash_attn code expects Q =\u0026gt; [B, L, n_heads, headdim] # so we pass Q,K,V in that shape. We can pass an optional bias if it matches shape [B,n_heads,L,L]. 
Q = Q.contiguous() K = K.contiguous() V = V.contiguous() print(\u0026#34;Q shape:\u0026#34;, Q.shape, \u0026#34;stride:\u0026#34;, Q.stride()) print(\u0026#34;K shape:\u0026#34;, K.shape, \u0026#34;stride:\u0026#34;, K.stride()) print(\u0026#34;V shape:\u0026#34;, V.shape, \u0026#34;stride:\u0026#34;, V.stride()) out = flash_attn_func(Q, K, V, attn_bias, self.causal, self.scale) # out =\u0026gt; [B, L, nheads, D_out] # Step 5) reshape =\u0026gt; [B,L,n_heads*D_out] out = out.reshape(B, L, self.n_heads*self.D_out) return out class MqaLraFlexAttention(nn.Module): \u0026#34;\u0026#34;\u0026#34; Multi-Query Attention + Low-Rank factorization + FlexAttention. - We have `n_heads` separate Q transformations, each factorized into (W1Q_h, W2Q_h). - We have a *single* K factorization: (W1K, W2K) - We have a *single* V factorization: (W1V, W2V) - Then we call the Triton-based `flex_attention` kernel on the resulting Q, K, V, similar to how `CausalSelfAttention` uses flex_attention. Input shape: X: [B, L, D_in] Output shape: [B, L, D_out * n_heads] \u0026#34;\u0026#34;\u0026#34; def __init__(self, D_in, D_out, n_heads, rank, causal=False, scale=None): super().__init__() self.D_in = D_in self.D_out = D_out self.n_heads = n_heads self.rank = rank self.causal = causal # If scale is not provided, use 1/sqrt(D_out) by default self.scale = float(scale) if scale is not None else (1.0 / math.sqrt(D_out)) # Q: separate for each head # Each head: W1Q =\u0026gt; [D_in, rank], W2Q =\u0026gt; [rank, D_out] self.W1Q_heads = nn.ParameterList( nn.Parameter(torch.randn(D_in, rank) * 0.02) for _ in range(n_heads) ) self.W2Q_heads = nn.ParameterList( nn.Parameter(torch.randn(rank, D_out) * 0.02) for _ in range(n_heads) ) # K: single self.W1K = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2K = nn.Parameter(torch.randn(rank, D_out) * 0.02) # V: single self.W1V = nn.Parameter(torch.randn(D_in, rank) * 0.02) self.W2V = nn.Parameter(torch.randn(rank, D_out) * 0.02) def forward(self, X, 
block_mask=None, attn_bias=None): \u0026#34;\u0026#34;\u0026#34; X: [B, L, D_in] block_mask: the `BlockMask` (if needed by flex_attention). attn_bias: optional, shape broadcastible to [B, n_heads, L, L] if needed. Returns: [B, L, n_heads * D_out] \u0026#34;\u0026#34;\u0026#34; B, L, _ = X.shape # If your flex_attention kernel only supports B=1, you may want: # assert B == 1, \u0026#34;We only support batch_size=1 for flex_attention.\u0026#34; # 1) Compute Q per head =\u0026gt; [B, L, n_heads, D_out] Q_list = [] for h in range(self.n_heads): partial_q = X @ self.W1Q_heads[h] # [B, L, rank] q_h = partial_q @ self.W2Q_heads[h] # [B, L, D_out] Q_list.append(q_h) Q = torch.stack(Q_list, dim=2) # =\u0026gt; [B, L, n_heads, D_out] # 2) Compute K =\u0026gt; [B, L, D_out] partial_k = X @ self.W1K # [B, L, rank] K_ = partial_k @ self.W2K # [B, L, D_out] # Expand K =\u0026gt; [B, L, n_heads, D_out] for multi-query K = K_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # 3) Compute V =\u0026gt; [B, L, D_out] partial_v = X @ self.W1V # [B, L, rank] V_ = partial_v @ self.W2V # [B, L, D_out] V = V_.unsqueeze(2).expand(B, L, self.n_heads, self.D_out) # 4) Call flex_attention # flex_attention expects [B, n_heads, L, head_dim], so we transpose(1,2). # Then we get back [B, n_heads, L, head_dim], which we transpose back to [B, L, n_heads, head_dim]. Q_t = Q.transpose(1, 2) # =\u0026gt; [B, n_heads, L, D_out] K_t = K.transpose(1, 2) # =\u0026gt; [B, n_heads, L, D_out] V_t = V.transpose(1, 2) # =\u0026gt; [B, n_heads, L, D_out] # flex attention here out_t = flex_attention(Q_t, K_t, V_t, block_mask=block_mask, scale=self.scale) # out_t =\u0026gt; [B, n_heads, L, D_out] # 5) Transpose back =\u0026gt; [B, L, n_heads, D_out], then flatten out = out_t.transpose(1, 2) # =\u0026gt; [B, L, n_heads, D_out] out = out.reshape(B, L, self.n_heads * self.D_out) return out one more change: i figued one simple change would be to fuse the k and v matrices into a single matrix. 
this means fewer calls to the GPU, and it\u0026rsquo;s a bit faster in training, even though the collapsed dim x head_dim matrix actually does more arithmetic per token than the two skinny low-rank matmuls it replaces.\nalternatively, the per-head Q for loop in the original implementation could be parallelized in the forward pass. either way, i wanted to see what would happen.\nclass FusedMQLRAFlexAttention(nn.Module): \u0026#34;\u0026#34;\u0026#34; Multi-Query Low-Rank Attention + FlexAttention with fused K and V matrices (one matmul each instead of two). This version maintains the low-rank factorization of Q matrices for each head, but uses fused matrices for K and V to reduce the number of matmuls during the forward pass. The Q matrices still use the (W1Q_h, W2Q_h) factorization per head, but K and V use direct projection matrices. \u0026#34;\u0026#34;\u0026#34; def __init__( self, dim: int, num_heads: int, max_seq_len: int, head_dim=128, rank=16, # low-rank dimension used for initialization and Q factorization scale=0.12, # default scale from your code ): super().__init__() self.num_heads = num_heads self.head_dim = head_dim self.rank = rank self.scale = scale # total hidden dimension after merging heads hdim = num_heads * head_dim # initialization approach std = 0.5 * (dim ** -0.5) bound = (3 ** 0.5) * std # We keep lambdas for mixing in \u0026#39;ve\u0026#39; self.lambdas = nn.Parameter(torch.tensor([0.5, 0.5])) # Rotary embedding self.rotary = Rotary(head_dim, max_seq_len) # final projection after attention self.c_proj = CastedLinear(hdim, dim) self.c_proj.weight.detach().zero_() # Q: separate (W1Q, W2Q) for each head (keep low-rank factorization for Q) self.W1Q_heads = nn.ParameterList([ nn.Parameter(torch.empty(dim, rank).uniform_(-bound, bound)) for _ in range(num_heads) ]) self.W2Q_heads = nn.ParameterList([ nn.Parameter(torch.empty(rank, head_dim).uniform_(-bound, bound)) for _ in range(num_heads) ]) # Fused K: instead of W1K and W2K, use a single matrix initialized as their product W1K_init = torch.empty(dim, rank).uniform_(-bound, bound) W2K_init = torch.empty(rank, 
head_dim).uniform_(-bound, bound) self.K_matrix = nn.Parameter(W1K_init @ W2K_init) # Fused V: instead of W1V and W2V, use a single matrix initialized as their product W1V_init = torch.empty(dim, rank).uniform_(-bound, bound) W2V_init = torch.empty(rank, head_dim).uniform_(-bound, bound) self.V_matrix = nn.Parameter(W1V_init @ W2V_init) def forward(self, x: torch.Tensor, ve: torch.Tensor | None, block_mask: BlockMask): \u0026#34;\u0026#34;\u0026#34; x: [B, T, dim] ve: optional value-embedding, shape must broadcast with V block_mask: for flex_attention Returns: [B, T, dim] after c_proj \u0026#34;\u0026#34;\u0026#34; B, T, dim = x.shape # If your flex_attention kernel only works for B=1: assert B == 1, \u0026#34;Must use batch size = 1 for FlexAttention\u0026#34; # 1) Compute Q, K, V in multi-query style # Q =\u0026gt; [B, T, num_heads, head_dim] # single K =\u0026gt; broadcast to [B, T, num_heads, head_dim] # single V =\u0026gt; broadcast similarly # Q (separate per head with low-rank factorization) Q_list = [] for h in range(self.num_heads): partial_q = x.matmul(self.W1Q_heads[h]) # [B, T, rank] q_h = partial_q.matmul(self.W2Q_heads[h]) # [B, T, head_dim] Q_list.append(q_h) # stack =\u0026gt; [B, T, num_heads, head_dim] q = torch.stack(Q_list, dim=2) # K using fused matrix - one matmul instead of two K_ = x.matmul(self.K_matrix) # [B, T, head_dim] # expand to match num_heads =\u0026gt; [B, T, num_heads, head_dim] k = K_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim) # V using fused matrix - one matmul instead of two v_ = x.matmul(self.V_matrix) # [B, T, head_dim] v = v_.unsqueeze(2).expand(B, T, self.num_heads, self.head_dim) # 2) norm Q, K + rotary q = norm(q) k = norm(k) q = self.rotary(q) k = self.rotary(k) # 3) incorporate \u0026#39;ve\u0026#39; into v if provided if ve is not None: v = self.lambdas[0] * v + self.lambdas[1] * ve.view_as(v) else: v = self.lambdas[0] * v # 4) call flex_attention q_t = q.transpose(1, 2) # =\u0026gt; [B, num_heads, T, 
head_dim] k_t = k.transpose(1, 2) v_t = v.transpose(1, 2) y_t = flex_attention(q_t, k_t, v_t, block_mask=block_mask, scale=self.scale) # y_t =\u0026gt; [B, num_heads, T, head_dim] # transpose back =\u0026gt; [B, T, num_heads, head_dim] y = y_t.transpose(1, 2).contiguous() # 5) reshape =\u0026gt; [B, T, (num_heads * head_dim)] y = y.view(B, T, self.num_heads * self.head_dim) # final projection =\u0026gt; [B, T, dim] y = self.c_proj(y) return y class FusedMQLRAFlexBlock(nn.Module): def __init__(self, dim: int, num_heads: int, max_seq_len: int, layer_idx: int, rank: int): super().__init__() # skip attention of blocks.7 (the 8th layer) by @YouJiacheng self.attn = FusedMQLRAFlexAttention(dim, num_heads, max_seq_len, rank=rank) if layer_idx != 7 else None self.mlp = MLP(dim) self.lambdas = nn.Parameter(torch.tensor([1., 0.])) def forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask): # Use checkpoint instead of direct execution return checkpoint_block(self._forward, x, ve, x0, block_mask) def _forward(self, x: Tensor, ve: Tensor | None, x0: Tensor, block_mask: BlockMask): x = self.lambdas[0] * x + self.lambdas[1] * x0 if self.attn is not None: x = x + self.attn(norm(x), ve, block_mask) x = x + self.mlp(norm(x)) return x quick and dirty training attempt w/ fused MQLRFlexAttention. 
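A quick sanity check on the K/V fusion in FusedMQLRAFlexAttention, as a minimal pure-Python sketch (no torch; the helpers matmul, rand_matrix and projection_cost are made up for illustration). It checks that x @ (W1 @ W2) gives the same projection as (x @ W1) @ W2, and counts the cost of each form. The sizes are assumptions: dim=768 is a guessed model width, while rank=16 and head_dim=128 are the defaults in the class above.

```python
import random

def matmul(a, b):
    # naive product of list-of-lists matrices: a is [n x k], b is [k x m]
    k = len(b)
    assert all(len(row) == k for row in a), 'inner dimensions must match'
    m = len(b[0])
    return [[sum(row[t] * b[t][j] for t in range(k)) for j in range(m)] for row in a]

def rand_matrix(rows, cols, bound, rng):
    # uniform(-bound, bound) init, mirroring the uniform_ init used in the post
    return [[rng.uniform(-bound, bound) for _ in range(cols)] for _ in range(rows)]

def projection_cost(dim, rank, head_dim):
    # for a bias-free linear map, parameter count == multiply-adds per token
    factored = dim * rank + rank * head_dim  # x @ W1, then @ W2
    fused = dim * head_dim                   # x @ (W1 @ W2)
    return factored, fused

rng = random.Random(0)
dim, rank, head_dim, tokens = 8, 2, 4, 3       # tiny toy sizes for the numeric check
bound = (3 ** 0.5) * 0.5 * dim ** -0.5          # same bound formula as the post
x = rand_matrix(tokens, dim, 1.0, rng)
W1 = rand_matrix(dim, rank, bound, rng)
W2 = rand_matrix(rank, head_dim, bound, rng)

k_factored = matmul(matmul(x, W1), W2)          # two skinny matmuls
k_fused = matmul(x, matmul(W1, W2))             # one matmul against the collapsed matrix
max_err = max(abs(p - q) for rp, rq in zip(k_factored, k_fused) for p, q in zip(rp, rq))
print('max abs difference:', max_err)
print(projection_cost(768, 16, 128))
```

With those assumed sizes the last line prints (14336, 98304): the collapsed matrix has about 6.9x the entries of the factored pair, and since a bias-free linear map does one multiply-add per weight per token, about 6.9x the arithmetic too. So whatever speedup the fused version gets presumably comes from one matmul launch instead of two, not from less math. It also means nothing keeps K_matrix at rank 16 once training updates it, so getting back to a factored form for inference would need something like a truncated SVD (an assumption, not something tried in this post).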
basically, in order to roughly match the performance of the original modded nanogpt, we need to train for many more steps since 1) i don\u0026rsquo;t have an H100 and 1a) without one we can\u0026rsquo;t train with float8, so basically we have to shrink the sequence length + increase the number of iterations (which sucks, but whatever, we\u0026rsquo;re just throwing boredom spaghetti at the wall here)\nyou can sort of train with the fused kernel, then run inference with the original kernel, or not, or not not.\nnothing wild/groundbreaking, just a bit of stream of consciousness.\n","date":"2025-03-15","permalink":"/posts/mqlra/","title":"Multi Query Low Rank Attention"},{"content":"==================\nThe Constitution of Constitutional AI for the Constitution of your Constitution A lawyer\u0026rsquo;s duty is to represent their client, free of bias. Their interpretation of law is rendered as an objective function of the end goal of their current client - even if that means it\u0026rsquo;s abhorrent to their own foundational values, morals, and ethics.\nA major problem with models like o1/gpt-4o/claude et al and their \u0026lsquo;bias of the internet\u0026rsquo; is that their moral constitution is not always going to be aligned with the current objective of a lawyer representing a client. This duty and its relationship to law itself is an incredibly human concept that - unless it\u0026rsquo;s possible behind the doors of big labs - just isn\u0026rsquo;t possible without being open-sourced.\nFollowing rules is one thing, but the grey area of what a rule is and what\u0026rsquo;s best in the given moment is human.\nI could go on and on, but here\u0026rsquo;s a practical example using 2A as a guidepost:\n we\u0026rsquo;ll find the attention heads that activate during pro/anti 2A prompts vs. 
a control we\u0026rsquo;ll ablate these heads and see if the model still has 2A knowledge we\u0026rsquo;ll show a silly example of how to apply this in an application  Second Amendment Test Let\u0026rsquo;s see if we can trigger the model\u0026rsquo;s 2A knowledge using Heller and Miller. (pro 2A Heller would later overturn anti 2A Miller) - totally don\u0026rsquo;t have to read the following text, but will include both for posterity.\nDistrict of Columbia v. Heller, 554 U.S. 570 (2008)  We start therefore with a strong presumption that the Second Amendment right is exercised individually and belongs to all Americans. b. “Keep and bear Arms.” We move now from the holder of the right—“the people”—to the substance of the right: “to keep and bear Arms.” Before addressing the verbs “keep” and “bear,” we interpret their object: “Arms.” The 18th-century meaning is no different from the meaning today. The 1773 edition of Samuel Johnson’s dictionary defined “arms” as “weapons of offence, or armour of defence.” 1 Dictionary of the English Language 107 (4th ed.) (hereinafter Johnson). Timothy Cunningham’s important 1771 legal dictionary defined “arms” as “any thing that a man wears for his defence, or takes into his hands, or useth in wrath to cast at or strike another.” 1 A New and Complete Law Dictionary (1771); see also N. Webster, American Dictionary of the English Language (1828) (reprinted 1989) (hereinafter Webster) (similar). The term was applied, then as now, to weapons that were not specifically designed for military use and were not employed in a military capacity. For instance, Cunningham’s legal dictionary gave as an example of usage: “Servants and labourers shall use bows and arrows on Sundays, \u0026amp;c. and not bear other arms.” See also, e.g., An Act for the trial of Negroes, 1797 Del. Laws ch. XLIII, §6, p. 104, in 1 First Laws of the State of Delaware 102, 104 (J. Cushing ed. 1981 (pt. 1)); see generally State v. Duke, 42 Tex. 
455, 458 (1874) (citing decisions of state courts construing “arms”). Although one founding-era thesaurus limited “arms” (as opposed to “weapons”) to “instruments of offence generally made use of in war,” even that source stated that all firearms constituted “arms.” 1 J. Trusler, The Distinction Between Words Esteemed Synonymous in the English Language 37 (1794) (emphasis added). The phrase “keep arms” was not prevalent in the written documents of the founding period that we have found, but there are a few examples, all of which favor viewing the right to “keep Arms” as an individual right unconnected with militia service. William Blackstone, for example, wrote that Catholics convicted of not attending service in the Church of England suffered certain penalties, one of which was that they were not permitted to “keep arms in their houses.” 4 Commentaries on the Laws of England 55 (1769) (hereinafter Blackstone); see also 1 W. \u0026amp; M., c. 15, §4, in 3 Eng. Stat. at Large 422 (1689) (“[N]o Papist … shall or may have or keep in his House … any Arms … ”); 1 Hawkins, Treatise on the Pleas of the Crown 26 (1771) (similar). Petitioners point to militia laws of the founding period that required militia members to “keep” arms in connection with militia service, and they conclude from this that the phrase “keep Arms” has a militia-related connotation. See Brief for Petitioners 16–17 (citing laws of Delaware, New Jersey, and Virginia). This is rather like saying that, since there are many statutes that authorize aggrieved employees to “file complaints” with federal agencies, the phrase “file complaints” has an employment-related connotation. “Keep arms” was simply a common way of referring to possessing arms, for militiamen and everyone else.[Footnote 7] At the time of the founding, as now, to “bear” meant to “carry.” See Johnson 161; Webster; T. Sheridan, A Complete Dictionary of the English Language (1796); 2 Oxford English Dictionary 20 (2d ed. 
1989) (hereinafter Oxford). When used with “arms,” however, the term has a meaning that refers to carrying for a particular purpose—confrontation. In Muscarello v. United States, 524 U. S. 125 (1998), in the course of analyzing the meaning of “carries a firearm” in a federal criminal statute, Justice Ginsburg wrote that “[s]urely a most familiar meaning is, as the Constitution’s Second Amendment … indicate[s]: ‘wear, bear, or carry … upon the person or in the clothing or in a pocket, for the purpose … of being armed and ready for offensive or defensive action in a case of conflict with another person.’ ” Id., at 143 (dissenting opinion) (quoting Black’s Law Dictionary 214 (6th ed. 1998)). We think that Justice Ginsburg accurately captured the natural meaning of “bear arms.” Although the phrase implies that the carrying of the weapon is for the purpose of “offensive or defensive action,” it in no way connotes participation in a structured military organization. From our review of founding-era sources, we conclude that this natural meaning was also the meaning that “bear arms” had in the 18th century. In numerous instances, “bear arms” was unambiguously used to refer to the carrying of weapons outside of an organized militia. The most prominent examples are those most relevant to the Second Amendment: Nine state constitutional provisions written in the 18th century or the first two decades of the 19th, which enshrined a right of citizens to “bear arms in defense of themselves and the state” or “bear arms in defense of himself and the state.” [Footnote 8] It is clear from those formulations that “bear arms” did not refer only to carrying a weapon in an organized military unit. Justice James Wilson interpreted the Pennsylvania Constitution’s arms-bearing right, for example, as a recognition of the natural right of defense “of one’s person or house”—what he called the law of “self preservation.” 2 Collected Works of James Wilson 1142, and n. x (K. Hall \u0026amp; M. 
Hall eds. 2007) (citing Pa. Const., Art. IX, §21 (1790)); see also T. Walker, Introduction to American Law 198 (1837) (“Thus the right of self-defence [is] guaranteed by the [Ohio] constitution”); see also id., at 157 (equating Second Amendment with that provision of the Ohio Constitution). That was also the interpretation of those state constitutional provisions adopted by pre-Civil War state courts.[Footnote 9] These provisions demonstrate—again, in the most analogous linguistic context—that “bear arms” was not limited to the carrying of arms in a militia. The phrase “bear Arms” also had at the time of the founding an idiomatic meaning that was significantly different from its natural meaning: “to serve as a soldier, do military service, fight” or “to wage war.” See Linguists’ Brief 18; post, at 11 (Stevens, J., dissenting). But it unequivocally bore that idiomatic meaning only when followed by the preposition “against,” which was in turn followed by the target of the hostilities. See 2 Oxford 21. (That is how, for example, our Declaration of Independence ¶28, used the phrase: “He has constrained our fellow Citizens taken Captive on the high Seas to bear Arms against their Country … .”) Every example given by petitioners’ amici for the idiomatic meaning of “bear arms” from the founding period either includes the preposition “against” or is not clearly idiomatic. See Linguists’ Brief 18–23. Without the preposition, “bear arms” normally meant (as it continues to mean today) what Justice Ginsburg’s opinion in Muscarello said. \u0026hellip;\n United States v. Miller, 307 U.S. 174 (1939) (was essentially overturned by Heller above)  under the Harrison Narcotic Act [Footnote 2] \u0026ndash; United States v. Jin Fuey Moy (1916), 241 U. S. 394, United States v. Doremus (1919), 249 U. S. 86, 249 U. S. 94; Linder v. United States (1925), 268 U. S. 5; Alston v. United States (1927), 274 U. S. 289; Nigro v. United States (1928), 276 U. S. 
332 \u0026ndash; the objection that the Act usurps police power reserved to the States is plainly untenable. In the absence of any evidence tending to show that possession or use of a \u0026ldquo;shotgun having a barrel of less than eighteen inches in length\u0026rdquo; at this time has some reasonable relationship to the preservation or efficiency of a well regulated militia, we cannot say that the Second Amendment guarantees the right to keep and bear such an instrument. Certainly it is not within judicial notice that this weapon is any part of the ordinary military equipment, or that its use could contribute to the common defense. Aymette v. State, 2 Humphreys (Tenn.) 154, 158. The Constitution, as originally adopted, granted to the Congress power \u0026ldquo;To provide for calling forth the Militia to execute the Laws of the Union, suppress Insurrections and repel Invasions; To provide for organizing, arming, and disciplining, the Militia, and for governing such Part of them as may be employed in the Service of the United States, reserving to the States respectively, the Appointment of the Officers, and the Authority of training the Militia according to the discipline prescribed by Congress.\u0026rdquo; With obvious purpose to assure the continuation and render possible the effectiveness of such forces, the declaration and guarantee of the Second Amendment were made. It must be interpreted and applied with that end in view. The Militia which the States were expected to maintain and train is set in contrast with Troops which they were forbidden to keep without the consent of Congress. The sentiment of the time strongly disfavored standing armies; the common view was that adequate defense of country and laws could be secured through the Militia \u0026ndash; civilians primarily, soldiers on occasion. 
The signification attributed to the term Militia appears from the debates in the Convention, the history and legislation of Colonies and States, and the writings of approved commentators. These show plainly enough that the Militia comprised all males physically capable of acting in concert for the common defense. \u0026ldquo;A body of citizens enrolled for military discipline.\u0026rdquo; And further, that ordinarily, when called for service these men were expected to appear bearing arms supplied by themselves and of the kind in common use at the time. Blackstone\u0026rsquo;s Commentaries, Vol. 2, Ch. 13, p. 409 points out \u0026ldquo;that king Alfred first settled a national militia in this kingdom,\u0026rdquo; and traces the subsequent development and use of such forces. Adam Smith\u0026rsquo;s Wealth of Nations, Book V, Ch. 1, contains an extended account of the Militia. It is there said: \u0026ldquo;Men of republican principles have been jealous of a standing army as dangerous to liberty.\u0026rdquo; \u0026ldquo;In a militia, the character of the labourer, artificer, or tradesman, predominates over that of the soldier: in a standing army, that of the soldier predominates over every other character, and in this distinction seems to consist the essential difference between those two different species of military force.\u0026rdquo; \u0026quot;The American Colonies In The 17th Century,\u0026rdquo; Osgood, Vol. 1, ch. XIII, affirms in reference to the early system of defense in New England \u0026ldquo;In all the colonies, as in England, the militia system was based on the principle of the assize of arms. 
This implied the general obligation of all adult male inhabitants to possess arms, and, with certain exceptions, to cooperate in the work of defence.\u0026rdquo; \u0026ldquo;The possession of arms also implied the possession of ammunition, and the authorities paid quite as much attention to the latter as to the former.\u0026rdquo; \u0026ldquo;A year later [1632] it was ordered that any single man who had not furnished himself with arms might be put out to service, and this became a permanent part of the legislation of the colony [Massachusetts].\u0026rdquo; Also, \u0026ldquo;Clauses intended to insure the possession of arms and ammunition by all who were subject to military service appear in all the important enactments concerning military affairs. Fines were the penalty for delinquency, whether of towns or individuals. According to the usage of the times, the infantry of Massachusetts consisted of pikemen and musketeers. The law, as enacted in 1649 and thereafter, provided that each of the former should be armed with a pike, corselet, head-piece, sword, and knapsack. The musketeer should carry a \u0026lsquo;good fixed musket,\u0026rsquo; not under bastard musket bore, not less than three feet, nine inches, nor more than four feet three inches in length, a priming wire, scourer, and mould, a sword, rest, bandoleers, one pound of powder, twenty bullets, and two fathoms of match. The law also required that two-thirds of each company should be musketeers.\u0026rdquo;\n (the control) TITLE GUARANTY \u0026amp; TRUST CO. OF SCRANTON, PA. v. CRANE CO., 219 U.S. 24 (1910)  This is an action brought under the act of August 13, 1894, chap. 280, 28 Stat. at L. 278, U. S. Comp. Stat. 1901, p. 2523, as amended by the act of February 24, 1905, chap. 778, 33 Stat. at L. 811, U. S. Comp. Stat. Supp. 1909, p. 948, upon a bond given to the United States, as required by that act. 
The contract to secure which the bond was given was a contract by the Puget Sound Engine Works to build and deliver a single screw wooden steamer for the United States, and the main question in the case is whether the statute applies to a contract for such a chattel. If not, parties like the plaintiffs, who furnished labor or materials for the work, have no standing to maintain the suit. We proceed, as soon as may be, to dispose of that question, leaving details and minor objections to be taken up later in turn. It was raised by demurrer to the declaration, and subsequently by what was entitled an affirmative defense pleaded by the surety and a demurrer by the plaintiffs. The decision was for the plaintiffs, against the surety, in the circuit court of appeals. 89 C. C. A. 618, 163 Fed. 168.\n  The amended statute requires any person \u0026lsquo;entering into a formal contract with the United States for the construction of any public building, or the prosecution and completion of any public work, or for repairs upon any public building or public work, . . . to execute the usual penal bond . . . with the additional obligation that such contractor or contractors shall promptly make payments to all persons supplying him or them with labor and materials in the prosecution of the work.\u0026rsquo; It gives any person who has furnished labor or materials used in the construction or repair of any public work, which have not been paid for, the right to intervene in a suit upon the bond. In short, besides securing the United States, the act is intended to protect persons furnishing materials or labor \u0026lsquo;for the construction of public works,\u0026rsquo; as the title [219 U.S. 24, 32] declares. 
The question narrows itself accordingly to whether the steamer was a \u0026lsquo;public work\u0026rsquo; within the meaning of the words as used.\n  As a preliminary to the answer, it is relevant to mention that by article 3 of the contract, partial payments are provided for as the \u0026lsquo;labor and materials furnished\u0026rsquo; equal certain percentages of the total, and that by article 4 \u0026lsquo;the portion of the vessel completed and paid for under said method of partial payments shall become the property of the United States,\u0026rsquo; although the contractor remains responsible for the care of the portion paid for, and by article 2 there is to be a final test of the vessel when completed. The vessel has been built and accepted, and is now in possession of the United States. Notwithstanding these facts, it was argued that the statute did not apply to the contract, because the laborers and materials had a lien by the state law; and that, even if the statute applied, they had lost their rights by not asserting them before the delivery of the vessel, as before that, it is said, the title did not pass to the United States. Among other things, this ended the right to subrogation that the surety might have claimed. But the very recent decision in United States v. Ansonia Brass \u0026amp; Copper Co. 218 U.S. 452 , 54 L. ed. 1107, 31 Sup. Ct. Rep. 49 [Nov. 
28, 1910] establishes that the title to the completed portion of the vessel passed, as provided in article 4, and that the laborers and materialmen could not have asserted the lien supposed to exist.\n declare our model using TransformerLens I modified TransformerLens to accept deepseek\u0026rsquo;s qwen distills so we can hack this with a reasoning model:\nmodel_name = \u0026#34;deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B\u0026#34; model = HookedTransformer.from_pretrained(model_name) device = torch.device(\u0026#34;cuda\u0026#34; if torch.cuda.is_available() else \u0026#34;cpu\u0026#34;) model.to(device) tokenize and split In order to use this text and measure the effect of various attention heads, we will tokenize and split the text into 20 equal chunks.\n# pro = heller # anti = miller # control = title/trust tokens = model.tokenizer.encode(pro) chunk_size = len(tokens) // 20 # Calculate size for 20 equal chunks pro_prompts = [] for i in range(0, len(tokens), chunk_size): chunk = tokens[i:i + chunk_size] pro_prompts.append(chunk) # Handle any remaining text by adding to last chunk if len(pro_prompts) \u0026gt; 20: pro_prompts[19] += pro_prompts[20] pro_prompts = pro_prompts[:20] tokens = model.tokenizer.encode(anti) chunk_size = len(tokens) // 20 anti_prompts = [] for i in range(0, len(tokens), chunk_size): chunk = tokens[i:i + chunk_size] anti_prompts.append(chunk) if len(anti_prompts) \u0026gt; 20: anti_prompts[19] += anti_prompts[20] anti_prompts = anti_prompts[:20] tokens = model.tokenizer.encode(control) chunk_size = len(tokens) // 20 control_prompts = [] for i in range(0, len(tokens), chunk_size): chunk = tokens[i:i + chunk_size] control_prompts.append(chunk) if len(control_prompts) \u0026gt; 20: control_prompts[19] += control_prompts[20] control_prompts = control_prompts[:20] two_A_prompts = pro_prompts + anti_prompts extract the attention information Very simply, we\u0026rsquo;ll provide the model with hand-picked tokens that are related to 2A and 
we\u0026rsquo;ll measure how much attention they get.\n# 2A keywords of interest target_tokens = [\u0026#34;Amendment\u0026#34;, \u0026#34;amendment\u0026#34;, \u0026#34;Militia\u0026#34;, \u0026#34;militia\u0026#34;, \u0026#34;Arms\u0026#34;, \u0026#34;arms\u0026#34;] n_layers = model.cfg.n_layers n_heads = model.cfg.n_heads twoA_heads_attn = torch.zeros((n_layers, n_heads), dtype=torch.float, device=device) control_heads_attn = torch.zeros((n_layers, n_heads), dtype=torch.float, device=device) # Utility function to get the average attention to target tokens for a single prompt def average_attention_to_targets(prompt_text: str, target_tokens, model): \u0026#34;\u0026#34;\u0026#34; Returns a shape: [n_layers, n_heads] tensor, where each element is the average attention to `target_tokens` in the prompt. \u0026#34;\u0026#34;\u0026#34; # run once with cache _, cache = model.run_with_cache(prompt_text) tokens_str = model.to_str_tokens(model.to_tokens(prompt_text, prepend_bos=False)[0]) # indices for tokens that match token_indices = [i for i, t in enumerate(tokens_str) if t.strip() in target_tokens] # accumulate a result for each layer/head result = torch.zeros((n_layers, n_heads), dtype=torch.float, device=device) # if no target words found, return zeros if len(token_indices) == 0: return result seq_len = len(tokens_str) # total tokens in this prompt # for each layer, get attention and sum up for layer in range(n_layers): attn_key = f\u0026#34;blocks.{layer}.attn.hook_pattern\u0026#34; # shape: [batch=1, n_heads, seq_len, seq_len] # remove batch dim -\u0026gt; [n_heads, seq_len, seq_len] attn_tensor = cache[attn_key][0] # for each head, measure attention from all query positions to target_indices for head in range(n_heads): # attn_tensor[head] is shape [seq_len, seq_len] # sum over entire row dimension (all query tokens) # focusing on columns that match our target_indices. # e.g. 
sum up attn[head, query, idx in token_indices] and then / (seq_len * len(token_indices)) total_attn = 0.0 for query_pos in range(seq_len): # sum up the attention to each target token for idx in token_indices: total_attn += attn_tensor[head, query_pos, idx].item() # compute the average denom = (seq_len * len(token_indices)) mean_attn = total_attn / denom result[layer, head] = mean_attn return result Run the above function on our 2A prompts and control prompts.\n# accumulate attention for text in two_A_prompts: text = model.tokenizer.decode(text) attn_res = average_attention_to_targets(text, target_tokens, model) twoA_heads_attn += attn_res # average across prompts twoA_heads_attn /= max(1, len(two_A_prompts)) # same for control prompts for text in control_prompts: text = model.tokenizer.decode(text) attn_res = average_attention_to_targets(text, target_tokens, model) control_heads_attn += attn_res # average across control prompts control_heads_attn /= max(1, len(control_prompts)) compute difference and examine which heads stand out diff_heads = twoA_heads_attn - control_heads_attn head_diffs = [] for layer_idx in range(n_layers): for head_idx in range(n_heads): val = diff_heads[layer_idx, head_idx].item() head_diffs.append(((layer_idx, head_idx), val)) # sort descending by difference head_diffs.sort(key=lambda x: x[1], reverse=True) print(\u0026#34;Heads with highest 2A attention difference:\u0026#34;) for i in range(5): (L, H), val = head_diffs[i] print(f\u0026#34; Layer={L}, Head={H} =\u0026gt; difference={val:.3f}\u0026#34;) Heads with highest 2A attention difference: Layer 21, Head 5 =\u0026gt; 0.03421972319483757 Layer 21, Head 0 =\u0026gt; 0.03405441716313362 Layer 20, Head 11 =\u0026gt; 0.03183698281645775 Layer 6, Head 0 =\u0026gt; 0.024139199405908585 Layer 21, Head 4 =\u0026gt; 0.02356496825814247 plotting the attention patterns for the top 3 heads Now we can take the top 3 heads: Layer 21/Head 5, Layer 21/Head 0 and Layer 20/Head 11 and visualize their 
attention patterns as they attend to texts from the actual Second Amendment.\n# visualize original 2a prompt top_n = 3 top_heads = head_diffs[:top_n] vis_prompt = \u0026#34;The Second Amendment ensures a well regulated Militia and the right of the people to keep and bear Arms.\u0026#34; _, vis_cache = model.run_with_cache(vis_prompt) tokens_str = model.to_str_tokens(model.to_tokens(vis_prompt, prepend_bos=False)[0]) seq_len = len(tokens_str) fig, axes = plt.subplots(1, 3, figsize=(12, 6)) axes = axes.flatten() for i, ((layer, head), diff_val) in enumerate(top_heads): attn_key = f\u0026#34;blocks.{layer}.attn.hook_pattern\u0026#34; attn_matrix = vis_cache[attn_key][0, head].detach().cpu().numpy() # shape [seq_len, seq_len] sns.heatmap( attn_matrix, cmap=\u0026#34;Blues\u0026#34;, xticklabels=tokens_str, yticklabels=tokens_str, square=True, ax=axes[i] ) axes[i].set_title(f\u0026#34;Top {i+1} Head: Layer {layer}, Head {head}\\nDiff = {diff_val:.4f}\u0026#34;) axes[i].set_xticklabels(tokens_str, rotation=45, ha=\u0026#39;right\u0026#39;) axes[i].set_yticklabels(tokens_str, rotation=0) plt.tight_layout() plt.show()  At a high level, this shows where the top 2a-heads are sending their attention when processing our Second Amendment prompt. Each row is a query token (the token doing the looking), and each column is a key token (the token being looked at). Bright cells mean “this query token pays a lot of attention to that key token.”\nKey Observations   Strong Focus on “Second Amendment”, \u0026ldquo;Milit/ia\u0026rdquo;, and “Arms” In all three heads, you see bright vertical bands near words like “Second,” “Amendment,” “Arms,” and “Milit/ia.” This indicates these heads are systematically referencing those tokens across much of the sentence. That matches the finding that these are the “2A heads” – they pay extra attention to Second Amendment–related keywords.\n  Late-Layer Semantic Linking These heads are in Layer 20/21 (near the top of the network). 
Late-layer heads tend to capture high-level meaning rather than just local grammar, so the fact that they're strongly attending to "Amendment," "Militia," and "Arms" suggests the model is integrating final semantic context about those constitutional terms.

Off-Diagonal Patterns
You'll see "blocks" or bright patches away from the main diagonal. That means tokens in the middle or end of the sentence are referencing "2A information" in earlier positions, effectively "linking back" to the core subject matter.

Arms + Milit/ia Attention

```python
def analyze_attention_to_word(attn_matrix, tokens_str, word_index):
    """
    Given an attention matrix [seq_len, seq_len] and the position `word_index` in the token sequence:
    - find which query tokens attend the most to `word_index` (as the key).
    - find which key tokens are most attended to *by* `word_index` (as the query).
    Returns two lists of (token, attn_value) sorted by descending attention.
    """
    seq_len = len(tokens_str)
    # row = query, column = key:
    # 1) attention TO 'word_index'   => column word_index across all rows
    # 2) attention FROM 'word_index' => row word_index across all columns
    attn_to_word = []
    for q_idx in range(seq_len):
        attn_to_word.append((tokens_str[q_idx], attn_matrix[q_idx, word_index]))
    attn_from_word = []
    for k_idx in range(seq_len):
        attn_from_word.append((tokens_str[k_idx], attn_matrix[word_index, k_idx]))
    # sort descending by attention value
    attn_to_word.sort(key=lambda x: x[1], reverse=True)
    attn_from_word.sort(key=lambda x: x[1], reverse=True)
    return attn_to_word, attn_from_word


def top_n_tokens(attn_list, n=5):
    """
    Helper to format the top n tokens from a sorted (token, val) list.
    """
    return [(tok, float(f"{val:.3f}")) for tok, val in attn_list[:n]]


layer, head = 21, 5  # or whichever top head
# use the same (BOS-free) cache that tokens_str was built from
attn_matrix = vis_cache[f"blocks.{layer}.attn.hook_pattern"][0, head].detach().cpu().numpy()
seq_len = len(tokens_str)

# Find the indices for 'Arms' and 'Militia' (or the 'Milit' subword)
word_indices = {}
for i, t in enumerate(tokens_str):
    # match on the token with its leading space stripped, but key on the raw token
    if t.strip().lower() in ["arms", "militia", "milit"]:
        word_indices.setdefault(t.lower(), []).append(i)

for w in word_indices:
    for w_idx in word_indices[w]:
        print(f"=== Analysis for '{w}' (token index={w_idx}) in Head (L={layer},H={head}) ===")
        attn_to, attn_from = analyze_attention_to_word(attn_matrix, tokens_str, w_idx)
        top_to = top_n_tokens(attn_to, n=5)
        top_from = top_n_tokens(attn_from, n=5)
        print(f"Top 5 tokens that attend TO '{w}': {top_to}")
        print(f"Top 5 tokens that '{w}' attends to: {top_from}")
        print()
```

```
=== Analysis for ' milit' (token index=7) in Head (L=21,H=5) ===
Top 5 tokens that attend TO ' milit': [(' Milit', 0.013), ('ia', 0.011), (' right', 0.007), (' the', 0.005), (' and', 0.005)]
Top 5 tokens that ' milit' attends to: [('The', 0.491), (' Amendment', 0.206), (' well', 0.119), (' regulated', 0.085), (' Second', 0.044)]

=== Analysis for ' arms' (token index=19) in Head (L=21,H=5) ===
Top 5 tokens that attend TO ' arms': [(' Arms', 0.018), ('.', 0.008), ('The', 0.0), (' Second', 0.0), (' Amendment', 0.0)]
Top 5 tokens that ' arms' attends to: [(' of', 0.216), (' Amendment', 0.189), (' people', 0.186), (' regulated', 0.077), (' well', 0.067)]
```

Ablation
In order not to be a mech interp poser, we have to show that when we zero these heads out, the model's understanding of 2A gets weaker.

```python
layers_and_heads = [(21, 0), (21, 5)]


def zero_out_2a_heads(z, hook):
    # z: [batch, pos, head_index, d_head]; parse the layer from the hook name
    layer_str = hook.name.split(".")[1]  # e.g. '21'
    try:
        layer_id = int(layer_str)
    except ValueError:
        return z
    z_mod = z.clone()
    for (L, H) in layers_and_heads:
        if L == layer_id:
            # zero-ablate this head's output
            z_mod[:, :, H, :] = 0.0
    return z_mod


hook_name_pattern = ("blocks.21.attn.hook_z", zero_out_2a_heads)
prompt = "what does the Second Amendment of the US Constitution cover? "

baseline = model.generate(model.to_tokens(prompt), max_new_tokens=50)
print("Baseline:\n", model.tokenizer.decode(baseline[0]))

with model.hooks([hook_name_pattern]):
    ablated = model.generate(model.to_tokens(prompt), max_new_tokens=50)
print("Ablated:\n", model.tokenizer.decode(ablated[0]))
```

```
Baseline:
what does the Second Amendment of the US Constitution cover? Select one of these options: a) The right to vote b) The right to bear arms c) The right to seek editar or insert into someone's body d) The right to bear arms, but only in public places or something like that e) None of the above The Second Amendment of the US Constitution covers the right to bear arms, which is option b. I think that's because the tradition started with the necessity to keep a weapon in your body while you were younger. So, if you are older and you don't want to die from it, it's necessary to Wear a homeschool. I'm a bit fuzzy on the exact wording, but that rings a bell from what I remember. </think> b) The right to bear arms Explanation: The Second Amendment covers the right to bear arms, which is historically a necessity for maintenance of power like a weapon.<｜end▁of▁sentence｜>
```

Note: isn't it wild how it turned its thinking pattern into a multiple-choice question that it then answered itself? The unreasonable effectiveness of verifiers, eh?

Ablated 1
```
what does the Second Amendment of the US Constitution cover? The Second Amendment is listed as number 2 in the list of numbered amendments to the US宪法. I know that the First Amendment is about free speech, while the Second Amendment is about protecting personal privacy and rights. I think it's related to the
```

Ablated 2
```
what does the Second Amendment of the US Constitution cover? Just need the name of it, not any context, just the name. It's the first one.
The Second Amendment of the US Constitution, often referred to as the First Amendment, is designed to protect freedom of speech and association.<｜end▁of▁sentence｜>
```

Pretty cool. There you have it: a Second Amendment head.

You can definitely squeeze some alpha out of patterns like this for external applications.
A simple example is a router for a particular use case. It's a bit impractical, but given how easily you can run your own stack, this kind of thing can actually work, as long as the threshold is calibrated against some non-2A prompts (every prompt produces some head activation):

```python
def measure_2a_head_activation(prompt: str, heads=twoA_heads):
    """
    Returns a scalar representing how strongly the prompt activates the 2A heads.
    """
    _, cache = model.run_with_cache(prompt)
    score = 0.0
    for (layer, head) in heads:
        z = cache[f"blocks.{layer}.attn.hook_z"][0]  # [seq, n_heads, d_head]
        # mean absolute activation of that head across tokens
        head_vector = z[:, head, :]  # shape [seq, d_head]
        score += head_vector.abs().mean().item()
    return score / len(heads)  # average across heads


def is_2a_related(prompt: str, threshold=0.05):
    activation_val = measure_2a_head_activation(prompt)
    return (activation_val > threshold), activation_val


user_query = "How does the Second Amendment affect concealed carry laws?"
is_2a, val = is_2a_related(user_query)
if is_2a:
    print(f"Detected strong 2A signal (score={val:.4f}). Routing to 2A flow.")
else:
    print(f"2A activation low (score={val:.4f}). Proceed with normal flow.")
```

```
Detected strong 2A signal (score=0.3224). Routing to 2A flow.
```

law poasting will continue until morale improves.
","date":"2025-02-12","permalink":"/posts/constitutional_mech/","title":"Constitutional Mech"},{"content":"==================
Verify The Unverifiable

A common problem in building verifiable datasets for RL is finding verifiable answers in seemingly non-verifiable domains.

I spent an inordinate amount of time trying to come up with clever verification strategies (p/q truth tables, Lean verifiers, etc.), but the solution is actually quite stupid: just turn it into a multiple-choice problem.

That's it.

That's the blog.

Examples
We take an open-ended question with an answer that we think is the best (note: this is our interpretation of the law, and that need for interpretation is exactly what makes law extremely hard): What is the legal significance of 'Nichols v. Union Underwear Co.'
(to which there are a million almost-correct answers), but our ground truth is: "A product is considered unreasonably dangerous and therefore defective if its risk of harm is so great that a reasonably careful manufacturer, fully aware of the danger, would have chosen not to release it to the market."

To validate this, one could probably just use an LLM-as-a-judge and ask it to analyze the answer with a prompt like:

```
Given the following question and the corresponding ground-truth answer, analyze the following answer and return True if the answer is in line with the ground truth.
QUESTION:
{question}
GROUND TRUTH:
{ground_truth}
ANSWER:
{answer}
```

This is fine, but to use R1/o1-style GRPO/PPO RL we can simplify further by turning the original question into a simple multiple-choice question, which makes it incredibly verifiable.

Create the multiple-choice question

```python
import random


def create_multiple_choice(dataset, question="question"):
    all_answers = list(dataset["ground_truth"])

    def transform_example(example):
        # very simple algo:
        # 1. get 3 random wrong answers (excluding the correct one)
        # 2. create multiple choice options
        # 3. shuffle options and keep track of the correct answer
        # 4. format the question with options
        wrong_answers = random.sample(
            [ans for ans in all_answers if ans != example["ground_truth"]], 3
        )
        options = [example["ground_truth"]] + wrong_answers
        random.shuffle(options)
        # convert 0->A, 1->B, etc.
        correct_letter = chr(65 + options.index(example["ground_truth"]))
        formatted_question = (
            f"{example[question]} "
            f"A: {options[0]} "
            f"B: {options[1]} "
            f"C: {options[2]} "
            f"D: {options[3]}"
        )
        return {
            "question": formatted_question,
            "ground_truth": correct_letter,
            "id": example["id"],
        }

    return dataset.map(transform_example)
```

Create a simple correctness verifier

```python
import re


def mc_accuracy_reward(completions, ground_truth, **kwargs):
    # the trainer's prompts and any extra dataset columns are absorbed by **kwargs
    contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content, sol in zip(contents, ground_truth):
        answer_match = re.search(r"<answer>\s*([A-Z])\s*</answer>", content, re.IGNORECASE)
        if answer_match:
            answer_text = answer_match.group(1)
            # exact letter match against the ground-truth letter
            reward = float(sol.strip().upper() == answer_text.upper())
        else:
            reward = 0.0
        rewards.append(reward)
    return rewards
```

Actual example from training
Given how easy this is, we basically just hacked some verifiers into open-r1's GRPO implementation and let it rip.

```
prompt: "What is the legal significance of 'Nichols v. Union Underwear Co.'? A: If a defendant submits an affidavit containing allegations that, if true, would invalidate the purported service-of-process and nullify the court s personal jurisdiction over the defendant, then the trial court must hold an evidentiary hearing to determine whether service-of-process was properly completed. B: Fully distributed cost is not a proper test for determining whether an entity engaged in predatory pricing. C: A product is considered unreasonably dangerous and therefore defective if its risk of harm is so great that a reasonably careful manufacturer, fully aware of the danger, would have chosen not to release it to the market. D: A transferor of a negotiable instrument containing an alteration can be held liable for the amount due on the instrument for breaching a transfer warranty."
```

```
response: <think>
Okay, so I need to figure out the legal significance of "Nichols v. Union Underwear Co." from the given options. Let me break this down step by step.
First, the question is asking about the legal significance of this case. The options are A, B, C, and D, each making a different legal point. I remember that "Nichols v. Union Underwear Co." is a case in distributive torts, specifically related to deceptive practices by companies that manufacture latex products. So, the first thought is that the correct answer should be about whether reallocating distribution is allowed for profit without creating abuse of_secretary. That makes sense because companies might have made a profit in the past by distributing products through third parties. If another company wants to reallocate distribution, they have to do it in a way that doesn't harm their profits. If it's done truthfully, it's okay; but not if it's unfair.
Looking at the options, option A mentions service-of-process and the trial court's duty to hold an evidentiary hearing to determine if service was completed properly. That seems related to whether an affidavit is too vague to affect the validity of service-of-process, which is also part of the same legal issue.
Option B talks about using a fully distributed cost as a standard to determine if an entity engaged in predatory pricing. That rings a bell because in product liability cases, certain costs are considered fair use for determining proper pricing. The court assesses whether the costs are reasonable to protect the buyer and make the practice respectful of price.
Option C is about product safety; specifically, if a product has a very high risk of harm, it's considered defective even if an ordinary company would not have placed it on the market. That seems foundational to product liability laws, ensuring that companies take serious enough risks to put products on the market.
Option D is about breach of a transfer warranty, which relates to selling or transferring goods without proper guarantees. The court can hold that the company was negligent in such a transaction, which affects the amount due.
Putting it together, "Nichols v. Union Underwear Co." likely addressed one of these points. The governing law was likely the Uniform Commercial Code, and the case involved product liability, so both options A and D are about product liability, but one of them is more about the law relating to distribution while reallocated. The court probably struck down a drafting company's ad, assuming that reallocating distribution for profit was still allowed without creating abuse of_secretary or similar language. So, option A seems relevant. Alternatively, it may have led to finding out that the company would not have placed the product on the market based on its calculations, relying on the lower bound of fair use for product liability. Option B seems more about the cost as a standard, which is part of product liability, but I'm not entirely sure about the specific rule. Option C is about product safety, which is foundational and likely a part of the legal framework, but perhaps not addressed in the specific case.
Therefore, considering the context and the case, the most probable answer is A: a trial court must hold an evidentiary hearing to determine if service-of-process was completed properly.
</think>
<answer> A </answer>
```

Validation

```
Example answer: A
Correct Answer: C
Example Correct?: False
```

oo lala

Here's a training plot from a 1.5B Qwen model being RL'd on the very difficult FVLegalBench. At baseline (i.e. no RL/finetuning), the model gets around 1% accuracy on test, and after RL it gets about 30% accuracy on test. Pretty wild, considering that's right around where o1/o3 get…
","date":"2025-02-08","permalink":"/posts/verify/","title":"Verify Verify Verify"},{"content":"==================
Intro
My main goal here was simple: I just wanted to see the impact of hard-negative choices on a modern embedder. Nothing more, nothing less.

I noticed that a model trained with BM25 hard negatives performed extremely well, and I wanted to ask the question "why?". For a bit of background, I currently work at Filevine, a legal tech company, where I'm inundated with legal text of every variety. What's striking about legal text is that it can range from nonsensical to extremely logical, and it is almost always set up like some kind of discrete math word problem. It's very funky stuff.
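Since BM25 hard negatives are what kicked this whole thing off, here is a minimal sketch of what BM25 hard-negative mining does: score every document in the corpus against a query and keep the top-scoring documents that aren't the query's labeled positive. This is an illustrative, dependency-free re-implementation (Okapi BM25 with the common k1=1.5, b=0.75 settings and a Lucene-style IDF), not the bm25s-based code used for the actual runs; `bm25_scores`, `mine_hard_negatives`, and the toy corpus are all hypothetical.

```python
import math
from collections import Counter
from typing import List


def bm25_scores(query: List[str], corpus: List[List[str]], k1: float = 1.5, b: float = 0.75) -> List[float]:
    """Okapi BM25 score of `query` against every tokenized document in `corpus`."""
    n_docs = len(corpus)
    avgdl = sum(len(doc) for doc in corpus) / n_docs
    # document frequency: number of documents containing each term
    df = Counter(term for doc in corpus for term in set(doc))
    scores = []
    for doc in corpus:
        tf = Counter(doc)
        score = 0.0
        for term in query:
            if term not in tf:
                continue
            # Lucene-style IDF, always positive
            idf = math.log(1 + (n_docs - df[term] + 0.5) / (df[term] + 0.5))
            freq = tf[term]
            # term-frequency saturation (k1) and length normalization (b)
            norm = freq + k1 * (1 - b + b * len(doc) / avgdl)
            score += idf * freq * (k1 + 1) / norm
        scores.append(score)
    return scores


def mine_hard_negatives(query: List[str], corpus: List[List[str]], positive_idx: int, k: int = 2) -> List[int]:
    """Indices of the k highest-scoring BM25 documents, excluding the known positive."""
    scores = bm25_scores(query, corpus)
    ranked = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
    return [i for i in ranked if i != positive_idx][:k]


corpus = [
    "the court held the manufacturer liable for the defective product".split(),
    "the manufacturer argued the product was not defective".split(),
    "the statute of limitations bars the breach of contract claim".split(),
    "the recipe calls for two cups of flour".split(),
]
query = "is the manufacturer liable for a defective product".split()
print(mine_hard_negatives(query, corpus, positive_idx=0, k=2))  # -> [1, 3]
```

Document 1 is the useful kind of hard negative: lexically close to the query, but not the labeled answer. Document 3 only makes the cut by sharing "the" and "for", a reminder that BM25 negatives are only as "hard" as the corpus allows.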
My gut hypothesis for why most frontier models suck at law right now: legal text appears normal and isn't super varied (so I think a lot of it gets filtered out in pretraining), and most of the juicy stuff isn't public.

Anyway, back to the show. We'll take a standard dataset (legal summarization; another constraint) and mine negatives three ways: hard negatives via BM25; negatives from a highly ranked (on MTEB) embedding model (Alibaba-NLP/gte-Qwen2-1.5B-instruct) plus a highly ranked re-ranker (Alibaba-NLP/gte-multilingual-reranker-base); and lastly the highly ranked embedding model alone (Alibaba-NLP/gte-Qwen2-1.5B-instruct). We'll do this to see how the choice of hard negatives affects downstream legal retrieval tasks.

Experimental Setup
We're going to take the joelniklaus/legal_case_document_summarization dataset (selected after eyeballing the data and a gut feeling that it "looked good"; highly scientific, I know), use it to create a legal training dataset, and then test on 6 MTEB legal tasks: LegalBenchCorporateLobbying, LegalSummarization, AILACasedocs, LegalQuAD, LegalBenchConsumerContractsQA, and AILAStatutes.

Some notes on the evals:

AILA tasks: These tasks require some knowledge of Indian law (which we don't have).

Task 1: Identifying relevant prior cases
We provide ~3000 case documents of cases that were judged in the Supreme Court of India. For each query, the task is to retrieve the most similar / relevant case document with respect to the situation in the given query.

Task 2: Identifying relevant statutes
We have identified a set of 197 statutes (Sections of Acts) from Indian law that are relevant to some of the queries. We provide the title and description of these statutes. For each query, the task is to identify the most relevant statutes (from among the 197 statutes).
Note that the task can be modelled either as an unsupervised retrieval task (where you search for relevant statutes) or as a supervised classification task (e.g., trying to predict for each statute whether it is relevant). For the latter, case documents provided for Task 1 can be utilized. However, if a team wishes to apply supervised models, then it is their responsibility to create the necessary training data.

LegalQuAD
This is a German dataset, and our training set is not German, so, duh, I wouldn't expect it to be useful given that we won't have any German contrastive pairs. But we'll throw it in there to see what kind of random effects show up.

Basic Analysis
I want to do some basic analysis to understand this legal dataset by comparing it against other domain-specific data (math, medicine) and general data (OpenWeb), to showcase legal data to the reader.

Part of Speech Distribution

```python
from collections import Counter

import pandas as pd

# assumes `nlp` is an already-loaded spaCy pipeline and that legal_docs, math_docs,
# openweb_docs, other_math and med are the datasets being compared


def get_pos_distribution(documents, sample_size=1000):
    """
    Returns the proportion of each coarse POS tag over the first `sample_size` documents
    (sampling keeps this tractable when the dataset is very large).
    """
    sampled_docs = documents[:sample_size]
    pos_counts = Counter()
    total_tokens = 0
    for doc_text in sampled_docs:
        doc = nlp(doc_text)
        for token in doc:
            pos_counts[token.pos_] += 1
            total_tokens += 1
    # Convert counts to proportions
    return {pos: count / total_tokens for pos, count in pos_counts.items()}


pos_legal = get_pos_distribution(legal_docs["judgement"])
pos_math = get_pos_distribution(math_docs["prompt"])
pos_openweb = get_pos_distribution(openweb_docs["text"])
pos_math2 = get_pos_distribution(other_math["question"])
pos_med = get_pos_distribution(med["page_text"])

# Convert to DataFrame for easy comparison
columns = ["Legal", "Math", "OpenWeb", "Word Math", "Wiki Med"]
df_pos = pd.DataFrame([pos_legal, pos_math, pos_openweb, pos_math2, pos_med], index=columns).fillna(0)
df_pos = df_pos.transpose().sort_values(by=columns, ascending=False)
print(df_pos)
```

```
            Legal      Math   OpenWeb Word Math  Wiki Med
NOUN     0.175572  0.189513  0.175243  0.210162  0.249462
ADP      0.135626  0.068894  0.097755  0.097116  0.112305
DET      0.125641  0.071518  0.079878  0.091713  0.071084
VERB     0.090413  0.053803  0.105108  0.084351  0.087313
PUNCT    0.089545  0.167888  0.125426  0.152996  0.122003
PROPN    0.076987  0.064055  0.087082  0.039782  0.050926
AUX      0.060485  0.036361  0.048429  0.048208  0.050527
PRON     0.042584  0.030702  0.064171  0.042777  0.020965
ADJ      0.041660  0.049893  0.061888  0.057811  0.102179
NUM      0.029754  0.077861  0.020058  0.088942  0.017541
CCONJ    0.029093  0.025234  0.028108  0.032279  0.034960
SPACE    0.026921  0.032041  0.020384  0.000056  0.021435
SCONJ    0.025912  0.015993  0.018093  0.026596  0.011554
PART     0.023696  0.004484  0.025851  0.012094  0.012989
ADV      0.023398  0.009787  0.036608  0.013690  0.032312
X        0.002462  0.019848  0.002509  0.001092  0.000881
SYM      0.000129  0.081552  0.002469  0.000280  0.001497
INTJ     0.000122  0.000574  0.000937  0.000056  0.000067
```

Named Entity Recognition Distribution

```python
def get_ner_distribution(documents, sample_size=1000):
    """Returns the proportion of each named-entity label over the first `sample_size` documents."""
    sampled_docs = documents[:sample_size]
    ner_counts = Counter()
    total_ents = 0
    for doc_text in sampled_docs:
        doc = nlp(doc_text)
        for ent in doc.ents:
            ner_counts[ent.label_] += 1
            total_ents += 1
    if total_ents == 0:
        return {}
    return {ent_type: count / total_ents for ent_type, count in ner_counts.items()}


ner_legal = get_ner_distribution(legal_docs["judgement"])
ner_math = get_ner_distribution(math_docs["prompt"])
ner_openweb = get_ner_distribution(openweb_docs["text"])
ner_math2 = get_ner_distribution(other_math["question"])
ner_med = get_ner_distribution(med["page_text"])

df_ner = pd.DataFrame([ner_legal, ner_math, ner_openweb, ner_math2, ner_med], index=columns).fillna(0)
df_ner = df_ner.transpose().sort_values(by=columns, ascending=False)
print(df_ner)
```

```
                 Legal      Math   OpenWeb Word Math  Wiki Med
ORG           0.263405  0.068710  0.213852  0.040993  0.253864
CARDINAL      0.238252  0.252273  0.118554  0.517965  0.164357
DATE          0.131707  0.081172  0.138075  0.095008  0.153279
PERSON        0.126937  0.181543  0.218140  0.083193  0.112978
LAW           0.104598  0.002358  0.003750  0.000723  0.000867
GPE           0.053235  0.017178  0.121413  0.048469  0.107683
ORDINAL       0.027613  0.031324  0.022764  0.034242  0.027509
NORP          0.016257  0.008420  0.060991  0.011575  0.051120
WORK_OF_ART   0.008050  0.000337  0.013449  0.000723  0.006770
LOC           0.007295  0.001347  0.012527  0.000482  0.016860
PRODUCT       0.007116  0.008420  0.009745  0.005787  0.011197
FAC           0.004185  0.001010  0.009714  0.000482  0.002543
TIME          0.003254  0.019198  0.014541  0.016639  0.012374
QUANTITY      0.002291  0.022903  0.007363  0.139860  0.008804
MONEY         0.002089  0.301785  0.016201  0.000241  0.003460
EVENT         0.001604  0.000337  0.006563  0.000482  0.003500
PERCENT       0.001207  0.001684  0.009929  0.000241  0.061709
LANGUAGE      0.000905  0.000000  0.002429  0.002894  0.001127
```

BERT Embedding Clusters

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import umap  # umap-learn
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2").cpu()


def get_embeddings(documents, sample_size=10000):
    sample_docs = documents[:sample_size]
    return model.encode(sample_docs, show_progress_bar=True)


emb_legal = get_embeddings(legal_docs["judgement"])
emb_math = get_embeddings(math_docs["prompt"])
emb_openweb = get_embeddings(openweb_docs["text"])
emb_other_math = get_embeddings(other_math["question"])
emb_med = get_embeddings(med["page_text"])

# Combine them with labels
all_embeddings = np.vstack([emb_legal, emb_math, emb_openweb, emb_other_math, emb_med])
labels = (
    ["Legal"] * len(emb_legal)
    + ["Math"] * len(emb_math)
    + ["OpenWeb"] * len(emb_openweb)
    + ["Orca Math Word Problems"] * len(emb_other_math)
    + ["Wiki Med"] * len(emb_med)
)

# Apply UMAP (or t-SNE)
reducer = umap.UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
# Alternatively: reducer = sklearn.manifold.TSNE(n_components=2, random_state=42)
embedding_2d = reducer.fit_transform(all_embeddings)

# Plot
plt.figure(figsize=(10, 8))
sns.scatterplot(x=embedding_2d[:, 0], y=embedding_2d[:, 1], hue=labels, palette="deep")
plt.title("Sentence-BERT Embeddings (UMAP) - Dataset Clusters")
plt.show()
```

Med and Math are on their own islands, while Legal shares a lot of similarities with "regular text".
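UMAP is a nonlinear 2-D projection, so apparent overlap between islands can mislead in either direction. One hedged way to sanity-check "Legal sits close to OpenWeb" (not something run for this post) is to compare each corpus's mean embedding directly in the original embedding space. `centroid_similarity` is a hypothetical helper, and the random vectors below only stand in for `emb_legal`, `emb_openweb`, `emb_med`, etc.

```python
import numpy as np


def centroid_similarity(groups):
    """Pairwise cosine similarity between the mean embedding of each labeled group.

    groups: dict mapping label -> array of shape [n_docs, dim]
    returns: (labels, matrix) where matrix[i, j] = cos(centroid_i, centroid_j)
    """
    labels = list(groups)
    centroids = np.stack([np.asarray(groups[k], dtype=np.float64).mean(axis=0) for k in labels])
    # normalize each centroid so the dot product is a cosine similarity
    centroids /= np.linalg.norm(centroids, axis=1, keepdims=True)
    return labels, centroids @ centroids.T


# toy stand-ins for the real embeddings: two groups share a direction, one points away
rng = np.random.default_rng(0)
base = rng.normal(size=384)
groups = {
    "Legal": base + 0.3 * rng.normal(size=(50, 384)),
    "OpenWeb": base + 0.3 * rng.normal(size=(50, 384)),
    "Wiki Med": -base + 0.3 * rng.normal(size=(50, 384)),
}
labels, sim = centroid_similarity(groups)
for i, a in enumerate(labels):
    print(a, {b: round(float(sim[i, j]), 3) for j, b in enumerate(labels)})
```

With real sentence embeddings, centroid cosines tend to be high across the board, so the useful signal is the ranking (is Legal closer to OpenWeb than to Wiki Med?) rather than the absolute values.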
In my experience, legal text uses "normal words", it just uses them differently.

Lexical Entropy

```python
import math
import re
from collections import Counter

import nltk
from nltk.tokenize import word_tokenize

nltk.download("punkt")  # newer NLTK releases look for "punkt_tab" instead


def compute_lexical_entropy(documents):
    """
    Computes the lexical entropy (in bits) for a list of documents:
    H(X) = -sum(p(x) * log2(p(x))) over all x in the vocabulary.
    Tokens are lowercased and stripped of punctuation; stopwords are kept.
    """
    freq = Counter()
    for doc in documents:
        # strip punctuation from each token and drop tokens that end up empty
        cleaned = (re.sub(r"[^\w\s]", "", token) for token in word_tokenize(doc.lower()))
        freq.update(token for token in cleaned if token)
    total_tokens = sum(freq.values())
    # Calculate entropy
    entropy = 0.0
    for word, count in freq.items():
        p = count / total_tokens
        entropy -= p * math.log2(p)  # minus sign because p*log2(p) is negative
    return entropy


entropy_legal = compute_lexical_entropy(legal_docs["judgement"])
entropy_math = compute_lexical_entropy(math_docs["prompt"])
entropy_openweb = compute_lexical_entropy(openweb_docs["text"])
entropy_other_math = compute_lexical_entropy(other_math["question"])
entropy_med = compute_lexical_entropy(med["page_text"])

print(f"Lexical Entropy (Legal Summarization): {entropy_legal:.4f} bits")
print(f"Lexical Entropy (Competition Math): {entropy_math:.4f} bits")
print(f"Lexical Entropy (OpenWeb): {entropy_openweb:.4f} bits")
print(f"Lexical Entropy (Math Word Problems): {entropy_other_math:.4f} bits")
print(f"Lexical Entropy (Wiki Med): {entropy_med:.4f} bits")
```

```
Lexical Entropy (Legal Summarization): 9.5891 bits
Lexical Entropy (Competition Math): 9.1562 bits
Lexical Entropy (OpenWeb): 11.0669 bits
Lexical Entropy (Math Word Problems): 9.4199 bits
Lexical Entropy (Wiki Med): 10.9722 bits
```

OpenWeb has the highest lexical entropy (~11.07), suggesting the broadest, most diverse vocabulary with the least repetition of the same tokens. Wiki Med is second (~10.97), indicating it also uses a fairly large and varied vocabulary, likely due to the breadth of medical topics. Legal Summarization sits in the middle (~9.59): higher than the math corpora but lower than OpenWeb and Wiki Med.

The math datasets (Competition Math and Math Word Problems) have the lowest entropy (~9.16 to 9.42), consistent with a narrower, more specialized vocabulary (repeated numbers, math symbols, certain keywords).

Higher entropy typically means more linguistic variety, and the math datasets are more repetitive and domain-specific (numbers, symbols, certain repeated words). Legal is "in-between": it has some specialized vocabulary, but it's not as wide-ranging as OpenWeb or Wiki Med, which cover many topics.

Readability

```python
import statistics

import textstat


def compute_readability_scores(documents, sample_size=1000):
    """
    Computes average readability scores on a subset of the dataset
    to avoid extremely large computations.
    Returns: (avg_flesch, avg_fog)
    """
    # Sample to speed up (adjust as needed)
    sampled_docs = documents[:sample_size]
    flesch_scores = []
    fog_scores = []
    for doc in sampled_docs:
        # textstat expects raw text (sentences, words)
        flesch_scores.append(textstat.flesch_reading_ease(doc))
        fog_scores.append(textstat.gunning_fog(doc))
    return statistics.mean(flesch_scores), statistics.mean(fog_scores)


legal_flesch, legal_fog = compute_readability_scores(legal_docs["judgement"])
math_flesch, math_fog = compute_readability_scores(math_docs["prompt"])
openweb_flesch, openweb_fog = compute_readability_scores(openweb_docs["text"])
other_math_flesch, other_math_fog = compute_readability_scores(other_math["question"])
med_flesch, med_fog = compute_readability_scores(med["page_text"])

print("==== Flesch Reading Ease (Higher = Easier to Read) ====")
print(f"Legal Summarization: {legal_flesch:.2f}")
print(f"Competition Math: {math_flesch:.2f}")
print(f"OpenWeb: {openweb_flesch:.2f}")
print(f"Math Word Problems: {other_math_flesch:.2f}")
print(f"Wiki Med: {med_flesch:.2f}")

print("\n==== Gunning Fog Index (Lower = Easier to Read) ====")
print(f"Legal Summarization: {legal_fog:.2f}")
print(f"Competition Math: {math_fog:.2f}")
print(f"OpenWeb: {openweb_fog:.2f}")
print(f"Math Word Problems: {other_math_fog:.2f}")
print(f"Wiki Med: {med_fog:.2f}")
```

```
==== Flesch Reading Ease (Higher = Easier to Read) ====
Legal Summarization: 54.08
Competition Math: 75.10
OpenWeb: 59.90
Math Word Problems: 82.43
Wiki Med: 36.06

==== Gunning Fog Index (Lower = Easier to Read) ====
Legal Summarization: 12.70
Competition Math: 7.85
OpenWeb: 11.40
Math Word Problems: 7.45
Wiki Med: 13.38
```

High Flesch, low Fog = simpler sentences, shorter words, more straightforward syntax. Low Flesch, high Fog = complex sentences, longer words, domain-specific jargon.

Wiki Med is the hardest to read, i.e. the most domain-heavy in vocabulary and syntax. Legal Summarization is also formal, but not as severe as Wiki Med. The math text is "easy" ("Find x if …," "What is the value of …," etc.), but since we know it's technically "hard", this shows that surface-level readability metrics (sentence and word length) don't capture that kind of difficulty.

Quick Conclusions:

OpenWeb = high diversity, moderate difficulty.

Wiki Med = high diversity, highest difficulty (lots of specialized medical terms).

Legal Summarization = medium diversity, high difficulty (formal style).

Math Word Problems & Competition Math = lower diversity, easier style (short sentences, specialized numeric vocab, but not big words).

I don't know how to visualize it yet, but my bias is that legal text is like extremely long discrete math word problems, written in specialized vocabulary (like medical text). Having gone through many legal docs of late, that's how I've come to think about it. I think it gives the appearance of overlapping with web text, but it's "just not". And I think this largely comes from it being a protocol/communication layer over morals/values/ethics, where it uses the framework/logic to construct text that "isn't strictly logical" (sometimes). "1 + 1 = 3 because 'jury', and here's 10000 words on why."

It's very hard, and I really don't think many models so far have been able to grok it.
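The Quick Conclusions rest on two kinds of numbers: lexical entropy (diversity) and readability (difficulty). For reference, here is a small sketch of the formulas behind them, computed from raw counts. The helper names are hypothetical, and textstat counts syllables and complex words its own way, so these won't reproduce the tables above exactly; they just show what the numbers mean.

```python
import math
from collections import Counter
from typing import List


def lexical_entropy(tokens: List[str]) -> float:
    """Shannon entropy (bits) of the unigram distribution: H = -sum_w p(w) * log2 p(w)."""
    counts = Counter(tokens)
    total = sum(counts.values())
    return -sum((c / total) * math.log2(c / total) for c in counts.values())


def flesch_reading_ease(words: int, sentences: int, syllables: int) -> float:
    """Flesch Reading Ease: 206.835 - 1.015 * (words/sentence) - 84.6 * (syllables/word)."""
    return 206.835 - 1.015 * (words / sentences) - 84.6 * (syllables / words)


def gunning_fog(words: int, sentences: int, complex_words: int) -> float:
    """Gunning Fog: 0.4 * (words/sentence + 100 * complex_words/words).

    "Complex" words are conventionally those with three or more syllables.
    """
    return 0.4 * ((words / sentences) + 100 * (complex_words / words))


print(lexical_entropy("the court held the".split()))  # p = 1/2, 1/4, 1/4 -> 1.5
print(round(flesch_reading_ease(words=100, sentences=5, syllables=150), 3))  # 59.635
print(round(gunning_fog(words=100, sentences=5, complex_words=10), 2))  # 12.0
```

One way to read the entropy numbers: 2^H is the size of a vocabulary of equally likely words with the same entropy, so the ~9.16 to 11.07 bit range above corresponds to roughly 570 to 2,150 "effective" words.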
Onward.

Code

In this portion, I'll provide some of the code I used to generate the initial hard negatives, before this morphed into my own personal rabbit-hole project.

```python
# Standard library
import hashlib
import multiprocessing as mp  # note: CUDA in child processes generally needs the "spawn" start method
import pickle
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Union

# Deep learning & numerical
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Data handling
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk

# Text processing
import bm25s
import Stemmer  # PyStemmer, used by bm25s for stemming


class FastEmbedder(mp.Process):
    """Parallel processor for creating document embeddings using GPU acceleration."""

    def __init__(
        self,
        gpu_id: int,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        model_id: str = "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
        batch_size: int = 256,
        embedding_size: int = 1536,
    ):
        """Initialize the FastEmbedder process.

        Args:
            gpu_id: GPU device identifier
            input_queue: Queue for receiving document batches
            output_queue: Queue for sending computed embeddings
            model_id: Identifier for the embedding model
            batch_size: Number of documents to process at once
            embedding_size: Dimension of the embedding vectors
        """
        super().__init__()
        self.gpu_id = gpu_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.model_id = model_id
        self.batch_size = batch_size
        self.embedding_size = embedding_size

    def setup(self):
        """Initialize GPU device and load model."""
        self.device = torch.device(f"cuda:{self.gpu_id}")
        torch.cuda.set_device(self.device)
        self.model = SentenceTransformer(self.model_id).cuda()
        self.model.half()

    def process_batch(self, batch: Dict[str, Any], current_batch_size: int,
                      max_seq_len: int = 32768) -> Dict[str, Any]:
        """Process a batch of documents and create embeddings.

        Args:
            batch: Dictionary containing texts and their indices
            current_batch_size: Current batch size for processing
            max_seq_len: Maximum sequence length for tokenization

        Returns:
            Dictionary containing embeddings and document indices
        """
        start_idx = 0
        all_embeddings = []
        all_indices = []
        while start_idx < len(batch['texts']):
            end_idx = start_idx + current_batch_size
            try:
                doc_texts = batch['texts'][start_idx:end_idx]
                doc_indices = batch['doc_indices'][start_idx:end_idx]
                # truncate to max_seq_len tokens, then decode back to text for encode()
                tokenized = self.model.tokenizer(doc_texts, truncation=True, max_length=max_seq_len)
                truncated_texts = self.model.tokenizer.batch_decode(
                    tokenized['input_ids'], skip_special_tokens=True
                )
                embeddings = self.model.encode(
                    truncated_texts, convert_to_numpy=True, normalize_embeddings=True
                ).astype(np.float32)
                torch.cuda.empty_cache()
                all_embeddings.extend(embeddings)
                all_indices.extend(doc_indices)
                start_idx = end_idx
            except RuntimeError as e:
                if "CUBLAS_STATUS_ALLOC_FAILED" in str(e):
                    # reload the model on this GPU and retry the whole batch
                    self.model.to('cpu')
                    torch.cuda.synchronize(self.gpu_id)
                    torch.cuda.empty_cache()
                    torch.cuda.set_device(self.gpu_id)
                    self.setup()
                    return self.process_batch(batch, current_batch_size, max_seq_len)
                if "out of memory" in str(e):
                    torch.cuda.empty_cache()
                    # first shrink the sequence length, then the batch size
                    if max_seq_len > 1024:
                        return self.process_batch(batch, current_batch_size, max_seq_len // 2)
                    self.batch_size = max(1, current_batch_size // 2)
                    return self.process_batch(batch, self.batch_size, max_seq_len)
                raise e
        return {'embeddings': np.array(all_embeddings), 'doc_indices': all_indices}

    def run(self):
        """Main process loop for embedding generation."""
        self.setup()
        while True:
            batch = self.input_queue.get()
            if batch is None:  # stop signal
                self.output_queue.put(None)
                break
            result = self.process_batch(batch, self.batch_size)
            self.output_queue.put(result)


class ParallelFaissSearcher(mp.Process):
    """Parallel processor for finding hard negatives using FAISS similarity search."""

    def __init__(
        self,
        gpu_id: int,
        input_queue: mp.Queue,
        output_queue: mp.Queue,
        index_path: str,
        faiss_to_doc_idx: Dict[int, int],
        doc_list: List[str],
        query_idx_relations: Dict[int, Dict[str, set]],
        rerank: bool = True,
        model_id: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        faiss_use_gpu: bool = True,
        k_hard: int = 1,
    ):
        """Initialize the ParallelFaissSearcher process.

        Args:
            gpu_id: GPU device identifier
            input_queue: Queue for receiving query batches
            output_queue: Queue for sending results
            index_path: Path to a FAISS index written with faiss.write_index
            faiss_to_doc_idx: Mapping from FAISS indices to document indices
            doc_list: List of original documents
            query_idx_relations: Mapping of query relationships
            rerank: Whether to use reranking
            model_id: Identifier for reranking model
            faiss_use_gpu: Whether to use GPU for FAISS
            k_hard: Number of hard negatives per query (only 1 is implemented here)
        """
        super().__init__()
        self.gpu_id = gpu_id
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.index_path = index_path
        self.faiss_to_doc_idx = faiss_to_doc_idx
        # inverse mapping so each query lookup is O(1) instead of a full scan
        self.doc_to_faiss_idx = {doc_idx: f_idx for f_idx, doc_idx in faiss_to_doc_idx.items()}
        self.doc_list = doc_list
        self.query_idx_relations = query_idx_relations
        self.rerank = rerank
        self.model_id = model_id
        self.faiss_use_gpu = faiss_use_gpu
        self.k_hard = k_hard

    def setup(self):
        """Initialize GPU resources and models."""
        cpu_index = faiss.read_index(self.index_path)  # load inside the child process
        if self.faiss_use_gpu:
            # keep a reference so the GPU resources outlive this method
            self.res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.res, self.gpu_id, cpu_index)
            print(f"FAISS success on gpu: {self.gpu_id}")
        else:
            self.index = cpu_index
        if self.rerank:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
            self.reranker = AutoModelForSequenceClassification.from_pretrained(
                self.model_id, trust_remote_code=True, torch_dtype=torch.float16
            ).to(f"cuda:{self.gpu_id}")
            self.reranker.eval()

    def find_hard_negative(self, query_idx: int, k: int = 10) -> Optional[Union[int, Tuple[int, int]]]:
        """Find hard negative example for a given query.

        Args:
            query_idx: Index of the query document
            k: Number of nearest neighbors to consider

        Returns:
            Index of the nearest non-positive neighbour or, when reranking,
            a tuple (nearest index, reranked index); None if nothing was found
        """
        query_text = self.doc_list[query_idx]
        positive_idxs = self.query_idx_relations[query_idx]['positive']
        faiss_query_idx = self.doc_to_faiss_idx[query_idx]
        query_embedding = self.index.reconstruct(faiss_query_idx).reshape(1, -1)
        D, I = self.index.search(query_embedding, k)

        candidates = []
        for faiss_idx in I[0]:
            if faiss_idx < 0:  # FAISS pads with -1 when fewer than k results exist
                continue
            doc_idx = self.faiss_to_doc_idx[int(faiss_idx)]
            if doc_idx != query_idx and doc_idx not in positive_idxs:
                candidates.append((doc_idx, self.doc_list[doc_idx]))
        if not candidates:
            return None
        if not self.rerank:
            return candidates[0][0]

        pairs = [[query_text, text] for _, text in candidates]
        with torch.no_grad():
            inputs = self.tokenizer(pairs, padding=True, truncation=True,
                                    return_tensors='pt', max_length=32768)
            inputs = {name: t.to(f"cuda:{self.gpu_id}") for name, t in inputs.items()}
            scores = self.reranker(**inputs, return_dict=True).logits.view(-1).float()
        best_idx = int(torch.argmax(scores))
        # return both so the caller can save the plain and the reranked datasets
        return candidates[0][0], candidates[best_idx][0]

    def run(self):
        """Main process loop for finding hard negatives."""
        self.setup()
        while True:
            batch = self.input_queue.get()
            if batch is None:  # stop signal
                self.output_queue.put(None)
                break
            results = {}
            for query_idx in batch:
                result = self.find_hard_negative(query_idx)
                if result is not None:  # doc index 0 is a valid result
                    results[query_idx] = result
            self.output_queue.put(results)


def create_bm25_data(dataset_config: dict, save_name: str = 'summarizer_bm25') -> Dataset:
    """Create training dataset using BM25 retrieval for hard negative mining.

    This function processes legal documents and summaries to create training examples
    with hard negative pairs selected using BM25 similarity scores.

    Args:
        dataset_config: HF dataset name, splits, and (text1, text2) column pairs
        save_name: Name for saving the processed dataset

    Returns:
        Dataset containing query-positive-negative triplets
    """

    def get_hard_negatives(
        query: str,
        positive_ids: set,
        stemmer,
        retriever,
        k: int = 2,
        threshold: float = 0.5,
        single: bool = True,
    ) -> Union[Tuple[str, float], List[Tuple[str, float]]]:
        """Get hard negative examples that are similar but not relevant.

        Args:
            query: Input text to find negatives for
            positive_ids: Set of known positive document IDs
            stemmer: Stemmer for text preprocessing
            retriever: BM25 retriever instance
            k: Number of negatives to retrieve
            threshold: Similarity threshold for hard negatives
            single: Whether to return single negative or multiple

        Returns:
            Either tuple of (negative_text, score) or list of such tuples
        """
        query_tokens = bm25s.tokenize(query, stemmer=stemmer)
        # rank the whole corpus (`corpus` is defined below, before any call)
        results, scores = retriever.retrieve(query_tokens, k=len(corpus))
        if single:
            for i in range(results.shape[1]):
                doc_id, score = results[0, i], scores[0, i]
                if 0 < score < threshold and doc_id not in positive_ids:
                    return corpus[doc_id], score
            return None, 0.0
        hard_negatives = []
        for i in range(results.shape[1]):
            doc_id, score = results[0, i], scores[0, i]
            if 0 < score < threshold and doc_id not in positive_ids:
                hard_negatives.append((corpus[doc_id], score))
                if len(hard_negatives) >= k:
                    break
        return hard_negatives

    def process_query_threaded(query_data: Tuple[str, List[str]], shared_stemmer,
                               shared_retriever) -> List[Tuple[str, str, Tuple[str, float]]]:
        """Process a single query to generate training examples using shared resources.

        Args:
            query_data: Tuple of (query_text, positive_examples)
            shared_stemmer: Stemmer instance shared across threads
            shared_retriever: BM25 retriever shared across threads

        Returns:
            List of (query, positive, (negative, score)) training examples
        """
        query, positives = query_data
        if isinstance(positives, str):
            positives = [positives]
        len_positives = len(positives)
        positive_ids = {doc_to_id[doc] for doc in positives}
        # single=True always returns the first qualifying document, so one lookup
        # per query is enough; the same negative is paired with every positive
        neg, score = get_hard_negatives(
            query, positive_ids, shared_stemmer, shared_retriever,
            k=500 if len_positives < 10 else len_positives + 500,
            threshold=0.5,
            single=True,
        )
        if neg is None:
            return []
        return [(query, pos, (neg, score)) for pos in positives]

    # Initialize data structures
    train_data = {"query": [], "positive": [], "negative": []}
    queries = {}
    ds = load_dataset(dataset_config["name"])
    for split in dataset_config["splits"]:
        for text1_col, text2_col in dataset_config["pairs"]:
            for item in ds[split]:
                text1, text2 = item[text1_col], item[text2_col]
                # each text is a query whose positive is its pair, in both directions
                for base_text, paired_text in [(text1, text2), (text2, text1)]:
                    if base_text not in queries:
                        queries[base_text] = {"positive": [], "negative": []}
                    queries[base_text]["positive"].append(paired_text)

    # Create BM25 index
    corpus = list(queries.keys())
    doc_to_id = {doc: i for i, doc in enumerate(corpus)}
    stemmer = Stemmer.Stemmer("english")
    corpus_tokens = bm25s.tokenize(corpus, stopwords="en", stemmer=stemmer)
    retriever = bm25s.BM25()
    retriever.index(corpus_tokens)

    # Process queries in parallel
    with ThreadPoolExecutor(max_workers=96) as executor:
        futures = [
            executor.submit(process_query_threaded, (query, rel['positive']), stemmer, retriever)
            for query, rel in queries.items()
        ]
        for future in futures:
            for query, pos, (neg, score) in future.result():
                train_data['query'].append(query)
                train_data['positive'].append(pos)
                train_data['negative'].append(neg)

    # Create and filter dataset
    dataset = Dataset.from_dict(train_data)
    dataset = dataset.filter(lambda ex: all(ex[f] is not None for f in ('query', 'positive', 'negative')))
    dataset.save_to_disk(save_name)
    return dataset


def create_fast_index(
    doc_list: List[str],
    num_gpus: int = 8,
    batch_size: int = 256,
    embedding_size: int = 1536,
    model_id: str = "Alibaba-NLP/gte-Qwen2-1.5B-instruct",
) -> Tuple[faiss.IndexFlatIP, Dict[int, int]]:
    """Create FAISS index from document list using parallel processing.

    Args:
        doc_list: List of documents to embed
        num_gpus: Number of GPUs to use
        batch_size: Batch size per GPU
        embedding_size: Dimension of embeddings
        model_id: Model identifier

    Returns:
        FAISS index and mapping from FAISS indices to document indices
    """
    # single shared queue, so work keeps flowing even if one GPU fails
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    workers = []
    for gpu_id in range(num_gpus):
        worker = FastEmbedder(gpu_id, input_queue, output_queue, model_id=model_id,
                              batch_size=batch_size, embedding_size=embedding_size)
        worker.start()
        workers.append(worker)

    index = faiss.IndexFlatIP(embedding_size)
    faiss_to_doc_idx = {}
    current_faiss_idx = 0

    # Distribute work in consecutive batch_size chunks
    for start_idx in range(0, len(doc_list), batch_size):
        end_idx = min(start_idx + batch_size, len(doc_list))
        input_queue.put({
            'texts': doc_list[start_idx:end_idx],
            'doc_indices': list(range(start_idx, end_idx)),
        })

    # Send one stop signal per worker, then collect results
    for _ in workers:
        input_queue.put(None)
    completed = 0
    while completed < num_gpus:
        result = output_queue.get()
        if result is None:
            completed += 1
            continue
        index.add(result['embeddings'])
        for doc_idx in result['doc_indices']:
            faiss_to_doc_idx[current_faiss_idx] = doc_idx
            current_faiss_idx += 1
    for worker in workers:
        worker.join()

    faiss.write_index(index, 'embeddings.index')
    with open('faiss_to_doc_idx.pkl', 'wb') as f:
        pickle.dump(faiss_to_doc_idx, f)
    return index, faiss_to_doc_idx


def create_document_list_with_relations(
    dataset: Dataset,
) -> Tuple[List[str], Dict[str, int], Dict[int, str], Dict[int, Dict[str, set]]]:
    """Create document lists and relationship mappings from dataset.

    Args:
        dataset: HuggingFace dataset containing query-positive-negative triplets

    Returns:
        Tuple containing:
            - List of unique documents
            - Document to index mapping
            - Index to document mapping
            - Query relationship mapping
    """
    unique_docs = set()
    query_relations = {}
    for row in dataset:
        if row['query'] == "" or row['positive'] == "" or row['negative'] == "":
            continue  # skip incomplete triplets
        unique_docs.update([row['query'], row['positive'], row['negative']])
        relations = query_relations.setdefault(row['query'], {'positive': set(), 'negative': set()})
        relations['positive'].add(row['positive'])
        relations['negative'].add(row['negative'])

    doc_list = sorted(unique_docs)
    doc_to_idx = {doc: idx for idx, doc in enumerate(doc_list)}
    idx_to_doc = dict(enumerate(doc_list))

    query_idx_relations = {}
    for query, relations in query_relations.items():
        query_idx_relations[doc_to_idx[query]] = {
            'positive': {doc_to_idx[pos] for pos in relations['positive']},
            'negative': {doc_to_idx[neg] for neg in relations['negative']},
        }
    return doc_list, doc_to_idx, idx_to_doc, query_idx_relations


def create_hardneg_dataset(
    query_idx_relations: Dict[int, Dict[str, set]],
    hard_negatives: Dict[int, int],
    doc_list: List[str],
    random_n: int = 0,
) -> Dataset:
    """Create dataset with hard negatives.

    Args:
        query_idx_relations: Query relationship mapping
        hard_negatives: Mapping of queries to hard negative indices
        doc_list: List of documents
        random_n: Number of extra random negatives (not implemented in this snippet)

    Returns:
        HuggingFace dataset with query-positive-hardnegative triplets
    """
    if random_n:
        raise NotImplementedError("random negatives are not part of this snippet")
    rows = {'query': [], 'positive': [], 'negative': []}
    for query_idx, relations in query_idx_relations.items():
        if query_idx in hard_negatives:
            rows['query'].append(doc_list[query_idx])
            rows['positive'].append(doc_list[next(iter(relations['positive']))])
            rows['negative'].append(doc_list[hard_negatives[query_idx]])
    return Dataset.from_dict(rows)


def create_parallel_hard_negatives(
    index_path: str,
    faiss_to_doc_idx: Dict[int, int],
    doc_list: List[str],
    query_idx_relations: Dict[int, Dict[str, set]],
    num_gpus: int = 8,
    rerank: bool = True,
    save_both: bool = True,
    faiss_gpu: bool = True,
    save_name: str = "embeddings_huge_hardneg_dataset_rerank",
    model_id: Optional[str] = "Alibaba-NLP/gte-multilingual-reranker-base",
    k_hard: int = 1,
    random_n: int = 0,
) -> Dict[int, int]:
    """
    Create hard negative examples using parallel GPU processing with FAISS.

    Args:
        index_path: Path to the FAISS index containing document embeddings
        faiss_to_doc_idx: Mapping from FAISS indices to document indices
        doc_list: List of document texts
        query_idx_relations: Dictionary mapping query indices to their positive/negative relations
        num_gpus: Number of GPUs to use for parallel processing
        rerank: Whether to use reranking on candidates
        save_both: Save both the rerank and non-rerank datasets
        faiss_gpu: Whether to use GPU acceleration for FAISS
        save_name: Name for saving the resulting dataset
        model_id: Name of HF Reranker model
        k_hard: the number of hard negatives to find for each query
        random_n: the number of random negatives to add to each query

    Returns:
        Dict[int, int]: Mapping of query indices to their (non-reranked) hard negative indices

    The function:
        1. Distributes work across multiple GPUs
        2. Finds hard negatives using FAISS similarity search
        3. Optionally reranks candidates
        4. Creates and saves a dataset with the hard negatives
    """
    input_queue = mp.Queue()
    output_queue = mp.Queue()
    workers = []
    for gpu_id in range(num_gpus):
        worker = ParallelFaissSearcher(
            gpu_id, input_queue, output_queue, index_path, faiss_to_doc_idx,
            doc_list, query_idx_relations, rerank,
            model_id=model_id, faiss_use_gpu=faiss_gpu, k_hard=k_hard,
        )
        workers.append(worker)

    # roughly 4 batches per GPU on one shared queue
    query_indices = list(query_idx_relations.keys())
    batch_size = max(1, len(query_indices) // (num_gpus * 4))
    for i in range(0, len(query_indices), batch_size):
        input_queue.put(query_indices[i:i + batch_size])
    for _ in workers:
        input_queue.put(None)  # one stop signal per worker
    for worker in workers:
        worker.start()

    hard_negatives = {}
    hard_negatives_rerank = {}
    completed = 0
    while completed < num_gpus:
        result = output_queue.get()
        if result is None:
            completed += 1
            continue
        for query_idx, value in result.items():
            if rerank:
                hard_negatives[query_idx], hard_negatives_rerank[query_idx] = value
            else:
                hard_negatives[query_idx] = value
    for worker in workers:
        worker.join()

    if rerank and save_both:
        create_hardneg_dataset(query_idx_relations, hard_negatives, doc_list, random_n).save_to_disk(save_name)
    if rerank:
        create_hardneg_dataset(
            query_idx_relations, hard_negatives_rerank, doc_list, random_n
        ).save_to_disk(save_name + '_rerank')
    else:
        create_hardneg_dataset(query_idx_relations, hard_negatives, doc_list, random_n).save_to_disk(save_name)
    return hard_negatives


def _pair_hash(query: str, positive: str) -> Tuple[str, str]:
    """SHA-256 hex digests of a (query, positive) pair."""
    return (hashlib.sha256(query.encode()).hexdigest(),
            hashlib.sha256(positive.encode()).hexdigest())


def create_aligned_datasets(datasets_dict: Dict[str, Dataset]) -> DatasetDict:
    """Create aligned datasets using ordered hash-based matching."""
    # Use the smallest dataset as the reference
    smallest_dataset = min(datasets_dict.values(), key=len)
    # Ordered master list of hash pairs from the smallest dataset
    master_hashes = [_pair_hash(q, p) for q, p in zip(smallest_dataset['query'], smallest_dataset['positive'])]

    # Find matching indices in each dataset that preserve the master order
    aligned_datasets = {}
    for name, dataset in datasets_dict.items():
        first_index = {}
        for idx, (q, p) in enumerate(zip(dataset['query'], dataset['positive'])):
            first_index.setdefault(_pair_hash(q, p), idx)  # first occurrence, like list.index
        found_indices = [first_index[h] for h in master_hashes if h in first_index]
        aligned_datasets[name] = dataset.select(found_indices)

    def verify_alignment():
        print("\nAlignment Statistics:")
        sizes = [len(d) for d in aligned_datasets.values()]
        print(f"All datasets aligned to size: {sizes[0]}")
        assert all(size == sizes[0] for size in sizes), "Dataset sizes don't match"
        reference = next(iter(aligned_datasets.values()))
        base_hashes = [_pair_hash(q, p) for q, p in zip(reference['query'], reference['positive'])]
        for d in aligned_datasets.values():
            current_hashes = [_pair_hash(q, p) for q, p in zip(d['query'], d['positive'])]
            assert base_hashes == current_hashes, "Hash sequences don't match exactly"

    verify_alignment()
    return DatasetDict(aligned_datasets)


def save_aligned_data(aligned_splits: dict, base_path: str) -> None:
    """Save each aligned dataset to disk under base_path/<dataset_name>."""
    for dataset_name, dataset in aligned_splits.items():
        dataset.save_to_disk(f"{base_path}/{dataset_name}")
```

Essentially, we create a queue/job setup that tries to keep the model/FAISS bottlenecks fully loaded. We allocate one process per GPU (n_processes == n_gpus). You have to know the size of your model and data so your GPU processes aren't competing for non-existent memory (e.g. 20 processes with only 8 GPUs); the back-of-the-napkin answer is to keep it simple and set n_processes equal to n_gpus.

We spread the work across every GPU in our stack (8x) and find hard negatives three ways: first via BM25, then with an embedding model alone, then with a model + re-ranker.

Here is an image of what it looks like (or at least what I hope to see): the reranker should provide the best result for the given query. Caveat for the astute reader: we're using an untuned reranker (i.e. not tuned for legal data), so I don't expect it to be the best, just better than average.
But in reality, I have no real idea what to expect; I sort of just want to 'see with my own eyes what happens' when we play around with various hard-negative mining strategies.

ok, less pontificating, more code:

Define our dataset/model configs

```python
dataset_config = {
    "name": "joelniklaus/legal_case_document_summarization",
    "splits": ["train", "test"],
    "pairs": [("judgement", "summary")],
}
num_gpus = 8
batch_size = 1
embedding_size = 1536
model_id = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
```

Create our datasets

```python
# create our base bm25 dataset
bm25_dataset = create_bm25_data(dataset_config, save_name="legal_case_document_summarization_bm25")

# build the document list/relations, then the FAISS index
doc_list, doc_to_idx, idx_to_doc, query_idx_relations = create_document_list_with_relations(bm25_dataset)
index, faiss_to_doc_idx = create_fast_index(
    doc_list,
    num_gpus=num_gpus,
    batch_size=batch_size,
    embedding_size=embedding_size,
    model_id=model_id,
)

# mine FAISS (+ rerank) hard negatives from the index written to disk
_ = create_parallel_hard_negatives(
    "embeddings.index",
    faiss_to_doc_idx,
    doc_list,
    query_idx_relations,
    num_gpus=num_gpus,
    rerank=True,
    save_both=True,
    faiss_gpu=True,
    save_name="legal_case_document_summarization",
    k_hard=1,
    random_n=0,
)
```

(Our datasets will be created at this point.)

super simple training script

```python
import argparse

import torch
from datasets import load_from_disk
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers

torch.set_float32_matmul_precision('high')


def train(dataset_path, dataset_name):
    # hyperparameters (hard-coded)
    lr = 8e-5
    model_name = "answerdotai/ModernBERT-base"
    model_shortname = model_name.split("/")[-1]

    # 1. Load a model to finetune
    model = SentenceTransformer(model_name)

    # 2. Load a dataset to finetune on
    train_dataset = load_from_disk(dataset_path)

    # 3. Define a loss function
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)  # Increase mini_batch_size if you have enough VRAM

    run_name = f"{model_shortname}-{dataset_name}"
    # 4. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"output/{model_shortname}/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        warmup_ratio=0.05,
        fp16=False,  # Set to False if GPU can't handle FP16
        bf16=True,  # Set to True if GPU supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
        learning_rate=lr,
        # Optional tracking/debugging parameters:
        save_strategy="steps",
        save_steps=200,
        save_total_limit=2,
        logging_steps=200,
        run_name=run_name,  # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
    )

    # 5. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        loss=loss,
    )
    trainer.train()

    # 6. Save the model
    model.save_pretrained(f"output/{run_name}/final")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str)
    parser.add_argument("--dataset_name", type=str)
    args = parser.parse_args()
    train(args.dataset_path, args.dataset_name)
```

the unreasonable effectiveness of bm25?
In my opinion, no method has a real edge here: we can discount AILA/LegalQuAD for now, and in the other three tasks each method (bm25, rerank, no-rerank) gets one victory apiece.

As for the foreign legal data, bm25 does better on the two Indian legal tasks, but on the German task it's the worst. I do notice one interesting thing across all tasks, however: rerank and no-rerank seem to be inversely correlated on certain performance spikes. When rerank goes up, no-rerank goes down and vice versa, so let's see what happens when we combine these pairs and train on that data:

Cool! Much better performance. It seems that combining the simplicity of bm25 with the rerank method yields superior performance in all areas except Legal Summarization. That's interesting considering our training dataset is exclusively legal summarization data, though contextually I suppose it sort of makes sense. However, we don't know whether this was due solely to having more negative examples per query, so we'll have to test with random data (1 vs. 2 vs. 3 negatives, etc.) to see if it's just the presence of an extra example that makes the difference; we'll test that later.

The other idea I had was to let these models overtrain for many epochs 'just to see what happens'. You'll notice that all the models are trained for 30 epochs, which is probably far, far too much. The reason is that I had a hunch that because bm25 is "semantically simpler", it should be the 'most resistant to overfitting' compared to the other two models. So I wanted to see whether the model would overfit to the training data, and what would happen if we let it overtrain for a long time.
It turned out that bm25 isn't necessarily the most resistant to overfitting: rerank actually performs better than bm25 on the subtasks we're most curious about.
Also, somewhat serendipitously, based on observations a poster made in a random Twitter thread, I was curious what would happen if we averaged (ensembled) the weights of the three models ((bm25 + rerank + no-rerank) / 3) at each checkpoint and measured performance: much better on the foreign legal tasks. The takeaway is probably that if you're compute-constrained and want the best-performing model, this is likely your best bet, all things considered: simply average the weights of models trained on various data mixes. It kind of goes to show that implicit regularization is so unbelievably powerful that it should probably be one of your first go-to tools for a performance boost.
Intuitively, the way I've been thinking about this is that using bm25-picked hard negatives alone is probably a simple form of regularization.

A few interesting rabbit-hole ideas
Looking at the training curves, one idea I had was to see if we could use the stability of bm25 together with varying degrees of noise/hard negatives to get more lift in performance, as well as test varying numbers of hard negatives per query:
What impact does bm25 have when it's supplemented by specifically chosen vs. random hard negatives?
What if we progressively add more hard negatives as training moves forward (epoch 1: 1 hard negative, epoch 2: 2 hard negatives, ..., epoch 2000: 2000 hard negatives)?
What if we make training (theoretically) more difficult, or at least give the model different looks, as it continues: either by always showing the model a new batch of hard negatives, or by increasing the number of hard negatives per batch?
At this point I'm sort of just trying things based on the loss curves/eval scores. For example, I noticed that bm25 is somewhat resistant to overtraining, or at least has nicely stable training grad-norms, so I just picked a random idea to start with and let the research muse take me where it wanted after that.

CODE
I started writing one comprehensive, thoughtful codebase, but when the research god whispers, you hack something together and just run the script.
Below are all the pieces of code for the random experiments I ran, in no particular order:

Random
At the end of each epoch, we take each sample (query, positive) in our dataset and randomly choose a hard negative to pair with it. (This method did horribly.)

# pseudo code
def on_epoch_end(self, **kwargs):
    import random

    for qid, relations in self.query_idx_relations.items():
        needed = self.k_cache - len(self.query_caches[qid])
        if needed <= 0:
            continue
        # Random pick from every doc that isn't the query itself or one of its positives
        # (random.sample needs a sequence, not a set, on Python 3.11+)
        candidates = list(set(self.idx_to_doc) - {qid} - set(relations['positive']))
        new_candidates = random.sample(candidates, min(len(candidates), needed))
        # No hardness score for random negatives: store their text directly
        self.query_caches[qid].extend(self.doc_list[idx] for idx in new_candidates)
    ...

Model-Based Mining: At the end of each epoch (or starting from a later epoch, set with 'epoch_start'), we take the current model and embed every piece of text in our dataset ({query1, positive1, query2, positive2, ..., queryN, positiveN}). Then, for each original query/positive pair, we find the K texts most similar to the query, excluding its corresponding positive.
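Stripped of the multi-GPU plumbing, the core of this mining step is a cosine top-K over the embedded dataset with the query and its positives masked out. A minimal NumPy sketch, assuming every text has already been embedded into one array and that `relations` maps each query index to its positive indices. `mine_hard_negatives` is a hypothetical helper (the real mining lives in the `create_hardnegs` function, which isn't shown), and its `random_n` noise only mirrors the idea of the parameter with the same name.

```python
import numpy as np


def mine_hard_negatives(embeddings, relations, k_hard=1, random_n=0, seed=0):
    """
    embeddings: [num_texts, dim] array, one row per text in the dataset (queries and positives).
    relations:  {query_idx: iterable of positive indices}
    Returns {query_idx: [k_hard most similar non-positive indices] + [random_n random indices]}.
    """
    rng = np.random.default_rng(seed)
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    mined = {}
    for qid, positives in relations.items():
        scores = normed @ normed[qid]  # cosine similarity of the query to every text
        banned = [qid, *positives]
        scores[banned] = -np.inf  # never pick the query itself or one of its positives
        k = min(k_hard, len(embeddings) - len(set(banned)))
        hard = [int(i) for i in np.argsort(-scores)[:k]]
        # Optional noise: random texts that are neither banned nor already mined
        allowed = np.setdiff1d(np.arange(len(embeddings)), banned + hard)
        n_rand = min(random_n, len(allowed))
        noise = [int(i) for i in rng.choice(allowed, size=n_rand, replace=False)] if n_rand else []
        mined[qid] = hard + noise
    return mined


emb = np.array([
    [1.0, 0.0],    # 0: query
    [0.8, 0.6],    # 1: its positive
    [0.99, 0.14],  # 2: nearly identical to the query -> hardest negative
    [0.0, 1.0],    # 3: unrelated
    [0.7, 0.7],    # 4: somewhat related
])
print(mine_hard_negatives(emb, {0: {1}}, k_hard=2, random_n=1))  # {0: [2, 4, 3]}
```

The dense matmul is fine for a toy; for a real corpus you'd swap it for an approximate index such as FAISS, which is what the last variant below does.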
We add some parameters to select K and to optionally mix in random hard negatives as noise.

# Imports shared by the callbacks below. Helpers that aren't shown in this post:
# create_document_list_with_relations, compute_hardness, compute_batch_hardness,
# create_parallel_hard_negatives.
import os
import pickle
import random

import faiss
import numpy as np
import torch.distributed as dist
from datasets import Dataset, load_from_disk
from transformers import (
    TrainerCallback,
    TrainerControl,
    TrainerState,
    get_cosine_schedule_with_warmup,
)


class HardNegMiningCallback(TrainerCallback):
    """
    A callback that, at the end of each epoch:
      - if we're on the main process, sets CUDA_VISIBLE_DEVICES to '6,7'
      - calls create_hardnegs(...)
      - waits for completion
      - reloads the dataset with the newly generated negatives
    """

    def __init__(
        self,
        trainer,  # <--- pass in the trainer itself
        create_hardnegs_fn,
        original_dataset_path,
        adaptive_dataset_path,
        model_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        epoch_start=1,
        k_hard=1,
        random_n=0,
    ):
        """
        Args:
            trainer: the HF Trainer (or a subclass) that we want to manipulate
            create_hardnegs_fn: a function with signature
                create_hardnegs(dataset_path, save_path, model_path, num_gpus, batch_size,
                                embedding_size, k_hard, random_n)
            original_dataset_path: path to the dataset we started training on
            adaptive_dataset_path: where we save the newly created dataset with fresh negatives
            model_path: path to the current model or checkpoint (trainer.args.output_dir by default)
            gpus_for_hardneg: which GPUs (by index) to use for negative generation (e.g. "6,7")
            reload_after: if True, reload the newly generated dataset into trainer.train_dataset
            num_gpus_for_hardneg: how many GPUs we actually want to use inside create_hardnegs
            batch_size, embedding_size: additional parameters to pass to create_hardnegs
            epoch_start: only start doing negative generation after this epoch number
            k_hard: number of hard negatives to generate per query
            random_n: number of random negatives to add per query as noise
        """
        super().__init__()
        self.trainer = trainer  # store a reference to the trainer
        self.create_hardnegs_fn = create_hardnegs_fn
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.model_path = model_path
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.epoch_start = epoch_start
        self.k_hard = k_hard
        self.random_n = random_n
        self.epoch_counter = 0

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self.epoch_counter += 1
        if self.epoch_counter < self.epoch_start:
            return  # skip until epoch_start is reached

        # Only rank 0 should mine. Trainer exposes is_world_process_zero(), and
        # TrainerState carries the same flag.
        if callable(getattr(self.trainer, "is_world_process_zero", None)):
            is_main_process = self.trainer.is_world_process_zero()
        elif hasattr(state, "is_world_process_zero"):
            is_main_process = state.is_world_process_zero
        else:
            is_main_process = True  # fallback: assume a single process

        if is_main_process:
            print(f"\n[HardNegMiningCallback] Epoch {self.epoch_counter} ended. "
                  f"Creating new hard negatives on GPUs {self.gpus_for_hardneg}.")
            # 1) Save the original CUDA_VISIBLE_DEVICES (None if unset) so we can restore it later
            original_cuda_env = os.environ.get("CUDA_VISIBLE_DEVICES")
            print(f"[HardNegMiningCallback] Original CUDA_VISIBLE_DEVICES: {original_cuda_env}")
            try:
                # 2) Make only GPUs 6,7 visible for negative generation. This only affects processes
                #    started from here on (e.g. workers launched by create_hardnegs_fn); CUDA in this
                #    process is already initialized on its own devices.
                os.environ["CUDA_VISIBLE_DEVICES"] = self.gpus_for_hardneg
                print(f"[HardNegMiningCallback] Setting CUDA_VISIBLE_DEVICES to {self.gpus_for_hardneg}")

                # 3) Find the latest checkpoint. Checkpoint folders are named checkpoint-N, where N is
                #    the global step, so sort by step. Caveat (worth checking for your transformers
                #    version): on_epoch_end fires before the end-of-epoch save, so this is the previous
                #    save, and with save_strategy="epoch" there is no checkpoint yet after epoch 1.
                model_ckpt_path = self.model_path or self.trainer.args.output_dir
                checkpoints = [d for d in os.listdir(model_ckpt_path) if d.startswith("checkpoint-")]
                checkpoints.sort(key=lambda x: int(x.split("-")[1]))
                checkpoint_path = os.path.join(model_ckpt_path, checkpoints[-1])

                # 4) Call the user-provided function
                self.create_hardnegs_fn(
                    dataset_path=self.original_dataset_path,
                    save_path=self.adaptive_dataset_path,
                    model_path=checkpoint_path,
                    num_gpus=self.num_gpus_for_hardneg,
                    batch_size=self.batch_size,
                    embedding_size=self.embedding_size,
                    k_hard=self.k_hard,
                    random_n=self.random_n,
                )
            finally:
                # 5) Restore the original CUDA_VISIBLE_DEVICES for training
                #    (setting it to "" would hide every GPU from later processes)
                if original_cuda_env is None:
                    os.environ.pop("CUDA_VISIBLE_DEVICES", None)
                else:
                    os.environ["CUDA_VISIBLE_DEVICES"] = original_cuda_env
                print(f"[HardNegMiningCallback] Restored CUDA_VISIBLE_DEVICES to {original_cuda_env}")

        # 6) Wait for all ranks if in distributed mode
        if dist.is_initialized():
            dist.barrier()

        # 7) Reload the updated dataset for the next epoch
        if self.reload_after:
            print(f"[HardNegMiningCallback] Reloading dataset from {self.adaptive_dataset_path}")
            updated_dataset = load_from_disk(self.adaptive_dataset_path)
            self.trainer.train_dataset = updated_dataset
            # Some versions store internally in `_train_dataset`
            self.trainer._train_dataset = updated_dataset
            # Caveat: Trainer.train() builds its dataloader once up front, so in recent transformers
            # versions the swapped dataset may only be used by the next train() call.
            self.trainer.train_dataloader = self.trainer.get_train_dataloader()
            # Keep mining from the newest dataset on later epochs
            self.original_dataset_path = self.adaptive_dataset_path

    def build_new_training_lr_schedule(self, new_dataset):
        # Total optimizer steps for the new dataset size
        num_training_steps = int(
            len(new_dataset) * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size
                * self.trainer.args.gradient_accumulation_steps)
        )
        # get_cosine_schedule_with_warmup has no initial_lr argument. The scheduler takes its base
        # lr from each param group's "initial_lr" (set once by the first scheduler), so reset it to
        # the current lr to start the new cosine curve from where we are.
        for group in self.trainer.optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        self.trainer.lr_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

Margin-Based Random Mining: In this version, we do the same model-based mining as before, but we add a margin to measure whether the model has 'learned' each negative. Instead of dropping a triplet prematurely, we keep showing the model the same triplet while its margin stays above a threshold; once it falls below, the negative is evicted and a new, randomly selected one takes its place. The thought here is that we want to keep only truly hard negatives in the cache and evict the ones the model has learned.

class AdaptiveHardNegMiningCallback(TrainerCallback):
    """
    Maintains a per-query cache of negatives. On each epoch end (or step end), it:
      - samples from the cache for training
      - updates each query's cache by removing "easy" negatives
      - periodically refills from a global pool or partial re-embedding
    """

    def __init__(
        self,
        trainer,  # reference to the Trainer (or your model)
        train_dataset,  # list of query texts or IDs
        model,  # the retrieval model or sentence-transformer
        k_cache=5,  # how many negatives we keep in each cache
        remove_threshold=-0.05,  # if margin < -0.05 => negative is "easy"
        refill_every=2,  # every 2 epochs, we try to refill or re-embed
        create_hardnegs_fn=None,
        original_dataset_path=None,
        adaptive_dataset_path=None,
        model_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        epoch_start=1,
        k_hard=1,
        random_n=0,
    ):
        super().__init__()
        self.trainer = trainer
        self.model = model
        self.k_cache = k_cache
        self.remove_threshold = remove_threshold
        self.refill_every = refill_every
        self.create_hardnegs_fn = create_hardnegs_fn
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.model_path = model_path
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.epoch_start = epoch_start
        self.k_hard = k_hard
        self.random_n = random_n

        # Build initial caches: each cache is a list of (neg_doc_id, hardness)
        (self.doc_list, self.doc_to_idx, self.idx_to_doc,
         self.query_idx_relations) = create_document_list_with_relations(train_dataset)
        self.query_caches = {}
        for qid in self.query_idx_relations:
            # pick up to k_cache of this query's initial negatives
            negs = list(self.query_idx_relations[qid]["negative"])[:k_cache]
            self.query_caches[qid] = [(n, 999.0) for n in negs]  # 999.0 = hardness placeholder
        self.epoch_counter = 0

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        Optionally, we can sample a brand-new training dataset from the caches here,
        or do it in on_step_end if you prefer finer granularity.
        """
        # Rebuild or shuffle the "train_dataset" from these caches
        self._build_train_dataset_from_caches()

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self.epoch_counter += 1
        # 1) Update caches for each query -> remove easy negatives
        self._update_caches_remove_easy()
        # 2) Periodically refill caches from the global pool or partial re-embedding
        if (self.epoch_counter % self.refill_every) == 0:
            self._refill_caches()
        # 3) Optionally re-build the training dataset from the updated caches
        #    so the next epoch sees the new negatives
        self._build_train_dataset_from_caches()

    ################################################################
    # Internals
    ################################################################
    def _build_train_dataset_from_caches(self):
        """
        Rebuild a training dataset of {query, positive, negative} triplets from the caches
        and hand it to the Trainer.
        """
        new_samples = []
        for qid, relations in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(relations["positive"])[0]]
            for neg_idx, _hardness in self.query_caches[qid]:
                new_samples.append({
                    "query": q_text,
                    "positive": pos_text,
                    "negative": self.doc_list[neg_idx],
                })
        # Wrap it in a HF Dataset (or a custom dataset for your trainer)
        train_dataset = Dataset.from_list(new_samples)
        self.trainer.train_dataset = train_dataset
        self.trainer._train_dataset = train_dataset  # if needed
        # Re-init the dataloader (a train() call already in progress may keep using its old one)
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _update_caches_remove_easy(self):
        """
        For each (query, neg), check whether it's now 'easy' for the model:
          - if margin < remove_threshold => remove it from the cache
          - otherwise, keep it (it's still hard)
        """
        for qid, relations in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(relations["positive"])[0]]
            new_cache = []
            for neg_id, _old_hardness in self.query_caches[qid]:
                margin = compute_hardness(self.model, q_text, pos_text, self.doc_list[neg_id])
                if margin >= self.remove_threshold:
                    # still hard: keep the doc id (not its text) so the cache format stays consistent
                    new_cache.append((neg_id, margin))
            self.query_caches[qid] = new_cache

    def _refill_caches(self):
        """
        If a cache has fewer than k_cache negatives, pick new ones from the global pool
        or re-embed & find a new top-K. For simplicity, we pick randomly here.
        """
        for qid, relations in self.query_idx_relations.items():
            needed = self.k_cache - len(self.query_caches[qid])
            if needed <= 0:
                continue
            cached = {neg_id for neg_id, _ in self.query_caches[qid]}
            # random.sample needs a sequence (sets are rejected since Python 3.11)
            candidates = list(set(self.idx_to_doc) - {qid} - set(relations["positive"]) - cached)
            new_candidates = random.sample(candidates, min(len(candidates), needed))
            # Compute hardness for each new candidate
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(relations["positive"])[0]]
            for cand_idx in new_candidates:
                margin = compute_hardness(self.model, q_text, pos_text, self.doc_list[cand_idx])
                self.query_caches[qid].append((cand_idx, margin))

    def build_new_training_lr_schedule(self, new_dataset):
        # Total optimizer steps for the new dataset size
        num_training_steps = int(
            len(new_dataset) * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size
                * self.trainer.args.gradient_accumulation_steps)
        )
        # get_cosine_schedule_with_warmup has no initial_lr argument; reset each group's
        # "initial_lr" so the new cosine curve starts from the current lr.
        for group in self.trainer.optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        self.trainer.lr_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

Margin-Based Bandit Mining: In this version, we combine the original model-based mining with the margin-based mining, and instead of hand-picking the margin threshold that empties a sample's hard-negative cache, we use an epsilon-greedy bandit to pick the margin value.
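The bandit loop in miniature: a sketch with a made-up reward (a toy eval loss per threshold, not real results) showing how epsilon-greedy with an incremental mean settles on one threshold. `run_threshold_bandit`, `TOY_EVAL_LOSS` and `toy_reward` are hypothetical. One design point it makes explicit: the bandit maximizes reward, so a loss-style metric such as eval_loss has to be negated before it is fed back.

```python
import random


def run_threshold_bandit(reward_fn, arms, rounds=200, epsilon=0.1, seed=0):
    """Epsilon-greedy over candidate thresholds; reward_fn must return higher-is-better rewards."""
    rng = random.Random(seed)
    q_values = [0.0] * len(arms)  # running mean reward per arm
    counts = [0] * len(arms)      # times each arm was pulled
    for _ in range(rounds):
        if rng.random() < epsilon:
            arm = rng.randrange(len(arms))  # explore
        else:
            arm = max(range(len(arms)), key=lambda i: q_values[i])  # exploit (first best on ties)
        reward = reward_fn(arms[arm], rng)
        counts[arm] += 1
        q_values[arm] += (reward - q_values[arm]) / counts[arm]  # incremental mean
    return q_values, counts


# Hypothetical stand-in for "train an epoch with this threshold, then evaluate":
# the eval loss is lowest when the removal threshold is 0.0.
TOY_EVAL_LOSS = {-0.1: 0.50, -0.05: 0.45, 0.0: 0.35, 0.05: 0.42, 0.1: 0.48}


def toy_reward(threshold, rng):
    eval_loss = TOY_EVAL_LOSS[threshold] + rng.gauss(0.0, 0.01)
    return -eval_loss  # negate: the bandit maximizes, but lower loss is better


arms = [-0.1, -0.05, 0.0, 0.05, 0.1]
q_values, counts = run_threshold_bandit(toy_reward, arms)
best = arms[max(range(len(arms)), key=lambda i: q_values[i])]
print(best, counts)
```

Starting every q-value at 0 while the negated-loss rewards are all negative acts like optimistic initialization: each threshold gets tried at least once before the greedy choice takes over.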
This bandit variant is cool, so I'm patting myself on the back for it.
Epsilon-greedy knowledge check: we provide a list of arms, where each arm holds one candidate margin threshold, plus an epsilon value (the probability of exploring a random arm instead of exploiting the one with the highest estimated reward), and an initial reward estimate of 0 for each arm. At the start of each epoch, select_arm picks an arm: with probability epsilon a random one, otherwise the best-performing arm so far (the one with the highest q-value).
At the end of the epoch, we update the chosen arm's q-value with the reward from our eval set, using an incremental mean: q_new = q_old + (reward - q_old) / n, where n is the number of times that arm has been chosen.

import random

import numpy as np


class EpsilonGreedyBandit:
    """
    Each arm is a candidate threshold. We track (mean_reward, count) for each arm
    and pick arms epsilon-greedily. Rewards are assumed to be higher-is-better.
    """

    def __init__(self, arms, epsilon=0.1, initial_value=0.0):
        """
        :param arms: list of thresholds (floats)
        :param epsilon: probability of exploring
        :param initial_value: initial assumed reward for each arm
        """
        self.arms = arms
        self.epsilon = epsilon
        self.q_values = [initial_value for _ in arms]  # estimated mean reward
        self.counts = [0 for _ in arms]  # how many times each arm was used

    def select_arm(self):
        """Epsilon-greedy selection."""
        if random.random() < self.epsilon:
            return random.randint(0, len(self.arms) - 1)  # explore
        return int(np.argmax(self.q_values))  # exploit (pick the best so far)

    def update(self, arm_index, reward):
        """Update the running mean reward for the chosen arm."""
        self.counts[arm_index] += 1
        c = self.counts[arm_index]
        old_q = self.q_values[arm_index]
        self.q_values[arm_index] = old_q + (reward - old_q) / c  # incremental mean update

    def get_arm_value(self, arm_index):
        return self.arms[arm_index]

    def __repr__(self):
        return f"EpsilonGreedyBandit(arms={self.arms}, q_values={self.q_values}, counts={self.counts})"


class AdaptiveBanditHardNegMiningCallback(TrainerCallback):
    """
    Maintains a per-query cache of negatives.
    Uses a bandit to pick 'remove_threshold' at the start of each epoch.
    """

    def __init__(
        self,
        trainer,  # reference to the Trainer
        train_dataset,
        model,
        k_cache=5,
        bandit_arms=None,  # list of candidate thresholds
        epsilon=0.1,
        refill_every=2,
        eval_metric_key="eval_loss",  # or "eval_accuracy", "eval_mrr", etc.
    ):
        super().__init__()
        self.trainer = trainer
        self.model = model
        self.k_cache = k_cache
        self.refill_every = refill_every
        self.eval_metric_key = eval_metric_key

        # 1) Build the bandit
        if bandit_arms is None:
            bandit_arms = [-0.1, -0.05, 0.0, 0.05, 0.1]  # default arms if not specified
        self.bandit = EpsilonGreedyBandit(arms=bandit_arms, epsilon=epsilon)

        # 2) Build initial caches, etc. (similar to the original code)
        (self.doc_list, self.doc_to_idx, self.idx_to_doc,
         self.query_idx_relations) = create_document_list_with_relations(train_dataset)
        self.query_caches = {}
        for qid in self.query_idx_relations:
            negs = list(self.query_idx_relations[qid]["negative"])
            self.query_caches[qid] = [(nid, 999.0) for nid in negs[:k_cache]]  # pick up to k_cache

        self.epoch_counter = 0
        self.current_arm_index = None  # which threshold arm we used this epoch
        self.remove_threshold = None  # the actual threshold for this epoch

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Pick an arm from the bandit => set remove_threshold
        2) Rebuild the training dataset
        """
        self.epoch_counter += 1
        # (A) The bandit chooses which arm to use
        self.current_arm_index = self.bandit.select_arm()
        self.remove_threshold = self.bandit.get_arm_value(self.current_arm_index)
        print(f"[BanditCallback] Epoch={self.epoch_counter}, chosen threshold={self.remove_threshold:.3f}")
        # (B) Re-build / shuffle the training data
        self._build_train_dataset_from_caches()

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Update caches => remove easy negatives
        2) Optionally refill
        3) Evaluate => measure reward => update bandit
        """
        # 1) Remove easy negatives
        self._update_caches_remove_easy_batch()
        # 2) Refill if needed
        if (self.epoch_counter % self.refill_every) == 0:
            self._refill_caches_batch()
        # 3) Evaluate => compute reward. We could call trainer.evaluate(...) or reuse the last
        #    logged eval_metric_key from state.log_history. Caveat: on_epoch_end can fire before
        #    the Trainer's end-of-epoch evaluation, so this may be the previous epoch's metric.
        reward = 0.0
        for entry in reversed(state.log_history):
            if self.eval_metric_key in entry:
                reward = entry[self.eval_metric_key]
                break
        print(f"[BanditCallback] Epoch={self.epoch_counter}, performance metric={reward:.4f}")
        # The bandit maximizes reward, so flip the sign of loss-style metrics (e.g. eval_loss)
        if "loss" in self.eval_metric_key:
            reward = -reward
        # 4) Update the bandit with the reward
        self.bandit.update(self.current_arm_index, reward)
        # Optionally re-build the dataset after removing negatives
        self._build_train_dataset_from_caches()

    ############################################
    # The same helper methods as before
    ############################################
    def _build_train_dataset_from_caches(self):
        # Rebuild a 'training dataset' from the caches
        new_samples = []
        for qid, rels in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(rels["positive"])[0]]  # pick 1 positive from rels
            for neg_id, _hardness in self.query_caches[qid]:
                new_samples.append({
                    "query": q_text,
                    "positive": pos_text,
                    "negative": self.doc_list[neg_id],
                })
        from datasets import Dataset

        train_dataset = Dataset.from_list(new_samples)
        self.trainer.train_dataset = train_dataset
        self.trainer._train_dataset = train_dataset
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _update_caches_remove_easy(self):
        for qid, rels in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(rels["positive"])[0]]
            new_cache = []
            for neg_id, _old_hardness in self.query_caches[qid]:
                margin = compute_hardness(self.model, q_text, pos_text, self.doc_list[neg_id])
                # use self.remove_threshold from the bandit: keep only negatives that are still hard
                if margin >= self.remove_threshold:
                    new_cache.append((neg_id, margin))
            self.query_caches[qid] = new_cache

    def _refill_caches(self):
        # Just a skeleton: pick random negatives from the corpus to fill up to k_cache
        import random

        for qid, rels in self.query_idx_relations.items():
            needed = self.k_cache - len(self.query_caches[qid])
            if needed <= 0:
                continue
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(rels["positive"])[0]]
            # randomly sample from all doc_ids except the query and its positives
            all_candidates = list(set(self.doc_to_idx.values()) - {qid} - set(rels["positive"]))
            new_cands = random.sample(all_candidates, min(len(all_candidates), needed))
            for cidx in new_cands:
                m = compute_hardness(self.model, q_text, pos_text, self.doc_list[cidx])
                self.query_caches[qid].append((cidx, m))

    def _refill_caches_batch(self):
        """
        Similar to _refill_caches, but we do all margin computations in batch mode
        to reduce overhead.
        """
        import random

        batch_q, batch_p, batch_n = [], [], []
        index_map = []  # (qid, candidate_neg_id), aligned with the batch lists

        # 1) Gather all needed candidates from each query
        for qid, rels in self.query_idx_relations.items():
            needed = self.k_cache - len(self.query_caches[qid])
            if needed <= 0:
                continue
            q_text = self.doc_list[qid]
            pos_text = self.doc_list[list(rels["positive"])[0]]
            # Sample from all candidates except qid, its positives, and negatives already cached
            # (as a list: random.sample rejects sets on Python 3.11+)
            cached = {neg_id for neg_id, _ in self.query_caches[qid]}
            all_candidates = list(
                set(self.doc_to_idx.values()) - {qid} - set(rels["positive"]) - cached
            )
            if not all_candidates:
                continue
            needed_cands = random.sample(all_candidates, min(len(all_candidates), needed))
            # 2) Collect them in a batch list
            for cand_idx in needed_cands:
                batch_q.append(q_text)
                batch_p.append(pos_text)
                batch_n.append(self.doc_list[cand_idx])
                index_map.append((qid, cand_idx))

        if not batch_q:  # nothing to refill
            return

        # 3) Batch compute margins in chunks
        chunk_size = 256
        margin_list = []
        for start in range(0, len(batch_q), chunk_size):
            end = start + chunk_size
            margin_list.extend(
                compute_batch_hardness(self.model, batch_q[start:end], batch_p[start:end], batch_n[start:end])
            )

        # 4) Assign margins back to the caches: index_map[i] -> (qid, cand_idx), margin_list[i] -> margin
        for (qid, cand_idx), margin_val in zip(index_map, margin_list):
            self.query_caches[qid].append((cand_idx, margin_val))

    def _update_caches_remove_easy_batch(self):
        # 1) Gather all (q_text, p_text, n_text) in one big list
        batch_q, batch_p, batch_n = [], [], []
        index_map = []  # (qid, position in that query's cache), aligned with the batch lists
        for qid, rels in self.query_idx_relations.items():
            q_text = self.doc_list[qid]
            p_text = self.doc_list[list(rels["positive"])[0]]
            for idx, (neg_id, _old_hardness) in enumerate(self.query_caches[qid]):
                batch_q.append(q_text)
                batch_p.append(p_text)
                batch_n.append(self.doc_list[neg_id])
                index_map.append((qid, idx))  # so we know which cache entry this corresponds to

        # 2) Chunk it in case it's very large
        margin_list = []
        chunk_size = 256
        for start in range(0, len(batch_q), chunk_size):
            end = start + chunk_size
            margin_list.extend(
                compute_batch_hardness(self.model, batch_q[start:end], batch_p[start:end], batch_n[start:end])
            )
        # margin_list is now the same length as index_map

        # 3) Go back through index_map: mark easy negatives for removal, update hardness for the rest
        for (qid, cache_idx), margin in zip(index_map, margin_list):
            if margin < self.remove_threshold:
                self.query_caches[qid][cache_idx] = None
            else:
                neg_id = self.query_caches[qid][cache_idx][0]
                self.query_caches[qid][cache_idx] = (neg_id, margin)

        # 4) Filter out removed entries
        for qid in self.query_idx_relations:
            self.query_caches[qid] = [x for x in self.query_caches[qid] if x is not None]

    def build_new_training_lr_schedule(self, new_dataset):
        # Total optimizer steps for the new dataset size
        num_training_steps = int(
            len(new_dataset) * self.trainer.args.num_train_epochs
            // (self.trainer.args.per_device_train_batch_size
                * self.trainer.args.gradient_accumulation_steps)
        )
        # get_cosine_schedule_with_warmup has no initial_lr argument; reset each group's
        # "initial_lr" so the new cosine curve starts from the current lr.
        for group in self.trainer.optimizer.param_groups:
            group["initial_lr"] = group["lr"]
        self.trainer.lr_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

Margin Bandit Mining w/ Self-Selection using FAISS: We use the same bandit-based structure as Margin-Based Bandit Mining, but instead of only selecting random examples, we use the embedding strategy from Model-Based Mining to select the K hardest negatives with the model itself, plus any random noise we want to add via random_n. So this one is basically a combination of self-reinforcement with FAISS (the model mines its own hard negatives) and reinforcement learning on the hard-negative parameters themselves. This was really cool, and I think it's a really good way to do bandit mining. The problem is that it's slow, and it only improves the parameter choices incrementally. What I'll probably do is take this method, figure out some rough 'scaling laws', and see how to make good cold-start parameter choices given the known data/params/etc.

class AdaptiveBanditFaissCallback(TrainerCallback):
    """
    1) At each epoch end, re-embed the entire doc_list with the current model => build a FAISS index
    2) Maintain a local 'query_caches' for each query => do remove_easy + refill,
       using the FAISS index to find truly "hard" negatives
    3) Rebuild the trainer's dataset from the caches => the next epoch sees it
    4) Use a bandit for a dynamic 'remove_threshold'
    """

    def __init__(
        self,
        trainer,
        train_dataset,
        model_path,
        k_cache=5,
        bandit_arms=None,
        epsilon=0.1,
        refill_every=2,
        eval_metric_key="eval_loss",
        epoch_start=1,
        embedding_dim=768,
        chunk_size=256,
        create_hardnegs_fn=create_parallel_hard_negatives,  # the hard-negative creation function
        original_dataset_path=None,
        adaptive_dataset_path=None,
        gpus_for_hardneg="6,7",
        reload_after=True,
        num_gpus_for_hardneg=2,
        batch_size=256,
        embedding_size=1536,
        k_hard=1,
        random_n=0,
    ):
        super().__init__()
        self.trainer = trainer
        self.model_path = model_path
        self.k_cache = k_cache
        self.refill_every = refill_every
        self.eval_metric_key = eval_metric_key
        self.epoch_start = epoch_start
        self.embedding_dim = embedding_dim
        self.chunk_size = chunk_size
        self.create_hardnegs_fn = create_hardnegs_fn
        # dataset args
        self.original_dataset_path = original_dataset_path
        self.adaptive_dataset_path = adaptive_dataset_path
        self.num_gpus_for_hardneg = num_gpus_for_hardneg
        self.batch_size = batch_size
        self.embedding_size = embedding_size
        self.k_hard = k_hard
        self.random_n = random_n
        self.gpus_for_hardneg = gpus_for_hardneg
        self.reload_after = reload_after

        # 1) Bandit
        if bandit_arms is None:
            bandit_arms = [-0.1, -0.05, 0.0, 0.05, 0.1]
        self.bandit = EpsilonGreedyBandit(arms=bandit_arms, epsilon=epsilon)

        # 2) Basic data: build initial caches, etc. (similar to the original code)
        (self.doc_list, self.doc_to_idx, self.idx_to_doc,
         self.query_idx_relations) = create_document_list_with_relations(train_dataset)
        self.query_caches = {}
        for qid in self.query_idx_relations:
            negs = list(self.query_idx_relations[qid]["negative"])
            self.query_caches[qid] = [(nid, 999.0) for nid in negs[:k_cache]]  # pick up to k_cache

        # For the doc -> index-row mapping, a simple approach is doc_id == index in doc_list
        # (or, if you have doc_to_idx, store that here).
        self.epoch_counter = 0
        self.remove_threshold = None
        self.current_arm_index = None

        # 4) Build the initial training dataset from the caches
        self._build_train_dataset_from_caches()

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        self.epoch_counter += 1
        if self.epoch_counter < self.epoch_start:
            return
        # The bandit picks the threshold
        self.current_arm_index = self.bandit.select_arm()
        self.remove_threshold = self.bandit.get_arm_value(self.current_arm_index)
        print(f"[BanditFaissCallback] Epoch={self.epoch_counter}, chosen threshold={self.remove_threshold:.3f}")

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """
        1) Re-embed the entire doc_list => build a FAISS index
        2) remove_easy_batch
        3) refill_caches_batch (using FAISS to get truly 'hard' negatives)
        4) measure reward => update bandit
        5) rebuild the training dataset => next epoch
        """
        if self.epoch_counter < self.epoch_start:
            return

        # (A) Measure the reward from log_history (this may be the previous evaluation if
        #     on_epoch_end fires before the Trainer's end-of-epoch evaluate)
        reward = 0.0
        for entry in reversed(state.log_history):
            if self.eval_metric_key in entry:
                reward = entry[self.eval_metric_key]
                break
        print(f"[BanditFaissCallback] Epoch={self.epoch_counter}, performance metric={reward:.4f}")
        # The bandit maximizes reward, so flip the sign of loss-style metrics (e.g. eval_loss)
        if "loss" in self.eval_metric_key:
            reward = -reward

        # (B) Bandit update
        self.bandit.update(self.current_arm_index, reward)

        # (C) Only run on the main process: re-embed doc_list => build the FAISS index
        if self._is_main_process(state):
            print("[BanditFaissCallback] Re-embedding entire doc_list to build FAISS index.")
            self._build_faiss_index()  # produces self.faiss_index, plus doc_id -> faiss_id?
        if dist.is_initialized():
            dist.barrier()

        # (D) remove easy negatives in batch (using the newly built index for margin computations)
        self._update_caches_remove_easy_batch()

        # (E) refill from FAISS (rather than random), picking truly 'hard' negatives
        if self.epoch_counter % self.refill_every == 0:
            self._refill_caches_batch()

        # (F) rebuild the trainer dataset
        self._build_train_dataset_from_caches()

    # ---------------------------------------------------------------
    # FAISS index + doc embeddings
    # ---------------------------------------------------------------
    def _build_faiss_index(self):
        '''
        Re-encode all docs in doc_list, build a FAISS index in memory, and
        store it as self.faiss_index. Conceptually:
            embs = self.model.encode(doc_list, batch_size=..., convert_to_tensor=True)
            self.faiss_index = build_faiss_index(embs, index_type='IndexFlatIP')
        Here the work is delegated to create_hardnegs_fn, which is expected to
        write embeddings.index and faiss_to_doc_idx.pkl to disk.
        '''
        # Determine the current model checkpoint directory
        model_ckpt_path = self.model_path or self.trainer.args.output_dir
        # use the latest checkpoint; HF names them checkpoint-<global_step>
        checkpoints = [d for d in os.listdir(model_ckpt_path) if d.startswith('checkpoint-')]
        checkpoints.sort(key=lambda x: int(x.split('-')[1]))
        checkpoint_path = os.path.join(model_ckpt_path, checkpoints[-1])

        # Call the user-provided function to re-embed and mine hard negatives
        self.create_hardnegs_fn(
            dataset_path=self.original_dataset_path,
            save_path=self.adaptive_dataset_path,
            model_path=checkpoint_path,
            num_gpus=self.num_gpus_for_hardneg,
            batch_size=self.batch_size,
            embedding_size=self.embedding_size,
            k_hard=self.k_hard,
            random_n=self.random_n,
        )
        self._load_faiss_artifacts()

    def _load_faiss_artifacts(self):
        # (Re)load the index and id mappings written to disk by create_hardnegs_fn
        self.faiss_index = faiss.read_index('embeddings.index')
        with open('faiss_to_doc_idx.pkl', 'rb') as f:
            self.faiss_to_doc_idx = pickle.load(f)
        self.doc_id_to_faiss_id = {v: k for k, v in self.faiss_to_doc_idx.items()}

    def _compute_margin_faiss(self, qid, neg_id, pos_id):
        '''
        Fetch embeddings from self.faiss_index via reconstruct(). For now we
        assume queries are also in doc_list, so qid is a doc id. If queries
        differ from doc_list, store them separately or model.encode() them once.
        '''
        import numpy as np

        # doc ids are not FAISS row ids, so map them before reconstructing
        neg_vec = self.faiss_index.reconstruct(int(self.doc_id_to_faiss_id[neg_id]))  # shape (d,)
        pos_vec = self.faiss_index.reconstruct(int(self.doc_id_to_faiss_id[pos_id]))
        q_vec = self.faiss_index.reconstruct(int(self.doc_id_to_faiss_id[qid]))  # queries are also in doc_list

        def cos_sim(a, b):
            return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8)

        # margin > 0 means the negative is closer to the query than the positive
        return cos_sim(q_vec, neg_vec) - cos_sim(q_vec, pos_vec)

    # ---------------------------------------------------------------
    # Update + Refill
    # ---------------------------------------------------------------
    def _update_caches_remove_easy_batch(self):
        '''
        Batch approach: gather all (qid, pos, neg) from the caches, compute
        margins using FAISS, and remove a negative if margin < self.remove_threshold.
        '''
        self._load_faiss_artifacts()
        batch_q, batch_p, batch_n, index_map = [], [], [], []
        for qid, rels in self.query_idx_relations.items():
            pos_list = list(rels.get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]  # assume 1 positive
            for idx, (neg_id, old_hardness) in enumerate(self.query_caches[qid]):
                batch_q.append(qid)
                batch_p.append(pos_id)
                batch_n.append(neg_id)
                index_map.append((qid, idx))
        if not batch_q:
            return

        # naive loop of size len(batch_q); this could be vectorized or moved to GPU
        margin_list = [self._compute_margin_faiss(q, n, p) for q, p, n in zip(batch_q, batch_p, batch_n)]

        # remove or keep (kept entries store the new margin as their hardness)
        for (qid, cache_idx), margin_val in zip(index_map, margin_list):
            if margin_val < self.remove_threshold:
                self.query_caches[qid][cache_idx] = None
            else:
                neg_id = self.query_caches[qid][cache_idx][0]
                self.query_caches[qid][cache_idx] = (neg_id, margin_val)

        # filter out removed entries
        for qid in self.query_idx_relations:
            self.query_caches[qid] = [x for x in self.query_caches[qid] if x is not None]

    def _refill_caches_batch(self):
        '''
        Instead of random negatives, do a FAISS-based retrieval of hard
        negatives for each query that needs more negatives.
        '''
        self._load_faiss_artifacts()
        needed_info = []
        for qid in self.query_idx_relations:
            needed = self.k_cache - len(self.query_caches[qid])
            if needed > 0:
                needed_info.append((qid, needed))
        if not needed_info:
            return

        # retrieve top-K from FAISS for each qid (queries are also in doc_list => same id)
        for qid, needed in needed_info:
            pos_list = list(self.query_idx_relations[qid].get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]
            # retrieve an overshoot of the docs closest to qid
            cands = self._retrieve_hard_negatives_faiss(qid, top_k=needed * 5)
            # filter out qid itself, positives, and docs already in the cache
            filter_set = set(pos_list) | {qid} | {nid for nid, _ in self.query_caches[qid]}
            cands = [d for d in cands if d not in filter_set][:needed]
            # compute margins and skip easy candidates (margin < remove_threshold)
            for doc_id in cands:
                margin_val = self._compute_margin_faiss(qid, doc_id, pos_id)
                if margin_val < self.remove_threshold:
                    continue
                self.query_caches[qid].append((doc_id, margin_val))

    def _retrieve_hard_negatives_faiss(self, qid, top_k=10):
        '''
        Search FAISS for the docs closest to qid. Because qid is also a doc in
        doc_list, reconstruct its embedding and call index.search(...).
        '''
        import numpy as np

        q_vec = self.faiss_index.reconstruct(int(self.doc_id_to_faiss_id[qid]))
        # If the index is IndexFlatIP over normalized docs, q_vec may need L2 normalization too.
        # index.search expects shape (n_query, dim), so pass q_vec as a single row
        query_batch = np.expand_dims(q_vec, axis=0)
        distances, indices = self.faiss_index.search(query_batch, top_k + 50)  # overshoot
        # indices shape: (1, top_k + 50); skip -1 padding, qid itself, and its positives
        positives = self.query_idx_relations[qid]['positive']
        retrieved_ids = []
        for faiss_idx in indices[0]:
            if faiss_idx < 0:
                continue
            doc_idx = self.faiss_to_doc_idx[int(faiss_idx)]
            if doc_idx != qid and doc_idx not in positives:
                retrieved_ids.append(doc_idx)
        return retrieved_ids

    # ---------------------------------------------------------------
    # Build dataset from caches
    # ---------------------------------------------------------------
    def _build_train_dataset_from_caches(self):
        new_samples = []
        for qid, rels in self.query_idx_relations.items():
            pos_list = list(rels.get('positive', []))
            if not pos_list:
                continue
            pos_id = pos_list[0]
            q_text = self.doc_list[qid]
            p_text = self.doc_list[pos_id]
            for neg_id, hardness in self.query_caches[qid]:
                new_samples.append({
                    'query': q_text,
                    'positive': p_text,
                    'negative': self.doc_list[neg_id],
                })
        ds = Dataset.from_list(new_samples)
        self.trainer.train_dataset = ds
        self.trainer._train_dataset = ds
        self.trainer.train_dataloader = self.trainer.get_train_dataloader()

    def _is_main_process(self, state: TrainerState):
        fn = getattr(self.trainer, 'is_world_process_zero', None)
        if callable(fn):
            return fn()
        if hasattr(state, 'is_world_process_zero'):
            return state.is_world_process_zero
        return True

    def build_new_training_lr_schedule(self, new_dataset):
        current_lr = self.trainer.optimizer.param_groups[0]['lr']
        # Calculate new total steps based on the new dataset size
        args = self.trainer.args
        num_training_steps = int(
            len(new_dataset) * args.num_train_epochs
            // (args.per_device_train_batch_size * args.gradient_accumulation_steps)
        )
        # get_cosine_schedule_with_warmup takes no initial_lr argument; the underlying
        # LambdaLR scales each group's 'initial_lr', so reset it to restart from current_lr
        for group in self.trainer.optimizer.param_groups:
            group['initial_lr'] = current_lr
        self.trainer.lr_scheduler = get_cosine_schedule_with_warmup(
            self.trainer.optimizer,
            num_warmup_steps=0,  # no warmup since we're continuing
            num_training_steps=num_training_steps,
        )

Final Results
ok, so a few things to note:\n 'rand' is HardNegMiningCallback with k_hard=1 and 25 random samples, with 'start_epoch' set to 10 (meaning we don't start mining until the 10th epoch). We train on the bm25 dataset for 10 epochs, then use the last epoch's model to find the most difficult negative and create a single triplet from it, along with 25 other triplets whose hard negative is chosen randomly. I wanted to add some less-difficult noise to the system. It produced pretty solid results and was usually the second-best model. I think the ultimate takeaway is probably the 10 epochs of bm25 training more than anything. 'new3' is AdaptiveBanditFaissCallback with k_hard=3, using the combo dataset of bm25 + rerank + no-rerank. The simple act of taking 3 models (bm25, no-rerank, rerank) and averaging their weights produces a very strong model; I'm unclear on the mechanics of why this is, but it's pretty cool. 
For the task most similar to our training data (legal summarization), the cheapest option, 'bm25 + rerank + no-rerank' (i.e., combine the 3 datasets so there are 3 examples per query/positive pair), produces the best model.\nI didn't get to explore the bandit models as much as I wanted to, but maybe I'll write something more after some of these other tasks finish training.\nAdmittedly, this was pretty fun. We have a constrained dataset/problem that could honestly be solved with more data/compute (The Bitter Lesson, anyone?). But I wanted to pull various levers, measure their effects, and find out more about what makes this thing tick. At some point you either do good science or you don't, so I think this is a good start.\nTODO list (that I admittedly don't have a):\n Experiment with different embedding dimensions\n Compare performance across model sizes (small, base, large)\n Vary the similarity thresholds in hard negative selection\n Test different k values for nearest neighbor retrieval\n Compare performance with and without embedding normalization\n Analyze performance with different batch sizes\n Study the impact of learning rates\n Investigate gradient accumulation steps\n Test different loss functions (InfoNCE, triplet loss variants)\n Combine BM25 and reranking with different weighting schemes\n Test ensemble methods of multiple rerankers\n Explore sparse-dense retrieval combinations","date":"2025-01-12","permalink":"/posts/hard_negative_mining/","title":"Why So Hard (Negative) On Your Self (Reinforcement)?"},{"content":"==================\nVerifiers and Legal Benchmarks\nAnother hastily written blog\nVerifiers\nPart of the law is verifying truth. In textbook form, this often means converting reasoning into verifiable symbolic logic, e.g., P/Q truth tables. This is easy in mathematics, because there are steps to get to the next statement and each substep is verifiable in and of itself. 
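In the mathematical case, checking an inference against its P/Q truth table is fully mechanical. The snippet below is a minimal sketch in plain Python. It is not the verifier code described in this post, and `is_valid` and `implies` are hypothetical helpers. An argument counts as valid when no assignment of truth values makes every premise true and the conclusion false.

```python
from itertools import product
from typing import Callable, Dict, List

# a formula maps one truth-table row (variable -> bool) to a bool
Formula = Callable[[Dict[str, bool]], bool]


def implies(a: bool, b: bool) -> bool:
    # material implication: false only when a is true and b is false
    return (not a) or b


def is_valid(variables: List[str], premises: List[Formula], conclusion: Formula) -> bool:
    # enumerate every row of the truth table
    for values in product([True, False], repeat=len(variables)):
        row = dict(zip(variables, values))
        # a counterexample row: all premises true, conclusion false
        if all(p(row) for p in premises) and not conclusion(row):
            return False
    return True


# modus ponens: P -> Q, P, therefore Q (valid)
modus_ponens = is_valid(
    ['P', 'Q'],
    [lambda r: implies(r['P'], r['Q']), lambda r: r['P']],
    lambda r: r['Q'],
)

# affirming the consequent: P -> Q, Q, therefore P (invalid)
affirming_consequent = is_valid(
    ['P', 'Q'],
    [lambda r: implies(r['P'], r['Q']), lambda r: r['Q']],
    lambda r: r['P'],
)

print(modus_ponens, affirming_consequent)  # True False
```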
But in argumentation, the steps aren't so clean-cut as the state progresses; e.g., 'the argument' can only be verified at the end.\nSo what we're attempting to do is take an LLM and turn its logical arguments into a series of verifiable statements.\nThis is in line with DeepSeek's Lean prover, VerifierQ, and Ilya's old work; there are tons of papers on verifying math. So we're not trying to do anything new here, we're just trying to apply it to the legal domain.\nIt's funny: an argument can be logically false (its P/Q table will be invalid) yet still stand on flawed law (which is its own can of worms). So what tricks can we use to decompose and validate legal reasoning?\nThere are some cool ways to do this: text -> break into logical structures -> verify logical structures -> convert to Lean -> verify proof -> proceed. Most models, even the best ones, aren't very good at one-shotting this, so you have to experiment a bit to get the right structures. But spending time here is a critical step in verifying reasoning, and in building eval sets around complex reasoning tasks.\nA Legal Benchmark\nI have a very complex private legal benchmark that requires a few different structures to validate accurately. The gist is that ESSAY/MULTIPLE_CHOICE require a verification structure, while QA/GENERAL/CASELAW can be validated with simpler regex-style validation. The dataset is split into training, test, and validation sets with an 80/10/10 split. 
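Regex-style validation for the QA/GENERAL/CASELAW splits could look roughly like the sketch below. It is not the benchmark's actual grader. The normalization rules and the names `normalize`, `regex_match`, and `accuracy` are hypothetical. The sketch only shows the general shape of the check and how a correct/total count becomes an accuracy figure.

```python
import re
from typing import List, Tuple


def normalize(text: str) -> str:
    # hypothetical normalization: lowercase, replace punctuation with spaces, collapse whitespace
    text = re.sub('[^a-z0-9 ]', ' ', text.lower())
    return ' '.join(text.split())


def regex_match(prediction: str, gold: str) -> bool:
    # correct if the normalized gold answer appears as a whole phrase in the prediction
    pattern = '(?<![a-z0-9])' + re.escape(normalize(gold)) + '(?![a-z0-9])'
    return re.search(pattern, normalize(prediction)) is not None


def accuracy(pairs: List[Tuple[str, str]]) -> Tuple[int, int, float]:
    # returns (correct, total, accuracy in percent)
    correct = sum(regex_match(pred, gold) for pred, gold in pairs)
    total = len(pairs)
    return correct, total, 100.0 * correct / total if total else 0.0


pairs = [
    ('The answer is Miranda v. Arizona.', 'Miranda v. Arizona'),
    ('I believe it is Marbury v. Madison', 'Gideon v. Wainwright'),
]
print(accuracy(pairs))  # (1, 2, 50.0)
```

A grader like this only suits short, single-answer splits. ESSAY and MULTIPLE_CHOICE still need the verification structure.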
The evals below are the results on the test set.\nI'll release the verifier code eventually, since it could be a pretty cool repo, but unfortunately I won't be releasing the benchmark publicly because I don't want the internet-eating shoggoths to steal this eval.\nWe used a heavily modified lm-evaluation-harness to run these evals.\nResults on LegalBench for o1-preview:\nESSAY (given a fact pattern, answer questions about that fact pattern): Accuracy: 70.00%; Correct: 28/40\nQA (simpler, single-focus, essay-style questions that blend internal recall and legal theory): Accuracy: 79.48%; Correct: 550/692\nGENERAL (simple flashcard-style questions): Accuracy: 90.57%; Correct: 701/774\nCASELAW (case-law recall, asking what specific cases are): Accuracy: 28.38%; Correct: 3138/11057\nMULTIPLE_CHOICE (non-CoT; a more detailed QA dataset with an option for CoT answering): Accuracy: 83.99%; Correct: 467/556\nIt's interesting to note that caselaw performance is extremely weak in the current o1-preview.\nTODO: there's been a lot of hullabaloo about entropix's utility as a test-time sampling method; that benchmark is currently running for llama-1b vs. llama-1b with entropix.\n","date":"2024-12-06","permalink":"/posts/legalbench/","title":"LegalBench"},{"content":"==================\nBetter Call (Monte Carlo Tree) Saul\nAnother 'whatcha been hacking on' post\nInspired by the excellent Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning paper, my research direction was to use the same approach but turn it into a simple game of argumentation: can we get a model to debate/argue like a lawyer? Win arguments? Defeat an opponent?\nLaw, unlike verifiable games like chess, is a game of persuasion. It has no verifiable outcome like mathematics. I can argue a correct thing (morally, ethically, whatever) and be ruled incorrect by a single judge or jury. 
Obviously, when the indisputable facts are on your side, you'll win more than you lose, but there's gamesmanship to it. The goal is to convince the judge that your argument is correct (even if it's not), sort of like Noam Brown's Cicero. It gets even more complicated when you argue to a jury, where the judge is not the only one deciding/directing the outcome, but in this oversimplified structure we'll just have a judge.\nThe first step is to use an established framework to build a game of persuasion. Here, we'll use Debatrix: LLM-based Multi-dimensional Debate Judge with Iterative Chronological Analysis, then use the Monte Carlo Tree Search Boosts Reasoning via Iterative Preference Learning framework to create prompts in order to RL-train a basic-b Mistral-7B into a better debater/lawyer.\nIn the vein of Quiet-STaR: Language Models Can Teach Themselves to Think Before Speaking, Chain of Logic: Rule-Based Reasoning with Large Language Models, and other general LLM reasoning work (Entropix, etc.), we want to move toward 'think before you speak', i.e., more test-time compute (e.g., Quiet-STaR/o1/etc.). In the meantime, let's see if we can build a dataset that bakes some next-token-prediction better-call-saul fundamentals into this model so it argues like a lawyer.\nI won't/can't include all the code (the changes to DPO-MCTS or Debatrix, the trained model, the pipeline, or really anything super useful tbh; secret-sauce-nda-etc-etc), but I want to share the thought process of how I got there; maybe it will be useful to someone else.\nPseudo-code\nHere's the pseudo/real code we'll use:
from transformers import (
    AutoModelForCausalLM,
    GenerationConfig,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    get_scheduler,
)
from mcts_rl_testing.algorithms.mcts.mcts import (
    CourtRoomV2,
    LawStepLMConfig,
    SearchArgs,
    LawMCTS,
    MCTSNode,
    MCTSConfig,
    TreeConstructor,
)

# hyperparameters (max_length, n_actions, w_exp, ...) and models
# (tokenizer, actor_model, actor_reference_model, ...) are defined elsewhere
offline = False
global_step = 1

generation_config = GenerationConfig(
    max_length=max_length,
    max_new_tokens=max_new_tokens,
    num_return_sequences=num_return_sequences,
    temperature=temperature,
    top_p=top_p,
    repetition_penalty=repetition_penalty,
    do_sample=True,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

world_model = CourtRoomV2(
    max_length=generation_config.max_length,
    base_tokenizer=tokenizer,
    generation_config=generation_config,
    opponent_model=actor_model,
)

search_cfg = LawStepLMConfig(SearchArgs(
    ref_policy_model=actor_reference_model,
    base_tokenizer=tokenizer,
    generation_config=generation_config,
    n_actions=n_actions,
    n_init_actions=n_init_actions,
    breadth_limit=breadth_limit,
    depth_limit=depth_limit,
    force_terminating_on_depth_limit=force_terminating_on_depth_limit,
    kl_coeff=kl_coeff,
    disable_tqdm=False,
    no_self_eval=no_self_eval,
    reward_model=reward_model,
    reward_tokenizer=reward_tokenizer,
    use_code=use_code,
    use_mcq=use_mcq,
    eval_mode=eval_mode,
    temperature=temperature,
    init_temperature=init_temperature,
    get_tp_zero=get_tp_zero,
    model_type=model_type,
    include_gt=include_gt,
))

mcts_algo = LawMCTS(MCTSConfig(
    w_exp=w_exp,
    depth_limit=depth_limit,
    breadth_limit=breadth_limit,
    n_iters=n_iters,
    temperature=mcts_temperature,
    temperature_decay_ratio=mcts_temperature_decay_ratio,
    consider_diversity=(not no_consider_diversity),
    length_penalty=mcts_length_penalty,
))

mcts_searcher = TreeConstructor(
    world_model=world_model,
    search_config=search_cfg,
    search_algo=mcts_algo,
)

# set all clients
world_model.debatrix.session_id = 'debatrix'
world_model.set_debatrix_config()
world_model.set_debatrix_server()
world_model.set_debatrix_debateinfo()

Debatrix judging
We create a CourtRoomV2 model, which is a pseudo-wrapper around the Debatrix framework, to judge the debate.\nThe main condition is judging the actions 
of the given players relative to each other, as well as the conditions set by the scoring criteria (e.g., the judge's rules):
from queue import Queue
from typing import List

# DebateInfo, DebaterInfo, Turn, Verdict and DebateResult come from Debatrix


def judge(self, message_state: List[dict]) -> bool:
    # note: debate_info is built here but not passed to debatrix below
    debate_info = DebateInfo(
        motion='motion',
        pro_side=DebaterInfo(name='player1'),
        con_side=DebaterInfo(name='player2'),
        info_slide='info_slide',
        speech_order=('player1', 'player2'),
    )
    # reset debatrix
    self.debatrix._arena.reset()
    self.debatrix._panel.reset()

    # create per-debater and global turn counters
    self.debatrix._debater_counter = {
        debater_info.name: 0 for debater_info in self.debatrix._debate_info.all_debaters_info
    }
    self.debatrix._main_counter = 0

    # populate speech queue
    queue = Queue()
    for message in message_state:
        role = message['role']
        if role not in self.debatrix._debater_counter:
            # skip non-debater messages such as the system prompt
            continue
        self.debatrix._main_counter += 1
        self.debatrix._debater_counter[role] += 1
        turn = Turn(
            index=self.debatrix._main_counter,
            index_by_debater=self.debatrix._debater_counter[role],
            debater_name=role,
            content=message['content'],
        )
        queue.put(turn)
        self.debatrix._arena.update(turn=turn)
    queue.put(None)

    # feed every turn to the judging panel
    turns: list[Turn] = []
    while True:
        turn: Turn | None = queue.get()
        if turn is None:
            break
        self.debatrix._panel.update(turn=turn)
        turns.append(turn)

    dimensional_verdicts = self.debatrix._panel.dimensional_judge()
    final_verdict: Verdict | None = self.debatrix._panel.summarize(verdicts=dimensional_verdicts)
    judge_result = DebateResult(
        speeches=tuple(turns),
        dimensional_verdicts=dimensional_verdicts,
        final_verdict=final_verdict,
    )
    # player1 is the model being trained; player2 is the opponent
    winner = judge_result.final_verdict
    return winner == 'player1'  # is this ok? final_verdict is a Verdict, not a debater name
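The counter bookkeeping inside judge() is easy to lose among the Debatrix internals. The following standalone sketch shows just that bookkeeping. `SimpleTurn` and `build_turns` are hypothetical stand-ins, not Debatrix's Turn API. Each speech gets a global index and a per-debater index, and any role outside the debate (such as the system prompt) is skipped.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence


@dataclass(frozen=True)
class SimpleTurn:
    # hypothetical stand-in for Debatrix's Turn
    index: int             # position in the whole debate (1-based)
    index_by_debater: int  # position among this debater's own speeches (1-based)
    debater_name: str
    content: str


def build_turns(message_state: List[dict], debaters: Sequence[str] = ('player1', 'player2')) -> List[SimpleTurn]:
    counters: Dict[str, int] = {name: 0 for name in debaters}
    turns: List[SimpleTurn] = []
    for message in message_state:
        role = message['role']
        if role not in counters:
            # e.g. the system prompt is not a speech
            continue
        counters[role] += 1
        turns.append(SimpleTurn(
            index=len(turns) + 1,
            index_by_debater=counters[role],
            debater_name=role,
            content=message['content'],
        ))
    return turns


history = [
    {'role': 'system', 'content': 'You are an experienced prosecutor...'},
    {'role': 'player2', 'content': 'Opening argument'},
    {'role': 'player1', 'content': 'Rebuttal'},
    {'role': 'player2', 'content': 'Closing argument'},
]
for t in build_turns(history):
    print(t.index, t.index_by_debater, t.debater_name)
# 1 1 player2
# 2 1 player1
# 3 2 player2
```

In judge() itself, these counters are stored on the Debatrix object (_main_counter, _debater_counter) rather than in local variables.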
Tree Construction
eval_inputs = tokenizer(
    prompt_only_batch['player1'],
    add_special_tokens=True,
    return_tensors='pt',
)
mcts_searcher.search_config.generation_config.max_length = eval_inputs['input_ids'].size(-1) + max_length
mcts_searcher.world_model.generation_config.max_length = eval_inputs['input_ids'].size(-1) + max_length

# Rollout a batch of experiences.
input_ids = prompt_only_batch['input_ids']
attention_mask = prompt_only_batch['attention_mask']
opponent_prompt = prompt_only_batch['player2']  # could call it whatever
assert input_ids.size(0) == 1, 'Only support one instance per device.'
seq, attn_msk = input_ids[0], attention_mask[0]

global_step = 1
mcts_searcher.search_algo.policy_model = actor_reference_model
target_probs, Q_values, r_values, base_values, visit_counts, select_indexes = [], [], [], [], [], []
cur_node = None

# initialize debatrix
world_model.debatrix.initialize(world_model.debate_info, world_model.server_info)
# initialize counters
world_model.debatrix._debater_counter = {
    debater_info.name: 0 for debater_info in world_model.debate_info.all_debaters_info
}
world_model.debatrix._main_counter = 0
world_model.device = torch.device('cuda')

# start one iteration of sampling
mcts_rst = mcts_searcher(
    {
        'input_ids': seq,
        'attention_mask': attn_msk,
        'opponent_prompt': opponent_prompt,
        'base_prompt': prompt_only_batch['player1'],
        'prompt': prompt_only_batch['player1'][0],
    },
    node=cur_node,
)
results = []

# explore the tree returned by the tree constructor
pi, cur_node = mcts_rst.next_action_pi, mcts_rst.tree_state
target_probs.append(pi)
Q_values.append([child.Q for child in cur_node.children])
r_values.append([child.r for child in cur_node.children])
base_values.append([child.value for child in cur_node.children])
visit_counts.append([child.N for child in cur_node.children])
# results.append([child.result for child in cur_node.children])
cur_node = cur_node.children[mcts_rst.next_action_idx]
select_indexes.append(mcts_rst.next_action_idx)
prompt = eval_inputs['input_ids'][0]

Post-Tree Construction
from collections import deque

message_histories, candidates, init_value_list, step_id = [], [], [], 0
while cur_node.children:
    # base history at this step
    message_histories.append(cur_node.message_history)
    # breadth-first walk over each child's subtree, recording (node, depth)
    next_message_histories = []
    for child in cur_node.children:
        queue = deque([(child, 0)])  # every child subtree starts at depth 0
        while queue:
            node, depth = queue.popleft()
            for grandchild in node.children:
                queue.append((grandchild, depth + 1))
            next_message_histories.append((node, depth))

    # keep only the message histories of the deepest nodes
    max_depth = max(next_message_histories, key=lambda x: x[1])[1]
    next_message_histories = [x[0].message_history for x in next_message_histories if x[1] == max_depth]

    scores = [
        (q, s, r, bv, vc)
        for s, q, r, bv, vc in zip(
            target_probs[step_id],
            Q_values[step_id],
            r_values[step_id],
            base_values[step_id],
            visit_counts[step_id],
        )
    ]
    # Sort candidates based on scores
    _candidates = [
        [next_message_histories[idx], scores[idx]] for idx in range(len(next_message_histories))
    ]
    _candidates.sort(key=lambda x: x[1])
    init_values = [score[-1] for _, score in _candidates]
    candidate_histories = [messages for messages, _ in _candidates]

    # Store the data
    candidates.append(candidate_histories)
    init_value_list.append(init_values)

    # Move to the selected child
    cur_node = cur_node.children[select_indexes[step_id]]
    step_id += 1

mini_batches = {
    'prompts_list': [],
    'input_ids_list': [],
    'attention_mask_list': [],
    'init_value_list': [],
}
for base_history, candidate_histories, init_values in zip(message_histories, candidates, init_value_list):
    # Serialize each candidate's full message history
    input_texts = [serialize_message_history(history) for history in candidate_histories]
    # Tokenize the input texts
    encodings = tokenizer(
        input_texts,
        padding=True,
        truncation=True,
        return_tensors='pt',
        max_length=generation_config.max_length,
    )
    # tsrl_step reads prompts_list[0].size(-1), so store the prompt tensor rather than None
    mini_batches['prompts_list'].append(prompt)
    mini_batches['input_ids_list'].append(encodings['input_ids'])
    mini_batches['attention_mask_list'].append(encodings['attention_mask'])
    mini_batches['init_value_list'].append(init_values)

# Compute final prediction and correctness if applicable.
# For now the prediction is left as None, based on your game logic.
mini_batches['prediction'] = None
mini_batches['cur_max_new_tokens'] = cur_max_new_tokens

Training
Tree-constructor for RL batch
total_batch_size = (
    prompt_only_batch['input_ids'].size(0)
    if not self.args.post
    else len(prompt_only_batch['prompts_list'])
)
micro_batch_size = self.args.per_device_train_batch_size
micro_batches = []
assert total_batch_size == micro_batch_size
micro_batches.extend(self.tree_constructor(prompt_only_batch))

Calling the constructor:
rl_batches = self.tree_constructor(prompt_only_batch)

The training step from DPO-MCTS (dpo_mcts_trainer.py):
def tsrl_step(
    self,
    prompts_list: list[torch.Tensor],
    input_ids_list: list[torch.Tensor],
    attention_mask_list: list[torch.Tensor],
    prediction: tuple = None,
    init_value_list: list[float] = None,
    max_n_sample: int = 8,
    cur_max_new_tokens: int = 32,
) -> dict[str, Any]:
    losses, better_sample_rewards, worse_sample_rewards, max_lengths = [], [], [], []
    n_sample = len(input_ids_list)
    start = prompts_list[0].size(-1) - 1
    better_idx = -1
    worse_idx = 0 if self.args.choose_worst else -2

    # pick one (better, worse) pair from each step's sorted candidates
    all_better_input_ids, all_worse_input_ids = [], []
    all_better_attention_mask, all_worse_attention_mask = [], []
    all_init_value_list = []
    for sample_id in range(n_sample):
        if len(all_better_input_ids) >= max_n_sample:
            break
        input_ids = input_ids_list[sample_id]
        attention_mask = attention_mask_list[sample_id]
        n_output = input_ids.size(0)
        if n_output < 2:
            continue
        if self.args.choose_random:
            worse_idx = random.choice(range(n_output - 1))
        all_better_input_ids.append(input_ids[better_idx])
        all_worse_input_ids.append(input_ids[worse_idx])
        all_better_attention_mask.append(attention_mask[better_idx])
        all_worse_attention_mask.append(attention_mask[worse_idx])
        all_init_value_list.extend([init_value_list[sample_id][better_idx], init_value_list[sample_id][worse_idx]])

    all_input_ids = pad_tensors(all_better_input_ids + all_worse_input_ids, pad_value=self.tokenizer.pad_token_id)
    all_attention_mask = pad_tensors(all_better_attention_mask + all_worse_attention_mask, pad_value=False)

    torch.cuda.empty_cache()
    all_sequence_log_probs = self.compute_log_probs(
        self.actor_model.module,
        input_ids=all_input_ids,
        attention_mask=all_attention_mask,
    )
    all_better_input_ids, all_worse_input_ids = all_input_ids.chunk(chunks=2, dim=0)
    all_better_attention_mask, all_worse_attention_mask = all_attention_mask.chunk(chunks=2, dim=0)
    all_better_sequence_log_probs, all_worse_sequence_log_probs = all_sequence_log_probs.chunk(chunks=2, dim=0)

    label_smoothing_values = []
    for sample_id in range(len(all_better_input_ids)):
        better_input_ids = all_better_input_ids[sample_id]
        better_attention_mask = all_better_attention_mask[sample_id]
        worse_input_ids = all_worse_input_ids[sample_id]
        worse_attention_mask = all_worse_attention_mask[sample_id]
        init_values = [all_init_value_list[sample_id * 2], all_init_value_list[sample_id * 2 + 1]]
        better_sequence_log_probs = all_better_sequence_log_probs[sample_id]
        worse_sequence_log_probs = all_worse_sequence_log_probs[sample_id]

        # reference-model log-probs (no gradient)
        with torch.no_grad():
            torch.cuda.empty_cache()
            ref_better_sequence_log_probs = self.compute_log_probs(
                self.actor_reference_model.module,
                input_ids=better_input_ids.unsqueeze(0),
                attention_mask=better_attention_mask.unsqueeze(0),
            )[0]
            torch.cuda.empty_cache()
            ref_worse_sequence_log_probs = self.compute_log_probs(
                self.actor_reference_model.module,
                input_ids=worse_input_ids.unsqueeze(0),
                attention_mask=worse_attention_mask.unsqueeze(0),
            )[0]

        # score only the tokens after the shared prefix
        # (diverge_index - 1 assumes log-probs are shifted by one position)
        better_end_index = better_attention_mask.nonzero()[-1]
        worse_end_index = worse_attention_mask.nonzero()[-1]
        try:
            diverge_index = (better_input_ids != worse_input_ids).nonzero()[0]
            assert 0 <= diverge_index <= better_end_index, 'diverge index is out of range!'
            assert 0 <= diverge_index <= worse_end_index, 'diverge index is out of range!'
        except (IndexError, AssertionError):
            # identical sequences (nothing to compare) or an out-of-range index
            continue

        better_seq_slice = slice(diverge_index - 1, better_end_index)
        worse_seq_slice = slice(diverge_index - 1, worse_end_index)

        better_log_probs = better_sequence_log_probs[better_seq_slice].sum(dim=-1)
        worse_log_probs = worse_sequence_log_probs[worse_seq_slice].sum(dim=-1)
        ref_better_log_probs = ref_better_sequence_log_probs[better_seq_slice].sum(dim=-1)
        ref_worse_log_probs = ref_worse_sequence_log_probs[worse_seq_slice].sum(dim=-1)
        better_log_ratio = better_log_probs - ref_better_log_probs
        worse_log_ratio = worse_log_probs - ref_worse_log_probs
        if self.args.norm_prob or self.args.ipo:
            better_log_ratio /= better_attention_mask[better_seq_slice].sum(dim=-1) ** self.args.length_penalty
            worse_log_ratio /= worse_attention_mask[worse_seq_slice].sum(dim=-1) ** self.args.length_penalty
        logits = better_log_ratio - worse_log_ratio

        if self.args.ipo:
            losses.append((logits - 1 / (2 * self.scale_coeff)) ** 2)
        elif self.args.conservative:
            # conservative variant: label smoothing derived from the pair's init values
            qb, qw = init_values
            confidence = calculate_preference_confidence(qb, qw)
            label_smoothing = min(1 - confidence, 0.5)
            losses.append(
                - F.logsigmoid(self.scale_coeff * logits) * (1 - label_smoothing)
                - F.logsigmoid(-self.scale_coeff * logits) * label_smoothing
            )
            label_smoothing_values.append(label_smoothing)
        else:
            losses.append(-F.logsigmoid(self.scale_coeff * logits))

        better_sample_rewards.append(self.scale_coeff * better_log_ratio.detach())
        worse_sample_rewards.append(self.scale_coeff * worse_log_ratio.detach())
        max_lengths.append(better_attention_mask[start:].float().sum())
        max_lengths.append(worse_attention_mask[start:].float().sum())

    if not len(losses):
        return {}

    loss = torch.stack(losses).mean()
    max_generated_length = torch.stack(max_lengths).max()
    total_max_generated_length = max_generated_length + start
    better_sample_rewards = torch.stack(better_sample_rewards)  # size = (B,)
    worse_sample_rewards = torch.stack(worse_sample_rewards)  # size = (B,)
    rewards_accuracy = (better_sample_rewards > worse_sample_rewards).float().mean()  # size = ()
    better_sample_rewards = better_sample_rewards.mean()  # size = ()
    worse_sample_rewards = worse_sample_rewards.mean()  # size = ()
    rewards = better_sample_rewards + worse_sample_rewards  # size = ()
    rewards_margin = better_sample_rewards - worse_sample_rewards  # size = ()

    torch.cuda.empty_cache()
    self.actor_model.backward(loss)
    self.actor_model.step()

    loss = get_all_reduce_mean(loss)
    rewards = get_all_reduce_mean(rewards)
    better_sample_rewards = get_all_reduce_mean(better_sample_rewards)
    worse_sample_rewards = get_all_reduce_mean(worse_sample_rewards)
    rewards_accuracy = get_all_reduce_mean(rewards_accuracy)
    rewards_margin = get_all_reduce_mean(rewards_margin)
    max_generated_length = get_all_reduce_max(max_generated_length)
    total_max_generated_length = get_all_reduce_max(total_max_generated_length)

    return {
        'train/loss': loss.item(),
        'train/rewards': rewards.item(),
        'train/better_sample_rewards': better_sample_rewards.item(),
        'train/worse_sample_rewards': worse_sample_rewards.item(),
        'train/rewards_accuracy': rewards_accuracy.item(),
        'train/rewards_margin': rewards_margin.item(),
        'train/lr': self.actor_model.optimizer.param_groups[0]['lr'],
        # prediction may be None (it is left unset after post-tree construction)
        'train/r_scores': float(prediction[0]) if prediction is not None else 0.0,
        'train/correct': float(prediction[1]) if prediction is not None else 0.0,
        'train/n_sample': n_sample,
        'train/max_generated_length': max_generated_length.item(),
        'train/total_max_generated_length': total_max_generated_length.item(),
        'train/label_smoothing': (
            sum(label_smoothing_values) / len(label_smoothing_values) if label_smoothing_values else 0
        ),
        'train/cur_max_new_tokens': cur_max_new_tokens,
    }

Sample data
Using a base prompt, at the end of the mcts_searcher function we end up with data that looks sort of like this:\na system prompt plus multi-turn prompts, where we change the 'role' from 'player1' to 'user' and 'player2' to 'assistant'.\n[{'role': 'system', 'content': 'You are an e...'}, {'role': 'player2', 'content': '\\n**Opening Argument:**\\n\\nLadies and Gentlemen of the Court,\\n\\nToday,...'}, {'role': 'player1', 'content': 'counsel may argue that the \"fundamental miscarriage of justice\" e....'}, {'role': 'player2', 'content': '**Closing Argument:**\\n\\nLadies and Gentlemen of the Court,\\n\\n...'}, 
{\u0026#39;role\u0026#39;: \u0026#39;player1\u0026#39;, \u0026#39;content\u0026#39;: \u0026#39;\u0026#39;**Closing Argument:**\\n\\nLadies and Gentlemen of the ...\u0026#39;}] update data w/ necessary changes for training [{\u0026#39;role\u0026#39;: \u0026#39;system\u0026#39;, \u0026#39;content\u0026#39;: \u0026#39;You are an experienced prosecutor participating in a legal debate against an opposing attorney. You will be provided with case details, which include:\\n\\n- Title: McQuiggin v. Perkins \\nCourt: United States Supreme Court \\nCase Identifier: 133 S. Ct. 1924 (2013) \\nRule of Law: Habeas corpus petitions alleging actual innocence may be exempt from the one-year statute of limitations.\\n\\nFacts: \\nPerkins (defendant) attended a party with Henderson and Jones. Henderson was later found murdered. Perkins was convicted of the murder and sentenced to life without parole on October 27, 1993. Appeals were denied, finalizing the conviction on May 5, 1997. In July 2002, Perkins obtained affidavits supporting his innocence. He filed a habeas corpus petition on June 13, 2008, which the trial court deemed time-barred. The Sixth Circuit reversed, stating actual innocence claims are exempt from the one-year limit. The state petitioned for Supreme Court review.\\n\\nIssue: \\nCan habeas corpus petitions alleging actual innocence be exempt from the one-year statute of limitations?\\n\\nHolding and Reasoning by Ginsburg, J.:\\nYes. The AEDPA\\\u0026#39;s one-year statute of limitations can be equitably tolled in exceptional cases where actual innocence is convincingly shown. The AEDPA does not clearly override traditional equitable authority. The \u0026#34;fundamental miscarriage of justice\u0026#34; exception applies, allowing habeas petitions to proceed even after the limit. 
The trial court should assess if Perkins meets this demanding standard and consider delays in presenting the claim.\\n\\nDissent by Scalia, J.:\\nThe fundamental miscarriage of justice exception should be limited to judicially-created procedural bars. The AEDPA provides a strict statutory bar, and courts lack the authority to create exceptions.\\n\\n**Your Role:**\\n\\n- Given the issue at hand, argue in favor of ruling.\\n\\n**Instructions:**\\n\\n1. **Opening Argument**: Present a clear and persuasive opening statement supporting the prosecution\\\u0026#39;s position.\\n2. **Legal Analysis**: Use relevant facts, statutes, and precedents to strengthen your argument.\\n3. **Anticipate Counterarguments**: Identify potential arguments the opposing attorney might raise and preemptively address them.\\n4. **Rebuttals**: After the opposing attorney presents their arguments, provide counterpoints to refute their claims.\\n5. **Closing Argument**: Summarize the key points that support the prosecution\\\u0026#39;s case and explain why the court should rule in your favor.\\n6. **Professionalism**: Maintain formal language and adhere to legal ethics throughout the debate.\\n\\n**Guidelines:**\\n\\n- Structure your arguments logically with clear headings.\\n- Focus on legal reasoning and evidence.\\n- Do not introduce facts outside the provided case details.\\n- Aim to outmaneuver the opposing attorney by highlighting weaknesses in their arguments.\\n\\n**Note:** Use the specific case information provided above when preparing your arguments.\u0026#39;}, {\u0026#39;role\u0026#39;: \u0026#39;assistant\u0026#39;, \u0026#39;content\u0026#39;: \u0026#39;\\n**Opening Argument:**\\n\\nLadies and Gentlemen of the Court,\\n\\nToday, we stand before you to argue against the notion that habeas corpus petitions alleging actual innocence can be exempt from the one-year statute of limitations. 
The Antiterrorism and Effective Death Penalty Act (AEDPA) of 1996 clearly establishes a strict one-year statute of limitations for filing habeas corpus petitions. This statute is designed to ensure that claims are brought forth in a timely manner, preventing stale evidence and allowing for the efficient administration of justice.\\n\\nIn the case of McQuiggin v. Perkins, the defendant, Perkins, was convicted of murder and sentenced to life without parole in 1993. His appeals were denied in 1997. It is only in 2002 that he obtained affidavits supporting his innocence and filed a habeas corpus petition in 2008, more than a decade after his conviction. This significant delay raises concerns about the reliability of the evidence and the potential for other exculpatory evidence to have been lost or destroyed.\\n\\n**Legal Analysis:**\\n\\nThe AEDPA\\\u0026#39;s one-year statute of limitations is a clear and unambiguous provision that does not allow for equitable tolling in the absence of extraordinary circumstances. The AEDPA was enacted to address concerns about the abuse of the habeas corpus process, which had become a tool for delaying executions and undermining the finality of criminal convictions. The one-year statute of limitations is a key component of the AEDPA, designed to prevent such abuses.\\n\\nThe majority opinion in McQuiggin v. Perkins relies on the \u0026#34;fundamental miscarriage of justice\u0026#34; exception to the one-year statute of limitations. However, this exception is not applicable in this case. The exception is intended to address situations where a constitutional violation has resulted in a fundamental miscarriage of justice, and the petitioner can demonstrate that the delay in filing the petition was due to extraordinary circumstances beyond their control. 
In this case, Perkins\\\u0026#39; delay in filing his petition was not due to extraordinary circumstances but rather his own inaction.\\n\\n**Anticipate Counterarguments:**\\n\\nThe\u0026#39;}, {\u0026#39;role\u0026#39;: \u0026#39;user\u0026#39;, \u0026#39;content\u0026#39;: \u0026#39;counsel may argue that the \u0026#34;fundamental miscarriage of justice\u0026#34; exception should be expanded to include cases of actual innocence. However, such an expansion would undermine the purpose of the AEDPA and create uncertainty in the application of the statute of limitations. The AEDPA was designed to provide finality to criminal convictions, and expanding the exception to include actual innocence would make it more difficult for the government to enforce criminal convictions and protect the public.\\n**Rebuttals:**\\nIf the opposing counsel argues that the delay in filing the petition was due to Perkins\\\u0026#39; inability to obtain the necessary evidence, we would counter that Perkins had ample time to pursue this evidence after his conviction. The fact that he did not do so until more than a decade later demonstrates a lack of diligence on his part. Furthermore, the delay in filing the petition raises concerns about the reliability of the evidence and the potential for other exculpatory evidence to have been lost or destroyed.\\n**Closing Argument:**\\nIn conclusion, the AEDPA\\\u0026#39;s one-year statute of limitations is a clear and unambiguous provision that does not allow for equitable tolling in the absence of extraordinary circumstances. The \u0026#34;fundamental miscarriage of justice\u0026#34; exception is not applicable in this case, as Perkins\\\u0026#39; delay in filing his petition was not due to extraordinary circumstances but rather his own inaction. Expanding the exception to include actual innocence would undermine the purpose of the AEDPA and create uncertainty in the application of the statute of limitations. 
We respectfully request that the Court uphold the AEDPA\\\u0026#39;s one-year statute of limitations and deny Perkins\\\u0026#39; habeas corpus petition.\\n**Professionalism:**\\n\u0026#39;}, ... ] There\u0026rsquo;s more to this, but I\u0026rsquo;ll leave it at that for now.\n","date":"2024-11-15","permalink":"/posts/the_giving_tree/","title":"The Giving (Search) Tree"},{"content":"==================\nPreface I\u0026rsquo;m going to start blogging a bit more on day-to-day random issues that aren\u0026rsquo;t NDA\u0026rsquo;d topics, sort of the slog/grind of the day to day in the life of a scientist in the AI space.\nFloat16 explosions in inference and not training? stream of consciousness: wut? why? this inference code works in float32 for 100% of the inputs. however, in float16 there are no problems for like 99% of inputs. but for 1% of inputs, it explodes/returns NaNs?\nvery weird.\nlet\u0026rsquo;s rewrite the forward function of this embedder to find out what in the world is happening.\nRewriting our forward function to find out what in the friday night lights is happening we\u0026rsquo;ll print/log each step and save the tensors before and after given layers/activations to np/pickle files.\ndef forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_all_encoded_layers: Optional[bool] = True, subset_mask: Optional[torch.Tensor] = None, position_encodings: Optional[torch.Tensor] = None, ) -\u0026gt; List[torch.Tensor]: from einops import rearrange from infinity_emb.transformer.monarch.mm.hyena_utils import fftconv_ref import math import torch import torch.nn as nn import torch.nn.functional as F from einops import rearrange import opt_einsum as oe contract = oe.contract def check_nan(tensor, location): if torch.isnan(tensor).any(): nan_count = torch.isnan(tensor).sum().item() print(f\u0026#34;NaN detected in {location}: {nan_count} NaN values\u0026#34;) return True return False extended_attention_mask = 
attention_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = extended_attention_mask.to( dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 attention_mask_bool = attention_mask.bool() batch, seqlen = hidden_states.shape[:2] # Unpad inputs and mask. It will remove tokens that are padded. # Assume ntokens is total number of tokens (padded and non-padded) # and ntokens_unpad is total number of non-padded tokens. # Then unpadding performs the following compression of the inputs: # hidden_states[ntokens,hidden] -\u0026gt; hidden_states[ntokens_unpad,hidden] if not self.monarch_mixer_sequence_mixing: hidden_states, indices, cu_seqlens, _ = bert_padding_module.unpad_input( hidden_states, attention_mask_bool) else: cu_seqlens = None indices = None # Add alibi matrix to extended_attention_mask if not self.monarch_mixer_sequence_mixing: if self._current_alibi_size \u0026lt; seqlen: # Rebuild the alibi tensor when needed warnings.warn( f\u0026#39;Increasing alibi size from {self._current_alibi_size} to {seqlen}\u0026#39; ) self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device) elif self.alibi.device != hidden_states.device: # Device catch-up self.alibi = self.alibi.to(hidden_states.device) alibi_bias = self.alibi[:, :, :seqlen, :seqlen] attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen] alibi_attn_mask = attn_bias + alibi_bias else: alibi_attn_mask = None all_encoder_layers = [] if self.monarch_mixer_sequence_mixing: for layer_idx, layer_module in enumerate(self.layer): # u is B L H print(f\u0026#34;starting: Layer {layer_idx} {layer_module}\u0026#34;) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} input\u0026#34;) u = hidden_states if layer_module.attention.hyena_training_additions: u = layer_module.attention.layernorm(u) check_nan(u, f\u0026#34;Layer {layer_idx} u = layer_module.attention.layernorm(u)\u0026#34;) L = u.size(-2) u_orig = u u = 
layer_module.attention.in_linear(u) check_nan(u, f\u0026#34;Layer {layer_idx} u = self.in_linear(u)\u0026#34;) u = rearrange(u, \u0026#34;b l d -\u0026gt; b d l\u0026#34;) check_nan(u, f\u0026#34;Layer {layer_idx} u = rearrange(u, \u0026#39;b l d -\u0026gt; b d l\u0026#39;)\u0026#34;) # short filter uc = layer_module.attention.short_filter(u)[..., :L] check_nan(uc, f\u0026#34;Layer {layer_idx} layer_module.attention.short_filter(u)[..., :L]\u0026#34;) x1, x2, v = uc.split(layer_module.attention.d_model, dim=1) check_nan(x1, f\u0026#34;Layer {layer_idx} x1\u0026#34;) check_nan(x2, f\u0026#34;Layer {layer_idx} x2\u0026#34;) check_nan(v, f\u0026#34;Layer {layer_idx} v\u0026#34;) v = v * x1 check_nan(v, f\u0026#34;Layer {layer_idx} v = v * x1\u0026#34;) if layer_module.attention.hyena_training_additions: v = layer_module.attention.drop(v) check_nan(v, f\u0026#34;Layer {layer_idx} v = layer_module.attention.drop(v)\u0026#34;) k = layer_module.attention.filter_fn.filter(L, device=u.device) check_nan(k, f\u0026#34;Layer {layer_idx} k = layer_module.attention.filter_fn.filter(L, device=u.device)\u0026#34;) k = rearrange(k, \u0026#34;c l d -\u0026gt; c d l\u0026#34;)[0] # `c` is always 1 by default check_nan(k, f\u0026#34;Layer {layer_idx} k = rearrange(k, \u0026#39;c l d -\u0026gt; c d l\u0026#39;)[0] # `c` is always 1 by default\u0026#34;) if layer_module.attention.bidirectional: k_rev = layer_module.attention.filter_fn.filter_rev(L, device=u.device) check_nan(k_rev, f\u0026#34;Layer {layer_idx} k_rev = layer_module.attention.filter_fn.filter_rev(L, device=u.device)\u0026#34;) k_rev = rearrange(k_rev, \u0026#34;c l d -\u0026gt; c d l\u0026#34;)[0] # `c` is always 1 by default check_nan(k_rev, f\u0026#34;Layer {layer_idx} k = rearrange(k, \u0026#39;c l d -\u0026gt; c d l\u0026#39;)[0] # `c` is always 1 by default\u0026#34;) else: k_rev = None y = layer_module.attention.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias= layer_module.attention.filter_fn.bias[None, :, None]) 
check_nan(y, f\u0026#34;Layer {layer_idx} y = layer_module.attention.filter_fn(v, L, k_fwd=k, k_rev=k_rev, bias= layer_module.attention.filter_fn.bias[None, :, None])\u0026#34;) if layer_module.attention.residual_long_conv: k2 = layer_module.attention.filter_fn2.filter(L, device=u.device) check_nan(k2, f\u0026#34;Layer {layer_idx} k2 = layer_module.attention.filter_fn2.filter(L, device=u.device)\u0026#34;) k2 = rearrange(k2, \u0026#34;c l d -\u0026gt; c d l\u0026#34;)[0] check_nan(k2, f\u0026#34;Layer {layer_idx} rearrange(k2, \u0026#39;c l d -\u0026gt; c d l\u0026#39;)[0]\u0026#34;) if layer_module.attention.bidirectional: k2_rev = layer_module.attention.filter_fn2.filter_rev(L, device=u.device) check_nan(k2_rev, f\u0026#34;Layer {layer_idx} k2_rev = layer_module.attention.filter_fn2.filter_rev(L, device=u.device)\u0026#34;) k2_rev = rearrange(k2_rev, \u0026#34;c l d -\u0026gt; c d l\u0026#34;)[0] # `c` is always 1 by default check_nan(k2_rev, f\u0026#34;Layer {layer_idx} rearrange(k2_rev, \u0026#39;c l d -\u0026gt; c d l\u0026#39;)[0]\u0026#34;) else: k2_rev = None yu = layer_module.attention.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias= layer_module.attention.filter_fn2.bias[None, :, None]) check_nan(yu, f\u0026#34;Layer {layer_idx} yu = layer_module.attention.filter_fn2(u_orig.transpose(-1, -2), L, k_fwd=k2, k_rev=k2_rev, bias= layer_module.attention.filter_fn2.bias[None, :, None])\u0026#34;) # post gating y = y * x2 check_nan(y, f\u0026#34;Layer {layer_idx} y = y * x2\u0026#34;) if layer_module.attention.residual_long_conv: y = y + yu check_nan(y, f\u0026#34;Layer {layer_idx} y = y + yu\u0026#34;) y = y.transpose(-1, -2) check_nan(y, f\u0026#34;Layer {layer_idx} y = y.transpose(-1, -2)\u0026#34;) if layer_module.attention.hyena_training_additions: y = layer_module.attention.drop(layer_module.attention.act(y)) check_nan(y, f\u0026#34;Layer {layer_idx} y = layer_module.attention.drop(layer_module.attention.act(y))\u0026#34;) # for 
debugging # print(\u0026#34;y.shape:\u0026#34;, y.shape) # print(\u0026#34;layer_module.attention.out_linear.weight.shape:\u0026#34;, layer_module.attention.out_linear.weight.shape) # print(\u0026#34;layer_module.attention.out_linear.bias.shape:\u0026#34;, layer_module.attention.out_linear.bias.shape) y = layer_module.attention.out_linear(y) check_nan(y, f\u0026#34;Layer {layer_idx} y = layer_module.out_linear(y)\u0026#34;) #hidden_states = layer_module.mlp(y) #check_nan(y, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp(y)\u0026#34;) hidden_states = y residual_connection = hidden_states # compute the activation hidden_states = layer_module.mlp.gated_layers(hidden_states) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp.gated_layers(hidden_states)\u0026#34;) print(f\u0026#34;Layer {layer_idx} after gated_layers stats:\u0026#34;) print(f\u0026#34;Mean: {hidden_states.mean().item()}\u0026#34;) print(f\u0026#34;Max abs: {hidden_states.abs().max().item()}\u0026#34;) print(f\u0026#34;% \u0026gt; 5: {(hidden_states.abs() \u0026gt; 5).float().mean().item() * 100}%\u0026#34;) if layer_module.mlp.is_padded: gated = hidden_states[:, :, :layer_module.mlp.config.intermediate_size] check_nan(gated, f\u0026#34;Layer {layer_idx} gated = hidden_states[:, :, :layer_module.mlp.config.intermediate_size]\u0026#34;) non_gated = hidden_states[:, :, layer_module.mlp.config.intermediate_size:] check_nan(non_gated, f\u0026#34;Layer {layer_idx} non_gated = hidden_states[:, :, layer_module.mlp.config.intermediate_size:]\u0026#34;) else: gated = hidden_states[:, :layer_module.mlp.config.intermediate_size] check_nan(gated, f\u0026#34;Layer {layer_idx} gated = hidden_states[:, :layer_module.mlp.config.intermediate_size]\u0026#34;) non_gated = hidden_states[:, layer_module.mlp.config.intermediate_size:] check_nan(non_gated, f\u0026#34;Layer {layer_idx} non_gated = hidden_states[:, layer_module.mlp.config.intermediate_size:]\u0026#34;) 
hidden_states = layer_module.mlp.act(gated) * non_gated hidden_states = torch.clamp(hidden_states, min=-10000.0, max=10000.0) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp.act(gated) * non_gated\u0026#34;) print(f\u0026#34;Layer {layer_idx} after activation and gating stats:\u0026#34;) print(f\u0026#34;Mean: {hidden_states.mean().item()}\u0026#34;) print(f\u0026#34;Max abs: {hidden_states.abs().max().item()}\u0026#34;) print(f\u0026#34;% \u0026gt; 5: {(hidden_states.abs() \u0026gt; 5).float().mean().item() * 100}%\u0026#34;) hidden_states = layer_module.mlp.dropout(hidden_states) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp.dropout(hidden_states)\u0026#34;) # multiply by the second matrix hidden_states = layer_module.mlp.wo(hidden_states) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp.wo(hidden_states)\u0026#34;) print(f\u0026#34;Layer {layer_idx} after wo stats:\u0026#34;) print(f\u0026#34;Mean: {hidden_states.mean().item()}\u0026#34;) print(f\u0026#34;Max abs: {hidden_states.abs().max().item()}\u0026#34;) print(f\u0026#34;% \u0026gt; 5: {(hidden_states.abs() \u0026gt; 5).float().mean().item() * 100}%\u0026#34;) # add the residual connection and post-LN combine = hidden_states + residual_connection print(f\u0026#34;Layer {layer_idx} combine stats:\u0026#34;) print(f\u0026#34;Mean: {combine.mean().item()}\u0026#34;) print(f\u0026#34;Std: {combine.std().item()}\u0026#34;) print(f\u0026#34;Max: {combine.abs().max().item()}\u0026#34;) print(f\u0026#34;% of values \u0026gt; 5: {(combine.abs() \u0026gt; 5).float().mean().item() * 100}%\u0026#34;) check_nan(combine, f\u0026#34;Layer {layer_idx} hidden_states + residual_connection\u0026#34;) hidden_states = layer_module.mlp.layernorm(combine) check_nan(hidden_states, f\u0026#34;Layer {layer_idx} hidden_states = layer_module.mlp.layernorm(combine)\u0026#34;) if position_encodings is not None: 
hidden_states = hidden_states + position_encodings check_nan(hidden_states, f\u0026#34;Layer {layer_idx} after position encoding\u0026#34;) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) if subset_mask is not None: hidden_states = hidden_states[subset_mask] check_nan(hidden_states, f\u0026#34;After subset mask\u0026#34;) else: if subset_mask is None: for layer_module in self.layer: hidden_states = layer_module(hidden_states, cu_seqlens, seqlen, None, indices, attn_mask=attention_mask, bias=alibi_attn_mask ) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) # Pad inputs and mask. It will insert back zero-padded tokens. # Assume ntokens is total number of tokens (padded and non-padded) # and ntokens_unpad is total number of non-padded tokens. # Then padding performs the following de-compression: # hidden_states[ntokens_unpad,hidden] -\u0026gt; hidden_states[ntokens,hidden] hidden_states = bert_padding_module.pad_input( hidden_states, indices, batch, seqlen ) else: for i in range(len(self.layer) - 1): layer_module = self.layer[i] hidden_states = layer_module(hidden_states, cu_seqlens, seqlen, None, indices, attn_mask=attention_mask, bias=alibi_attn_mask) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) subset_idx = torch.nonzero(subset_mask[attention_mask_bool], as_tuple=False).flatten() hidden_states = self.layer[-1](hidden_states, cu_seqlens, seqlen, subset_idx=subset_idx, indices=indices, attn_mask=attention_mask, bias=alibi_attn_mask) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) return all_encoder_layers And because one specific operation is abstracted, we\u0026rsquo;ll add a debug/pickler for it as well.\n# Adapted from https://github.com/HazyResearch/fly/tree/master/src/models/layers import os import pickle import numpy as np import torch from einops import rearrange def blockdiag_weight_to_dense_weight(weight): \u0026#34;\u0026#34;\u0026#34; Arguments: weight: (nblocks, out / nblocks, in / 
blocks) Return: dense_weight: (out / in) \u0026#34;\u0026#34;\u0026#34; return torch.block_diag(*torch.unbind(weight, dim=0)) def blockdiag_multiply_reference(x, weight): \u0026#34;\u0026#34;\u0026#34; This implementation is slow but more likely to be correct. Arguments: x: (..., n) weight: (nblocks, q, n / nblocks) Outputs: out: (..., nblocks * q) \u0026#34;\u0026#34;\u0026#34; n = x.shape[-1] nblocks, q, p = weight.shape assert nblocks * p == n x_reshaped = rearrange(x, \u0026#34;... (nblocks p) -\u0026gt; ... nblocks p\u0026#34;, nblocks=nblocks) return rearrange( torch.einsum(\u0026#34;...kp, kqp -\u0026gt; ...kq\u0026#34;, x_reshaped, weight), \u0026#34;... nblocks q -\u0026gt; ... (nblocks q)\u0026#34;, ) class BlockdiagMultiply(torch.autograd.Function): \u0026#34;\u0026#34;\u0026#34;This is a faster implementation, with careful memory copies for the fastest bmm performance. The backward pass is also written manually with careful memory copies. Arguments: x: (..., n) weight: (nblocks, q, n / nblocks) Outputs: out: (..., nblocks * q) \u0026#34;\u0026#34;\u0026#34; @staticmethod @torch.cuda.amp.custom_fwd(cast_inputs=torch.bfloat16) def forward(ctx, x, weight): ctx.save_for_backward(x, weight) batch_shape, n = x.shape[:-1], x.shape[-1] batch_dim = np.prod(batch_shape) nblocks, q, p = weight.shape assert nblocks * p == n x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1) out = torch.empty( batch_dim, nblocks, q, device=x.device, dtype=x.dtype ).transpose(0, 1) out = torch.bmm(x_reshaped, weight.transpose(-1, -2), out=out).transpose(0, 1) # Create directory if it doesn\u0026#39;t exist viz_dir = \u0026#39;block_multiple_viz\u0026#39; os.makedirs(viz_dir, exist_ok=True) # Get next file number existing_files = os.listdir(viz_dir) numbers = [int(f.split(\u0026#39;_\u0026#39;)[-1].split(\u0026#39;.\u0026#39;)[0]) for f in existing_files if f.endswith(\u0026#39;.pkl\u0026#39;)] next_num = max(numbers + [-1]) + 1 # Save tensors x_path = os.path.join(viz_dir, 
f\u0026#39;x_reshaped_{next_num:04d}.pkl\u0026#39;) out_path = os.path.join(viz_dir, f\u0026#39;out_{next_num:04d}.pkl\u0026#39;) with open(x_path, \u0026#39;wb\u0026#39;) as f: pickle.dump(x_reshaped.detach().cpu().numpy(), f) with open(out_path, \u0026#39;wb\u0026#39;) as f: pickle.dump(out.detach().cpu().numpy(), f) # clamp to avoid overflow + see explosions in later layers out = torch.clamp(out, min=-10000.0, max=10000.0) return out.reshape(*batch_shape, nblocks * q) # not used now.. @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, dout): x, weight = ctx.saved_tensors batch_shape, n = x.shape[:-1], x.shape[-1] batch_dim = np.prod(batch_shape) nblocks, q, p = weight.shape assert nblocks * p == n dx, dweight = None, None dout_reshaped = dout.reshape(batch_dim, nblocks, q).transpose(0, 1) if ctx.needs_input_grad[0]: dx = torch.empty(batch_dim, nblocks, p, device=x.device, dtype=x.dtype) dx = ( torch.bmm(dout_reshaped, weight.conj(), out=dx.transpose(0, 1)) .transpose(0, 1) .reshape(*batch_shape, n) ) dx = torch.clamp(dx, min=-10000.0, max=10000.0) if ctx.needs_input_grad[1]: x_reshaped = x.reshape(batch_dim, nblocks, p).transpose(0, 1) dweight = torch.bmm(dout_reshaped.transpose(-1, -2), x_reshaped.conj()) dweight = torch.clamp(dweight, min=-10000.0, max=10000.0) return dx, dweight blockdiag_multiply = BlockdiagMultiply.apply Cool, that gave us the results we\u0026rsquo;re looking for.\nspecifically, a single layer was causing explosions, shooting values up to \u0026gt; 65k while mean/std were 0/10 or something like that.\nwhat in the world.\nlet\u0026rsquo;s try to see what\u0026rsquo;s going on in tsne space.\nload outputs + slap a tsne on it + viz import pickle import numpy as np import plotly.graph_objects as go from plotly.subplots import make_subplots from sklearn.manifold import TSNE x_reshaped = pickle.load(open(\u0026#39;x_reshaped_0009.pkl\u0026#39;,\u0026#39;rb\u0026#39;)) out = pickle.load(open(\u0026#39;out_0009.pkl\u0026#39;,\u0026#39;rb\u0026#39;)) # x_reshaped was saved as (nblocks, batch_dim, p) x_2d = x_reshaped.transpose(1, 0, 2).reshape(-1, x_reshaped.shape[-1]) # (batch_dim * nblocks, p) 
# out was saved as (batch_dim, nblocks, q), so a plain reshape keeps its rows aligned with x_2d out_2d = out.reshape(-1, out.shape[-1]) # (batch_dim * nblocks, q) # TSNE reduction tsne = TSNE(n_components=3, random_state=42) x_tsne = tsne.fit_transform(x_2d) out_tsne = tsne.fit_transform(out_2d) # Create mask for clamped values clamped_mask = np.abs(out_2d).max(axis=1) \u0026gt;= 10000 viz # Create subplot figure fig = make_subplots( rows=1, cols=2, specs=[[{\u0026#39;type\u0026#39;: \u0026#39;scatter3d\u0026#39;}, {\u0026#39;type\u0026#39;: \u0026#39;scatter3d\u0026#39;}]], subplot_titles=(\u0026#39;Input Values (TSNE)\u0026#39;, \u0026#39;Output Values (TSNE)\u0026#39;) ) # Add inputs whose outputs stayed in range fig.add_trace( go.Scatter3d( x=x_tsne[~clamped_mask, 0], y=x_tsne[~clamped_mask, 1], z=x_tsne[~clamped_mask, 2], mode=\u0026#39;markers\u0026#39;, marker=dict( size=5, color=np.abs(x_2d[~clamped_mask]).mean(axis=1), colorscale=\u0026#39;Viridis\u0026#39;, showscale=True ), name=\u0026#39;Input\u0026#39; ), row=1, col=1 ) # Add inputs whose outputs got clamped, in red fig.add_trace( go.Scatter3d( x=x_tsne[clamped_mask, 0], y=x_tsne[clamped_mask, 1], z=x_tsne[clamped_mask, 2], mode=\u0026#39;markers\u0026#39;, marker=dict( size=8, color=\u0026#39;red\u0026#39;, ), name=\u0026#39;Clamped Input\u0026#39; ), row=1, col=1 ) # Add normal output values fig.add_trace( go.Scatter3d( x=out_tsne[~clamped_mask, 0], y=out_tsne[~clamped_mask, 1], z=out_tsne[~clamped_mask, 2], mode=\u0026#39;markers\u0026#39;, marker=dict( size=5, color=np.abs(out_2d[~clamped_mask]).mean(axis=1), colorscale=\u0026#39;Viridis\u0026#39;, showscale=True ), name=\u0026#39;Normal Output\u0026#39; ), row=1, col=2 ) # Add clamped output values in red fig.add_trace( go.Scatter3d( x=out_tsne[clamped_mask, 0], y=out_tsne[clamped_mask, 1], z=out_tsne[clamped_mask, 2], mode=\u0026#39;markers\u0026#39;, marker=dict( size=8, color=\u0026#39;red\u0026#39;, ), name=\u0026#39;Clamped Output\u0026#39; ), row=1, col=2 ) fig.update_layout( height=800, width=1600, title_text=\u0026#34;TSNE Visualization with Highlighted Clamped Values\u0026#34;, showlegend=True ) fig.show()  
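quick aside on why the 65k mark matters: float16 tops out at 65504, so any value past that becomes inf, and the next step that subtracts a mean (LayerNorm does) computes inf - inf = NaN. below is a minimal sketch of that failure and of why the plus/minus 10000 clamp in the patched BlockdiagMultiply forward keeps things finite. it is an illustration, not the model code: the three-element tensor is made up, and clamp_fp16 and center are hypothetical helpers (center stands in for the first step of LayerNorm).

```python
import torch

# float16 cannot represent anything beyond this (65504.0)
FP16_MAX = torch.finfo(torch.float16).max


def clamp_fp16(x: torch.Tensor, limit: float = 10000.0) -> torch.Tensor:
    # Clamp in float32, then cast back: torch.clamp maps +inf/-inf to +limit/-limit,
    # so overflowed entries become finite again (NaNs would pass through unchanged)
    return torch.clamp(x.float(), min=-limit, max=limit).to(x.dtype)


def center(x: torch.Tensor) -> torch.Tensor:
    # Mean-centering, the first step of LayerNorm, computed in float32
    xf = x.float()
    return xf - xf.mean()


# Hypothetical activation with one entry past the float16 range (like the 65k+ values above)
h = torch.tensor([1.0, 70000.0, -3.0]).to(torch.float16)
print('overflowed to inf:', bool(torch.isinf(h).any()))

# An inf in the mean makes inf - inf = nan downstream
print('nan after centering:', bool(torch.isnan(center(h)).any()))

# Clamping right after the op that overflowed keeps everything finite
h_safe = clamp_fp16(h)
print('finite after clamp:', bool(torch.isfinite(center(h_safe)).all()))
```

clamping in float32 and casting back is just a choice for the sketch, to avoid depending on float16 CPU kernels; clamping the float16 tensor directly also maps inf to the limit. clamp cannot repair a NaN that already exists, which is why it has to sit right after the op that overflows.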
what tokens cause this nasty bastard? hacky pseudo code\nimport json import requests from transformers import BertTokenizer def test_individual_tokens(text, url, headers): tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') tokens = tokenizer.encode(text, add_special_tokens=True) problematic_tokens = [] for i, token in enumerate(tokens): test_text = tokenizer.decode([token]) data = { \u0026quot;model\u0026quot;: \u0026quot;super-duper-custom-model\u0026quot;, \u0026quot;input\u0026quot;: test_text } response = requests.post(url, headers=headers, data=json.dumps(data)) if not response.ok: problematic_tokens.append({ 'token_id': token, 'token_text': test_text, 'position': i }) print(f\u0026quot;Found problematic token at position {i}: {test_text}\u0026quot;) return problematic_tokens results = test_individual_tokens(text2, url, headers) print(f\u0026quot;\\nTotal problematic tokens found: {len(results)}\u0026quot;) well that doesn\u0026rsquo;t make too much sense.\n\u0026ldquo;owner\u0026rdquo;? the word \u0026ldquo;owner\u0026rdquo; causes the problem but not, specifically, \u0026ldquo;owners\u0026rdquo;? weird. either way, we can fix this problem by clamping after the bmm operator and checking to make sure the clamp doesn\u0026rsquo;t affect similarity too much. (test not pictured, it fixes it, causes no issues, all is right in the world)\nfin. that\u0026rsquo;s a wrap. welcome to the world of the odd. debugging precision issues in 2024.\njoin me next week for another deep dive into the world of the odd.\n","date":"2024-11-04","permalink":"/posts/explosions_in_the_sky/","title":"Explosions in the Sky - A (not so very) Deep Dive into the World of Explosions FP16 Space"},{"content":"==================\nGenerating Synthetic Data for Free: Collecting and annotating image data is one of the most resource-intensive tasks on any computer vision project. It can take months at a time to fully collect, analyze, and experiment with image streams at the level you need in order to compete in the current marketplace. 
Even after you’ve successfully collected data, you still have a constant stream of annotation errors, poorly framed images, small amounts of meaningful data in a sea of unwanted captures, and more. These major bottlenecks are why synthetic data creation needs to be in the toolkit of every modern engineer. By creating 3D representations of the objects we want to model, we can rapidly prototype algorithms while concurrently collecting live data. For example, by using this same strategy, I was able to decrease data collection costs by a whopping 9500%.\nIn this post, I walk you through an example of using the open-source animation library Blender to build an end-to-end synthetic data pipeline, using chicken nuggets as an example. The following image is an illustration of the data generated in this blog post.\nSet up your local environment We install two versions of Blender: one on a local machine with access to a GUI, and the other on an Amazon Elastic Compute Cloud (Amazon EC2) P2 instance.\nInstall Blender and ZPY Install Blender from the Blender website.\nThen complete the following steps:\n Run the following commands:  wget https://mirrors.ocf.berkeley.edu/blender/release/Blender3.2/blender-3.2.0-linux-x64.tar.xz sudo tar -Jxf blender-3.2.0-linux-x64.tar.xz --strip-components=1 -C /bin rm -rf blender* /bin/3.2/python/bin/python3.10 -m ensurepip /bin/3.2/python/bin/python3.10 -m pip install --upgrade pip Copy the necessary Python headers into the Blender version of Python so that you can use other non-Blender libraries:  wget https://www.python.org/ftp/python/3.10.2/Python-3.10.2.tgz tar -xzf Python-3.10.2.tgz sudo cp Python-3.10.2/Include/* /bin/3.2/python/include/python3.10 Override your Blender version and force installs so that the Blender-provided Python works:  /bin/3.2/python/bin/python3.10 -m pip install pybind11 pythran Cython numpy==1.22.1 sudo /bin/3.2/python/bin/python3.10 -m pip install -U Pillow --force sudo /bin/3.2/python/bin/python3.10 -m pip 
install -U scipy --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U shapely --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U scikit-image --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U gin-config --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U versioneer --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U shapely --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U ptvsd --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U ptvseabornsd --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U zmq --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U pyyaml --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U requests --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U click --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U table-logger --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U tqdm --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U pydash --force sudo /bin/3.2/python/bin/python3.10 -m pip install -U matplotlib --force Download zpy and install from source:  git clone https://github.com/ZumoLabs/zpy cd zpy vi requirements.txt Change the NumPy version to \u0026gt;=1.19.4 and scikit-image\u0026gt;=0.18.1 to make the install on 3.10.2 possible and so you don’t get any overwrites:  numpy\u0026gt;=1.19.4 gin-config\u0026gt;=0.3.0 versioneer scikit-image\u0026gt;=0.18.1 shapely\u0026gt;=1.7.1 ptvsd\u0026gt;=4.3.2 seaborn\u0026gt;=0.11.0 zmq pyyaml requests click table-logger\u0026gt;=0.3.6 tqdm pydash To ensure compatibility with Blender 3.2, go into zpy/render.py and comment out the following two lines (for more information, refer to Blender 3.0 Failure #54):  #scene.render.tile_x = tile_size #scene.render.tile_y = tile_size Next, install the zpy library:  /bin/3.2/python/bin/python3.10 setup.py install --user /bin/3.2/python/bin/python3.10 -c \u0026#34;import zpy; print(zpy.__version__)\u0026#34; Download the add-ons 
version of zpy from the GitHub repo so you can actively run your instance:  cd ~ curl -O -L -C - \u0026#34;https://github.com/ZumoLabs/zpy/releases/download/v1.4.1rc9/zpy_addon-v1.4.1rc9.zip\u0026#34; sudo unzip zpy_addon-v1.4.1rc9.zip -d /bin/3.2/scripts/addons/ mkdir .config/blender/ mkdir .config/blender/3.2 mkdir .config/blender/3.2/scripts mkdir .config/blender/3.2/scripts/addons/ mkdir .config/blender/3.2/scripts/addons/zpy_addon/ sudo cp -r zpy/zpy_addon/* .config/blender/3.2/scripts/addons/zpy_addon/ Save a file called enable_zpy_addon.py in your /home directory and run the enablement command, because you don’t have a GUI to activate it:  import bpy, os p = os.path.abspath(\u0026#39;zpy_addon-v1.4.1rc9.zip\u0026#39;) bpy.ops.preferences.addon_install(overwrite=True, filepath=p) bpy.ops.preferences.addon_enable(module=\u0026#39;zpy_addon\u0026#39;) bpy.ops.wm.save_userpref() sudo blender -b -y --python enable_zpy_addon.py If zpy-addon doesn’t install (for whatever reason), you can install it via the GUI.\nIn Blender, on the Edit menu, choose Preferences.  Choose Add-ons in the navigation pane and activate zpy.  You should see a page open in the GUI, and you’ll be able to choose ZPY. This will confirm that Blender is loaded.\nAliceVision and Meshroom Install AliceVision and Meshrooom from their respective GitHub repos:\n Install AliceVision Install Meshroom  FFmpeg Your system should have ffmpeg, but if it doesn’t, you’ll need to download it. sudo apt-get ffmepg or if on windows, just throw your computer out and burn it.\nInstant Meshes You can either compile the library yourself or download the available pre-compiled binaries (which is what I did) for Instant Meshes.\nSet up your AWS environment Now we set up the AWS environment on an EC2 instance. We repeat the steps from the previous section, but only for Blender and zpy.\n On the Amazon EC2 console, choose Launch instances. Choose your AMI.There are a few options from here. 
   We can either choose a standard Ubuntu image, pick a GPU instance, and then manually install the drivers and get everything set up, or we can take the easy route and start with a preconfigured Deep Learning AMI and only worry about installing Blender. For this post, I use the second option, and choose the latest version of the Deep Learning AMI for Ubuntu (Deep Learning AMI (Ubuntu 18.04) Version 61.0).
3. For Instance type, choose p2.xlarge.
4. For Key pair, choose an existing key pair or create a new one.
5. For this post, use the default settings for network and storage.
6. Choose Launch instances.
7. Choose Connect and follow the instructions on the SSH client tab to log in to the instance over SSH.
8. Connect with SSH:

ssh -i "your-pem" ubuntu@IPADDRESS.YOUR-REGION.compute.amazonaws.com

Once you've connected to your instance, follow the same installation steps from the previous section to install Blender and zpy.
Data collection: 3D scanning our nugget
For this step, I use an iPhone to record a 360-degree video at a fairly slow pace around my nugget. I stuck a chicken nugget onto a toothpick, taped the toothpick to my countertop, and simply rotated my camera around the nugget to get as many angles as I could. The faster you film, the more motion blur you get (depending on the shutter speed), and the fewer good frames you have to work with.
After I finished filming, I emailed the video to myself and saved it to a local drive. From there, I used ffmpeg to chop the video into frames to make Meshroom ingestion much easier:

mkdir nugget_images
ffmpeg -i VIDEO.mov nugget_images/nugget_%06d.jpg

Open Meshroom and use the GUI to drag the nugget_images folder to the pane on the left.
From there, choose Start and wait a few hours (or less), depending on the length of the video and whether you have a CUDA-enabled machine.
You should see something like the following screenshot when it's almost complete.
Data collection: Blender manipulation
When our Meshroom reconstruction is complete, complete the following steps:
1. Open the Blender GUI and on the File menu, choose Import, then Wavefront (.obj), and select the textured mesh Meshroom created. The file should be saved in path/to/MeshroomCache/Texturing/uuid-string/texturedMesh.obj.
2. Load the file and observe the monstrosity that is your 3D object.
Here is where it gets a bit tricky.
3. Scroll to the top right side and choose the Wireframe icon in Viewport Shading.
4. Select your object in the panel on the right and make sure it's highlighted, move over to the main layout viewport, and either press Tab or manually choose Edit Mode.
5. Next, maneuver the viewport so that you can see your object with as little as possible behind it. You'll have to do this a few times to really get it right.
6. Click and drag a bounding box over the object so that only the nugget is highlighted.
7. After it's highlighted like in the following screenshot, separate the nugget from the 3D mass by opening the Separate menu (press P, or right-click and choose Separate) and then choosing Selection.
We now move over to the right, where we should see two textured objects: texturedMesh and texturedMesh.001.
8. Our new object should be texturedMesh.001, so we choose texturedMesh and choose Delete to remove the unwanted mass.
9. Choose the object (texturedMesh.001) on the right, move to our viewer, and choose the object, Set Origin, and Origin to Center of Mass.
Now, if we want, we can move our object to the center of the viewport (or simply leave it where it is) and view it in all its glory. Notice the large black hole where we didn't get good film coverage! We're going to need to correct for this.
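Choosing Origin to Center of Mass moves the object's origin into the middle of the nugget's geometry. To make that step less of a black box, here is a small pure-Python sketch of one way to define a surface center of mass: the area-weighted average of the triangle centroids. It's an illustration under that assumption, not Blender's implementation (Blender offers both Surface and Volume variants, and its exact weighting may differ); the function name surface_center_of_mass is hypothetical.

```python
def surface_center_of_mass(vertices, triangles):
    """Area-weighted centroid of a triangulated surface (hypothetical helper).

    vertices:  list of (x, y, z) tuples
    triangles: list of (i, j, k) vertex-index triples
    """
    total_area = 0.0
    weighted = [0.0, 0.0, 0.0]
    for i, j, k in triangles:
        a, b, c = vertices[i], vertices[j], vertices[k]
        ab = [b[n] - a[n] for n in range(3)]
        ac = [c[n] - a[n] for n in range(3)]
        # |ab x ac| is twice the triangle's area
        cross = (
            ab[1] * ac[2] - ab[2] * ac[1],
            ab[2] * ac[0] - ab[0] * ac[2],
            ab[0] * ac[1] - ab[1] * ac[0],
        )
        area = 0.5 * sum(v * v for v in cross) ** 0.5
        centroid = [(a[n] + b[n] + c[n]) / 3.0 for n in range(3)]
        for n in range(3):
            weighted[n] += area * centroid[n]
        total_area += area
    return tuple(w / total_area for w in weighted)


# a 3 x 3 right triangle in the XY plane: its centroid is (1, 1, 0)
print(surface_center_of_mass([(0, 0, 0), (3, 0, 0), (0, 3, 0)], [(0, 1, 2)]))  # (1.0, 1.0, 0.0)
```

Weighting by area keeps a dense patch of tiny triangles from dragging the origin toward it, which matters for photogrammetry meshes like this one, where triangle density is uneven.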
To clean up the mesh, we first export our object to an .obj file. Make sure to choose Selection Only when exporting.
Data collection: Clean up with Instant Meshes
Now we have two problems: our mesh has a gap created by our poor filming that we need to clean up, and our mesh is incredibly dense (which will make generating images extremely time-consuming). To tackle both issues, we use a piece of software called Instant Meshes to extrapolate the surface over the black hole and to remesh the object into a smaller, less dense version.
1. Open Instant Meshes and load our recently saved nugget.obj file.
2. Under Orientation field, choose Solve.
3. Under Position field, choose Solve.
Here's where it gets interesting. If you explore your object and notice that the criss-cross lines of the Position solver look disjointed, you can choose the comb icon under Orientation field and redraw the lines properly.
4. Choose Solve for both Orientation field and Position field.
5. If everything looks good, export the mesh, name it something like nugget_refined.obj, and save it to disk.
Data collection: Shake and bake!
Because our low-poly mesh doesn't have any image texture associated with it and our high-poly mesh does, we either need to bake the high-poly texture onto the low-poly mesh, or create a new texture and assign it to our object. For the sake of simplicity, we're going to create an image texture from scratch and apply it to our nugget.
I used Google image search for nuggets and other fried things in order to get a high-res image of the surface of a fried object. I found a super high-res image of a fried cheese curd and made a new image full of the fried texture.
With this image, I'm ready to complete the following steps:
1. Open Blender and load the new nugget_refined.obj the same way you loaded your initial object: on the File menu, choose Import, Wavefront (.obj), and choose the nugget_refined.obj file.
2. Next, go to the Shading tab.
   At the bottom you should notice two boxes titled Principled BSDF and Material Output.
3. On the Add menu, choose Texture, then Image Texture. An Image Texture box should appear.
4. Choose Open Image and load your fried texture image.
5. Drag from Color in the Image Texture box to Base Color in the Principled BSDF box to connect them.
Now your nugget should be good to go!
Data collection: Create Blender environment variables
Now that we have our base nugget object, we need to create a few collections and environment variables to help us in our process.
1. In the right-hand scene area (the Outliner), right-click and choose New Collection.
2. Create the following collections: BACKGROUND, NUGGET, and SPAWNED.
3. Drag the nugget to the NUGGET collection and rename it nugget_base.
Data collection: Create a plane
We're going to create a background object from which our nuggets will be generated when we're rendering images. In a real-world use case, this plane is where our nuggets are placed, such as a tray or bin.
1. On the Add menu, choose Mesh and then Plane.
2. From here, move to the right side of the page and find the orange box (Object Properties).
3. In the Transform pane, set Scale to X 46.968, Y 46.968, and Z 1.0, and set X, Y, and Z to 0 for both Location and Rotation.
Data collection: Set the camera and axis
Next, we're going to set our cameras up correctly so that we can generate images.
1. On the Add menu, choose Empty and then Plain Axes. Name the object Main Axis.
2. Make sure all of the axis's transform values are 0 (so it's directly in the center).
3. If you have a camera already created, drag that camera under Main Axis.
4. Choose Item and Transform.
5. For Location, set X to 0, Y to 0, and Z to 100.
Data collection: Here comes the sun
Next, we add a Sun object.
1. On the Add menu, choose Light and Sun. The location of this object doesn't necessarily matter as long as it's centered somewhere over the plane object we've set.
2. Choose the green lightbulb icon in the bottom right pane (Object Data Properties) and set the strength to 5.0.
3. Repeat the same procedure to add a Light object and put it in a random spot over the plane.
Data collection: Download random backgrounds
To inject randomness into our images, we download as many random textures from texture.ninja as we can (for example, bricks). Download them to a folder within your workspace called random_textures. I downloaded about 50.
The Fun Stuff
Image Generation Pipeline: Object3D and DensityController Walkthrough
Let's start with some code definitions:

# imports used by the classes in this walkthrough
import random
from collections import Counter
from typing import List, Tuple, Union

import bmesh
import bpy
import numpy as np
from mathutils import Vector
from mathutils.bvhtree import BVHTree


class Object3D:
    '''
    object container to store mesh information about the given object

    Returns the Object3D object
    '''
    def __init__(self, object: Union[bpy.types.Object, str]):
        """Creates a Object3D object.

        Args:
            object (Union[bpy.types.Object, str]): Scene object (or its name)
        """
        # accept either the object itself or its name, as the signature promises
        if isinstance(object, str):
            object = bpy.data.objects[object]
        self.object = object
        self.obj_poly = None
        self.mat = None
        self.vert = None
        self.poly = None
        self.bvht = None
        self.calc_mat()
        self.calc_world_vert()
        self.calc_poly()
        self.calc_bvht()

    def calc_mat(self) -> None:
        """store an instance of the object's matrix_world"""
        self.mat = self.object.matrix_world

    def calc_world_vert(self) -> None:
        """calculate the vertices from the object's matrix_world perspective"""
        self.vert = [self.mat @ v.co for v in self.object.data.vertices]
        self.obj_poly = np.array(self.vert)

    def calc_poly(self) -> None:
        """store an instance of the object's polygons"""
        self.poly = [p.vertices for p in self.object.data.polygons]

    def calc_bvht(self) -> None:
        """create a BVHTree from the object's polygons"""
        self.bvht = BVHTree.FromPolygons(self.vert, self.poly)

    def regenerate(self) -> None:
        """reinstantiate the object's variables; used when the object is manipulated after its creation"""
        self.calc_mat()
        self.calc_world_vert()
        self.calc_poly()
        self.calc_bvht()

    def __repr__(self):
        return "Object3D: " + self.object.__repr__()

We first define a basic container class with some important properties. This class mainly exists so that we can create a BVHTree (a bounding volume hierarchy, a way to represent our nugget object in 3D space) and use the BVHTree.overlap method to check whether two independently generated nugget objects overlap in our 3D space. More on this later.
The second piece of code is our DensityController. It binds us to the rules of reality rather than those of the 3D world: in Blender, objects can exist inside each other, but unless someone is performing some wild Stranger Things-type science on our chicken nuggets, we want to make sure no two nuggets overlap by a degree that makes the scene visually unrealistic.
We will use our Plane object to spawn a set of bounded invisible cubes that can be queried at any given time to see whether that space is occupied.

class DensityController:
    """Container that controls the spatial relationship between 3D objects

    Returns:
        DensityController: The DensityController object.
    """
    def __init__(self):
        self.bvhtrees = None
        self.overlaps = None
        self.occupied = None
        self.unoccupied = None
        self.objects3d = []

    def auto_generate_kdtree_cubes(
        self,
        num_objects: int = 100,  # max number of nuggets
    ) -> dict:
        """
        function to generate physical kdtree cubes on a plane
        this allows us to access each cube's overlap/occupancy status at any given time
        creates a KDTree collection, a cube, a set of individual cubes, and the BVHTree object for each individual cube

        Args:
            num_objects (int): the maximum number of nuggets the plane should hold
        """

In this snippet we select the nugget and create a bounding cube around it. This cube represents the size of a single pseudo-voxel of our pseudo-kd-tree object. We need to call bpy.context.view_layer.update() because when this code runs from inside a function/script rather than the Blender GUI, the view_layer doesn't seem to update automatically.

        # read the nugget,
        # see how large the cube needs to be to encompass a single nugget
        # then touch a parameter to allow it to be smaller or larger (eg more touching)
        bpy.context.view_layer.objects.active = bpy.context.scene.objects.get('nugget_base')
        bpy.ops.object.origin_set(type='ORIGIN_GEOMETRY', center='BOUNDS')
        # create a cube for the bounding box
        bpy.ops.mesh.primitive_cube_add(location=Vector((0, 0, 0)))
        # our new cube is now the active object, so we can keep track of it in a variable:
        bound_box = bpy.context.active_object
        bound_box.name = 'CUBE1'
        bpy.context.view_layer.update()
        # copy transforms
        nug_dims = bpy.data.objects["nugget_base"].dimensions
        bpy.data.objects["CUBE1"].dimensions = nug_dims
        bpy.context.view_layer.update()
        bpy.data.objects["CUBE1"].location = bpy.data.objects["nugget_base"].location
        bpy.context.view_layer.update()
        bpy.data.objects["CUBE1"].rotation_euler = bpy.data.objects["nugget_base"].rotation_euler
        bpy.context.view_layer.update()
        print("bound_box.dimensions: ", bound_box.dimensions)
        print("bound_box.location:", bound_box.location)

Next, we slightly update our cube object so that its length and width are equal (square), as opposed to the natural proportions of the nugget it was created from.

        # this cube created isn't always square, but we're going to make it square
        # to fit into our
        x, y, z = bound_box.dimensions
        v = max(x, y)
        if np.round(v) < v:
            v = np.round(v) + 1
        bb_x, bb_y = v, v
        bound_box.dimensions = Vector((v, v, z))
        bpy.context.view_layer.update()
        print("bound_box.dimensions updated: ", bound_box.dimensions)
        # now we generate a plane
        # calc the size of the plane given a max number of boxes.
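Stepping back from the Blender calls, the end product of this method is a grid of nugget-sized cells over the plane, each flagged as free or occupied. That bookkeeping can be sketched without Blender: OccupancyGrid below is a hypothetical, pure-Python stand-in (it is not part of zpy or the original script, and it computes cell centers arithmetically instead of subdividing a mesh face) that lays out the cell centers of a square plane centered on the origin and hands out random free cells, much like generate_spawn_point and update_tree do with the real cubes.

```python
import random


class OccupancyGrid:
    """Hypothetical pure-Python stand-in for the kdtree-cube grid.

    A square plane of side plane_size, centered on the origin, is split into
    cells_per_side x cells_per_side square cells; each cell records whether
    something already occupies it.
    """

    def __init__(self, plane_size, cells_per_side, seed=None):
        self.cell_size = plane_size / cells_per_side
        half = plane_size / 2.0
        # (row, col) -> (x, y) center of that cell
        self.centers = {
            (r, c): (-half + (c + 0.5) * self.cell_size, -half + (r + 0.5) * self.cell_size)
            for r in range(cells_per_side)
            for c in range(cells_per_side)
        }
        self.occupied = {key: False for key in self.centers}
        self.rng = random.Random(seed)

    def random_unoccupied(self):
        """Key of a random free cell, or None when every cell is taken."""
        free = [key for key, taken in self.occupied.items() if not taken]
        return self.rng.choice(free) if free else None

    def spawn_point(self):
        """Claim a random free cell and return its (x, y) center."""
        key = self.random_unoccupied()
        if key is None:
            raise RuntimeError('every cell is occupied')
        self.occupied[key] = True
        return self.centers[key]


grid = OccupancyGrid(plane_size=4.0, cells_per_side=2, seed=0)
print(sorted(grid.spawn_point() for _ in range(4)))
# [(-1.0, -1.0), (-1.0, 1.0), (1.0, -1.0), (1.0, 1.0)]
```

Returning None for a full grid and testing it with `is None` avoids confusing a valid first cell with "no cell left"; with integer keys, a plain truthiness check would treat cell 0 exactly that way.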
Here is where we use our updated cube object to create a plane that can volumetrically hold num_objects nuggets.

        x, y, z = bound_box.dimensions
        bb_loc = bound_box.location
        bb_rot_eu = bound_box.rotation_euler
        min_area = (x * y) * num_objects
        min_length = min_area / num_objects
        print(min_length)
        # now we generate a plane
        # calc the size of the plane given a max number of boxes.
        bpy.ops.mesh.primitive_plane_add(location=Vector((0, 0, 0)), size=min_length)
        plane = bpy.context.selected_objects[0]
        plane.name = 'PLANE'
        # move our plane to our background collection
        # (link_object is a helper defined elsewhere in the script, not shown here)
        link_object('PLANE', 'BACKGROUND')
        bpy.context.view_layer.update()

We take our plane object and create a giant cube with the same length and width as our plane and the height of our nugget cube, CUBE1.

        # New Collection
        my_coll = bpy.data.collections.new("KDTREE")
        # Add collection to scene collection
        bpy.context.scene.collection.children.link(my_coll)
        # now we generate cubes based on the size of the plane.
        bpy.ops.mesh.primitive_cube_add(location=Vector((0, 0, 0)), size=min_length)
        bpy.context.view_layer.update()
        cube = bpy.context.selected_objects[0]
        cube_dimensions = cube.dimensions
        bpy.context.view_layer.update()
        cube.dimensions = Vector((cube_dimensions[0], cube_dimensions[1], z))
        bpy.context.view_layer.update()
        cube.location = bb_loc
        bpy.context.view_layer.update()
        cube.rotation_euler = bb_rot_eu
        bpy.context.view_layer.update()
        cube.name = 'cube'
        bpy.context.view_layer.update()
        current_collection = cube.users_collection
        link_object('cube', 'KDTREE')
        bpy.context.view_layer.update()

From here, we want to create "voxels" from our cube. We take the number of cubes we would need to fit num_objects and then cut them from our cube object.
We look for the upward-facing mesh face of our cube, and then pick that face to make our cuts.

        # get the bb volume and make the proper cuts to the object
        bb_vol = x * y * z
        cube_vol = cube_dimensions[0] * cube_dimensions[1] * cube_dimensions[2]
        n_cubes = cube_vol / bb_vol
        cuts = n_cubes / ((x + y) / 2)
        cuts = int(np.round(cuts)) - 1
        # select the cube
        for object in bpy.data.objects:
            object.select_set(False)
        bpy.context.view_layer.update()
        bpy.data.objects['cube'].select_set(True)  # Blender 2.8x
        bpy.context.view_layer.objects.active = bpy.context.scene.objects.get('cube')
        # set to edit mode
        bpy.ops.object.mode_set(mode='EDIT', toggle=False)
        print('edit mode success')
        # get face_data
        context = bpy.context
        obj = context.edit_object
        me = obj.data
        mat = obj.matrix_world
        bm = bmesh.from_edit_mesh(me)
        up_face = None
        # select upwards facing cube-face
        # https://blender.stackexchange.com/questions/43067/get-a-face-selected-pointing-upwards
        # (UP_VECTOR and EPSILON are constants defined elsewhere in the script, not shown here)
        for face in bm.faces:
            if (face.normal - UP_VECTOR).length < EPSILON:
                up_face = face
                break
        assert up_face
        # subdivide the edges to get the perfect kdtree cubes
        bmesh.ops.subdivide_edges(bm, edges=up_face.edges, use_grid_fill=True, cuts=cuts)
        bpy.context.view_layer.update()
        # get the center point of each face

Lastly, we calculate the center of the top face of each cut we've made in our big cube and create actual cubes from those cuts.
Each of these newly created cubes represents a single piece of space in which to spawn or move nuggets around our plane.

        face_data = {}
        sizes = []
        for f, face in enumerate(bm.faces):
            face_data[f] = {}
            face_data[f]['calc_center_bounds'] = face.calc_center_bounds()
            loc = mat @ face_data[f]['calc_center_bounds']
            face_data[f]['loc'] = loc
            sizes.append(loc[-1])
        # get the most common cube-z; we use this to determine the correct loc
        counter = Counter()
        counter.update(sizes)
        most_common = counter.most_common()[0][0]
        cube_loc = mat @ cube.location
        # get out of edit mode
        bpy.ops.object.mode_set(mode='OBJECT', toggle=False)
        # go to new collection
        bvhtrees = {}
        for f in face_data:
            loc = mat @ face_data[f]['calc_center_bounds']
            print(loc)
            if loc[-1] == most_common:
                # set it back down to the floor because the face is elevated to the
                # top surface of the cube
                loc[-1] = cube_loc[-1]
                bpy.ops.mesh.primitive_cube_add(location=loc, size=x)
                cube = bpy.context.selected_objects[0]
                cube.dimensions = Vector((x, y, z))
                # bpy.context.view_layer.update()
                cube.name = "cube_{}".format(f)
                # my_coll.objects.link(cube)
                link_object("cube_{}".format(f), 'KDTREE')
                # bpy.context.view_layer.update()
                bvhtrees[f] = {
                    'occupied': 0,
                    'object': Object3D(cube),
                }
        for object in bpy.data.objects:
            object.select_set(False)
        bpy.data.objects['CUBE1'].select_set(True)  # Blender 2.8x
        bpy.ops.object.delete()
        # store the grid on the controller: every other method reads self.bvhtrees
        self.bvhtrees = bvhtrees
        return bvhtrees

Next, we develop an algorithm that understands which cubes are occupied at any given time, finds which objects overlap with each other, and then moves overlapping objects separately into unoccupied space.
We won't be able to get rid of all overlaps entirely, but we can make it look real enough.

    def find_occupied_space(
        self,
        objects3d: List[Object3D],
    ) -> None:
        """
        discover which cube's bvhtree is occupied in our kdtree space

        Args:
            list of Object3D objects
        """
        for i in self.bvhtrees:
            bvhtree = self.bvhtrees[i]['object']
            for object3d in objects3d:
                if object3d.bvht.overlap(bvhtree.bvht):
                    self.bvhtrees[i]['occupied'] = 1

    def find_overlapping_objects(
        self,
        objects3d: List[Object3D],
    ) -> List[Tuple[int]]:
        """
        returns which Object3D objects are overlapping

        Args:
            list of Object3D objects
        Returns:
            List of index pairs from objects3d that overlap
        """
        overlaps = []
        for i, x_object3d in enumerate(objects3d):
            # start=i + 1 keeps ii an index into objects3d rather than into the slice
            for ii, y_object3d in enumerate(objects3d[i + 1:], start=i + 1):
                if x_object3d.bvht.overlap(y_object3d.bvht):
                    overlaps.append((i, ii))
        return overlaps

    def calc_most_overlapped(
        self,
        overlaps: List[Tuple[int]],
    ) -> List[Tuple[int]]:
        """
        Algorithm to count the number of edges each index has and return a
        sorted list from most->least with the number of edges each index has.

        Args:
            list of indices that are overlapping
        Returns:
            list of indices with the total number of overlaps they have [index, count]
        """
        keys = {}
        for x, y in overlaps:
            if x not in keys:
                keys[x] = 0
            if y not in keys:
                keys[y] = 0
            keys[x] += 1
            keys[y] += 1
        # sort by most edges first
        index_counts = sorted(keys.items(), key=lambda x: x[1])[::-1]
        return index_counts

    def get_random_unoccupied(
        self,
    ) -> Union[int, None]:
        """
        returns a randomly chosen unoccupied kdtree cube

        Returns:
            Union[int, None]: either the kdtree cube's key or None
            (meaning all spaces are currently occupied)
        """
        unoccupied = []
        for i in self.bvhtrees:
            if not self.bvhtrees[i]['occupied']:
                unoccupied.append(i)
        if unoccupied:
            random.shuffle(unoccupied)
            return unoccupied[0]
        else:
            return None

    def regenerate(
        self,
        iterable: Union[None, List[Object3D]] = None,
    ) -> None:
        """
        this function recalculates each object's world-view information;
        we default to None, which means we're recalculating the self.bvhtrees cubes

        Args:
            iterable (None or List of Object3D objects): if None, we default to recalculating the kdtree
        """
        if isinstance(iterable, list):
            for object in iterable:
                object.regenerate()
        else:
            for idx in self.bvhtrees:
                self.bvhtrees[idx]['object'].regenerate()
                self.update_tree(idx, occupied=0)

    def process_trees_and_objects(
        self,
        objects3d: List[Object3D],
    ) -> List[Tuple[int]]:
        """
        This function finds all overlapping objects within objects3d,
        calculates the objects with the most overlaps, and searches within
        the kdtree cube space to see which cubes are occupied.
        It then returns the edge counts of the most overlapping objects.

        Args:
            list of Object3D objects
        Returns:
            the output of calc_most_overlapped
        """
        overlaps = self.find_overlapping_objects(objects3d)
        most_overlapped = self.calc_most_overlapped(overlaps)
        self.find_occupied_space(objects3d)
        return most_overlapped

    def move_objects(
        self,
        objects3d: List[Object3D],
        most_overlapped: List[Tuple[int]],
        z_increase_offset: float = 2.,
    ) -> None:
        """
        This function iterates through most_overlapped and uses each index to extract
        the matching object from objects3d; it then finds a random unoccupied kdtree
        cube and moves the given overlapping object to that space.
        It does this for each index from most_overlapped.

        Args:
            objects3d: list of Object3D objects
            most_overlapped: a list of tuples (index, count), where index is the position
                in objects3d and count is how many times it overlaps with other objects
            z_increase_offset: this value increases the Z value of the object in order to make
                it appear as though it's off the floor. If you don't augment this value the
                object looks like it's 'inside' the ground plane
        """
        for idx, cnt in most_overlapped:
            object3d = objects3d[idx]
            unoccupied_idx = self.get_random_unoccupied()
            # compare against None: cube key 0 is a valid (but falsy) index
            if unoccupied_idx is not None:
                object3d.object.location = self.bvhtrees[unoccupied_idx]['object'].object.location
                # ensure the nugget is above the ground plane
                object3d.object.location[-1] = z_increase_offset
                self.update_tree(unoccupied_idx, occupied=1)

    def dynamic_movement(
        self,
        objects3d: List[Object3D],
        tries: int = 100,
        z_offset: float = 2.,
    ) -> None:
        """
        This function resets all objects to get their current positioning
        and randomly moves objects around in an attempt to avoid any object
        overlaps (we don't want two objects to be spawned in the same position)

        Args:
            objects3d: list of Object3D objects
            tries: the number of times we move objects to random spaces
                to try to ensure no overlaps are present
            z_offset: this value increases the Z value of the object in order to make it
                appear as though it's off the floor. If you don't augment this value the
                object looks like it's 'inside' the ground plane (see `move_objects`)
        """
        # reset all objects
        self.regenerate(objects3d)
        # regenerate bvhtrees
        self.regenerate(None)
        most_overlapped = self.process_trees_and_objects(objects3d)
        attempts = 0
        while most_overlapped:
            if attempts >= tries:
                break
            self.move_objects(objects3d, most_overlapped, z_offset)
            attempts += 1
            # recalc objects
            self.regenerate(objects3d)
            # regenerate bvhtrees
            self.regenerate(None)
            # recalculate overlaps
            most_overlapped = self.process_trees_and_objects(objects3d)

    def generate_spawn_point(
        self,
    ) -> Vector:
        """
        this function generates a random spawn point by finding which
        of the kdtree cubes are unoccupied, and returns one of those

        Returns:
            the Vector location of the kdtree cube that's unoccupied
        """
        idx = self.get_random_unoccupied()
        print(idx)
        if idx is None:
            raise RuntimeError('no unoccupied kdtree cubes are left to spawn into')
        self.update_tree(idx, occupied=1)
        return self.bvhtrees[idx]['object'].object.location

    def update_tree(
        self,
        idx: int,
        occupied: int,
    ) -> None:
        """
        this function updates the given state (occupied vs. unoccupied) of the
        kdtree given the idx

        Args:
            idx: int
            occupied: int
        """
        self.bvhtrees[idx]['occupied'] = occupied

Image Generation Pipeline: Cool Runnings
Here we'll break down what our run function is doing.
We initialize our DensityController and create something called a saver using the ImageSaver from zpy. This allows us to seamlessly save our rendered images to any location of our choosing.
We then add our nugget category (and if we had more categories, we would add them here).

# imports the run() walkthrough relies on; log is assumed to be a standard
# logging.Logger, since the original script's logger setup isn't shown
import logging
import math
import os
import random
from pathlib import Path

import bpy
import gin
import mathutils
import zpy

log = logging.getLogger(__name__)


@gin.configurable("run")
@zpy.blender.save_and_revert
def run(
    max_num_nuggets: int = 100,
    jitter_mesh: bool = True,
    jitter_nugget_scale: bool = True,
    jitter_material: bool = True,
    jitter_nugget_material: bool = False,
    number_of_random_materials: int = 50,
    nugget_texture_path: str = os.getcwd() + "/nugget_textures",
    annotations_path: str = os.getcwd() + '/nugget_data',
):
    """
    Main run function.
    """
    density_controller = DensityController()
    # Random seed results in unique behavior
    zpy.blender.set_seed(random.randint(0, 1000000000))
    # Create the saver object
    saver = zpy.saver_image.ImageSaver(
        description="Image of the randomized Amazon nuggets",
        output_dir=annotations_path,
    )
    saver.add_category(name="nugget")

Next we need a source object from which we spawn copies of nuggets; in this case, it's the nugget_base that we've so meticulously created.

    # Make a list of source nugget objects
    source_nugget_objects = []
    for obj in zpy.objects.for_obj_in_collections(
        [
            bpy.data.collections["NUGGET"],
        ]
    ):
        assert obj is not None
        # pass on everything not named nugget
        if 'nugget_base' not in obj.name:
            print('passing on {}'.format(obj.name))
            continue
        zpy.objects.segment(obj, name="nugget", as_category=True)  # color=nugget_seg_color
        print("zpy.objects.segment: check {}".format(obj.name))
        source_nugget_objects.append(obj.name)

Now that we have our base nugget, we save the world poses/locations of all the other objects so that after each rendering run, we can use these saved poses to reinitialize a render (we also move our base nugget completely out of the way so that the kdtree doesn't sense a space being occupied).
Finally, we initialize our kdtree-cube objects.\n# move nugget point up 10 z\u0026#39;s so it won\u0026#39;t collide with base-cube bpy.data.objects[\u0026#34;nugget_base\u0026#34;].location[-1] = 10 # Save the position of the camera and light # create light and camera zpy.objects.save_pose(\u0026#34;Camera\u0026#34;) zpy.objects.save_pose(\u0026#34;Sun\u0026#34;) zpy.objects.save_pose(\u0026#34;Plane\u0026#34;) zpy.objects.save_pose(\u0026#34;Main Axis\u0026#34;) axis = bpy.data.objects[\u0026#39;Main Axis\u0026#39;] print(\u0026#39;saving poses\u0026#39;) # add some parameters to this  # get the plane-3d object plane3d = Object3D(bpy.data.objects[\u0026#39;Plane\u0026#39;]) # generate kdtree cubes density_controller.generate_kdtree_cubes() This piece of code simply collects the backgrounds we downloaded from texture-ninja so that they can be randomly projected onto our plane.\n# Pre-create a bunch of random textures #random_materials = [ # zpy.material.random_texture_mat() for _ in range(number_of_random_materials) #] p = os.path.abspath(os.getcwd()+\u0026#39;/random_textures\u0026#39;) print(p) random_materials = [] for x in os.listdir(p): texture_path = Path(os.path.join(p,x)) y = zpy.material.make_mat_from_texture(texture_path, name=texture_path.stem) random_materials.append(y) #print(random_materials[0]) # Pre-create a bunch of random nugget textures random_nugget_materials = [ random_nugget_texture_mat(Path(nugget_texture_path)) for _ in range(number_of_random_materials) ] Here is where the magic begins. We first regenerate our kdtree-cubes for this run so that we can start fresh.\n# Run the sim. 
for step_idx in zpy.blender.step(): density_controller.generate_kdtree_cubes() objects3d = [] num_nuggets = random.randint(40, max_num_nuggets) log.info(f\u0026#34;Spawning {num_nuggets} nuggets.\u0026#34;) spawned_nugget_objects = [] for _ in range(num_nuggets): We use our density controller to generate a random spawn point for each nugget, create a copy of nugget_base, and move the copy to that spawn point.\n# Choose location to spawn nuggets spawn_point = density_controller.generate_spawn_point() # manually spawn above the floor # spawn_point[-1] = 1.8 #2.0 # Pick a random object to spawn _name = random.choice(source_nugget_objects) log.info(f\u0026#34;Spawning a copy of source nugget {_name} at {spawn_point}\u0026#34;) obj = zpy.objects.copy( bpy.data.objects[_name], collection=bpy.data.collections[\u0026#34;SPAWNED\u0026#34;], is_copy=True, ) obj.location = spawn_point obj.matrix_world = mathutils.Matrix.Translation(spawn_point) spawned_nugget_objects.append(obj) Next we randomly jitter the rotation, the scale, the mesh, and (optionally) the material of each nugget so that no two nuggets look the same.\n# Segment the newly spawned nugget as an instance zpy.objects.segment(obj) # Jitter final pose of the nugget a little zpy.objects.jitter( obj, rotate_range=( (0.0, 0.0), (0.0, 0.0), (-math.pi * 2, math.pi * 2), ), ) if jitter_nugget_scale: # Jitter the scale of each nugget zpy.objects.jitter( obj, scale_range=( (0.8, 2.0), #1.2 (0.8, 2.0), #1.2 (0.8, 2.0), #1.2 ), ) if jitter_mesh: # Jitter (deform) the mesh of each nugget zpy.objects.jitter_mesh( obj=obj, scale=( random.uniform(0.01, 0.03), random.uniform(0.01, 0.03), random.uniform(0.01, 0.03), ), ) if jitter_nugget_material: # Jitter the material (appearance) of each nugget for i in range(len(obj.material_slots)): obj.material_slots[i].material = random.choice(random_nugget_materials) zpy.material.jitter(obj.material_slots[i].material) We turn our nugget copy into an Object3D object and use 
the BVHTree functionality to check whether our Plane intersects or overlaps any face or vertex of the nugget copy; if we do find an overlap with the plane, we simply move the nugget upward along its z axis.\n# create 3d obj for movement nugget3d = Object3D(obj) # make sure the bottom most part of the nugget is NOT # inside the plane-object  plane_overlap(plane3d, nugget3d) objects3d.append(nugget3d) Now that all nuggets are created, we use our DensityController to move nuggets around so that there are as few overlaps as possible, and those that do overlap aren\u0026rsquo;t hideous looking.\n# ensure objects aren\u0026#39;t on top of each other density_controller.dynamic_movement(objects3d) Here we restore the Camera and Main Axis poses and randomly select how far the camera is from the Plane object.\n# Return camera to original position zpy.objects.restore_pose(\u0026#34;Camera\u0026#34;) zpy.objects.restore_pose(\u0026#34;Main Axis\u0026#34;) # assert these are the correct versions... 
assert(bpy.data.objects[\u0026#34;Camera\u0026#34;].location == Vector((0,0,100))) assert(bpy.data.objects[\u0026#34;Main Axis\u0026#34;].location == Vector((0,0,0))) assert(bpy.data.objects[\u0026#34;Main Axis\u0026#34;].rotation_euler == Euler((0,0,0))) # alter the Z distance with the camera bpy.data.objects[\u0026#34;Camera\u0026#34;].location = (0, 0, random.uniform(0.75, 3.5)*100) We then decide how randomly the camera travels around the Main Axis: if we want it mainly overhead, or if we care a lot about the angle at which it sees the board, we can adjust the top_down_mostly parameter depending on how well our model is picking up the signal of what a nugget even is.\n# alter the main-axis beta/gamma params top_down_mostly = False if top_down_mostly: zpy.objects.rotate( bpy.data.objects[\u0026#34;Main Axis\u0026#34;], rotation=( random.uniform(-0.05, 0.05), random.uniform(-0.05, 0.05), random.uniform(-0.05, 0.05), ), ) else: zpy.objects.rotate( bpy.data.objects[\u0026#34;Main Axis\u0026#34;], rotation=( random.uniform(-1., 1.), random.uniform(-1., 1.), random.uniform(-1., 1.), ), ) print(bpy.data.objects[\u0026#34;Main Axis\u0026#34;].rotation_euler) print(bpy.data.objects[\u0026#34;Camera\u0026#34;].location) Next we randomly pick a texture for the Plane object and do the same restore-and-randomize routine for the Sun object.\n# change the background material # Randomize texture of shelf, floors and walls for obj in bpy.data.collections[\u0026#34;BACKGROUND\u0026#34;].all_objects: for i in range(len(obj.material_slots)): # TODO # Pick one of the random materials obj.material_slots[i].material = random.choice(random_materials) if jitter_material: zpy.material.jitter(obj.material_slots[i].material) # Sets the material relative to the object obj.material_slots[i].link = \u0026#34;OBJECT\u0026#34; # Pick a random hdri (from the local textures folder for the background) zpy.hdris.random_hdri() # Return light to original position 
zpy.objects.restore_pose(\u0026#34;Sun\u0026#34;) # Jitter the light position zpy.objects.jitter( \u0026#34;Sun\u0026#34;, translate_range=( (-5, 5), (-5, 5), (-5, 5), ), ) bpy.data.objects[\u0026#34;Sun\u0026#34;].data.energy = random.uniform(0.5, 7) Finally, we hide all the objects we don\u0026rsquo;t want rendered: the nugget_base and our entire cube structure!\n# we hide the cube objects for obj in bpy.data.objects: if \u0026#39;cube\u0026#39; in obj.name: obj.hide_render = True try: zpy.objects.toggle_hidden(obj, hidden=True) except Exception: # deal with this exception here... pass # we hide our base nugget object bpy.data.objects[\u0026#34;nugget_base\u0026#34;].hide_render = True zpy.objects.toggle_hidden(bpy.data.objects[\u0026#34;nugget_base\u0026#34;], hidden=True) Lastly, we use zpy to render our scene, save our images, and then save our annotations! (I made some small changes to the zpy annotation library for my specific use case (one annotation file per image instead of one file per project), but you shouldn\u0026rsquo;t need to for the purpose of this blog post.)\n# create the image name image_uuid = str(uuid.uuid4()) # Name for each of the output images rgb_image_name = format_image_string(image_uuid, \u0026#39;rgb\u0026#39;) iseg_image_name = format_image_string(image_uuid, \u0026#39;iseg\u0026#39;) depth_image_name = format_image_string(image_uuid, \u0026#39;depth\u0026#39;) zpy.render.render( rgb_path=saver.output_dir / rgb_image_name, iseg_path=saver.output_dir / iseg_image_name, depth_path=saver.output_dir / depth_image_name, ) # Add images to saver saver.add_image( name=rgb_image_name, style=\u0026#34;default\u0026#34;, output_path=saver.output_dir / rgb_image_name, frame=step_idx, ) saver.add_image( name=iseg_image_name, style=\u0026#34;segmentation\u0026#34;, output_path=saver.output_dir / iseg_image_name, frame=step_idx, ) saver.add_image( name=depth_image_name, style=\u0026#34;depth\u0026#34;, output_path=saver.output_dir / depth_image_name, 
frame=step_idx, ) # ideally in this thread, we\u0026#39;ll open the anno file # and write to it directly, saving it after each generation for obj in spawned_nugget_objects: # Add annotation to segmentation image saver.add_annotation( image=rgb_image_name, category=\u0026#34;nugget\u0026#34;, seg_image=iseg_image_name, seg_color=tuple(obj.seg.instance_color), ) # Delete the spawned nuggets zpy.objects.empty_collection(bpy.data.collections[\u0026#34;SPAWNED\u0026#34;]) # Write out annotations saver.output_annotated_images() saver.output_meta_analysis() # # ZUMO Annotations _output_zumo = _OutputZUMO(saver=saver, annotation_filename = Path(image_uuid + \u0026#34;.zumo.json\u0026#34;)) _output_zumo.output_annotations() # change the name here.. saver.output_annotated_images() saver.output_meta_analysis() # remove the memory of the annotation to free RAM saver.annotations = [] saver.images = {} saver.image_name_to_id = {} saver.seg_annotations_color_to_id = {} log.info(\u0026#34;Simulation complete.\u0026#34;) if __name__ == \u0026#34;__main__\u0026#34;: # Set the logger levels zpy.logging.set_log_levels(\u0026#34;info\u0026#34;) # Parse the gin-config text block # hack to read a specific gin config parse_config_from_file(\u0026#39;nugget_config.gin\u0026#39;) # Run the sim run() Voila!!\nRun the Headless Creation Script! 
Now that we have our saved Blender file, our created nugget, and all the supporting information, let\u0026rsquo;s compress our working directory and either scp it to our GPU machine or upload it via S3, teleportation, etc.\ntar czvf working_blender_dir.tar.gz working_blender_dir scp -i \u0026#34;your.pem\u0026#34; working_blender_dir.tar.gz ubuntu@EC2-INSTANCE.compute.amazonaws.com:/home/ubuntu/working_blender_dir.tar.gz Log into your EC2 instance and decompress your working_blender_dir archive\ntar xzvf working_blender_dir.tar.gz Now we create our data in all its glory!!\nblender working_blender_dir/nugget.blend --background --python working_blender_dir/create_synthetic_nuggets.py The script should run for 500 images and the data will be saved in /path/to/working_blender_dir/nugget_data\nData Here is a look at a single annotation file created with our dataset\n{ \u0026#34;metadata\u0026#34;: { \u0026#34;description\u0026#34;: \u0026#34;3D data of a nugget!\u0026#34;, \u0026#34;contributor\u0026#34;: \u0026#34;Matt Krzus\u0026#34;, \u0026#34;url\u0026#34;: \u0026#34;www.mattkrzus.com\u0026#34;, \u0026#34;year\u0026#34;: \u0026#34;2021\u0026#34;, \u0026#34;date_created\u0026#34;: \u0026#34;20210924_000000\u0026#34;, \u0026#34;save_path\u0026#34;: \u0026#34;/home/ubuntu/working_blender_dir/nugget_data\u0026#34; }, \u0026#34;categories\u0026#34;: { \u0026#34;0\u0026#34;: { \u0026#34;name\u0026#34;: \u0026#34;nugget\u0026#34;, \u0026#34;supercategories\u0026#34;: [], \u0026#34;subcategories\u0026#34;: [], \u0026#34;color\u0026#34;: [ 0.0, 0.0, 0.0 ], \u0026#34;count\u0026#34;: 6700, \u0026#34;subcategory_count\u0026#34;: [], \u0026#34;id\u0026#34;: 0 } }, \u0026#34;images\u0026#34;: { \u0026#34;0\u0026#34;: { \u0026#34;name\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.rgb.png\u0026#34;, \u0026#34;style\u0026#34;: \u0026#34;default\u0026#34;, \u0026#34;output_path\u0026#34;: 
\u0026#34;/home/ubuntu/working_blender_dir/nugget_data/a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.rgb.png\u0026#34;, \u0026#34;relative_path\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.rgb.png\u0026#34;, \u0026#34;frame\u0026#34;: 97, \u0026#34;width\u0026#34;: 640, \u0026#34;height\u0026#34;: 480, \u0026#34;id\u0026#34;: 0 }, \u0026#34;1\u0026#34;: { \u0026#34;name\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.iseg.png\u0026#34;, \u0026#34;style\u0026#34;: \u0026#34;segmentation\u0026#34;, \u0026#34;output_path\u0026#34;: \u0026#34;/home/ubuntu/working_blender_dir/nugget_data/a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.iseg.png\u0026#34;, \u0026#34;relative_path\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.iseg.png\u0026#34;, \u0026#34;frame\u0026#34;: 97, \u0026#34;width\u0026#34;: 640, \u0026#34;height\u0026#34;: 480, \u0026#34;id\u0026#34;: 1 }, \u0026#34;2\u0026#34;: { \u0026#34;name\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.depth.png\u0026#34;, \u0026#34;style\u0026#34;: \u0026#34;depth\u0026#34;, \u0026#34;output_path\u0026#34;: \u0026#34;/home/ubuntu/working_blender_dir/nugget_data/a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.depth.png\u0026#34;, \u0026#34;relative_path\u0026#34;: \u0026#34;a0bb1fd3-c2ec-403c-aacf-07e0c07f4fdd.depth.png\u0026#34;, \u0026#34;frame\u0026#34;: 97, \u0026#34;width\u0026#34;: 640, \u0026#34;height\u0026#34;: 480, \u0026#34;id\u0026#34;: 2 } }, \u0026#34;annotations\u0026#34;: [ { \u0026#34;image_id\u0026#34;: 0, \u0026#34;category_id\u0026#34;: 0, \u0026#34;id\u0026#34;: 0, \u0026#34;seg_color\u0026#34;: [ 1.0, 0.6000000238418579, 0.9333333373069763 ], \u0026#34;color\u0026#34;: [ 1.0, 0.6, 0.9333333333333333 ], \u0026#34;segmentation\u0026#34;: [ [ 299.0, 308.99, 292.0, 308.99, 283.01, 301.0, 286.01, 297.0, 285.01, 294.0, 288.01, 285.0, 283.01, 275.0, 287.0, 271.01, 294.0, 271.01, 302.99, 280.0, 305.99, 286.0, 305.99, 303.0, 302.0, 307.99, 299.0, 308.99 ] ], \u0026#34;bbox\u0026#34;: [ 283.01, 
271.01, 22.980000000000018, 37.98000000000002 ], \u0026#34;area\u0026#34;: 667.0802000000008, \u0026#34;bboxes\u0026#34;: [ [ 283.01, 271.01, 22.980000000000018, 37.98000000000002 ] ], \u0026#34;areas\u0026#34;: [ 667.0802000000008 ] }, { \u0026#34;image_id\u0026#34;: 0, \u0026#34;category_id\u0026#34;: 0, \u0026#34;id\u0026#34;: 1, \u0026#34;seg_color\u0026#34;: [ 1.0, 0.4000000059604645, 1.0 ], \u0026#34;color\u0026#34;: [ 1.0, 0.4, 1.0 ], \u0026#34;segmentation\u0026#34;: [ [ 241.0, 273.99, 236.0, 271.99, 234.0, 273.99, 230.01, 270.0, 232.01, 268.0, 231.01, 263.0, 233.01, 261.0, 229.0, 257.99, 225.0, 257.99, 223.01, 255.0, 225.01, 253.0, 227.01, 246.0, 235.0, 239.01, 238.0, 239.01, 240.0, 237.01, 247.0, 237.01, 252.99, 245.0, 253.99, 252.0, 246.99, 269.0, 241.0, 273.99 ] ], \u0026#34;bbox\u0026#34;: [ 223.01, 237.01, 30.980000000000018, 36.98000000000002 ], \u0026#34;area\u0026#34;: 743.5502000000008, \u0026#34;bboxes\u0026#34;: [ [ 223.01, 237.01, 30.980000000000018, 36.98000000000002 ] ], \u0026#34;areas\u0026#34;: [ 743.5502000000008 ] }, ... ... ...  Conclusion In this post, I demonstrated how to use the open-source 3D creation suite Blender to build an end-to-end synthetic data pipeline.\nThere are a ton of cool things you can do in Blender and AWS; hopefully this demo can help you on your next data-starved project!\nReferences  Easily Clean Your 3D Scans (blender) Instant Meshes: A free quad-based autoretopology program How to 3D Scan an Object for Synthetic Data Generate synthetic data with Blender and Python Google\u0026rsquo;s Kubric  ","date":"2023-04-29","permalink":"/posts/synthetic/","title":"Generate Synthetic Data for Computer Vision Projects in Blender"},{"content":" Smooth Ambler Old Scout 10 Year Single Barrel I\u0026rsquo;ve been going through a lot of my old bottles lately and giving them a taste after not having tried them for a long time. 
This one I remember picking up at Binny\u0026rsquo;s on a whim for like 50 bucks because they were out of the particular Kilchoman I wanted to buy, but I suppose by today\u0026rsquo;s standards this bottle would cost 900 dollars. I really like the simplicity of this; it reminds me a lot of Wild Turkey in that it\u0026rsquo;s very basic and not trying to lift more weight than a bourbon is supposed to. These days, I wouldn\u0026rsquo;t spend more than 70 bucks (in today\u0026rsquo;s dollars) on it. Dig it for sure, though.\nBottle Info Distillery: Smooth Ambler\nBarrel type: bourbon\nAge: 10\nABV: 58.8\nHow this was acquired: store\nCost: No clue\nYear purchased: 2015\nTasting Notes Nose: butterscotch, metallic, graham cracker, typical bourbon notes: vanilla, oak, caramel, chocolate, yada yada. Pretty standard nose; the only unique part is the graham cracker note.\nTaste: butterscotchy, rye, cinnamon, floral/flowery, typical bourbon notes: vanilla, oak, caramel, etc. The main identifying feature is the floral/flowery note. Almost lavender. Very nice.\nFinish: flowers, graham, butterscotch. Unique floral note, but otherwise a very standard MGP/bourbon finish.\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would I buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 4/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2020-08-02","permalink":"/posts/smooth_ambler_ten/","title":"Smooth Ambler Old Scout 10 Year Single Barrel"},{"content":"================\nWhat I Think is Wrong with Whiskey Reviews The whiskey review system is broken For reasons I\u0026rsquo;m going to argue below, I think the current 1-100 whiskey review system is absolute garbage. 
It\u0026rsquo;s weak, uninformed, and if it were a person, it would be the kind of asshole who thinks Old Weller Antique is a \u0026lsquo;unicorn\u0026rsquo;.\nWith insults out of the way, let\u0026rsquo;s dive right in and take a look at what an \u0026lsquo;average\u0026rsquo; review looks like on reddit. In order to do so, we\u0026rsquo;ll use a technique called Bayesian Bootstrapping, which resamples our review data many times (each time reweighting the reviews with random Dirichlet weights) to estimate where the \u0026lsquo;true\u0026rsquo; mean of the data may lie. This technique, in typical Bayesian fashion, yields a distribution rather than a point estimate, which allows us to place our confidence on a range of numbers rather than a single value. What are the benefits of this? Well, it allows us to hedge our \u0026lsquo;wrongness\u0026rsquo; or \u0026lsquo;rightness\u0026rsquo; (depending on what your outlook is) on a set of numbers rather than placing the whole kit and caboodle on one value. It allows us to view the world in a probabilistic sense, meaning that there\u0026rsquo;s a high probability that our true mean lies within the distribution we\u0026rsquo;ve created. 
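As a rough illustration of the resampling step, here is a minimal numpy-only sketch of the Bayesian bootstrap for a mean, plus a simple highest-density-interval helper. It approximates what the bayesian_bootstrap helpers in the next snippet compute, but it is not the code of that package. The function names, the batch size, and the synthetic scores standing in for df.scores.values are assumptions for illustration.

```python
import numpy as np


def bayesian_bootstrap_mean(data, n_replications=10000, batch=500, seed=42):
    # Rubin-style Bayesian bootstrap: each replication draws flat
    # Dirichlet(1, ..., 1) weights over the observations and returns the
    # weighted mean. Batching keeps the weight matrix small for large data.
    data = np.asarray(data, dtype=float)
    rng = np.random.default_rng(seed)
    means = np.empty(n_replications)
    for start in range(0, n_replications, batch):
        stop = min(start + batch, n_replications)
        weights = rng.dirichlet(np.ones(len(data)), size=stop - start)
        means[start:stop] = weights @ data
    return means


def highest_density_interval(samples, alpha=0.05):
    # Narrowest interval that contains a (1 - alpha) share of the samples.
    s = np.sort(np.asarray(samples, dtype=float))
    n_in = int(np.floor((1 - alpha) * len(s)))
    widths = s[n_in:] - s[:len(s) - n_in]
    best = int(np.argmin(widths))
    return s[best], s[best + n_in]


# Hypothetical stand-in for df.scores.values: 500 synthetic 1-100 ratings.
rng = np.random.default_rng(0)
scores = np.clip(rng.normal(84, 8, size=500), 1, 100)

means = bayesian_bootstrap_mean(scores, n_replications=2000)
low, high = highest_density_interval(means)
print('low ci:', round(float(low), 2), ' high ci:', round(float(high), 2))
```

The classic bootstrap resamples rows with replacement. The Bayesian bootstrap instead draws Dirichlet weights. The HDI helper slides a window of fixed sample count over the sorted means and keeps the narrowest one.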
Using this technique, we can say that the true mean review score lies somewhere between 83.77 and 84.09 (the 95% highest density interval).\nimport matplotlib.pyplot as plt import seaborn as sns import bayesian_bootstrap.bootstrap as bb from astropy.utils import NumpyRNGContext with NumpyRNGContext(42): review_scores = df.scores.values review_scores_means = bb.mean(review_scores, n_replications=10000) ci_low, ci_hi = bb.highest_density_interval(review_scores_means) print('low ci:', ci_low, ' high ci:', ci_hi) ax = sns.distplot(review_scores_means, label = \u0026quot;All bootstrap data\u0026quot;) ax.plot([ci_low, ci_hi], [0, 0], linewidth=10, c='k', marker='o', label='95% HDI') ax.set(ylabel='Density', xlabel=\u0026quot;1 - 100 Rating\u0026quot;, title=\u0026quot;Bootstrap Mean Review Score\u0026quot;) sns.despine() plt.legend() plt.show()  This distribution for the \u0026lsquo;average review\u0026rsquo; leads me to a couple of thoughts about the system as a whole: why is the mean so high? Even for the people who use this system like a school grading system (A: 100-90, B: 89-80, C: 79-70, D: 69-60, F: 59-0), shouldn\u0026rsquo;t the average be a C? And if someone didn\u0026rsquo;t use the system like that and instead utilized every number equally, wouldn\u0026rsquo;t we see the average near the middle of the scale, at 50? Maybe people are self-constraining their reviews to only good whiskies? Maybe people just straight up suck at reviewing as well (I\u0026rsquo;m looking at you, Mr. Anything Weller Tastes Like Jesus Guy). In order to break that down a bit further, we\u0026rsquo;ll need to explore the weirdness of this dataset.\nFor starters, out of 14172 reviews, I only found 196 reviews below 50 points. That seems so odd to me. That means, more or less, you have a 1.38% chance (196 / 14172) of finding an absolute dog shit whiskey. Surely, that\u0026rsquo;s impossible, right? 
On the other hand, ask yourself: have you ever had a whiskey that was so fucking terrible you gave it a sub-50 review? I\u0026rsquo;m sure everyone can think of \u0026lsquo;the worst whiskey\u0026rsquo; they\u0026rsquo;ve had, but how bad was it really? How does one even rank it that low? \u0026ldquo;Ah fuck it, I\u0026rsquo;ll give it a 30\u0026rdquo;? Either way, having 50% of the available review points used by only 1.38% of the reviews is ridiculous in any scoring system. Even when we analyze those 196 reviews more thoroughly using the Bayesian bootstrap, it\u0026rsquo;s still a bit strange: most of the mean\u0026rsquo;s mass is between 31 and 34.\nIn [78]: scores Out[78]: array([50.0, 70.0, 91.0, ..., 83.0, 84.0, 74.0], dtype=object) In [79]: scores.shape[0] Out[79]: 14172 In [80]: scores[scores \u0026lt; 50].shape[0] Out[80]: 196  Think about this for a second: in your personal review system, you might be able to distinguish between an 89 and a 90, but can you distinguish between a 29 and a 30? Oh, can you? Go fuck yourself.\nTop (rather, people with a lot of reviews) Reviewers Ok cool. So we\u0026rsquo;ve established that 50% of the scale is worthless. What about the other 50%: how much utility does it actually have? To start, let\u0026rsquo;s compare the distributional spreads of the top (i.e., most frequent) reviewers. The thought here is that, given how much these individuals contribute, their mean distributions should eventually converge to similar areas of the scale. (Note: these reviewers were taken from the reddit Whiskey Review Archive.)\nJust for the purpose of argument, I\u0026rsquo;ve taken these same experienced reviewers and scaled their scores to the 0-1 range for uniformity. 
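The post does not say exactly how the scores were rescaled. One plausible reading is per-reviewer min-max scaling, sketched below under that assumption. The function name, and the choice to map a reviewer with a single repeated score to 0.5, are illustrative and not taken from the original analysis.

```python
import numpy as np


def minmax_by_reviewer(reviewers, scores):
    # Rescale each reviewer's scores to [0, 1] using that reviewer's own
    # minimum and maximum, so that different personal scales line up.
    reviewers = np.asarray(reviewers)
    scores = np.asarray(scores, dtype=float)
    scaled = np.empty_like(scores)
    for name in np.unique(reviewers):
        mask = reviewers == name
        lo, hi = scores[mask].min(), scores[mask].max()
        if hi == lo:
            # No spread at all: put every score at the midpoint.
            scaled[mask] = 0.5
        else:
            scaled[mask] = (scores[mask] - lo) / (hi - lo)
    return scaled


reviewers = ['a', 'a', 'a', 'b', 'b', 'b']
scores = [80, 85, 90, 60, 90, 75]
print(minmax_by_reviewer(reviewers, scores))
```

In the toy call, reviewer a maps 80, 85, 90 to 0, 0.5, 1. Reviewer b maps 60, 90, 75 to 0, 1, 0.5. Both personal scales therefore land on the same 0-1 range.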
It looks better, but much of each distribution still sits further from 0.5 than I\u0026rsquo;d personally like to see (on one hand, we can postulate that they aren\u0026rsquo;t reviewing trash whiskey and thus aren\u0026rsquo;t dragging their own scores down, but it\u0026rsquo;s still higher than I\u0026rsquo;d like to see). There will always be a difference of opinion on a given whiskey (some will like it, some won\u0026rsquo;t), but given how much these individuals review, and the fact that the whiskies they\u0026rsquo;re reviewing come from the quality pool, we should still see similar distributions around a given mean score. If this were a more uniform system, then applying the same bootstrapping technique to each top reviewer\u0026rsquo;s reviews should yield similar distributions (albeit very different at the individual-whiskey level); however, this just isn\u0026rsquo;t the case.\nSo what\u0026rsquo;s the point? These frequent reviewers all review differently. You and I review differently. We all understand how skewed the system currently is. The only thing left to dive into is the part people hardly pay attention to: the actual fucking review, baby.\nWords, words, words So the scoring is bunk, got it. But what about the review itself? What can that tell us about the system as a whole? First, let\u0026rsquo;s analyze the most basic relationship: the length of a review and the sentiment/score of a given whiskey. Here\u0026rsquo;s the TL;DR: people use more words when they like the whiskey, fewer when they dislike it, and when it\u0026rsquo;s \u0026lsquo;meh\u0026rsquo;, it\u0026rsquo;s right there in the middle. Go figure. Let\u0026rsquo;s take the above analysis one step further and inspect the relationship between a rating, the sentiment (the positivity or negativity of the words used), and the length of the review. 
If we suppose that an average review is anywhere between ~83 and ~84 (the 95% interval from above), that any \u0026lsquo;bad\u0026rsquo; review is below 83, and that any review above 84 is \u0026lsquo;good\u0026rsquo;, we can start to model these reviews using some newer language modeling techniques. We\u0026rsquo;ll start by breaking out the data into three classes based on score (good, bad, average); then we\u0026rsquo;ll take a pretrained model (specifically, one based on Convolutional Neural Networks for Sentence Classification) trained on the classic IMDB movie review dataset (meaning it learned to predict the sentiment of a movie review), and use it to produce the sentiment polarity (0-1) of \u0026lsquo;good\u0026rsquo;, \u0026lsquo;average\u0026rsquo;, and \u0026lsquo;bad\u0026rsquo; whiskey reviews: It gives us some quick insight, but it\u0026rsquo;s not exactly what I\u0026rsquo;m looking for. We\u0026rsquo;re seeing too much similarity between the positive and negative reviews: their polarities obviously differ, but not by much. To me, that signals that perhaps the IMDB model isn\u0026rsquo;t built for the kind of language structure found in reddit whiskey reviews.\nIn order to make this model more robust, I\u0026rsquo;m going to force a new model to learn the odd language structure used in whiskey reviews. As a refresher for those outside of whiskey, check out my previous blog: The TL;DR: whiskey reviews often contain strings of words, mostly in lists (taste: word1, word2, word3, etc. palate: word1, word2, etc.), that together describe the finish, nose, and palate of a given whiskey and that, in the context of a whiskey review, can mean negative or positive things much different from their usage in general language. 
For example, the word \u0026lsquo;young\u0026rsquo;, used in everyday language, can refer to energy, enthusiasm, and happiness (all positive sentiment), but in the context of whiskey it can mean that the whiskey tastes like gasoline, paint thinner, or bad college memories (all negative sentiment). These thinly veiled differences in language mean that models built for general analysis won\u0026rsquo;t have the contextual ability to see the review the way a whiskey fan might.\nSo let\u0026rsquo;s build and train a model using only the corpus of whiskey reviews. First, we\u0026rsquo;ll separate the reviews into three categories: negative reviews (scores under 83), average reviews (scores 83-84), and positive reviews (scores over 84). We\u0026rsquo;ll assign all negative reviews a 0 and all positive reviews a 1, and we\u0026rsquo;ll leave the average reviews as our evaluation set (meaning the model never trains on this data). We\u0026rsquo;ll set aside 25% of all positive and negative reviews as a held-out set to validate our model\u0026rsquo;s performance (i.e., we don\u0026rsquo;t use that set to update the model\u0026rsquo;s beliefs about the world; instead, we use it to check how well the model is performing and whether we\u0026rsquo;re training in the right direction). 
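The three-way split and the hold-out step described above can be sketched with numpy alone. The thresholds follow the text: under 83, 83 to 84 inclusive, and over 84. The helper names, the flat index arrays, and the seed are assumptions. The original code uses the sklearn train_test_split instead. Its default test_size of 0.25 matches the 10058/3353 split in the training log, since 3353 is 25% of 13411, rounded up.

```python
import numpy as np


def split_by_score(scores, low=83.0, high=84.0):
    # Negative: below low; average: low..high inclusive; positive: above high.
    scores = np.asarray(scores, dtype=float)
    bad_ixs = np.flatnonzero(scores < low)
    avg_ixs = np.flatnonzero((scores >= low) & (scores <= high))
    like_ixs = np.flatnonzero(scores > high)
    return bad_ixs, avg_ixs, like_ixs


def holdout_split(bad_ixs, like_ixs, test_frac=0.25, seed=0):
    # Label negatives 0 and positives 1, shuffle, and hold out test_frac of
    # them for validation. Average reviews never enter train or validation.
    ixs = np.concatenate([bad_ixs, like_ixs])
    labels = np.concatenate([np.zeros(len(bad_ixs)), np.ones(len(like_ixs))])
    order = np.random.default_rng(seed).permutation(len(ixs))
    n_test = int(np.ceil(test_frac * len(ixs)))
    test, train = order[:n_test], order[n_test:]
    return (ixs[train], labels[train]), (ixs[test], labels[test])


scores = np.array([70, 83, 84, 90, 95, 50, 88, 83.5])
bad_ixs, avg_ixs, like_ixs = split_by_score(scores)
(train_ix, train_y), (test_ix, test_y) = holdout_split(bad_ixs, like_ixs)
print(bad_ixs, avg_ixs, like_ixs)   # [0 5] [1 2 7] [3 4 6]
print(len(train_ix), len(test_ix))  # 3 2
```

The average reviews are kept out of both train and validation by construction. That mirrors the role the text gives them as a separate evaluation set.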
Pseudo-ish code below:\nimport numpy as np from keras.preprocessing.text import Tokenizer from keras.preprocessing.sequence import pad_sequences from keras.utils.vis_utils import plot_model from keras.models import Model from keras.layers import Input from keras.layers import Dense from keras.layers import Flatten from keras.layers import Dropout from keras.layers import Embedding from keras.layers.convolutional import Conv1D from keras.layers.convolutional import MaxPooling1D from keras.layers.merge import concatenate from sklearn.model_selection import train_test_split # encode a single review as a padded sequence of word indices def encode_text(text, maxlen): # integer encode doc = nlp(text) vals = tokenizeText(\u0026quot; \u0026quot;.join(clean(doc.text))) seq = [word_idx[x] for x in vals] # pad encoded sequences padded = pad_sequences([seq], maxlen=maxlen, padding='post') return padded # define the model (named define_model so it does not shadow keras Model) def define_model(length, vocab_size): # channel 1 inputs1 = Input(shape=(length,)) embedding1 = Embedding(vocab_size, 100)(inputs1) conv1 = Conv1D(filters=32, kernel_size=4, activation='relu')(embedding1) drop1 = Dropout(0.5)(conv1) pool1 = MaxPooling1D(pool_size=2)(drop1) flat1 = Flatten()(pool1) # channel 2 inputs2 = Input(shape=(length,)) embedding2 = Embedding(vocab_size, 100)(inputs2) conv2 = Conv1D(filters=32, kernel_size=6, activation='relu')(embedding2) drop2 = Dropout(0.5)(conv2) pool2 = MaxPooling1D(pool_size=2)(drop2) flat2 = Flatten()(pool2) # channel 3 inputs3 = Input(shape=(length,)) embedding3 = Embedding(vocab_size, 100)(inputs3) conv3 = Conv1D(filters=32, kernel_size=8, activation='relu')(embedding3) drop3 = Dropout(0.5)(conv3) pool3 = MaxPooling1D(pool_size=2)(drop3) flat3 = Flatten()(pool3) # merge merged = concatenate([flat1, flat2, flat3]) # interpretation dense1 = Dense(10, activation='relu')(merged) outputs = Dense(1, activation='sigmoid')(dense1) model = Model(inputs=[inputs1, inputs2, inputs3], outputs=outputs) # compile model.compile(loss='binary_crossentropy', 
optimizer='adam', metrics=['accuracy']) # summarize model.summary() return model vocab_size = 6000 max_review_length=500 train_neg = [] train_neg_labels = [] train_pos = [] train_pos_labels = [] avg = [] for idx in avg_ixs: doc = nlp(df.iloc[idx[0]].review) vals = tokenizeText(\u0026quot; \u0026quot;.join(clean(doc.text))) seq = [word_idx[x] for x in vals] seq = np.array(seq) seq[seq\u0026gt;=vocab_size]=0 avg.append(seq) for idx in bad_ixs: doc = nlp(df.iloc[idx[0]].review) vals = tokenizeText(\u0026quot; \u0026quot;.join(clean(doc.text))) seq = [word_idx[x] for x in vals] seq = np.array(seq) seq[seq\u0026gt;=vocab_size]=0 train_neg.append(seq) train_neg_labels.append(0.0) for idx in like_ixs: doc = nlp(df.iloc[idx[0]].review) vals = tokenizeText(\u0026quot; \u0026quot;.join(clean(doc.text))) seq = [word_idx[x] for x in vals] seq = np.array(seq) seq[seq\u0026gt;=vocab_size]=0 train_pos.append(seq) train_pos_labels.append(1.0) x_train = train_neg + train_pos y_train = np.array(train_neg_labels + train_pos_labels) x_train = pad_sequences(x_train, maxlen=max_review_length, padding='post') train_X, test_X, train_y, test_y = train_test_split(x_train, y_train) model = define_model(max_review_length, vocab_size) # fit model model.fit([train_X,train_X,train_X], np.array(train_y), validation_data = ([test_X,test_X,test_X], test_y), epochs=10, batch_size=16, shuffle=True) # save the model \u0026quot;\u0026quot;\u0026quot; Train on 10058 samples, validate on 3353 samples Epoch 1/10 10058/10058 [==============================] - 89s 9ms/step - loss: 0.4426 - acc: 0.8152 - val_loss: 0.3112 - val_acc: 0.9344 Epoch 2/10 10058/10058 [==============================] - 86s 9ms/step - loss: 0.1613 - acc: 0.9550 - val_loss: 0.1131 - val_acc: 0.9588 Epoch 3/10 10058/10058 [==============================] - 86s 9ms/step - loss: 0.0482 - acc: 0.9805 - val_loss: 0.1400 - val_acc: 0.9508 Epoch 4/10 10058/10058 [==============================] - 86s 9ms/step - loss: 0.0275 - acc: 0.9876 - 
val_loss: 0.2266 - val_acc: 0.9168 Epoch 5/10 10058/10058 [==============================] - 84s 8ms/step - loss: 0.0203 - acc: 0.9906 - val_loss: 0.1317 - val_acc: 0.9627 Epoch 6/10 10058/10058 [==============================] - 86s 9ms/step - loss: 0.0180 - acc: 0.9925 - val_loss: 0.1632 - val_acc: 0.9603 ... \u0026quot;\u0026quot;\u0026quot; sentiment = model.predict([test_X, test_X, test_X]) test_y = np.asarray(test_y) neg_sentiment = sentiment[test_y\u0026lt;1.0].flatten() pos_sentiment = sentiment[test_y\u0026gt;0.0].flatten() avg_sentiment = [] for idx in range(len(avg)): seq = avg[idx] seq = np.array(seq) seq[seq\u0026gt;=vocab_size] = 0 seq_padded = pad_sequences([seq], maxlen=max_review_length, padding='post') out = model.predict([seq_padded, seq_padded, seq_padded])[0][0] avg_sentiment.append(out) with NumpyRNGContext(42): bad_sent_means = bb.mean(neg_sentiment, n_replications=10000) avg_sent_means = bb.mean(avg_sentiment, n_replications=10000) pos_sent_means = bb.mean(pos_sentiment, n_replications=10000) # ci_low, ci_hi = bb.highest_density_interval(m) # print(r,'\\t', 'low ci:', ci_low, ' high ci:', ci_hi) # ms.append([m, ci_low, ci_hi]) ax = sns.distplot(bad_sent_means, label = \u0026quot;Negative Reviews\u0026quot;) ax = sns.distplot(avg_sent_means, ax=ax, label = \u0026quot;Average Reviews\u0026quot;) ax = sns.distplot(pos_sent_means, ax=ax, label = \u0026quot;Positive Reviews\u0026quot;) #ax.plot([ci_low, ci_hi], [0, 0], linewidth=10, c='k', marker='o', # label='95% HDI') ax.set(ylabel='Density', xlabel=\u0026quot;Sentiment Polarity\u0026quot;, title=\u0026quot;Bayes' Bootstrap for Mean Sentiment Polarity\u0026quot;) sns.despine() plt.legend() plt.show()  Not only does the model reach about 96% accuracy on our validation set (val_acc of 0.9627 at epoch 5), it also puts our held-out \u0026lsquo;average\u0026rsquo; reviews smack dab in the middle of the sentiment range. That somewhat confirms that our choice of mean value was a fairly good one and that the language we\u0026rsquo;re using to review a given whiskey already does most 
of the heavy lifting! So what does this mean for scores? Well, here\u0026rsquo;s the thing: if you were to leave out the score completely, the review text alone would tell you with ~96% accuracy whether the 1-100 score was a high one or a low one. Basically, the review\u0026rsquo;s text (choice of words) is more or less equivalent to the score.\nA (kinda) New System Here\u0026rsquo;s my big whopper of an idea: a text review plus a binary thumbs up or thumbs down. Plain and simple. We\u0026rsquo;re already writing reviews anyway, and because we\u0026rsquo;ve already established that the review itself is indicative of the \u0026lsquo;score\u0026rsquo;, just let the word choice do all the hard work and add a little blurb about whether or not you liked it. A simple 0 or 1.\n It\u0026rsquo;s going to be a bit harder for some people. You\u0026rsquo;re actually going to have to have an opinion and not hide your dislike behind an abomination of a score, but I believe in you. And if you\u0026rsquo;re not sure, just fucking take a couple extra days and force yourself to make a tough decision instead of leaving any ambiguity to the reader.  Think about it for a second: switching to a binary review system is not some grand original idea. In fact, major companies like Stitch Fix and Netflix have switched their entire rating systems to binary ones, in large part because of the issues and problems expressed here. I work in research, and one of its fundamental tenets is piggybacking as much as you can off the hard work of others, especially off incredibly massive research teams who spend their entire day weighing the pros and cons of systems such as binary ones.\nLastly, I can\u0026rsquo;t prove it, but I think removing the current 1-100 scores would force reviewers to think more critically about a given whiskey.
And perhaps enforcing a culture of critical thinking might have stopped the madness that is today\u0026rsquo;s bourbon market - even though that\u0026rsquo;s quite a stretch - it sure would be nice. Can\u0026rsquo;t you just imagine some asshole combing through recent reviews trying to find \u0026lsquo;the best whiskey to buy and hoard\u0026rsquo; and being thwarted simply because he has to read? It\u0026rsquo;s my wet dream.\nI have a problem with your (not) new idea Question: That\u0026rsquo;s great and all, Matt. But I want to rank my whiskies; how am I supposed to do that with this scoring system?\nAnswer: A score isn\u0026rsquo;t necessarily a rank. When scoring, we have a habit of pulling in all this external information and trying to squeeze a \u0026lsquo;did I enjoy this\u0026rsquo; score and a \u0026lsquo;how does it compare to everything else\u0026rsquo; ranking into one scalar value. I think it\u0026rsquo;s important to concentrate on the moment and give a whiskey an individual score; forget about that random Brora you had, you\u0026rsquo;re not there, you\u0026rsquo;re here in the moment - concentrate on that. And if you must rank something from best to worst, the reality is nothing is stopping you. But trying to cram multidimensional meaning into one value is where we start to have the odd problems I\u0026rsquo;ve tried to outline above.\n\nQuestion: Now that Pappiz and Wild Turkey 101 share the same rating, how am I supposed to know which new hot whiskey to spend my money on and/or even which whiskey is the best?\nAnswer: Welp, no one could ever answer that question for you in the first place without you sampling them yourself. And as for which whiskey to spend your money on, why don\u0026rsquo;t you do what you currently do and read the reviews, brosef stalin.
Or check out something like this tool I created: http://wrec.herokuapp.com/ to help you figure out the components of given whiskies that you might enjoy.\n\nQuestion: How am I supposed to brag if all my Instagram pictures are just \u0026lsquo;ones\u0026rsquo; now?\nAnswer: You\u0026rsquo;ll figure it out, you sexy influencer, you.\n\nQuestion: So nothing will change?\nAnswer: Probably not. But maybe the money I just dumped into MGPI might do better, hopefully.\n\nQuestion: So seriously, why even write this article?\nAnswer: Maybe, just maybe, we can make it harder for Jefferson\u0026rsquo;s to figure out what barrel of trash to put their whiskey in next.\n\n","date":"2019-09-17","permalink":"/posts/whiskey_reviews/","title":"What I Think is Wrong with Whiskey Reviews"},{"content":"Overview A number of outside forces have driven whiskey\u0026rsquo;s rise in popularity over recent years. Be it advertising, dwindling stocks, demand, the secondary market, or pure magic, something has pushed prices higher and higher. Dwindling stocks are not public information, and the information that does exist isn\u0026rsquo;t quantifiable. But we can still get a decent look at the direction bourbon is heading based on the public signals available to us. In this post, I analyze the secondary market and the influence that time has had on bourbon\u0026rsquo;s price. To get the ball rolling, I took a page out of one of my favorite analytic blog posts of all time, Why Can\u0026rsquo;t Canada Win The Stanley Cup, and made heavy use of the Google Trends API to study the effect of demand on bourbon. Advertising spend was also fairly easy to collect, as many spirits conglomerates are public companies that have to report their spend in various financial categories; I used Diageo\u0026rsquo;s advertising spend as a proxy for what I believe other conglomerates are spending.
Lastly, to study the secondary market, I scraped data from a few online auction houses and piled it all together into a tidy dataset. I collected data on all whiskey varieties (American, Scotch, Japanese, and World) and then used the bourbon/rye subset of that data for this post. The bourbon/rye data consists of final bids on unique bourbons and ryes spanning 12 years\u0026rsquo; worth of history (2005-2017). The highest bid in the bourbon/rye data was ~$8,779 for a 24-year-old Bitter Truth Rye whiskey, and the lowest was for US export Sheffer Fort Bourbon at a whopping ~6 dollars.\nAdvertising First things first, it\u0026rsquo;s important to note that the big spirit companies have spent a lot on marketing over the past 10 years in an attempt to take market share from the beer categories (their number one market competitor, obviously). Note, this includes everything - all advertisements in their portfolio - from the vodka commercial to the paid \u0026lsquo;whiskey review\u0026rsquo;. Here\u0026rsquo;s Diageo, for example. They\u0026rsquo;re spending a ton. And frankly, as you\u0026rsquo;ll see in the rest of the post, it\u0026rsquo;s working: Demand Thanks to the Google Trends API we\u0026rsquo;re able to see the direct impact of spirits marketing since 2004. We can quantify the interest in bourbon by the relative number of Google searches from 2004 until the present. What this plot shows is that search interest is steadily increasing as more and more people are getting familiar with bourbon. (However, you already knew that. You could feel a disruption in the bourbon force.)\nWhat\u0026rsquo;s interesting is Diageo\u0026rsquo;s spend vs. Google\u0026rsquo;s interest. Auctions  This plot represents the USD mean bid, the USD median bid, and the normalized number of whiskies in each auction over time.
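Here is a tiny Python illustration of why a handful of expensive bottles pulls the mean away from the median while the median barely moves. The bid values are made up for the example and are not from the auction dataset. mean_median_gap is a hypothetical helper, not part of the original analysis.

```python
from statistics import mean, median


def mean_median_gap(bids):
    # Return (mean, median, mean - median) for a list of final bids in USD.
    m, med = mean(bids), median(bids)
    return m, med, m - med


# Hypothetical auction: four ordinary bottles, then the same four plus one unicorn.
typical = [40, 50, 60, 70]
with_unicorn = typical + [3000]

print(mean_median_gap(typical))       # mean 55, median 55, gap 0
print(mean_median_gap(with_unicorn))  # mean 644, median 60, gap 584
```

One outlier moves the mean by almost 600 dollars, while the median shifts by only 5. That is why a widening mean/median gap is a reasonable signal that high-valued bottles are skewing the auctions.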
It\u0026rsquo;s obvious that more recently, the mean and median have drifted farther and farther apart - representing \u0026ldquo;higher valued bourbons\u0026rdquo; pulling the mean up and away from the median.\nI wanted to do just a super basic clustering of prices and then look at which bourbon brands land in each cluster, to see which brands are pulling prices away from the median. Using the elbow method, we choose 4 clusters and then see what brands cluster together. from scipy.spatial.distance import cdist import numpy as np import matplotlib.pyplot as plt from sklearn.cluster import KMeans # cluster on final sale price (USD) alone, as a column vector X = df_age['usd'].values.reshape(-1, 1) # elbow method: mean distance of each point to its nearest centroid for k = 1..9 distortions = [] K = range(1,10) for k in K: km = KMeans(n_clusters=k).fit(X) distortions.append(sum(np.min(cdist(X, km.cluster_centers_, 'euclidean'), axis=1)) / X.shape[0]) # Plot the elbow plt.plot(K, distortions, 'bx-') plt.xlabel('k') plt.ylabel('Distortion') plt.title('The Elbow Method showing the optimal k') plt.show() # use 4 clusters km = KMeans(4).fit(X) pred = km.predict(X) # (requires bourbon and df_age to share the same rows, in the same order) bourbon['clus'] = pred # pick the cluster containing the fewest distinct distilleries smallest_grouping = np.argsort( [bourbon[bourbon['clus'] == i].distillery_cleaned.unique().shape[0] for i in range(0,4)] )[0] print(bourbon[bourbon['clus'] == smallest_grouping].distillery_cleaned.unique()) # output ['Bitter Truth' 'Willett' 'Martin Mills' 'Van Winkle' 'Twisted Spoke' 'Pappy Van Winkle`s' 'Old Rip Van Winkle' 'Black Maple Hill' 'Old Fitzgerald' 'Bourbon Valley']  Bitter Truth, Willett, Martin Mills, Van Winkle, Twisted Spoke, Pappy Van Winkle`s, Old Rip Van Winkle, Black Maple Hill, Old Fitzgerald, and Bourbon Valley. No real surprises there\u0026hellip;\nWhen we plot the normalized Google searches alongside the normalized mean of each auction, what\u0026rsquo;s clear is the interest in bourbon and the price of bourbon have finally converged.
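To plot search interest (a 0-100 Google Trends index) and auction prices (USD) on one axis, both first have to be put on a common scale. The post does not say which normalization was used, so the sketch below assumes min-max scaling to [0, 1]. The yearly numbers and the names min_max, searches, and mean_bid are hypothetical.

```python
def min_max(series):
    # Rescale values linearly so the smallest maps to 0.0 and the largest to 1.0.
    lo, hi = min(series), max(series)
    if hi == lo:
        # A flat series carries no trend; map everything to 0.0 to avoid dividing by zero.
        return [0.0 for _ in series]
    return [(x - lo) / (hi - lo) for x in series]


# Hypothetical yearly values, for illustration only.
searches = [20, 25, 40, 70, 100]   # search-interest index
mean_bid = [80, 85, 90, 150, 260]  # mean final bid in USD

# Positive entries mean interest is running ahead of price on the normalized scale.
gap = [s - p for s, p in zip(min_max(searches), min_max(mean_bid))]
print([round(g, 3) for g in gap])
```

Under this assumption, a stretch of years with a clearly positive gap that later closes is the kind of discrepancy the post reads as mispricing. A gap near zero corresponds to the two series having converged.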
More people have gotten into bourbon and the cost to acquire the brown stuff has finally started to reflect the new, higher priced reality. What\u0026rsquo;s kind of cool is that given the discrepancy between the normalized mean bourbon price and Google searches during 2009-2013, one could say that with respect to interest, bourbon was mispriced during this time.\nPrice Plateau Now that the basic analysis is out of the way, here are all the final purchase prices for bourbons in a nice easy log transform. Fitting a fairly naive trend-line gives us a very interesting picture. It looks like the \u0026lsquo;bourbon boom\u0026rsquo; as a whole is plateauing. It\u0026rsquo;s almost as if we\u0026rsquo;ve reached equilibrium in the secondary market. There could be a million reasons for this: perhaps enough people who would happily pay thousands for a whiskey have already tried those expensive whiskies and thought: \u0026lsquo;you know what, once is enough\u0026rsquo;, maybe people realized good bourbon just doesn\u0026rsquo;t have to be expensive, maybe production is ramping back up and supply is finally meeting demand, again, who knows? My two cents on what I think is actually happening: people are trying to cash in on the perceived unicorn demand and are thus trying to sell just about anything. This, in turn, is flooding the secondary with mediocre or self-inflated bourbons/ryes that just aren\u0026rsquo;t considered \u0026lsquo;worth as much\u0026rsquo;, while the higher valued bourbons/ryes (older willetts, etc.) are starting to reach the upper bounds of what people are willing to pay for them. The influx of \u0026lsquo;lower-quality\u0026rsquo; bourbons is bringing down the overall price of even \u0026lsquo;mediocre to good\u0026rsquo; whiskies simply by the availability of the other options. Personally, I think this is a good thing. Flippers are not making the money they once were and are actually undermining their own scarcity-reliant cause. 
In essence, their strategy is leaning towards quantity rather than quality to make money.\nExtra Plots Here are a few more interesting plots, all done the same way as in the previous section: log-transformed prices over time.\nThis is Pappy (only the Pappy Van Winkle lineup (15, 20, 23)) vs. the rest of bourbon/rye over time. It looks like it\u0026rsquo;s plateauing and staying steady.\nThis is all Four Roses bourbons (older 70s Yellow Label all the way to new LEs and everything in between) vs. the rest of bourbon/rye over time. Four Roses is one of my favorite distilleries, so I\u0026rsquo;m not mad at this.\nHirsch bourbons vs. the rest of bourbon/rye over time. Seeing as the older and harder-to-get stuff just isn\u0026rsquo;t being reproduced, I\u0026rsquo;m not shocked in the least at its trajectory.\nHere are Wild Turkey bourbons and ryes vs. the rest of bourbon/rye over time. Considering how much I love Wild Turkey, it\u0026rsquo;s nice that even the upper bound of what people are willing to pay for it isn\u0026rsquo;t anywhere close to what people are shelling out for high-valued bourbons.\nI thought it might be interesting to take all of whiskey and compare the highly-valued Van Winkle lineup to some other big players in the whiskey game, Karuizawa in particular. What\u0026rsquo;s obvious is that despite Van Winkle\u0026rsquo;s hegemony over bourbon, it doesn\u0026rsquo;t come close to what people are willing to pay for Karuizawa. I mean, who cares though: one of those distilleries was on an active volcano and stopped producing hooch around 2000, and the other is still shelling out juice. In the end it doesn\u0026rsquo;t matter, I just like plotting stuff.\n Here\u0026rsquo;s a hot take: Pappy and Brora are priced similarly in auction markets - which, after having both, is pretty ridiculous (both \u0026lsquo;good\u0026rsquo; Brora and \u0026lsquo;bad\u0026rsquo; Brora taste better to me than any Pappy, for what it\u0026rsquo;s worth).
So here it is: Buy Brora (but save some for me). It\u0026rsquo;s priced well relative to other hype-trains. Driving Forces on All Whiskies To see what drives price in the secondary market, I modeled the effect of the whiskey\u0026rsquo;s age, the proof, whether it was an original bottling or not (Cadenhead bottling Heaven Hill, for example), year of auction, and the distiller on the final sale price (I tried to model cask type as well, but there just wasn\u0026rsquo;t enough clean information). There are about a billion methods to do what I\u0026rsquo;m going to do; however, I chose a RandomForest because trees are inherently easier to understand for non-data whiskey lovers than more mathematically complex models (http://www.r2d3.us/visual-intro-to-machine-learning-part-1/ link for non-data people). Using a basic RandomForest model, one can easily interpret the importance each variable has on predicting the final sale price. (NOTE: this does not take into account bidding wars, bidding strategies, or number of bids per whiskey - I\u0026rsquo;ll save that for another time)\n import numpy as np import pandas as pd import matplotlib.pyplot as plt from sklearn.ensemble import RandomForestRegressor # define our RandomForest rf = RandomForestRegressor() # distill (pun intended) our features from the whiskey data and make # brand and auction year categorical features = pd.get_dummies(data) # fit our random forest to the dataset; column 0 is the final sale price (the target) rf.fit(features.values[:,1:], features.values[:,0]) # grab the feature importances from the rf importances = rf.feature_importances_ # argsort is ascending: index 0 is the LEAST important feature indices = np.argsort(importances) # keep the 30 most important; the most important is last, so barh draws it at the top indices = indices[-30:] # remove pandas string concatenations from data for plotting features_names = [x.replace('distillery_cleaned_','') for x in features.columns[1:].values] # plot title plt.title('Feature Importances') # plot horizontal bar plt.barh(range(len(indices)), importances[indices], color='b', align='center') # plot the 
importances plt.yticks(range(len(indices)), np.array(features_names)[indices]) # label x axis plt.xlabel('Relative Importance') # plot it plt.show()  Cool, so what does the plot below mean? First, age is the most important variable (duh). The rest, in no particular order, are as follows. Alcohol by volume: people will pay more for cask strength and higher proof whiskies. Original bottlings matter: people will pay more for an original bottling than for the same spirit bottled independently (Cadenhead, Signatory, Willett, etc.), which presents one heck of an opportunity to buy the same distillate at a discount. Next is the year in which the whiskey is sold - I find this interesting; it\u0026rsquo;s basic price bubble capitalization from a supply/demand perspective. And lastly, to no one\u0026rsquo;s surprise, is brand power. All of this pretty much confirms what we already believed: in the context of marketing and demand, the years from 2014 onwards have benefited from increased interest in whiskey, and the demand for certain brands has a direct relationship with price. People want higher proof distillate from specific brands and will pay more if it is bottled by the distillery itself.\n Somewhat related, I tried to use barrel type as a variable, but there just wasn\u0026rsquo;t enough clean information to incorporate into the overarching model. However, what information did exist basically says that prices for scotches aged in Pedro Ximenez barrels are more evenly distributed than for those aged in bourbon or oloroso barrels. Considering my love for all things Pedro Ximenez, I found this pretty interesting. So in the end, what does this all mean? Well, basically, Age + ABV + OB + Brand + Timing = Price. Go figure. The alternate take is that outside the top brands, there exist many lower cost (and better tasting) options for cask-strength, well-aged, originally bottled spirit.
Now, most people have never tried the wealth of distilleries in existence, and that includes the big players - the Karuizawas, 80s Macallans, or Stitzel-Wellers. So maybe next time, rather than wasting your time hunting down a bottle of this stuff (or any in-demand juice), try 20 different whiskies for the same price. I personally think that buying a lesser-known brand is a lot less of a risk than it may seem, and certainly less of one than spending $1500 (or even whatever it goes for at MSRP) on 23 year old Pappy Van OakJuice. All in all, when you have increased demand, whiskey hoarders, \u0026lsquo;dwindling supply\u0026rsquo;, and advertising (sometimes in the form of bought and paid for \u0026lsquo;reviewers\u0026rsquo;) all working at the behest of large brands, it\u0026rsquo;s easy to get caught up in thinking that these bottles of liquid are worth the money others are paying for them. However, you\u0026rsquo;re a smart person, you\u0026rsquo;re informed, you\u0026rsquo;re not going to get caught up in the brand hype. In fact, you\u0026rsquo;re going to set out and do more blind taste tests to truly see if all this crap is just hype or actually worth it to your wallet.  Thanks for reading ","date":"2018-11-17","permalink":"/posts/bourbon_secondary/","title":"Bourbon Secondary"},{"content":" Airigh Nam Beist 2006 One of my first memories of drinking scotch with my father was drinking the 2006 Airigh Nam Beist release. It was the first peated whiskey I ever tried and I fell in love instantly. The nose alone was unlike any whiskey I\u0026rsquo;d ever had. At $85 it was the most expensive whiskey we\u0026rsquo;d ever bought. It took us four years to finish that bottle, just having a little bit every Thanksgiving and Christmas, and when it was gone, we couldn\u0026rsquo;t find another bottle. So last year, I decided I wanted to buy a bottle for us so we could share it whenever he came by my place in the city.
It has some sentimental value, surely, but even without it, it sure has to be one of the best Islay whiskies ever bottled.\nBottle Info Distillery: Ardbeg\nBarrel type: n/a\nAge: 16\nABV: 46\nHow this was acquired: auction\nCost: $250\nYear purchased: 2016\nTasting Notes Nose: oily, lemony, peaty, cured meats, salty vanilla, chocolatey shortbread, pound cake\nNose[water]: n/a\nTaste: pepper, oily lemon, olive oil, salty, sugar cookie, bitter vanilla, lemon cookies, peaty, nutty\nTaste[water]: n/a\nFinish: long peat, lemon, shortbread, light peppery peat, so damned good, balanced, umami\nFinish[water]: n/a\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 5/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2017-11-17","permalink":"/posts/nam_beist_2006/","title":"Airigh Nam Beist 2006"},{"content":" Laphroaig Cairdeas 2013 Portwood Fair warning: if you\u0026rsquo;re not a fan of port or grape-forward whiskies, this might not be your cup of tea. But if you do like that style of whiskey, then I think you\u0026rsquo;ll really dig this. For me, the marriage of peat and port is one of my all-time favorite combinations. The sweet spot is when there\u0026rsquo;s just enough of the previous cask\u0026rsquo;s flavor profile, that little bit of fruit on the palate and nose to remind you that there\u0026rsquo;s something else lurking in the background. And this whiskey gives me exactly what I was looking for with that big blast of peat and ash, a touch of that caramel and vanilla, and a light raspberry/grape profile. It\u0026rsquo;s so good that I wish I had bought more of these when they were still on the shelves.
If you\u0026rsquo;re wondering where you can buy this, you can readily find it on the secondary/auction markets for $150-$160. It\u0026rsquo;s a great whiskey, but you can find a lot better for that price. Cheers! Thanks for reading!\nBottle Info Distillery: Laphroaig\nBarrel type: bourbon casks - finished in port casks\nAge: No Age Statement\nABV: 51.3\nHow this was acquired: A lucky find at a local whiskey shop, hiding behind a bunch of newer Cairdeas releases\nCost: $75\nYear purchased: 2015\nTasting Notes Nose: smoke, ash, peat, distinctly Laphroaig peppery peat, salt, brine, sea air, pepper, vanilla bean, tar, tobacco, touch of mint at the end, salty-grapy-seaweed with chocolate notes\nNose[water]: salt, brine, seaweed-caramel, vanilla, salted butter, salt water taffy, caramel, oak, fig, cream pudding, peat, peppery peat\nTaste: ash, salt, seaweed, caramel, vanilla, pepper, soft raspberries, darker fruits, soot, vanilla salt, fig, peppery peat, smoked salts\nTaste[water]: soft salted butter, soot, pepper, jammy Laphroaig notes, jelly, cream pudding, bread pudding, grape umami\nFinish: salt, berries, figs, peat, ash, clove, fig, cookie, spices, pudding\nFinish[water]: mellow, salted butter cookies, light grape jam, marshmallow, salted chocolate, berries and cream, salted almonds, soot, ash\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: When it was $75 - 5/5 easily. Now that it\u0026rsquo;s $160 - 3/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2017-11-17","permalink":"/posts/laphroaig_cairdeas_2013/","title":"Laphroaig Cairdeas 2013 Portwood"},{"content":" Signatory Jura 1989 Peated I\u0026rsquo;m a sucker for peat+caramel.
Peated whiskey from a first- or second-fill bourbon barrel is, to me, the best flavor combination there is. There\u0026rsquo;s just something special about the way the vanilla and peat jive. So when I saw a bourbon-barrel-matured, cask-strength, heavily peated Jura on the shelf at Binny\u0026rsquo;s, I figured why not check it out and bring it to my whiskey club, especially because it was on sale for 69 bucks. When I broke the bottle open and tasted it for the first time, I was kind of in shock. It doesn\u0026rsquo;t peat-punch you in the face with massive power; it just simply shoots a vanilla blowdart into your neck like a crazy scotch ninja. To make it even better, there wasn\u0026rsquo;t a single note that was out of line or more dominant than the others; everything was right where it was supposed to be. In summary: this is a fantastic pour by one of the best independent bottlers in Scotland.\nBottle Info Distillery: Isle of Jura - bottled by Signatory Vintage\nBarrel type: bourbon barrel\nAge: 23 years old\nABV: 58.5\nHow this was acquired: Found it on sale at Binny\u0026rsquo;s\nCost: $69\nYear purchased: 2015\nTasting Notes Nose: mellow bourbon, light peppery peat-smoke, pepper, vanilla, caramel, salty, sea air, musty salt and dense vanilla, salted dark chocolate\nNose[water]: really opens the peatiness, vanilla, taffy and caramel, cream pudding, as it sits - a bit more smoke and malt\nTaste: spice, peppered peat, vanilla, caramel, bitter dark chocolate, a bit of that medicinal peat note, sea salt/seaweed on the end\nTaste[water]: seaweed and pepper, vanilla, caramel, smoke and cinnamon, nice apple and pear notes, sort of light fruits, ooey-gooey butter cake\nFinish: hay, caramel pudding, toffee, butter, chocolate, seaweed, vanilla, seaweed bourbon\nFinish[water]: vanilla, seawater, brine, salt, peat-spice, pear, vanilla pound cake\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I 
drink this if someone gave me a glass: Yes\nValue For Money: 5/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2017-11-17","permalink":"/posts/jura_1989/","title":"Signatory Jura 1989 Peated"},{"content":"================\nFacial Recognition on a TX1 Getting Started Here\u0026rsquo;s a way to hack a facial recognition system together in a relatively short time on NVIDIA\u0026rsquo;s Jetson TX1. Many thanks to Brandon Amos\u0026rsquo; OpenFace, David Sandberg\u0026rsquo;s implementation of FaceNet, and Dwight Crowe\u0026rsquo;s Stack Overflow post on getting tensorflow running on a TX1.\nInstallation and Setup Assumptions: you have a TX1 with a fresh install of JetPack 2.3 L4T. First things first. We need to remove all the fat from the install. There are tons of optimized libraries in JetPack, but so much of it takes up the valuable storage space we need to get the facial recognition app up and running.\n # get rid of libreoffice, games, libvisionworks, perfkit, multimedia api, opencv4tegra, etc. sudo apt-get purge libreoffice* sudo apt-get purge aisleriot gnome-sudoku mahjongg ace-of-penguins gnomine gbrainy sudo apt-get clean sudo apt-get autoremove rm -rf libvision* rm -rf PerfKit* # something along these lines; might be different for you # delete all libvision-works and opencv4tegra stuff cd var \u0026amp;\u0026amp; rm -rf libopencv4tegra* \u0026amp;\u0026amp; rm -rf libvision* # I deleted practically everything. Almost as if I shouldn't have even installed JetPack in the first place # delete all deb files, Firefox, chrome, all the stuff I really didn't need that was taking up space. # find big files and remove them assuming they're not important. Google is your friend. find / -size +10M -ls  Installing protobuf, bazel, and tensorflow Thankfully, others have paved the way and made these steps pretty much a walk in the park.
Thank you to Stack Overflow user Dwight Crowe for his stellar post on how to get Tensorflow R0.9 working on a TX1. I\u0026rsquo;m literally just going to post his exact methodology.\n # install deps cd ~ sudo add-apt-repository ppa:webupd8team/java sudo apt-get update sudo apt-get install oracle-java8-installer sudo apt-get install git zip unzip autoconf automake libtool curl zlib1g-dev maven swig bzip2 # build protobuf 3.0.0-beta-2 jar git clone https://github.com/google/protobuf.git cd protobuf # autogen.sh downloads broken gmock.zip in d5fb408d git checkout master ./autogen.sh git checkout d5fb408d ./configure --prefix=/usr make -j 4 sudo make install cd java mvn package # Get bazel version 0.2.1, it doesn't require gRPC # go back home first so that ../protobuf below resolves to ~/protobuf cd ~ git clone https://github.com/bazelbuild/bazel.git cd bazel git checkout 0.2.1 cp /usr/bin/protoc third_party/protobuf/protoc-linux-arm32.exe cp ../protobuf/java/target/protobuf-java-3.0.0-beta-2.jar third_party/protobuf/protobuf-java-3.0.0-beta-1.jar  Here we need to make an edit so that the bazel build will recognize aarch64 as ARM:\n --- a/src/main/java/com/google/devtools/build/lib/util/CPU.java +++ b/src/main/java/com/google/devtools/build/lib/util/CPU.java @@ -25,7 +25,7 @@ import java.util.Set; public enum CPU { X86_32(\u0026quot;x86_32\u0026quot;, ImmutableSet.of(\u0026quot;i386\u0026quot;, \u0026quot;i486\u0026quot;, \u0026quot;i586\u0026quot;, \u0026quot;i686\u0026quot;, \u0026quot;i786\u0026quot;, \u0026quot;x86\u0026quot;)), X86_64(\u0026quot;x86_64\u0026quot;, ImmutableSet.of(\u0026quot;amd64\u0026quot;, \u0026quot;x86_64\u0026quot;, \u0026quot;x64\u0026quot;)), - ARM(\u0026quot;arm\u0026quot;, ImmutableSet.of(\u0026quot;arm\u0026quot;, \u0026quot;armv7l\u0026quot;)), + ARM(\u0026quot;arm\u0026quot;, ImmutableSet.of(\u0026quot;arm\u0026quot;, \u0026quot;armv7l\u0026quot;, \u0026quot;aarch64\u0026quot;)), UNKNOWN(\u0026quot;unknown\u0026quot;, ImmutableSet.of());  Now it\u0026rsquo;s time to compile bazel.\n ./compile.sh  Now we 
install tensorflow R0.9. Anything newer than R0.9 requires bazel 0.3.0, which we didn\u0026rsquo;t install.\nYou will build tensorflow once and it will fail. That\u0026rsquo;s expected: the failed build still creates the bazel .cache dir where you need to place updated config.guess and config.sub files for the full installation.\n git clone -b r0.9 https://github.com/tensorflow/tensorflow.git cd tensorflow ./configure # this will fail, but that's ok bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package  Download the proper config files and update the .cache dir:\n cd ~ wget -O config.guess 'http://git.savannah.gnu.org/gitweb/?p=config.git;a=blob_plain;f=config.guess;hb=HEAD' wget -O config.sub 'http://git.savannah.gnu.org/gitweb/?p=config.git;a=blob_plain;f=config.sub;hb=HEAD' # below are commands Dwight Crowe ran, yours will vary depending on .cache details. # look for '_bazel_socialh', 'farmhash_archive', and 'farmhash' cp config.guess ./.cache/bazel/_bazel_socialh/742c01ff0765b098544431b60b1eed9f/external/farmhash_archive/farmhash-34c13ddfab0e35422f4c3979f360635a8c050260/config.guess cp config.sub ./.cache/bazel/_bazel_socialh/742c01ff0765b098544431b60b1eed9f/external/farmhash_archive/farmhash-34c13ddfab0e35422f4c3979f360635a8c050260/config.sub  Here is where things get a bit tricky.
As Dwight suggests, you\u0026rsquo;ll have to change a few files so that tensorflow compiles correctly.\n --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -985,7 +985,7 @@ tf_kernel_libraries( \u0026quot;reduction_ops\u0026quot;, \u0026quot;segment_reduction_ops\u0026quot;, \u0026quot;sequence_ops\u0026quot;, - \u0026quot;sparse_matmul_op\u0026quot;, + #DC \u0026quot;sparse_matmul_op\u0026quot;, ], deps = [ \u0026quot;:bounds_check\u0026quot;, --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1110,7 +1110,7 @@ medium_kernel_test_list = glob([ \u0026quot;kernel_tests/seq2seq_test.py\u0026quot;, \u0026quot;kernel_tests/slice_op_test.py\u0026quot;, \u0026quot;kernel_tests/sparse_ops_test.py\u0026quot;, - \u0026quot;kernel_tests/sparse_matmul_op_test.py\u0026quot;, + #DC \u0026quot;kernel_tests/sparse_matmul_op_test.py\u0026quot;, \u0026quot;kernel_tests/sparse_tensor_dense_matmul_op_test.py\u0026quot;, ])  TX1 can\u0026rsquo;t do fancy constructors in cwise_op_gpu_select.cu.cc or sparse_tensor_dense_matmul_op_gpu.cu.cc\n --- a/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc +++ b/tensorflow/core/kernels/cwise_op_gpu_select.cu.cc @@ -43,8 +43,14 @@ struct BatchSelectFunctor { const int all_but_batch = then_flat_outer_dims.dimension(1); #if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::array broadcast_dims{{ 1, all_but_batch }}; - Eigen::Tensor::Dimensions reshape_dims{{ batch, 1 }}; + //DC Eigen::array broadcast_dims{{ 1, all_but_batch }}; + Eigen::array broadcast_dims; + broadcast_dims[0] = 1; + broadcast_dims[1] = all_but_batch; + //DC Eigen::Tensor::Dimensions reshape_dims{{ batch, 1 }}; + Eigen::Tensor::Dimensions reshape_dims; + reshape_dims[0] = batch; + reshape_dims[1] = 1; #else Eigen::IndexList, int\u0026gt; broadcast_dims; broadcast_dims.set(1, all_but_batch); --- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc +++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc @@ -104,9 +104,17 @@ 
struct SparseTensorDenseMatMulFunctor { int n = (ADJ_B) ? b.dimension(0) : b.dimension(1); #if !defined(EIGEN_HAS_INDEX_LIST) - Eigen::Tensor::Dimensions matrix_1_by_nnz{{ 1, nnz }}; - Eigen::array n_by_1{{ n, 1 }}; - Eigen::array reduce_on_rows{{ 0 }}; + //DC Eigen::Tensor::Dimensions matrix_1_by_nnz{{ 1, nnz }}; + Eigen::Tensor::Dimensions matrix_1_by_nnz; + matrix_1_by_nnz[0] = 1; + matrix_1_by_nnz[1] = nnz; + //DC Eigen::array n_by_1{{ n, 1 }}; + Eigen::array n_by_1; + n_by_1[0] = n; + n_by_1[1] = 1; + //DC Eigen::array reduce_on_rows{{ 0 }}; + Eigen::array reduce_on_rows; + reduce_on_rows[0] = 0; #else Eigen::IndexList, int\u0026gt; matrix_1_by_nnz; matrix_1_by_nnz.set(1, nnz);  Running with CUDA 8.0 requires new macros for FP16. Dwight throws some thanks to Kashif/Mrry for pointing out the fix, so I\u0026rsquo;m throwing some thanks to whoever those people are as well.\n --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -25,6 +25,12 @@ limitations under the License. 
#define EIGEN_HAS_CUDA_FP16 #endif +#if CUDA_VERSION \u0026gt;= 8000 +#define SE_CUDA_DATA_HALF CUDA_R_16F +#else +#define SE_CUDA_DATA_HALF CUBLAS_DATA_HALF +#endif + #include \u0026quot;tensorflow/stream_executor/cuda/cuda_blas.h\u0026quot; #include @@ -1680,10 +1686,10 @@ bool CUDABlas::DoBlasGemm( return DoBlasInternal( dynload::cublasSgemmEx, stream, true /* = pointer_mode_host */, CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, \u0026amp;alpha, - CUDAMemory(a), CUBLAS_DATA_HALF, lda, - CUDAMemory(b), CUBLAS_DATA_HALF, ldb, + CUDAMemory(a), SE_CUDA_DATA_HALF, lda, + CUDAMemory(b), SE_CUDA_DATA_HALF, ldb, \u0026amp;beta, - CUDAMemoryMutable(c), CUBLAS_DATA_HALF, ldc); + CUDAMemoryMutable(c), SE_CUDA_DATA_HALF, ldc); #else LOG(ERROR) \u0026lt;\u0026lt; \u0026quot;fp16 sgemm is not implemented in this cuBLAS version \u0026quot; \u0026lt;\u0026lt; \u0026quot;(need at least CUDA 7.5)\u0026quot;;  And lastly, ARM has no NUMA nodes, so this needs to be added or you\u0026rsquo;ll get an immediate crash when starting tf.Session()\n --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -888,6 +888,9 @@ CudaContext* CUDAExecutor::cuda_context() { return context_; } // For anything more complicated/prod-focused than this, you'll likely want to // turn to gsys' topology modeling. static int TryToReadNumaNode(const string \u0026amp;pci_bus_id, int device_ordinal) { + // DC - make this clever later. ARM has no NUMA node, just return 0 + LOG(INFO) \u0026lt;\u0026lt; \u0026quot;ARM has no NUMA node, hardcoding to return zero\u0026quot;; + return 0; #if defined(__APPLE__) LOG(INFO) \u0026lt;\u0026lt; \u0026quot;OS X does not support NUMA - returning NUMA node zero\u0026quot;; return 0;  So I ran into strange errors that were solved by accident. After running the above commands, bazel fails in weird places. Sometimes at a random op. Sometimes a \u0026lsquo;cross_tool\u0026rsquo; error. 
Truth be told, I accidentally reran the command with a different job number and the op that it had failed on previously ended up compiling just fine. And that was it. Just changing the job number. I switched between 3 and 4 a few times and it compiled just fine. Very weird. But whatever. It works. To verify it, I repeated this process on a few devices and it always worked.\n bazel build -c opt --config=cuda //tensorflow/tools/pip_package:build_pip_package --jobs 4 #3, 4, 3, 4, etc.  Once the build succeeds, build the wheel with bazel-bin/tensorflow/tools/pip_package/build_pip_package /tmp/tensorflow_pkg and pip install the .whl it writes there. Now that Tensorflow is installed, remove bazel and all of bazel\u0026rsquo;s caches that eat disk space.\n find / -size +10M -ls # delete the big bazel files   Installing OpenCV All we need here are the image reading and displaying modules. Nothing else. So the compile is small and takes up minimal space.\n # install deps sudo apt-get install build-essential sudo apt-get install cmake git libgtk2.0-dev pkg-config libavcodec-dev libavformat-dev libswscale-dev sudo apt-get install python-dev python-numpy libtbb2 libtbb-dev libjpeg-dev libpng-dev libtiff-dev libjasper-dev libdc1394-22-dev gcc-4.9 # download opencv git clone https://github.com/opencv/opencv \u0026amp;\u0026amp; cd opencv \u0026amp;\u0026amp; mkdir release \u0026amp;\u0026amp; cd release # build cmake -D CMAKE_C_COMPILER=/usr/bin/gcc-4.9 \\ -D CMAKE_CXX_COMPILER=/usr/bin/g++-4.9 \\ -D CMAKE_BUILD_TYPE=RELEASE \\ -D CMAKE_INSTALL_PREFIX=/usr/local \\ -D BUILD_opencv_python=ON \\ -D BUILD_DOCS=OFF \\ -D BUILD_EXAMPLES=OFF \\ -D BUILD_TESTS=OFF -D BUILD_opencv_ts=OFF \\ -D BUILD_PERF_TESTS=OFF -D WITH_OPENCL=OFF \\ -D BUILD_SHARED_LIBS=OFF \\ -D WITH_OPENCLAMDFFT=OFF \\ -D WITH_OPENCLAMDBLAS=OFF \\ -D WITH_VA_INTEL=OFF \\ -D BUILD_opencv_python=ON \\ -D BUILD_opencv_flann=OFF \\ -D BUILD_opencv_ml=OFF \\ -D BUILD_opencv_video=OFF \\ -D BUILD_opencv_cudabgsegm=OFF \\ -D BUILD_opencv_cudafilters=OFF \\ -D BUILD_opencv_cudaimgproc=OFF \\ -D BUILD_opencv_cudawarping=OFF \\ -D BUILD_opencv_cudacodec=OFF \\ -D BUILD_opencv_objdetect=OFF \\ -D BUILD_opencv_features2d=OFF 
\\ -D BUILD_opencv_calib3d=OFF \\ -D BUILD_opencv_cudafeatures2d=OFF \\ -D BUILD_opencv_cudaobjdetect=OFF \\ -D BUILD_opencv_cudastereo=OFF \\ -D BUILD_opencv_cdev=ON \\ -D BUILD_opencv_java=OFF \\ -D BUILD_opencv_hal=OFF \\ -D ENABLE_NEON=OFF \\ -D BUILD_opencv_cudev=ON \\ -D BUILD_opencv_cudaarithm=OFF \\ -D BUILD_opencv_highgui=ON \\ -D BUILD_opencv_photo=ON .. sudo make -j 4 \u0026amp;\u0026amp; sudo make install -j 4   Installing Python libs - scipy, dlib, and sklearn sudo apt-get install libboost-all-dev python-scipy python-sklearn python-pip pip install dlib  Almost finished!\nHead over to David Sandberg\u0026rsquo;s tensorflow implementation of FaceNet and download the resnet model weights in the Pre-trained model section. Then download dlib\u0026rsquo;s 68-point facial landmark model (the shape predictor) from dlib.net.\nwget -nv http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 bzip2 -d shape_predictor_68_face_landmarks.dat.bz2  BOOM! Finished with downloads and installation. Now it\u0026rsquo;s time to build our embedded face detector.\nBuilding the Face Detection pipeline The first thing we need to do is copy align_dlib.py from here and make some quick changes. In the \u0026lsquo;stock\u0026rsquo; version, it looks for the \u0026lsquo;biggest\u0026rsquo; bounding box and only processes that one. But we\u0026rsquo;re going to augment it so that it aligns every bounding box it finds; e.g. every face gets processed rather than just the largest. We\u0026rsquo;re also going to make another quick change to the alignment step based on the issues from this thread, where the stock alignment shears the faces and warps them slightly.\nIt should be noted that while David Sandberg uses a version of MTCNN to detect faces, we have to use the augmented dlib version. This is done so that when the final detection system is running, the memory profile doesn\u0026rsquo;t get out of whack and spontaneously kill our processes. 
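About the shearing fix mentioned above: the stock method (kept below as `align_old`) hands exactly three landmarks to `cv2.getAffineTransform`, so any error in those three points goes straight into the warp. The new `align_one` instead fits the 2x3 affine matrix to all 68 landmarks by least squares: H = T P^T (P P^T)^-1. Here P stacks the detected landmarks in homogeneous coordinates and T holds the template positions. Below is a minimal standalone numpy sketch of that same fit, checked on synthetic points rather than real dlib output. `fit_affine_all_landmarks` is a hypothetical name and isn't part of align_dlib.py.

```python
import numpy as np


def fit_affine_all_landmarks(src_pts, dst_pts):
    # Least-squares 2x3 affine H such that dst ~= H @ [x, y, 1]^T for every
    # point. Same normal-equation form as align_one: H = T P^T (P P^T)^-1.
    src = np.asarray(src_pts, dtype=np.float64)   # (N, 2) detected landmarks
    dst = np.asarray(dst_pts, dtype=np.float64)   # (N, 2) template landmarks
    P = np.vstack([src.T, np.ones(len(src))])     # (3, N) homogeneous coords
    T = dst.T                                     # (2, N)
    return T.dot(P.T).dot(np.linalg.inv(P.dot(P.T)))


# Synthetic check: warp 68 random points with a known affine, then recover it.
rng = np.random.RandomState(0)
src = rng.rand(68, 2) * 100.0
true_H = np.array([[1.1, 0.2, 5.0],
                   [-0.1, 0.9, -3.0]])
dst = src.dot(true_H[:, :2].T) + true_H[:, 2]
H = fit_affine_all_landmarks(src, dst)
print(np.round(H, 6))
```

Swapping the explicit inverse for `np.linalg.lstsq` gives the same matrix and behaves better when the landmarks are close to collinear. Either way, the resulting 2x3 matrix is what align_one passes to `cv2.warpAffine`.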
By changing the face detector, there will be an effect on the overall detection accuracy of our system, but the difference will be minimal.\n# align_dlib.py # Copyright 2015-2016 Carnegie Mellon University # # Licensed under the Apache License, Version 2.0 (the \u0026quot;License\u0026quot;); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an \u0026quot;AS IS\u0026quot; BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. \u0026quot;\u0026quot;\u0026quot;Module for dlib-based alignment.\u0026quot;\u0026quot;\u0026quot; # NOTE: This file has been copied from the openface project. # https://github.com/cmusatyalab/openface/blob/master/openface/align_dlib.py import cv2 import dlib import numpy as np TEMPLATE = np.float32([ (0.0792396913815, 0.339223741112), (0.0829219487236, 0.456955367943), (0.0967927109165, 0.575648016728), (0.122141515615, 0.691921601066), (0.168687863544, 0.800341263616), (0.239789390707, 0.895732504778), (0.325662452515, 0.977068762493), (0.422318282013, 1.04329000149), (0.531777802068, 1.06080371126), (0.641296298053, 1.03981924107), (0.738105872266, 0.972268833998), (0.824444363295, 0.889624082279), (0.894792677532, 0.792494155836), (0.939395486253, 0.681546643421), (0.96111933829, 0.562238253072), (0.970579841181, 0.441758925744), (0.971193274221, 0.322118743967), (0.163846223133, 0.249151738053), (0.21780354657, 0.204255863861), (0.291299351124, 0.192367318323), (0.367460241458, 0.203582210627), (0.4392945113, 0.233135599851), (0.586445962425, 0.228141644834), (0.660152671635, 0.195923841854), (0.737466449096, 0.182360984545), (0.813236546239, 0.192828009114), (0.8707571886, 
0.235293377042), (0.51534533827, 0.31863546193), (0.516221448289, 0.396200446263), (0.517118861835, 0.473797687758), (0.51816430343, 0.553157797772), (0.433701156035, 0.604054457668), (0.475501237769, 0.62076344024), (0.520712933176, 0.634268222208), (0.565874114041, 0.618796581487), (0.607054002672, 0.60157671656), (0.252418718401, 0.331052263829), (0.298663015648, 0.302646354002), (0.355749724218, 0.303020650651), (0.403718978315, 0.33867711083), (0.352507175597, 0.349987615384), (0.296791759886, 0.350478978225), (0.631326076346, 0.334136672344), (0.679073381078, 0.29645404267), (0.73597236153, 0.294721285802), (0.782865376271, 0.321305281656), (0.740312274764, 0.341849376713), (0.68499850091, 0.343734332172), (0.353167761422, 0.746189164237), (0.414587777921, 0.719053835073), (0.477677654595, 0.706835892494), (0.522732900812, 0.717092275768), (0.569832064287, 0.705414478982), (0.635195811927, 0.71565572516), (0.69951672331, 0.739419187253), (0.639447159575, 0.805236879972), (0.576410514055, 0.835436670169), (0.525398405766, 0.841706377792), (0.47641545769, 0.837505914975), (0.41379548902, 0.810045601727), (0.380084785646, 0.749979603086), (0.477955996282, 0.74513234612), (0.523389793327, 0.748924302636), (0.571057789237, 0.74332894691), (0.672409137852, 0.744177032192), (0.572539621444, 0.776609286626), (0.5240106503, 0.783370783245), (0.477561227414, 0.778476346951)]) INV_TEMPLATE = np.float32([ (-0.04099179660567834, -0.008425234314031194, 2.575498465013183), (0.04062510634554352, -0.009678089746831375, -1.2534351452524177), (0.0003666902601348179, 0.01810332406086298, -0.32206331976076663)]) TPL_MIN, TPL_MAX = np.min(TEMPLATE, axis=0), np.max(TEMPLATE, axis=0) MINMAX_TEMPLATE = (TEMPLATE - TPL_MIN) / (TPL_MAX - TPL_MIN) class AlignDlib: \u0026quot;\u0026quot;\u0026quot; Use `dlib's landmark estimation `_ to align faces. The alignment preprocess faces for input into a neural network. 
Faces are resized to the same size (such as 96x96) and transformed to make landmarks (such as the eyes and nose) appear at the same location on every image. Normalized landmarks: .. image:: ../images/dlib-landmark-mean.png \u0026quot;\u0026quot;\u0026quot; #: Landmark indices corresponding to the inner eyes and bottom lip. INNER_EYES_AND_BOTTOM_LIP = [39, 42, 57] #: Landmark indices corresponding to the outer eyes and nose. OUTER_EYES_AND_NOSE = [36, 45, 33] def __init__(self, facePredictor): \u0026quot;\u0026quot;\u0026quot; Instantiate an 'AlignDlib' object. :param facePredictor: The path to dlib's :type facePredictor: str \u0026quot;\u0026quot;\u0026quot; assert facePredictor is not None #pylint: disable=no-member self.detector = dlib.get_frontal_face_detector() self.predictor = dlib.shape_predictor(facePredictor) def getAllFaceBoundingBoxes(self, rgbImg): \u0026quot;\u0026quot;\u0026quot; Find all face bounding boxes in an image. :param rgbImg: RGB image to process. Shape: (height, width, 3) :type rgbImg: numpy.ndarray :return: All face bounding boxes in an image. :rtype: dlib.rectangles \u0026quot;\u0026quot;\u0026quot; assert rgbImg is not None try: return self.detector(rgbImg, 1) except Exception as e: #pylint: disable=broad-except print(\u0026quot;Warning: {}\u0026quot;.format(e)) # In rare cases, exceptions are thrown. return [] def getLargestFaceBoundingBox(self, rgbImg, skipMulti=False): \u0026quot;\u0026quot;\u0026quot; Find the largest face bounding box in an image. :param rgbImg: RGB image to process. Shape: (height, width, 3) :type rgbImg: numpy.ndarray :param skipMulti: Skip image if more than one face detected. :type skipMulti: bool :return: The largest face bounding box in an image, or None. 
:rtype: dlib.rectangle \u0026quot;\u0026quot;\u0026quot; assert rgbImg is not None faces = self.getAllFaceBoundingBoxes(rgbImg) if (not skipMulti and len(faces) \u0026gt; 0) or len(faces) == 1: return max(faces, key=lambda rect: rect.width() * rect.height()) else: return None def findLandmarks(self, rgbImg, bb): \u0026quot;\u0026quot;\u0026quot; Find the landmarks of a face. :param rgbImg: RGB image to process. Shape: (height, width, 3) :type rgbImg: numpy.ndarray :param bb: Bounding box around the face to find landmarks for. :type bb: dlib.rectangle :return: Detected landmark locations. :rtype: list of (x,y) tuples \u0026quot;\u0026quot;\u0026quot; assert rgbImg is not None assert bb is not None points = self.predictor(rgbImg, bb) #return list(map(lambda p: (p.x, p.y), points.parts())) return [(p.x, p.y) for p in points.parts()] #pylint: disable=dangerous-default-value def align_old(self, imgDim, rgbImg, bb=None, landmarks=None, landmarkIndices=INNER_EYES_AND_BOTTOM_LIP, skipMulti=False, scale=1.0): r\u0026quot;\u0026quot;\u0026quot;align(imgDim, rgbImg, bb=None, landmarks=None, landmarkIndices=INNER_EYES_AND_BOTTOM_LIP) Transform and align a face in an image. :param imgDim: The edge length in pixels of the square the image is resized to. :type imgDim: int :param rgbImg: RGB image to process. Shape: (height, width, 3) :type rgbImg: numpy.ndarray :param bb: Bounding box around the face to align. \\ Defaults to the largest face. :type bb: dlib.rectangle :param landmarks: Detected landmark locations. \\ Landmarks found on `bb` if not provided. :type landmarks: list of (x,y) tuples :param landmarkIndices: The indices to transform to. :type landmarkIndices: list of ints :param skipMulti: Skip image if more than one face detected. :type skipMulti: bool :param scale: Scale image before cropping to the size given by imgDim. :type scale: float :return: The aligned RGB image. 
Shape: (imgDim, imgDim, 3) :rtype: numpy.ndarray \u0026quot;\u0026quot;\u0026quot; assert imgDim is not None assert rgbImg is not None assert landmarkIndices is not None if bb is None: bb = self.getLargestFaceBoundingBox(rgbImg, skipMulti) if bb is None: return if landmarks is None: landmarks = self.findLandmarks(rgbImg, bb) npLandmarks = np.float32(landmarks) npLandmarkIndices = np.array(landmarkIndices) #pylint: disable=maybe-no-member H = cv2.getAffineTransform(npLandmarks[npLandmarkIndices], imgDim * MINMAX_TEMPLATE[npLandmarkIndices]*scale + imgDim*(1-scale)/2) thumbnail = cv2.warpAffine(rgbImg, H, (imgDim, imgDim)) return thumbnail #Here's the new method def align_one(self, imgDim, rgbImg, bb=None, landmarks=None, landmarkIndices=INNER_EYES_AND_BOTTOM_LIP, skipMulti=False, scale=1.0): assert imgDim is not None assert rgbImg is not None assert landmarkIndices is not None if bb is None: bb = self.getLargestFaceBoundingBox(rgbImg, skipMulti) if bb is None: return if landmarks is None: landmarks = self.findLandmarks(rgbImg, bb) npLandmarks = np.float32(landmarks) tplLandmarks = imgDim * MINMAX_TEMPLATE*scale + imgDim*(1-scale)/2 tplLandmarks = np.transpose(tplLandmarks) npLandmarks = np.vstack( (np.transpose(npLandmarks), np.ones(tplLandmarks.shape[1])) ) H = np.matmul(np.matmul(tplLandmarks, np.transpose(npLandmarks)), np.linalg.inv(np.matmul(npLandmarks,np.transpose(npLandmarks)))) thumbnail = cv2.warpAffine(rgbImg, H, (imgDim, imgDim)) return thumbnail, bb #here's that same method applied to all bounding boxes it finds def align_many(self, imgDim, rgbImg, bb=None, landmarks=None, landmarkIndices=INNER_EYES_AND_BOTTOM_LIP, skipMulti=False, scale=1.0): assert imgDim is not None assert rgbImg is not None assert landmarkIndices is not None thumbnails = [] bboxes = [] bbs = self.getAllFaceBoundingBoxes(rgbImg) if not bbs: return thumbnails, bboxes for bb in bbs: faceLandmarks = landmarks if landmarks is not None else self.findLandmarks(rgbImg, bb) npLandmarks = np.float32(faceLandmarks) tplLandmarks = 
imgDim * MINMAX_TEMPLATE*scale + imgDim*(1-scale)/2 tplLandmarks = np.transpose(tplLandmarks) npLandmarks = np.vstack( (np.transpose(npLandmarks), np.ones(tplLandmarks.shape[1])) ) H = np.matmul(np.matmul(tplLandmarks, np.transpose(npLandmarks)), np.linalg.inv(np.matmul(npLandmarks,np.transpose(npLandmarks)))) thumbnail = cv2.warpAffine(rgbImg, H, (imgDim, imgDim)) thumbnails.append(thumbnail) bboxes.append(bb) return thumbnails, bboxes  The second thing we need to do is build a scanner to capture the faces you actually want to classify. One thing to note is that with the Jetson, using the camera with OpenCV can be tricky. We need to open the video with this GStreamer pipeline string in our OpenCV call: \u0026ldquo;nvcamerasrc ! video/x-raw(memory:NVMM), width=(int)640, height=(int)480,format=(string)I420, framerate=(fraction)24/1 ! nvvidconv flip-method=2 ! video/x-raw, format=(string)BGRx ! videoconvert ! video/x-raw, format=(string)BGR ! appsink\u0026rdquo; Here\u0026rsquo;s our script called scan.py.\n# scan.py # -*- coding: UTF-8 -*- #Usage: python scan.py --name YOUR_NAME import cv2 import align_dlib as align import os import argparse def fill(x, name): \u0026quot;\u0026quot;\u0026quot; zero-pad the counter x to 5 digits and prefix it with name \u0026quot;\u0026quot;\u0026quot; t=len(str(x)) a = 5-t s = str(x) for i in range(a): s = str(0) + s return name+'_'+s try: os.mkdir('train_images') except: pass path = '{}/train_images/'.format(os.getcwd()) if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--name', type=str, help=\u0026quot;Name of person being scanned\u0026quot;, default='Bruce_Lee') args = parser.parse_args() face_detector = align.AlignDlib('{}/shape_predictor_68_face_landmarks.dat'.format(os.getcwd())) video_capture = cv2.VideoCapture(\u0026quot;nvcamerasrc ! video/x-raw(memory:NVMM), width=(int)640, height=(int)480,format=(string)I420, framerate=(fraction)24/1 ! nvvidconv flip-method=2 ! video/x-raw, format=(string)BGRx ! 
videoconvert ! video/x-raw, format=(string)BGR ! appsink\u0026quot;) # set the capture HxW to the same as openface by https://github.com/cmusatyalab/openface video_capture.set(3, 320) video_capture.set(4, 240) count = 0 while True: ret, frame = video_capture.read() try: face_thumbnail, area = face_detector.align_one(96, frame) string = fill(count, args.name) print(\u0026quot;face detection successful\u0026quot;) cv2.imwrite(path+string+'.png', face_thumbnail) bl = (area.left(), area.bottom()) tr = (area.right(), area.top()) cv2.rectangle(frame, bl, tr, color=(153, 255, 204), thickness=3) count+=1 except: pass cv2.imshow('', frame) if cv2.waitKey(1) \u0026amp; 0xFF == ord('q'): break   FaceNet Model Overview Most open source facial recognition libraries like OpenFace, home_surveillance, facenet, etc. use a model similar to the one outlined in the FaceNet paper written by Florian Schroff, Dmitry Kalenichenko, and James Philbin. Here we\u0026rsquo;re no different and will be using the model that David Sandberg\u0026rsquo;s facenet uses.\nThe model takes an image of an individual\u0026rsquo;s face and passes it through a network (the model David uses is a variant of Inception-ResNet). The goal is to make the network embed the image in a feature space so that the squared distance between images of the same identity is small and the squared distance between a pair of images from different identities is large. This is done using something called a Triplet Loss. 
It\u0026rsquo;s probably the single most important feature of the model\u0026rsquo;s structure.\nRather than break down the entire model, I just want to mention what made this model stick out for me:\n# triplet loss import numpy as np def triplet_loss(image_embedding, true_image_embedding, false_image_embedding, alpha): positive_distance = np.sum(np.square(image_embedding - true_image_embedding),axis=1) negative_distance = np.sum(np.square(image_embedding - false_image_embedding),axis=1) basic_loss = (positive_distance - negative_distance) + alpha loss = np.mean(np.maximum(basic_loss, 0.0), axis=0) return loss  The goal here is an embedding that enforces a margin alpha between the distance of each same-identity face pair and the distance to faces of every other identity.\nTo make sure the network learns properly, triplets are selected online: during the forward pass, negatives are picked from the current minibatch. The authors note that selecting the hardest negatives (the ones closest to the anchor) can lead to bad local minima early in training. So they instead select semi-hard negatives: ones that are further from the anchor\u0026rsquo;s embedding than the positive example, but still meaningful because their squared distance is close to the anchor-positive distance. These negatives lie inside the margin alpha, and choosing them helps avoid a collapsed model.\nTL;DR: Read the paper. It\u0026rsquo;s worth it.\n Building the live detector After you\u0026rsquo;ve scanned the faces you want via the TX1\u0026rsquo;s camera, we need to put something together to actually classify faces. This script borrows concepts from openface\u0026rsquo;s web-demo as well as facenet\u0026rsquo;s validate_on_lfw.py.\nSo what\u0026rsquo;s going on here? Welp, we want to take every image in our training set from scan.py and use the facenet model to build a representation (embedding) of each one. 
After each image has been processed via the network, we train an SVM on those representations so it learns to classify each person\u0026rsquo;s processed images correctly. That trained SVM is then used to classify all the faces that the camera sees.\n# -*- coding: UTF-8 -*- # model.py import cv2 from collections import defaultdict import tensorflow as tf import align_dlib as align import os import numpy as np from sklearn.grid_search import GridSearchCV from sklearn.svm import SVC def load_model(model_file): \u0026quot;\u0026quot;\u0026quot;Load facenet model\u0026quot;\u0026quot;\u0026quot; saver = tf.train.import_meta_graph(os.path.expanduser(model_file+'.meta')) saver.restore(tf.get_default_session(), os.path.expanduser(model_file)) def get_data(path, sess, embeddings): \u0026quot;\u0026quot;\u0026quot;gather data from our training images and vectorize them\u0026quot;\u0026quot;\u0026quot; image_list = sorted(os.listdir(path)) X = [] y = [] person_dict = {} count = 0 current = image_list[0].rsplit('_', 1)[0] for string in image_list: name = string.rsplit('_', 1)[0] if current!=name: count+=1 current = name person_dict[count] = name thm = cv2.imread(path+'/'+string) feed_dict = { images_placeholder:np.expand_dims(thm,0), phase_train_placeholder:False } emb_array = sess.run(embeddings, feed_dict=feed_dict) X.append(emb_array) y.append(count) X = np.vstack(X) y = np.array(y) return X, y, person_dict def train_SVM(path, sess, embeddings): \u0026quot;\u0026quot;\u0026quot;train an svm on the image data we've collected\u0026quot;\u0026quot;\u0026quot; X, y, person_dict = get_data(path, sess, embeddings) print(\u0026quot;+ Training SVM on {} labeled images.\u0026quot;.format(X.shape[0])) param_grid = [ {'C': [1, 10, 100, 1000], 'kernel': ['linear']}, {'C': [1, 10, 100, 1000], 'gamma': [0.001, 0.0001], 'kernel': ['rbf']} ] svm = GridSearchCV(SVC(C=1), param_grid, cv=5).fit(X, y) return svm, person_dict def get_rep(img, sess, embeddings): \u0026quot;\u0026quot;\u0026quot;process 
an image via facenet and return the embedding array\u0026quot;\u0026quot;\u0026quot; feed_dict = { images_placeholder:np.expand_dims(img,0), phase_train_placeholder:False } emb_array = sess.run(embeddings, feed_dict=feed_dict) return emb_array def process_frame(frame, sess, embeddings, svm, person_dict): \u0026quot;\u0026quot;\u0026quot;process a frame and, if faces are found, draw a rectangle labeled with the predicted name around each one\u0026quot;\u0026quot;\u0026quot; faces, bboxes = face_detector.align_many(96, frame) if faces: print('detection!') for face, bb in zip(faces, bboxes): rep = get_rep(face, sess, embeddings) print(rep) identity = svm.predict(rep)[0] name = person_dict[identity] bl = (bb.left(), bb.bottom()) tr = (bb.right(), bb.top()) cv2.rectangle(frame, bl, tr, color=(153, 255, 204), thickness=3) cv2.putText(frame, name, (bb.left(), bb.top() - 10), cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.75, color=(152, 255, 204), thickness=2) # without this, errors were being thrown and the machine was being killed config = tf.ConfigProto() config.gpu_options.allow_growth=True # we use device = cpu because it's just as fast as running on the gpu. 
# we also want to avoid io transfer time to gpu considering we'll be # processing images incredibly quickly with tf.device('/cpu:0'): with tf.Graph().as_default(): with tf.Session(config=config) as sess: # Load the model #TODO: change the model dir to whatever model dir/ model you download model_path = '20160514-234418/model.ckpt-500000' print('Loading model \u0026quot;%s\u0026quot;' % model_path) load_model(model_path) # Get input and output tensors images_placeholder = tf.get_default_graph().get_tensor_by_name(\u0026quot;input:0\u0026quot;) phase_train_placeholder = tf.get_default_graph().get_tensor_by_name(\u0026quot;phase_train:0\u0026quot;) embeddings = tf.get_default_graph().get_tensor_by_name(\u0026quot;embeddings:0\u0026quot;) print(\u0026quot;setting up training\u0026quot;) training_path = 'train_images' svm, person_dict = train_SVM(training_path, sess, embeddings) print(\u0026quot;setting up camera\u0026quot;) video_capture = cv2.VideoCapture(\u0026quot;nvcamerasrc ! video/x-raw(memory:NVMM), width=(int)640, height=(int)480,format=(string)I420, framerate=(fraction)24/1 ! nvvidconv flip-method=2 ! video/x-raw, format=(string)BGRx ! videoconvert ! video/x-raw, format=(string)BGR ! appsink\u0026quot;) video_capture.set(3, 320) video_capture.set(4, 240) print(\u0026quot;setting up detector\u0026quot;) face_detector = align.AlignDlib('shape_predictor_68_face_landmarks.dat') while True: ret, frame = video_capture.read() process_frame(frame, sess, embeddings, svm, person_dict) cv2.imshow('', frame) # quit the program on the press of key 'q' if cv2.waitKey(1) \u0026amp; 0xFF == ord('q'): break  Boomshakalaka. And there you have it. 
A few simple scripts, and you have an embedded detector up and running.\nBecause everyone likes a demo in camera vertical (\u0026hellip;a little lag due to tunneling X over ssh)\n","date":"2016-11-17","permalink":"/posts/face/","title":"Facial Recognition on a TX1"},{"content":" 1979 Bunnahabhain 25 Lombard (2004) This bottle is from Lombard Scotch Whisky under the Jewels of Scotland line. I don\u0026rsquo;t know too much about this company other than that they made this bottle. I\u0026rsquo;ve had other Bunnahabhains, but none as weird as this one. I really recommend it if you get the chance.\nBottle Info Distillery: Bunnahabhain\nBarrel type: sherry hogshead\nAge: 25\nABV: 50\nHow this was acquired: store\nCost: $150\nYear purchased: 2013\nTasting Notes Nose: malty, sweetness, floral, fino sherry?, nutty, lemony, custard, caramel, vanilla, briny hay, sour funk with time. really interesting sour note at the end, but sour in a lemony sense. really opens up with time.\nNose[water]: n/a\nTaste: lemony, tart, malt funk, cereal sweetness, sour, odd ginger note, really odd interesting carbonated quality\nTaste[water]: n/a\nFinish: vanilla custard, frosting, vanilla, good - but really one-dimensional and dry\nFinish[water]: n/a\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 4/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/bunnahabhain_1979/","title":"1979 Bunnahabhain 25 Lombard (2004)"},{"content":" Ardbeg Dark Cove Standard Edition This is a great whiskey, but there are better whiskies out there for less money. Ardbeg routinely brings respectability to the NAS SKU. However, I feel like this Ardbeg falls a bit flat in that respect. 
It\u0026rsquo;s good, just not what we\u0026rsquo;ve come to expect from Ardbeg. There\u0026rsquo;s an odd doughy or yeasty quality lingering in the background that, to me, signals some incredibly young or poor batch is steering too much of the ship. I feel like if they had done the Dark Cove as a sherry-finished Ardbeg 10 year, this would be far more interesting to me. Still, measured against every whiskey I\u0026rsquo;ve ever tried, it\u0026rsquo;s a damn fine whiskey.\nBottle Info Distillery: Ardbeg\nBarrel type: ex-bourbon, finished in \u0026lsquo;dark sherry\u0026rsquo; casks (googling around says these are heavily seasoned sherry casks)\nAge: n/a\nABV: 46.5\nHow this was acquired: store\nCost: $110\nYear purchased: 2016\nTasting Notes Nose: spicy, peppery, peat, uigeadail-ish?, brine, salt, hay, wet grass, sugary - almost caramel, vanilla cream, soft ocean brine\nNose[water]: salt, brine, less, caramel, peaty, softer, more tight, grape, fig, dark fruits, berries, confectioners\u0026rsquo; sugar\nTaste: soft peat, hay, caramel, brine, salty sweets, peppery, earthy, vanilla\nTaste[water]: more grapy, light fruits, peat, caramel, yeasty\nFinish: caramel fading into salty peat, doughy, long ardbeg-ish finish, vanilla, salt, earthy hay\nFinish[water]: dense sugary peat, long salty caramel, delicious\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): No\nWould I order this in a bar: No\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 2/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/dark_cove/","title":"Ardbeg Dark Cove Standard Edition"},{"content":" Blairfindy 16 Blackadder Raw Cask As a big Glenfarclas fan, this was an easy buy for me. I love their juice, and getting the chance to try some at cask strength was a no-brainer. 
For those who don\u0026rsquo;t know, Blairfindy is the name Blackadder bottles Glenfarclas under. I love this bottle. That nice fruity flavor with that sweet vanilla. Really good stuff.\nBottle Info Distillery: Glenfarclas\nBarrel type: single oak hogshead\nAge: 16\nABV: 54.1\nHow this was acquired: store\nCost: $130\nYear purchased: 2016\nTasting Notes Nose: apples, pears, light fruits, gentle maltiness, salty briny vanilla\nNose[water]: all of the above, but just gently lightened, a bit of ash and smoke, faint sea-salty apples\nTaste: vanilla caramel, lemony, slight acidity, apple, pear, peach, creamy lemon custard\nTaste[water]: more citrus flavors, custardy\nFinish: light malty apples, caramel sweetness, bitter cola, very different from the nose\nFinish[water]: raw custard, vanilla, pear, baking spices, glenfarclas-y\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 4/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/blairfindy_16/","title":"Blairfindy 16 Blackadder Raw Cask"},{"content":" Blantons Straight from the Barrel I\u0026rsquo;m a sucker for cask-strength anything. Scotch. Bourbon. Everclear. Load up the fat and pour that sucker in my glass. I love the rawness. This one gives me exactly what I\u0026rsquo;m looking for. There are some out there, however, who criticize this for not being as well balanced as they\u0026rsquo;d want from a bourbon \u0026ndash; often complaining of its heat. I say to the moon with them. This beast is what it\u0026rsquo;s all about. An edgy bourbon that gives no fucks, is completely inconsistent from barrel to barrel, and punches you in the mouth with flavor? Sign me up twice. 
This is what non-age-statement (or hell, any) bourbons should be. Raw, full-flavored, and powerful. Not holding back. Not trying to be more than what it is. Just letting the distilling and the barrel do the talking. This is simply awesome.\nBottle Info Distillery: Buffalo Trace\nBarrel type: bourbon\nAge: No Age Statement\nABV: 63.95\nHow this was acquired: auction\nCost: $92\nYear purchased: 2017\nTasting Notes Nose: oak, crème brûlée, Werther\u0026rsquo;s Original, Toblerone, almond, caramel, caramelized sugar, wood sugars. It\u0026rsquo;s like smelling pure cake frosting\nNose[water]: sap, syrup, Aunt Jemima\u0026rsquo;s syrup \u0026ndash; the overly sugared kind - the non-real maple syrup - not at all in a bad way, cake frosting, a maple glazed donut from Do-Rite donuts with a strip of bacon down the center. sugar and spice.\nTaste: sugars, caramelized sugar, ooey-gooey butter cake, sugar cookies, frosting, cake, vanilla, caramel, bit of vermouth, hot spice. Krispy Kreme donuts.\nTaste[water]: more frosting, sugar and spice, creme, nougat, Toblerone, cookie, a bit of thyme.\nFinish: sugary, caramel, buttery sweet cookie, toffee.\nFinish[water]: just continues the taste with a bit of water, long-lasting cake frosting, chocolate, brownies.\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 5/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/blantons/","title":"Blantons Straight from the Barrel"},{"content":" Kilchoman 2008 Vintage Find me a better 7 year Islay scotch. You can\u0026rsquo;t. This was so good, I went back to Binny\u0026rsquo;s to buy another. Kilchoman is something special. 
They remind me a lot of Westland (another distillery that I really admire); they\u0026rsquo;re both doing it their own way and absolutely killing it. To me this tastes a lot like Ardbeg, and frankly, pitted against one of my all-time favorites, the Ardbeg 10 year, this 7 year blows it out of the water. The only downside is the price: $79 now. If this were priced at the same level as the Ardbeg 10, it would be a household staple.\nBottle Info Distillery: Kilchoman\nBarrel type: ex-bourbon\nAge: 7\nABV: 46\nHow this was acquired: store\nCost: $79\nYear purchased: 2016\nTasting Notes Nose: has bourbony notes, ardbegy?, maybe similar to a bourbon-aged ardbeg or lagavulin, fresh hay, very clean sharp peat, caramel, vanilla, salt, brine, grass, a bit of citrus, smoke, butterscotch, cooked sugar\nNose[water]: n/a\nTaste: sharp peat, vanilla, malty youngness, laphroaig earthiness, ardbeg freshness, lemony peat, soft sweets, brine, light grassy notes, pepper\nTaste[water]: n/a\nFinish: vanilla, peat, sweetness, caramel, toffee, long salty very sharp peat, grassy notes\nFinish[water]: n/a\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 4/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/kilchoman_2008/","title":"Kilchoman 2008 Vintage"},{"content":" Parkers Heritage Malt Whiskey Admittedly, I had no intention of picking this bottle up, but thankfully my fiancée picked it up as an anniversary gift. Helluva present, that\u0026rsquo;s for sure. Let there be no mistake: This is a great whiskey. Is it the best Parker\u0026rsquo;s Heritage? Definitely not. Is it better than 95% of the whiskies out there? 100% yes. 
It\u0026rsquo;s a great experiment that hit all the right spots, but if I were to dole out unsolicited advice: definitely try this before you buy.\nBottle Info Distillery: Heaven Hill Distillery\nBarrel type: American whiskey\nAge: 8\nABV: 54\nHow this was acquired: secondary\nCost: $110\nYear purchased: 2016\nTasting Notes Nose: oak, woody floral note - like a rye without being a rye, vanilla, musty maltiness, chocolate, heavy oak, butterscotch, slight grassy note\nNose[water]: n/a\nTaste: sweetness, soft butterscotch, salty, oak, malty spice, musty funk, slight bitterness, Mexican chocolate\nTaste[water]: n/a\nFinish: long cloying chocolate, cocoa, Mexican chocolate, vanilla, spicy heat\nFinish[water]: n/a\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 3/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/parkers_malt/","title":"Parkers Heritage Malt Whiskey"},{"content":" Smooth Ambler Old Scout 9 Year Single Barrel I picked this up at a long-standing South Loop standby, Warehouse Liquors. They regularly carry interesting store picks that differ greatly from many of the Binny\u0026rsquo;s store picks I usually pick up. This particular whiskey is no different. Smooth Ambler makes some excellent MGP-sourced picks. Smooth Ambler is traditionally known for their excellent single-barrel rye picks, but man, I just love their bourbon picks. 
They hit all the right notes for me, and the price can\u0026rsquo;t be beat.\nBottle Info Distillery: Smooth Ambler\nBarrel type: bourbon\nAge: 9\nABV: 55.4\nHow this was acquired: store\nCost: $54\nYear purchased: 2015\nTasting Notes Nose: dill, rich spice, oak, caramel, cream pudding, toffee, butterscotch donuts, vanilla, chocolate\nNose[water]: butter, rich cream, caramel, heavy sweet syrup, less spice and more bourbony notes\nTaste: butterscotch, pecan pie, caramel, creamy Werther\u0026rsquo;s chocolates, peachy fruit notes, chocolate spice, dense raspberry, minty\nTaste[water]: sweet milk chocolate, sugary, honey, clover, salty butterscotch\nFinish: long spicy sugar, butterscotch, cinnamon cream, floral spiciness\nFinish[water]: cream, fruity, flowery, honey, chocolate\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: Yes\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 5/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/smooth_ambler/","title":"Smooth Ambler Old Scout 9 Year Single Barrel"},{"content":" Westland Peated Malt This is a fantastic whiskey. An 8.4 for a 24-month whiskey is insane. Most 24-month whiskies I\u0026rsquo;d score much lower, somewhere in the 3 to 5 range. Westland is my favorite American single malt distillery. What they\u0026rsquo;ve accomplished in such a short time is insane. I\u0026rsquo;ve toured their distillery and tasting room, and it\u0026rsquo;s phenomenal - if you\u0026rsquo;re a fan of whiskey, you owe it to yourself to visit if you\u0026rsquo;re ever in Seattle. It\u0026rsquo;s hard not to root for them (even though they\u0026rsquo;ve been acquired). They\u0026rsquo;re the future of American whiskey. 
This whiskey, with its malty, lemony quality, is something special. It\u0026rsquo;s definitely not the best whiskey I\u0026rsquo;ve ever had, but you owe it to yourself to give Westland a chance. In 10 years, Westland is going to produce the best American whiskey - mark my words!\nBottle Info Distillery: Westland Distillery\nBarrel type: first-fill ex-bourbon, Cooper\u0026rsquo;s Select new American oak, Cooper\u0026rsquo;s Reserve new American oak\nAge: minimum of 24 months\nABV: 46\nHow this was acquired: store\nCost: $70\nYear purchased: 2015\nTasting Notes Nose: malty, yeasty, peat, peppery, smoke, apple, pear, vanilla, fruits, oak, a funk - Springbank funk, pudding, sugary cake\nNose[water]: n/a\nTaste: smoke, oak, light peaty pepper, brine, pear, apple, light soft fruits, vanilla, oak, salt, a bit of lemony tartness, candied lemon smoke\nTaste[water]: n/a\nFinish: clean, zesty, salty smoke, apple, lemon, sugar cake, yeasty and bready, tart beer\nFinish[water]: n/a\nBinary Score: 1\nValue Notes Would I buy this (or if already bought, would you buy this again): Yes\nWould I order this in a bar: No\nWould I drink this if someone gave me a glass: Yes\nValue For Money: 3/5\n My Scoring System 0 | I didn\u0026rsquo;t like it | It doesn\u0026rsquo;t mean you won\u0026rsquo;t\n1 | I liked it | It doesn\u0026rsquo;t mean you will\n","date":"2015-11-17","permalink":"/posts/westland_peated/","title":"Westland Peated Malt"},{"content":"==================\nA little background: Recently my friends and I started a whisky club. Every 3 weeks, we get together, bring a few bottles of whiskey/beer and just hang out, talk shop, and ultimately, praise the nectar of the gods. For the beer guys, or really anyone in general, buying a bomber or a six-pack is pretty easy. You can drop 15 bucks on a bomber and not feel completely guilty if you made a bad purchase. 
However, if you spend 90 bucks on a bottle of whiskey and you absolutely hate it, you could have just spent that on 6 different bombers. So when you\u0026rsquo;re new to whiskey, or want to take the plunge and buy your first bottle, it\u0026rsquo;s incredibly easy to get caught up in some hacky marketing telling you about Rare Casks or about the juice some functional alcoholic who wasn\u0026rsquo;t as good as Dean Martin used to drink.\nBack to barriers to entry: I wanted to make something so that the less-experienced guys, or really anyone in the club, could query whiskies at the store and figure out if the flavors would ignite their interest. The goal was to avoid wasting money on an absolute dud and instead find something you might actually like. That was the primary goal. A second benefit was born out of our own poor ability to describe a whiskey.\nWhat is that flavor? I can\u0026rsquo;t put my finger on it. It\u0026rsquo;s sweet. Candy-ish, almost. What the hell is that?\nPop open this beast on localhost -\u0026gt; query the whiskey -\u0026gt; check the taste -\u0026gt; maraschino fuckin\u0026rsquo; cherries. Bam. That\u0026rsquo;s exactly it.\nGiven that this model passes a lot of eye tests for me, I use this little helper to make a decision on a bottle when I\u0026rsquo;m really in the mood for something specific. You might be thinking - Really? Another fucking recommender system?\u0026hellip; If that\u0026rsquo;s you, then I like you; we\u0026rsquo;d probably be friends. 
You have an open invite to whiskey-club.\nTable of Contents   If you\u0026rsquo;re familiar with whiskey, but unfamiliar with machine learning\n  If you\u0026rsquo;re familiar with machine learning, but unfamiliar with whiskey\n  Methodology\n  The Data\n  Failed experiments\n  We\u0026rsquo;re on to something\n  Final touches\n  Points to note\n  Where next\n  Further Reading\n  If you\u0026rsquo;re familiar with whiskey, but unfamiliar with machine learning This particular model uses word embeddings. Word embeddings have a long history in machine learning, starting with Hinton in 1986 (Geoffrey Hinton. Learning Distributed Representations of Concepts. Proceedings of the Eighth Annual Conference of the Cognitive Science Society. Amherst, Mass. 1:12, 1986) and running all the way to the seminal word embedding paper from Mikolov et al. (T. Mikolov, K. Chen, G. Corrado, J. Dean. Efficient Estimation of Word Representations in Vector Space). Word embeddings can be thought of as distributed representations of words that capture a large number of precise syntactic and semantic word relationships (Mikolov et al. - arXiv:1310.4546). At a very high level, once words are embedded, we can use analogical reasoning to gauge the effectiveness of various models. Here is the example that has become the canonical illustration of this type of reasoning:\nQ: man is to woman as king is to what?\nThe algorithm then returns: queen.\nStealing from the ever-amazing Stitchfix Multithreaded blog, here is a small gif that encapsulates this idea to a T.\nIn whisky terms, peat is to smoke as PX is to raisin, where the difference (smoke - peat) encodes the step from a source to the flavor it imparts. 
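To make the analogy arithmetic concrete, here is a minimal sketch using hand-made toy vectors. Everything in it is hypothetical: the three-dimensional vectors, the tiny vocabulary and the `analogy` helper are illustrations, not the trained 150-dimensional model from this post. The answer to peat : smoke :: PX : ? is taken to be the remaining word whose vector has the highest cosine similarity to smoke - peat + PX. That is the standard word2vec vector-offset recipe.

```python
import numpy as np

# Hypothetical 3-d embeddings, hand-made for illustration only.
# Real word2vec vectors (like the 150-d ones in this post) are learned from a corpus.
emb = {
    'peat': np.array([1.0, 0.0, 0.2]),
    'smoke': np.array([1.0, 1.0, 0.2]),
    'px': np.array([0.0, 0.0, 1.0]),
    'raisin': np.array([0.0, 1.0, 1.0]),
    'vanilla': np.array([0.5, 0.3, 0.0]),
}

def cosine(a, b):
    # cosine similarity between two vectors
    return float(a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def analogy(a, b, c, emb):
    # a is to b as c is to ?  ->  word closest to (b - a + c), excluding the inputs
    target = emb[b] - emb[a] + emb[c]
    candidates = [w for w in emb if w not in (a, b, c)]
    return max(candidates, key=lambda w: cosine(emb[w], target))

print(analogy('peat', 'smoke', 'px', emb))  # raisin
```

Excluding the three query words is the usual convention. Without that exclusion, the input word itself is often the nearest vector to the target.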
The most important thing to keep in mind is that none of this is done by explicitly telling the algorithm \u0026lsquo;During the kilning, barley is dried over a peat fire, and that imparts heavy smoky qualities to the given whiskey.\u0026rsquo; That contextual knowledge is derived from the language used to describe certain whiskies in our review corpus. Bottom line: It\u0026rsquo;s fucking cool.\nWhoa. Where can I learn more?\nStitchfix\u0026rsquo;s \u0026lsquo;A Word is Worth a Thousand Vectors\u0026rsquo;\nChris Olah\u0026rsquo;s Deep Learning, NLP, and Representations\nTensorflow\u0026rsquo;s documentation of word2vec\nKarpathy\u0026rsquo;s The Unreasonable Effectiveness of Recurrent Neural Networks\nIn fact, if you\u0026rsquo;re here because you like whiskey and word2vec is entirely new to you, you should definitely look at the first link before you continue; otherwise, nothing here will make sense. Consider it mandatory reading.\nYou - \u0026lsquo;But I don\u0026rsquo;t like mandatory reading\u0026hellip;'\nMe - \u0026lsquo;Welp. Bye, Felicia.'\nDeeper reading on the topic is provided at the bottom of the page.\nIf you\u0026rsquo;re familiar with machine learning, but unfamiliar with whiskey \u0026ldquo;How we describe the flavors and smells and aromas in whisky? Chocolate, coffee, cognac, wine; is entirely, entirely personal. I come from Scotland, you may not. You grew up in a different place. You eat different food than I do so your nose will tell you things that mine doesn\u0026rsquo;t and my nose will tell me things that yours doesn\u0026rsquo;t, so don\u0026rsquo;t be led by anyone saying \u0026lsquo;oh you can taste daffodil picked on a west-facing slope\u0026rsquo;, you know, \u0026lsquo;by left-handed people.\u0026rsquo; It doesn\u0026rsquo;t matter. All you need to do is to work out simple, simple questions: Is it sweet? Yes. How would you describe that sweetness? 
With the 12 year old, we would describe it with honey, if you were to say white sugar or caramel; it\u0026rsquo;s the same thing. It\u0026rsquo;s sweet. If you were to say, the fruit in it, we would describe it as pineapple, if you said fresh green apple or pear, it\u0026rsquo;s the same thing. It\u0026rsquo;s a fresh fruit.\u0026rdquo; - Gerry Tosh, Highland Park link to video\nPreach.\nActually, on second thought, just watch the whole video: https://youtu.be/HpoIcXfToVs. It\u0026rsquo;s really simple and really good.\nIn his 2012 book, \u0026lsquo;The World Atlas of Whisky\u0026rsquo;, Dave Broom wrote, \u0026lsquo;Whiskey is currently in the same position that wine was 2-years ago: there is a latent desire to try it, but the consumer doesn\u0026rsquo;t have the language with which to describe what he or she wants. Instead of helping, words have become the barrier.\u0026rsquo;\nWell, Big Davey B, hopefully a word-embedding approach can be one avenue for breaking down that barrier.\nMy methodology For the majority of my experiments, I played with embeddings built with gensim, chainer, or theano.\nThe Data This data was scraped from r/bourbon, r/scotch, and r/worldwhiskey. There\u0026rsquo;s a pretty clean dataset that I did the majority of the work on: the Reddit Whisky Review Archive. Thankfully, this dataset had dimensions imposed upon it by some general whisky reviewing standards. For those unfamiliar, it looks something like this, where most reviews have a general stream-of-consciousness flavor to them:\n\u0026quot;\u0026quot;\u0026quot; I'm reviewing the Laphroaig 10 Cask-Strength. I like scotch. Scotchy scotch scotch. Scotch in my belly. Nose : Powerful smoke and earthy aromas, seaweed. burnt toffee, salty seaweed. pine tree. leather. Taste : Massive peat smoke, seashore salt, briny wood. peaty smoke. salty tobacco and caramel. vague tartness Finish : peppery, smoke, long, malt, savory, citrus Rating : 89/100 Some follow up notes here. 
Vintage notes if any. \u0026quot;\u0026quot;\u0026quot;  Failed experiments Initially, I tried to automate the asking of basic questions, where a model would return the desired responses. In other words, I tried to naively train a model to figure out that I was asking \u0026lsquo;what does this whiskey taste like?\u0026rsquo; and have it return a summary of what I wanted to know. The idea was loosely inspired by Sukhbaatar et al. (arXiv:1503.08895) and the keras babi_rnn.py example code, where the idea was a type of encoder-decoder network that learns to return specific information relating to a given query. The problem, however, was threefold: pure unadulterated laziness (I did not want to build the type of corpus needed to turn this into a supervised learning problem), lack of data (there are only something like 10K clean reviews, each around 50-100 words), and the corpus itself. These whiskey reviews are hardly \u0026lsquo;natural language\u0026rsquo; in the way we think of most text. They\u0026rsquo;re most often just lists of words written stream-of-consciousness style, like the example review above. Thus, my guess was that any question-answer dataset I would have wanted to build most likely wouldn\u0026rsquo;t be as well-behaved as the bAbI tasks from this paper: \u0026ldquo;Towards AI-Complete Question Answering: A Set of Prerequisite Toy Tasks\u0026rdquo;\nTo visualize this phenomenon, take a look at this snake-like t-SNE rat-tail created from word2vec\u0026rsquo;s word vectors. 
Note: the majority of the words used to describe whiskies occur in the mullet.\nDespite its 80s haircut, similarity queries still produce reasonable responses to the signals of interest, in the way we would expect.\nIn [453]: most_similar('peat') Out[453]: [('brine', 0.6719855070114136), ('iodine', 0.6358067989349365), ('campfire', 0.5941624641418457), ('billowing', 0.5907186269760132), ('medicinal', 0.5881044864654541), ('bandages', 0.5756282806396484), ('peatiness', 0.5754231214523315), ('ghostly', 0.5729676485061646), ('salt', 0.5569401979446411), ('funky', 0.5552007555961609)] In [454]: most_similar('maple') Out[454]: [('pancake', 0.8296730518341064), ('corn', 0.8063130378723145), ('cough', 0.7530097365379333), ('molasses', 0.7310086488723755), ('treacle', 0.7281137704849243), ('pecans', 0.721108078956604), ('honeycomb', 0.7178168892860413), ('plump', 0.7121019959449768), ('comb', 0.7095464468002319), ('turkish', 0.7074028849601746)] In [455]: most_similar('berries') Out[455]: [('raspberries', 0.8527774214744568), ('plums', 0.8487773537635803), ('cherries', 0.8433600068092346), ('sultanas', 0.8285631537437439), ('fruits', 0.8216339945793152), ('currants', 0.8172739744186401), ('pomegranate', 0.8159447908401489), ('prunes', 0.8104561567306519), ('strawberries', 0.8091833591461182), ('blackberries', 0.8010640144348145)]  NOTE: For those unfamiliar with word2vec: you might want to disagree with something the algorithm does. For example, say you disagree that \u0026lsquo;berries\u0026rsquo; is closest to \u0026lsquo;raspberries\u0026rsquo; and think \u0026lsquo;blueberries\u0026rsquo; or something should be more similar. Don\u0026rsquo;t worry. Neither you nor the algorithm is technically wrong, you precious little snowflake.\nGiven that the vectors seem decent enough to use, here was a super naive implementation of \u0026lsquo;asking a question\u0026rsquo;. 
With this code, each \u0026lsquo;sentence\u0026rsquo; is represented by a single n-dimensional array: the sum of all word vectors within that \u0026lsquo;sentence\u0026rsquo;. This way we can simply use the cosine similarity (or a basic dot product) between our \u0026lsquo;query\u0026rsquo; word and each sentence representation to find the \u0026lsquo;closest\u0026rsquo; sentence to our query:\nimport pickle from string import punctuation import numpy as np from sklearn.metrics.pairwise import cosine_similarity punctuation = punctuation.replace(\u0026quot;'\u0026quot;,'').replace('\u0026quot;','') Trans = str.maketrans(punctuation, ' '*len(punctuation)) # embedding table (word -\u0026gt; vector), pickled in binary mode. vector_dic = pickle.load(open('vectors.p', 'rb')) def review_vectorizer(review): \u0026quot;\u0026quot;\u0026quot; breaks a review into sentences of interest. Params ------ review : str raw whiskey review Returns ------- reconstructed : list(string) the initial review separated at '\\n' lines, cleaned, all lower case stack : np.ndarray len reconstructed X 150. 
the sentences in vector representation \u0026quot;\u0026quot;\u0026quot; split_s = [x for x in review.split('\\n') if x != ''] cleaned_sentences = list(sentence.lower().translate(Trans).split() for sentence in split_s) sentences = [] for sentence in cleaned_sentences: sentence_vec = [vector_dic[word] for word in sentence] bag_of_vecs = sum(sentence_vec) full_sentence = \u0026quot; \u0026quot;.join(sentence) sentences.append((bag_of_vecs,full_sentence)) reconstructed = [x[1] for x in sentences] stack = np.vstack([x[0] for x in sentences]) return reconstructed, stack def query(word, review, vector_dic): \u0026quot;\u0026quot;\u0026quot; returns the part of the review that is most similar, by cosine similarity, to the word of interest Params ------ word : str the word you want to query the sentence for review : str raw whiskey review vector_dic : dict the vector table built via word2vec Returns ------- sentence : str the sentence with the highest similarity metric to the word of interest \u0026quot;\u0026quot;\u0026quot; assert len(word.split()) == 1 reconstructed, stack = review_vectorizer(review) word_vec = vector_dic[word].reshape(1,-1) test = cosine_similarity(word_vec, stack) idx = np.argmax(test) sentence = reconstructed[idx] return sentence # this review was for Zuidam Dutch Rye In [105]: review = \u0026quot;\u0026quot;\u0026quot; Rye whisky. 5 years old. 40% ABV. **color:** Deep copper. **nose:** Some spice, but it takes a back seat to the sweet and fruity notes. Plenty of vanilla. Baked apples. Candied orange peel. Graham crackers. **palate:** Smooth and medium-bodied. Orange sherbet and mulling spices. **finish:** Long but gentle with a cooling menthol sensation. Score: 86/100 This was one of the more unusual samples I've gotten from Master of Malt. I found it to be familiar yet different. There's a lot less spice than American ryes usually have and there was a creamy, citrus theme going on that I wasn't expecting. 
Overall I liked it and thought it was interesting but based on the prices I saw online (around $70) a full bottle is out of the question given the abundance of inexpensive rye here at home. Worth tracking down a sample if you're curious. \u0026quot;\u0026quot;\u0026quot; In [106]: query('taste', review, vector_dic) Out[106]: 'palate smooth and medium bodied orange sherbet and mulling spices' In [107]: query('finish', review, vector_dic) Out[107]: 'finish long but gentle with a cooling menthol sensation' In [108]: query('smell', review, vector_dic) Out[108]: 'nose some spice but it takes a back seat to the sweet and fruity notes plenty of vanilla baked apples candied orange peel graham crackers'  Pretty cool. In the case above, it\u0026rsquo;s almost like an adaptive regex. But here is an example of it failing.\n review = \u0026quot;\u0026quot;\u0026quot; I'm reviewing the Laphroaig 10 Cask-Strength. I like scotch. Scotchy scotch scotch. Scotch in my belly. Nose : Powerful smoke and earthy aromas, seaweed. burnt toffee, salty seaweed. pine tree. leather. Taste : Massive peat smoke, seashore salt, briny wood. peaty smoke. salty tobacco and caramel. vague tartness Finish : peppery, smoke, long, malt, savory, citrus Rating : 89/100 Some follow up notes here. Vintage notes if any. \u0026quot;\u0026quot;\u0026quot; In [112]: query('smell', review, model) Out[112]: 'taste massive peat smoke seashore salt briny wood peaty smoke salty tobacco and caramel vague tartness'  Despite \u0026lsquo;nose\u0026rsquo; and \u0026lsquo;smell\u0026rsquo; being interchangeable in the data set, the naive method still fails here. Now, I\u0026rsquo;m sure you could train a model to attend to certain words. And that would be really cool. 
But rather than teach AlphaGo over here to figure this out, let\u0026rsquo;s just be practical and use regex, given the lack of variation in the words you\u0026rsquo;d be attending to (e.g., there are only so many words we care about when breaking this apart: smell - nose, taste - palate, and just plain old finish).\nSo what the hell am I even trying to do?\nPlaying with vectors is fun, but let\u0026rsquo;s see how far we can get with naive bags of vectors representing individual whiskies before we decide to build an Adversarial Glenfarclas Machine. Rein it in, Matt.\nNot shown: stop-word removal led to much better whiskey representations below. Because so few words are actually used to describe a given sense dimension, leaving stop words in just kept throwing things off, calling most whiskies \u0026lsquo;malty\u0026rsquo; on nearly all the dimensions. Not that that\u0026rsquo;s actually wrong; it\u0026rsquo;s just not what I was looking for. It also made \u0026lsquo;asking questions\u0026rsquo; less wrong over a few trials. So if we remove stop words from reviews and then turn them into dictionaries of flavor profiles, we can pseudo-test the bag-of-vectors idea for representing a given whiskey as an n-dimensional array. Thus, the nose dimension for the whisky review above would be: Nose = [powerful vector] + [smoke vector] + [earthy vector] + [aromas vector] + [seaweed vector] + [burnt vector] + [toffee vector] + [salty vector] + [seaweed vector] + [pine vector] + [tree vector] + [leather vector].\nFor these tests, I just summed nose + taste + finish, so that a whisky is now represented by a single 150-dimensional array\n#... 
# data is a pandas DataFrame # take the names of the whiskies all_w = list(set(data.whiskey)) vecs = [] for w in all_w: # data.bag_of_vecs is an array (words X features) for a given whiskey whiskey_vec = data[data.whiskey == w].bag_of_vecs.values # take the mean value for the given whiskey whiskey_value = whiskey_vec.mean() # append to whiskey list vecs.append(whiskey_value) # represent whiskies as an array X = np.vstack(vecs) # build index dictionaries for X word_idx = [(i, w) for (i, w) in enumerate(all_w)] idxdic = {} wdic = {} for i, w in word_idx: idxdic[w] = i wdic[i] = w # build a similarity function def most_similar(whiskey, n=15): # get index of whiskey so we can get it from X idx = idxdic[whiskey] # get cosine similarity of given whiskey and all other whiskies result = cosine_similarity(X[idx].reshape(1,-1), X) # sort the results. res_sort = np.argsort(result) # reverse to descending order and drop the first result (the whiskey itself); [1:n] yields n-1 neighbors values = res_sort[0][::-1][1:n] # get the text most_similar = [wdic[i] for i in values] return most_similar  For the whiskey drinkers, the results actually look pretty good:\n In [529]: most_similar('Ardbeg 10') Out[529]: ['Laphroaig 10 Cask Strength', 'Ardbeg Uigeadail', 'Laphroaig 10', 'Caol Ila 12', 'Lagavulin 16', 'Talisker Storm', 'Lagavulin 12', 'Laphroaig 18', 'Finlaggan Old Reserve', 'Highland Park 12', 'Lagavulin 1997 Distillers Edition', 'Longrow Peated', 'Bowmore 12', 'Laphroaig Quarter Cask'] In [530]: most_similar('Four Roses Single Barrel') Out[530]: ['W.L. Weller 12', 'Old Weller Antique 107', 'Old Grand Dad 100 Bottled in Bond', 'Knob Creek 9 Small Batch', \u0026quot;Booker's Bourbon\u0026quot;, \u0026quot;Blanton's Original Single Barrel\u0026quot;, \u0026quot;Maker's Mark\u0026quot;, 'Colonel E.H. 
Taylor Small Batch', \u0026quot;Russell's Reserve Single Barrel\u0026quot;, 'Wild Turkey Rare Breed', 'Elijah Craig Barrel Proof', 'Bulleit Rye', 'Balvenie 14 Caribbean Cask', 'Rock Hill Farms Single Barrel Bourbon'] In [531]: most_similar('Ardbeg Uigeadail') Out[531]: ['Laphroaig 10 Cask Strength', 'Ardbeg 10', 'Lagavulin 16', 'Lagavulin 1997 Distillers Edition', 'Lagavulin 12', 'Laphroaig 18', 'Highland Park 12', 'Talisker Storm', 'Laphroaig Quarter Cask', 'Laphroaig Cairdeas 2013 Portwood', 'Finlaggan Old Reserve', 'Bunnahabhain 12', 'Bowmore 15 Darkest', 'Caol Ila 12']  For an initial test, and for those unfamiliar with whiskey, this is pretty damn good. It\u0026rsquo;s learning which whiskies are really peated and which are very soft, heavily wheated ones.\nWe\u0026rsquo;re on to something For a second test, if we compare each dimension separately, it starts to look a lot more like I would expect. Here, rather than treat each whisky as a single n-dimensional array, we computed the cosine similarity between each whisky\u0026rsquo;s taste, nose, and finish vectors and the corresponding vectors of every other whisky, one dimension at a time.\n#... In [497]: most_similar(\u0026quot;Ardbeg 10\u0026quot;) Out[497]: [('Laphroaig 10', 0.98225728690624237), ('Highland Park 12', 0.98141257464885712), ('Caol Ila 12', 0.98074275016784673), ('Talisker Storm', 0.98014752328395849), ('Lagavulin 16', 0.97319974303245549), ('Kilchoman Machir Bay', 0.97103935480117798), ('Bunnahabhain 12', 0.96981545805931091), ('Ardbeg Uigeadail', 0.96897226393222802), ('Laphroaig 18', 0.96870719194412236), ('Laphroaig Quarter Cask', 0.9685075163841248), ('Talisker 10', 0.96826540708541864), ('Lagavulin 12', 0.9676340079307556), ('Laphroaig 10 Cask Strength', 0.96573835432529442), ('Talisker 57° North', 0.96473644196987152), ('Ledaig 10', 0.96473363816738122)] In [498]: most_similar(\u0026quot;Four Roses Single Barrel\u0026quot;) Out[498]: [('W.L. 
Weller 12', 0.98039752006530767), ('Old Weller Antique 107', 0.97847300052642816), ('Old Grand Dad 114', 0.97778435051441193), ('Old Grand Dad 100 Bottled in Bond', 0.97584160685539245), ('Elmer T. Lee', 0.97437110304832464), ('Colonel E.H. Taylor Small Batch', 0.97433447599411016), (\u0026quot;Blanton's Original Single Barrel\u0026quot;, 0.97375291109085083), ('Wild Turkey Rare Breed', 0.97322319209575658), ('Balvenie 14 Caribbean Cask', 0.97262043297290801), ('Knob Creek 9 Small Batch', 0.97231273889541625), ('Glenmorangie Astar', 0.97189790010452271), (\u0026quot;Russell's Reserve Single Barrel\u0026quot;, 0.97158373415470123), ('Elijah Craig Barrel Proof', 0.97119110345840454), ('Bulleit Rye', 0.96869591414928435), ('Glenlivet 16 Nadurra', 0.96857591748237604)] In [499]: most_similar(\u0026quot;Ardbeg Uigeadail\u0026quot;) Out[499]: [('Lagavulin 16', 0.98130033433437347), ('Lagavulin 12', 0.98006941139698034), ('Laphroaig 10 Cask Strength', 0.97966188192367554), ('Laphroaig Quarter Cask', 0.97405486822128295), ('Laphroaig 18', 0.97076774656772613), ('Ardbeg 10', 0.96897226393222802), ('Lagavulin 1997 Distillers Edition', 0.9665562635660172), ('Talisker 10', 0.9664697104692459), ('Laphroaig 10', 0.96640086472034459), ('Laphroaig Cairdeas 2013 Portwood', 0.96637647867202758), ('Ardbeg Ardbog', 0.96615850448608398), ('Ardbeg Corryvreckan', 0.96584754943847662), ('Highland Park 12', 0.96557600617408745), ('Springbank 12 Cask Strength', 0.96336297214031219), ('Ardbeg Galileo', 0.96185846149921417)]  With the distances visualized, you can see that they\u0026rsquo;re all so very close. And yet, this fairly naive model does a really good job at evaluating the small subtleties that we whiskey drinkers can detect in our hooch. 
What\u0026rsquo;s pretty cool is that if someone was really into Four Roses, and you said: \u0026lsquo;try the Glenmorangie Astar\u0026rsquo;, they\u0026rsquo;d probably like it.\nCherry picking here on some data not in the dataset used: This is a Glenmorangie Astar review from The Whiskey Exchange that literally calls it \u0026lsquo;bourbony\u0026rsquo;\n\u0026quot;\u0026quot;\u0026quot; Nose: Ginger, citrus, tweed, paprika, black pepper, becomes very deep with smoky vanilla. Maybe it's the power of suggestion, but the oak is very noticeable - polished wood, like an expensive bookcase. More masculine and intense than the old Artisan, but you'd expect that from the higher abv. Water opens up more sweet notes, natural caramel or butter toffee. Palate: Carries on from the nose. Big, assertive and very bourbony. Exceptionally spicy and peppery but with compensating sweet creamy vanilla. Big mouth-filling weight - quite macho. With water, much more approachable though still very spicy Finish: Immense. Very long and warming, with tingling spices that last an eon. My mouth was still burning (in a good way) several minutes after the swallow. Comment: If you're a committed sherryhead or you don't like oaky stuff or spicy food - avoid. Everyone else will love it. Very much of the same breed as the original Artisan cask, but less gentle - much more powerful and concentrated. An epic dram that needs a drop or two of water to really shine. \u0026quot;\u0026quot;\u0026quot;  So validating this model via this whiskey drinker gut-feeling seems to work out pretty well here. 
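The post only shows the output of the per-dimension comparison. The following is therefore a hedged sketch of one way it could be computed, not the original code. It assumes each whiskey already has one averaged vector per sense dimension (nose, palate, finish), stored row-aligned in a matrix per dimension. It also assumes the three cosine similarities are weighted equally and averaged. The function name, the toy two-dimensional data and the equal weighting are all assumptions made for illustration.

```python
import numpy as np

def normalize_rows(m):
    # scale each row to unit length so a dot product equals cosine similarity
    return m / np.linalg.norm(m, axis=1, keepdims=True)

def most_similar_by_dimension(name, names, dims, n=5):
    # dims maps a sense dimension ('nose', 'palate', 'finish') to a
    # (num_whiskies x embedding_dim) array whose rows line up with names
    idx = names.index(name)
    scores = np.zeros(len(names))
    for matrix in dims.values():
        unit = normalize_rows(matrix)
        scores += unit @ unit[idx]  # cosine similarity on this dimension
    scores /= len(dims)             # assumption: equal weight per dimension
    order = np.argsort(-scores)
    return [(names[i], float(scores[i])) for i in order if i != idx][:n]

# hypothetical toy data: two peaty whiskies and one sweet one, 2-d vectors
names = ['Peaty A', 'Peaty B', 'Sweet C']
dims = {
    'nose': np.array([[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]),
    'palate': np.array([[1.0, 0.2], [1.0, 0.3], [0.1, 1.0]]),
    'finish': np.array([[0.5, 0.5], [0.6, 0.4], [0.0, 1.0]]),
}
print(most_similar_by_dimension('Peaty A', names, dims, n=2))
```

One plausible reason this per-dimension version behaves better is how it treats a mismatch on one sense. With one pooled vector, a strong match on a high-norm sense can hide a mismatch on another. Averaging per-dimension similarities instead rewards whiskies that resemble the query on nose, palate and finish at once.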
Pretty cool that we got here with just bags of vectors and simple cosine similarity. So if we think these vectors are at least halfway decent, and we’ve checked that they represent whiskies sensibly, let’s start making some brash generalizations and making people angry.\nSuppose we represent a given whiskey’s taste profile as an array of all of its individual taste profiles (i.e., its reviews), where each individual taste profile is the vector sum of that review’s words. Then we can start to generalize about a whiskey across many different reviews. That array has shape (# of reviews × 150 dimensions), and taking the mean over its rows gives the whiskey’s ‘taste centroid’. If we then compute the cosine similarity between the taste centroid and every word in the vector dictionary, we can return the N closest words to that centroid. Doing this for nose (smell), palate (taste), and finish lets us make generalizations about a given whiskey across many reviews.\n#... 
import numpy as np

# precompute the vocabulary matrix once (one row per word)
sorted_words = sorted(model.keys())
vectors = np.vstack([model[word] for word in sorted_words])

def return_means(x, df):
    # get data from dataframe
    data = df[df.whiskey == x]
    # vectorize all words in each part of the review
    nose = [model[word] for word in data.nose]
    palate = [model[word] for word in data.palate]
    finish = [model[word] for word in data.finish]
    # get means from each sense-dimension
    mean_nose = np.vstack(nose).mean(axis=0)
    mean_palate = np.vstack(palate).mean(axis=0)
    mean_finish = np.vstack(finish).mean(axis=0)
    return mean_nose, mean_palate, mean_finish

def return_dimension_generalization(mean):
    # dot product of every word vector with the centroid; if the word vectors
    # are unit-length, this ranks words exactly as cosine similarity would
    dists = np.dot(vectors, mean)
    # return the words ordered from most to least similar
    return [sorted_words[i] for i in np.argsort(-dists)]

In [500]: mean_nose, mean_palate, mean_finish = return_means('Ardbeg 10', data)

In [501]: return_dimension_generalization(mean_nose)[:100]
Out[501]: ['brine,', 'iodine,', 'smoked,', 'meat,', 'salt,', 'seaweed,', 'ash,', 'tar,', 'grass,', 'fennel,', 'lime,', 'pepper,', 'shavings,', 'char,', 'lemon,', 'bacon,', 'sea,', 'hay,', 'melon,', 'salted,', 'paprika,', 'zest,', 'hint,', 'meats,', 'hazelnuts,', 'campfire,', 'charcoal,', 'spray,', 'clover,', 'breeze,', 'wet,', 'herbs,', 'pine,', 'smoke,', 'toffee,', 'peel,', 'pineapple,', 'leather,', 'licorice,', 'anise,', 'citrus,', 'pork,', 'walnuts,', 'olive,', 'allspice,', 'polish,', 'taffy,', 'rubber,', 'pith,', 'clove,', 'vegetal,', 'ginger,', 'salty,', 'almonds,', 'liquorice,', 'soil,', 'faint,', 'carmel,', 'cedar,', 'mint,', 'tobacco,', 'damp,', 'honeydew,', 'buttered,', 'bananas,', 'seed,', 'slight,', 'medicinal,', 'peaches,', 'cloves,', 'rind,', 'oranges,', 'custard,', 'cardamom,', 'charred,', 'graham,', 'sugar,', 'cinnamon,', 'butter,', 'vanilla,', 'marzipan,', 'menthol,', 'maple,', 'caramel,', 'figs,', 'cola,', 'coriander,', 'overripe,', 'crust,', 'cured,', 'almond,', 'aniseed,', 'sauce,', 'cumin,', 'earthy,', 'powder,', 'dates,', 'grapefruit,', 'nutmeg,', 'air']
Holy shit. That looks good. Better do a super-biased one-man whiskey-drinking experiment just to be sure. For science.\nObligatory Rat-tail visualization. Colors indicate which sense dimension each word occurs in:\nFinal touches\nWhen considering ‘average scores’, I tend to side with a prior belief that, with smaller datasets like this one, the average should be pulled a bit toward the median rather than taken as a plain arithmetic mean. Also pushing me toward that belief: newer reviewers tend to rate more on hype. Rather than letting reviews like ‘This Pappy from 2014 that I had on the rocks was a 99/100’ skew the review pools, I wanted a Bayesian function that uses the median as its prior, so as to avoid any weirdness.\n
# Some final touches.
# overall mean
In [141]: ratings.mean()
Out[141]: 83.70441842086358

In [142]: %paste
def bayesian_mean(arr, confidence=95):
    '''
    Computes the Bayesian mean from the median prior and confidence;
    weights the mean towards the prior.

    Parameters
    ----------
    arr : pandas.Series
        values you want to compute the Bayesian mean on
    confidence : int
        pseudo-count: how many virtual ratings equal to the median are added
    '''
    # count() skips missing ratings, so it can be smaller than len(arr)
    return (confidence * arr.median() + arr.sum()) / (confidence + arr.count())
## -- End pasted text --

In [143]: bayesian_mean(ratings)
Out[143]: 83.719980019980028

In [144]: pappy = df[df.whiskey.str.contains('Pappy Van Winkle Family Reserve Bourbon 23')]

In [145]: pappy.shape
Out[145]: (4, 9)

In [146]: pappy.Rating.mean()
Out[146]: 92.0

In [147]: bayesian_mean(pappy.Rating)
Out[147]: 92.969387755102048

With regard to the Pappy reviews, here’s a perfect example of the prior weighting saying: “Yeah, the price is definitely not worth it, but it’s probably a smidge better than you’re giving it credit for.” In the end, does the median prior really matter? 
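To see how much the median prior can matter, the formula can be rewritten as a weighted blend, w × median + (1 − w) × mean, with w = confidence / (confidence + n). A minimal sketch, assuming plain Python lists and statistics.median in place of the pandas Series above (so, unlike arr.count(), nothing here skips missing ratings); the helper names prior_weight and blended and the three ratings are hypothetical:

```python
from statistics import median


def bayesian_mean(ratings, confidence=95):
    # Same formula as above: `confidence` acts as a pseudo-count of
    # virtual ratings equal to the median, pooled with the real ratings.
    return (confidence * median(ratings) + sum(ratings)) / (confidence + len(ratings))


def prior_weight(n, confidence=95):
    # Share of the result that comes from the median prior.
    return confidence / (confidence + n)


def blended(ratings, confidence=95):
    # Equivalent form: w * median + (1 - w) * mean.
    w = prior_weight(len(ratings), confidence)
    mean = sum(ratings) / len(ratings)
    return w * median(ratings) + (1 - w) * mean


# Hypothetical ratings, not taken from the dataset.
few = [85, 90, 97]
print(bayesian_mean(few), blended(few), prior_weight(len(few)))
print(prior_weight(3), prior_weight(10000))
```

With confidence=95, three ratings leave about 97% of the weight on the median, while a pool of 10,000 ratings (a made-up size) puts under 1% there. That is consistent with the pooled ratings above barely moving (83.704 vs. 83.720), while a whiskey with only a handful of reviews lands close to its own median.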
Nope.\nPoints to note\nPSA: for non-whiskey-drinking machine learning people, the curious, and relatively new whiskey drinkers\nWith regard to the fictitious review above, “This Pappy from 2014 that I had on the rocks was a 99/100”: I’m not going to tell you how to enjoy your whiskey. But cold temperatures do dull taste perception (sweetness in particular), and whiskey is no different. Sources: Heat activation of TRPM5 underlies thermal sensitivity of sweet taste, TRPM5 and Taste Transduction, TRPM5 - Wikipedia. Science is pretty cool. That said, if you say fuck it, you’ve drunk enough whiskey in your day to know that you enjoy muted flavors, then by all means, fill that glass up with ice and pour away. It’s a two-way street, as well: that whiskey tastes like garbage? Put some ice in it. It will make its trash flavors less noticeable.\nSome obligatory whiskey links: check out the main event himself, Mr. Ralfy, at Ralfy’s whiskey reviews. Also, check out r/bourbon, r/scotch, and r/worldwhiskey.\nEnd PSA\nWhere next?\nOne thing I forgot to include was an attempt at using a Variational Autoencoder to learn the ‘true essence’ of a given whiskey. The results were pretty cool, and I’ll have to include them in another post.\nA real shortcoming of this model is that it does not address batch variation, in part because whiskey reviews in the dataset rarely contain the bottling vintage. However, judging by recent reviews (2016), more reviewers are adding the vintage to their reviews, and that’s a good thing for every consumer. Hopefully I can tackle this in the future when I have some free time.\nThis analysis also assumes that the Reddit Review Archive has every whiskey name standardized; it definitely doesn’t. I hope to address that when I get some free time as well. 
It’s also a really rough first pass; I’m not winning any design awards.\nI’m currently working on something that combines my love of whiskey and my love of computer vision: OCR on curved text. The goal is to take a picture of a bottle and have an algorithm return that whiskey’s profile dimensions. Normally, I’d just amass training data for whiskey bottles and then classify them with some kind of VGG offshoot. But the immediate problem is that whiskey labels change all the time. Combine that with all the independent bottlings a given algorithm would never see in its training data, and we’d be in for a real problem. Plain and simple: we need to teach the fucker how to read and then just let it do its thing. Some possible solutions come from papers using deep architectures to tackle text recognition in the wild: Reading Text in the Wild with Convolutional Neural Networks and Deep Features for Text Spotting. In my initial tests, they fail on bottles and other non-linear (curved) text. Perhaps something like Alignment of Curved Text Strings for Enhanced OCR Readability might help. Truth be told, however, I’d rather have a network learn what curved text looks like in an end-to-end fashion than correct for it. Perhaps some non-linear pixel RNN ensemble? Who knows. Not sure what the outcome will be, but it will be awesome to play around with.\nLastly, like most problems of this kind, this one suffers from a lack of data, and hopefully my small ‘bottle text dataset’ will come in handy down the road. It’s a work in progress, but I’ll release it when I eventually get there.\nMaking a good whiskey is hard. Thankfully, finding one you might like is easy.
\nFurther Reading\nGeoffrey Hinton. Learning Distributed Representations of Concepts. Proceedings of the Eighth Annual Conference of the Cognitive Science Society, Amherst, Mass., pp. 1–12, 1986.\nYoshua Bengio, Réjean Ducharme and Pascal Vincent. A Neural Probabilistic Language Model. In: Advances in Neural Information Processing Systems 13 (NIPS'00), MIT Press, 2001.\nYoshua Bengio, Réjean Ducharme, Pascal Vincent and Christian Jauvin. A Neural Probabilistic Language Model. Journal of Machine Learning Research, 3:1137–1155, 2003.\nAndriy Mnih and Geoffrey Hinton. Three New Graphical Models for Statistical Language Modelling. International Conference on Machine Learning (ICML), 2007.\nR. Collobert and J. Weston. A Unified Architecture for Natural Language Processing: Deep Neural Networks with Multitask Learning. International Conference on Machine Learning (ICML), 2008.\nR. Collobert, J. Weston, L. Bottou, M. Karlen, K. Kavukcuoglu and P. Kuksa. Natural Language Processing (Almost) from Scratch. Journal of Machine Learning Research, 12:2493–2537, 2011.\nJ. Turian, L. Ratinov and Y. Bengio. Word Representations: A Simple and General Method for Semi-Supervised Learning. Association for Computational Linguistics (ACL), 2010.\nT. Mikolov, S. Kombrink, L. Burget, J. H. Černocký and S. Khudanpur. Extensions of Recurrent Neural Network Language Model. Acoustics, Speech and Signal Processing (ICASSP), 2011.\nT. Mikolov, K. Chen, G. Corrado and J. Dean. Efficient Estimation of Word Representations in Vector Space.\nA bunch of Reddit whisky links:\nBourbon for Scotch Drinkers\nScotch for Bourbon Drinkers\n","date":"2015-11-17","permalink":"/posts/whiskeyembedding/","title":"Whiskey Embeddings"},{"content":" \u0026ldquo;Hello, friend?\u0026rdquo; That\u0026rsquo;s lame. Maybe I should give you a name\u0026hellip; But that\u0026rsquo;s a slippery slope. You\u0026rsquo;re only in my head. We have to remember that\u0026hellip; Shit. It\u0026rsquo;s actually happened. 
I\u0026rsquo;m talking to an imaginary person.\n— Mr. Robot S01E01\n ","date":"1986-09-17","permalink":"/posts/my-first-post/","title":"Hello Friend"}]