Source code for nd2py.search.ndformer.ndformer_model_flash_ansr
# Copyright (c) 2024-present, Yumeow. Licensed under the MIT License.
"""
FLASH-ANSR variant of NDFormer.
Based on the paper describing FLASH-ANSR architecture:
- Pre-norm Transformer decoder (norm_first=True); the encoder layers currently keep the default post-norm
- Set Transformer encoder with induction points
- FlashAttention support (via torch.nn.MultiheadAttention with backend selection)
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from .ndformer_model import NDFormerModel
from .ndformer_config import NDFormerModelConfig
from .ndformer_tokenizer import NDFormerTokenizer
[docs]
class SetTransformerEncoder(nn.Module):
    """
    Set-Transformer-style encoder with learnable induction points.

    Drop-in replacement for ``nn.TransformerEncoder``: ``forward`` accepts the
    same arguments and returns a tensor with the same shape as ``src``.

    Computation (induced-attention idea from Set Transformer, Lee et al., 2019):
    1. ``n_induction_points`` learnable vectors are prepended to ``src`` and the
       concatenation is processed by a standard ``nn.TransformerEncoder`` built
       from ``encoder_layer`` (joint self-attention over induction points and
       data tokens; induction-point positions are never masked).
    2. Only the encoder outputs at the induction-point positions are kept; the
       outputs at the data positions are discarded.
    3. Every original ``src`` position queries those induction-point outputs
       through a separate ``nn.MultiheadAttention`` to produce its output.
    4. Output rows at padded positions (if a padding mask is given) are zeroed.

    Args:
        encoder_layer: Layer passed to ``nn.TransformerEncoder`` (required in
            practice; ``nn.TransformerEncoder`` cannot be built without it).
        num_layers: Number of encoder layers.
        norm: Optional final normalization passed to ``nn.TransformerEncoder``.
        enable_nested_tensor: Forwarded to ``nn.TransformerEncoder``.
        mask_check: Forwarded to ``nn.TransformerEncoder``.
        d_model: Embedding dimension. If ``None``, it is taken from
            ``encoder_layer.self_attn.embed_dim``.
        n_induction_points: Number of learnable induction points.
        n_head: Number of heads of the output cross-attention.

    Raises:
        ValueError: If ``d_model`` is ``None`` and cannot be inferred from
            ``encoder_layer``.
    """

    def __init__(
        self,
        encoder_layer=None,
        num_layers=2,
        norm=None,
        enable_nested_tensor=True,
        mask_check=True,
        d_model: Optional[int] = None,
        n_induction_points: int = 128,
        n_head: int = 8,
    ):
        super().__init__()
        if d_model is None:
            # Infer the width from the encoder layer so callers need not repeat
            # it; without any width the induction points cannot be allocated.
            self_attn = getattr(encoder_layer, "self_attn", None)
            embed_dim = getattr(self_attn, "embed_dim", None)
            if embed_dim is None:
                raise ValueError(
                    "d_model must be given when it cannot be inferred from "
                    "encoder_layer.self_attn.embed_dim"
                )
            d_model = embed_dim
        self.n_induction_points = n_induction_points
        # Registered before the sub-modules to keep the original parameter
        # order (and RNG consumption order at construction).
        self.induction_points = nn.Parameter(torch.randn(1, n_induction_points, d_model))
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            norm=norm,
            enable_nested_tensor=enable_nested_tensor,
            mask_check=mask_check,
        )
        # Output projection: data positions (queries) attend to the encoded
        # induction points (keys/values).
        self.attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_head,
            batch_first=True,
        )

    def forward(
        self,
        src: torch.Tensor,
        mask=None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        is_causal: Optional[bool] = None,
    ) -> torch.Tensor:
        """
        Args:
            src: Input tensor of shape (batch, seq_len, d_model).
            mask: Not used (kept for API compatibility).
            src_key_padding_mask: Optional padding mask of shape
                (batch, seq_len); True (or non-zero) marks padding.
            is_causal: Not used (kept for API compatibility).
        Returns:
            Tensor of shape (batch, seq_len, d_model); rows at padded positions
            are zero.
        """
        batch_size = src.shape[0]
        ind_points = self.induction_points.expand(batch_size, -1, -1)  # (batch, I, d_model)
        combined_src = torch.cat([ind_points, src], dim=1)  # (batch, I+seq_len, d_model)

        if src_key_padding_mask is not None:
            # Induction points are always valid: zeros mean "not padded" for
            # boolean masks and "no additive penalty" for float masks.
            ind_pad = torch.zeros(
                batch_size,
                self.n_induction_points,
                device=src_key_padding_mask.device,
                dtype=src_key_padding_mask.dtype,
            )
            combined_src_key_padding_mask = torch.cat([ind_pad, src_key_padding_mask], dim=1)
        else:
            combined_src_key_padding_mask = None

        out = self.transformer(combined_src, src_key_padding_mask=combined_src_key_padding_mask)
        ind_out = out[:, : self.n_induction_points, :]  # (batch, I, d_model)

        # The weights are discarded, so do not ask for them (lets PyTorch use
        # its fused attention path).
        attn_output, _ = self.attention(
            query=src, key=ind_out, value=ind_out, need_weights=False
        )  # (batch, seq_len, d_model)

        if src_key_padding_mask is not None:
            # Only zero padded rows when a mask exists: indexing with None would
            # select (and zero) the entire tensor. Out-of-place masked_fill avoids
            # in-place autograd issues; .to(bool) accepts bool/int/float masks.
            pad = src_key_padding_mask.to(torch.bool).unsqueeze(-1)
            attn_output = attn_output.masked_fill(pad, 0.0)
        return attn_output
[docs]
@NDFormerModel.register_model('flash_ansr')
class FlashANSRNDFormer(NDFormerModel):
    """
    FLASH-ANSR variant of NDFormer for transformer-based symbolic regression.

    After the parent constructor runs, the two transformer stacks are replaced:
    - ``transformer_encoder``: a SetTransformerEncoder with
      ``config.n_induction_points`` learnable induction points (default 128
      when the config lacks the field). Its encoder layers use the default
      post-norm (``norm_first`` is not set) and no final norm.
    - ``transformer_decoder``: an ``nn.TransformerDecoder`` whose layers use
      pre-norm (``norm_first=True``) followed by a final LayerNorm.

    Reuses NDFormerModel.encode_graph() and NDFormerModel.decode_sequence().
    """

    def __init__(self, config: NDFormerModelConfig, tokenizer: NDFormerTokenizer):
        super().__init__(config, tokenizer)
        # Optional config field; fall back to 128 induction points.
        self.n_induction_points = getattr(config, 'n_induction_points', 128)
        # Replace the parent's encoder with the induction-point variant.
        self.transformer_encoder = SetTransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.d_emb,
                nhead=self.config.n_head,
                dim_feedforward=self.config.d_ff,
                dropout=self.config.dropout,
                batch_first=True,
                # NOTE(review): norm_first=True (pre-norm, as described for
                # FLASH-ANSR) is intentionally left disabled here; confirm.
            ),
            num_layers=self.config.n_transformer_encoder_layers,
            d_model=self.config.d_emb,
            n_induction_points=self.n_induction_points,
            # Use the configured head count for the output cross-attention as
            # well; otherwise SetTransformerEncoder's default of 8 heads is
            # used, which fails whenever d_emb is not divisible by 8.
            n_head=self.config.n_head,
        )
        # Replace the parent's decoder with a pre-norm decoder + final norm.
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=self.config.d_emb,
                nhead=self.config.n_head,
                dim_feedforward=self.config.d_ff,
                dropout=self.config.dropout,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=self.config.n_transformer_decoder_layers,
            norm=nn.LayerNorm(self.config.d_emb),
        )