Skip to content

Projection System - Reference Implementation

Overview

This document provides a complete reference implementation of the projection formalism developed in PROJECTION_FORMALISM.md. This serves as both documentation and a specification for the actual implementation in LangCalc.


1. Core Abstractions

1.1 Projection Base Class

from abc import ABC, abstractmethod
from typing import List, Set, Tuple, Optional
import numpy as np

class Projection(ABC):
    """
    Abstract base class for context projections.

    A projection transforms a query context before matching against the corpus.
    Mathematically: π: Σ* × 2^(Σ*) → Σ*
    """

    @abstractmethod
    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        """
        Project context onto corpus.

        Args:
            context: Query context (sequence of token IDs)
            corpus: Corpus (sequence of token IDs)

        Returns:
            Transformed context
        """
        pass

    def project_multi(self, context: List[int], corpus: List[int]) -> Set[Tuple[int, ...]]:
        """
        Multi-valued projection (returns multiple candidate contexts).

        Default implementation returns singleton set. Override for projections
        that generate multiple candidates (e.g., synonym expansion).

        Args:
            context: Query context
            corpus: Corpus

        Returns:
            Set of transformed contexts (as tuples for hashability)
        """
        return {tuple(self.project(context, corpus))}

    # Composition operators

    def __rshift__(self, other: 'Projection') -> 'Projection':
        """
        Sequential composition: self >> other

        Applies self first, then other.
        Mathematically: (π₁ >> π₂)(x, C) = π₂(π₁(x, C), C)
        """
        return SequentialProjection(self, other)

    def __or__(self, other: 'Projection') -> 'Projection':
        """
        Parallel composition (union): self | other

        Returns multiple projected contexts.
        Mathematically: (π₁ | π₂)(x, C) = {π₁(x, C), π₂(x, C)}
        """
        return ParallelProjection(self, other)

    def __matmul__(self, weight: float) -> 'Projection':
        """
        Weighted projection: projection @ weight

        For use in stochastic composition.
        """
        return WeightedProjection(self, weight)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

1.2 Augmentation Base Class

class Augmentation(ABC):
    """
    Abstract base class for corpus augmentations.

    An augmentation expands the corpus by adding transformed variants.
    Mathematically: α: 2^(Σ*) → 2^(Σ*)
    """

    @abstractmethod
    def augment(self, corpus: List[int]) -> List[int]:
        """
        Augment corpus with transformed variants.

        Args:
            corpus: Original corpus

        Returns:
            Augmented corpus (original + variants)
        """
        pass

    def __add__(self, other: 'Augmentation') -> 'Augmentation':
        """
        Compose augmentations: self + other

        Applies both augmentations to the corpus.
        """
        return ComposedAugmentation(self, other)

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"

2. Composition Implementations

2.1 Sequential Projection

class SequentialProjection(Projection):
    """
    Sequential composition of projections.

    (π₁ >> π₂)(x, C) = π₂(π₁(x, C), C)
    """

    def __init__(self, first: Projection, second: Projection):
        self.first = first
        self.second = second

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        intermediate = self.first.project(context, corpus)
        return self.second.project(intermediate, corpus)

    def project_multi(self, context: List[int], corpus: List[int]) -> Set[Tuple[int, ...]]:
        # Apply first projection (may be multi-valued)
        intermediate_set = self.first.project_multi(context, corpus)

        # Apply second projection to each result
        result = set()
        for intermediate in intermediate_set:
            result.update(self.second.project_multi(list(intermediate), corpus))

        return result

    def __repr__(self) -> str:
        return f"({self.first} >> {self.second})"

2.2 Parallel Projection

class ParallelProjection(Projection):
    """
    Parallel composition (union) of projections.

    (π₁ | π₂)(x, C) = {π₁(x, C), π₂(x, C)}
    """

    def __init__(self, *projections: Projection):
        self.projections = projections

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        # For single-valued interface, return first projection
        # (This is somewhat arbitrary for parallel composition)
        return self.projections[0].project(context, corpus)

    def project_multi(self, context: List[int], corpus: List[int]) -> Set[Tuple[int, ...]]:
        # Union of all projections
        result = set()
        for proj in self.projections:
            result.update(proj.project_multi(context, corpus))
        return result

    def __repr__(self) -> str:
        return " | ".join(str(p) for p in self.projections)

2.3 Weighted Projection

class WeightedProjection:
    """
    Weighted projection for stochastic composition.

    Not a Projection itself, but used in mixture models.
    """

    def __init__(self, projection: Projection, weight: float):
        self.projection = projection
        self.weight = weight

    def __repr__(self) -> str:
        return f"{self.weight} * {self.projection}"

2.4 Composed Augmentation

class ComposedAugmentation(Augmentation):
    """
    Composition of multiple augmentations.

    (α₁ + α₂)(C) applies both augmentations.
    """

    def __init__(self, *augmentations: Augmentation):
        self.augmentations = augmentations

    def augment(self, corpus: List[int]) -> List[int]:
        result = corpus
        for aug in self.augmentations:
            result = aug.augment(result)
        return result

    def __repr__(self) -> str:
        return " + ".join(str(a) for a in self.augmentations)

3. Basic Projections

3.1 Identity Projection

class IdentityProjection(Projection):
    """
    Identity projection: π(x, C) = x

    No transformation.
    """

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        return context

3.2 Recency Projection

class RecencyProjection(Projection):
    """
    Recency projection: truncate to most recent k tokens.

    π_rec(x, C) = x[-k:] if |x| > k else x
    """

    def __init__(self, max_length: int):
        """
        Args:
            max_length: Maximum context length to keep
        """
        self.max_length = max_length

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        if len(context) <= self.max_length:
            return context
        return context[-self.max_length:]

    def __repr__(self) -> str:
        return f"RecencyProjection(max_length={self.max_length})"

3.3 Truncation Projection

class TruncationProjection(Projection):
    """
    Truncation projection: keep first k tokens.

    Useful for testing or limiting context scope.
    """

    def __init__(self, max_length: int):
        self.max_length = max_length

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        return context[:self.max_length]

    def __repr__(self) -> str:
        return f"TruncationProjection(max_length={self.max_length})"

4. Normalization Projections

4.1 Lowercase Projection

class LowercaseProjection(Projection):
    """
    Lowercase projection: convert context to lowercase.

    π_lower(x, C) = lowercase(x)

    Note: If corpus is augmented with lowercase variant,
    this projection can be skipped (projection-augmentation duality).
    """

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        try:
            text = bytes(context).decode('utf-8')
            lower_text = text.lower()
            return list(lower_text.encode('utf-8'))
        except (UnicodeDecodeError, UnicodeEncodeError):
            # If not valid UTF-8, return unchanged
            return context

    def __repr__(self) -> str:
        return "LowercaseProjection()"

4.2 Uppercase Projection

class UppercaseProjection(Projection):
    """Uppercase projection: convert context to uppercase."""

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        try:
            text = bytes(context).decode('utf-8')
            return list(text.upper().encode('utf-8'))
        except (UnicodeDecodeError, UnicodeEncodeError):
            return context

4.3 Whitespace Normalization Projection

import re

class WhitespaceProjection(Projection):
    """
    Whitespace normalization: collapse consecutive whitespace to single space.

    π_ws(x, C) = normalize_whitespace(x)
    """

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        try:
            text = bytes(context).decode('utf-8')
            normalized = re.sub(r'\s+', ' ', text)
            return list(normalized.encode('utf-8'))
        except (UnicodeDecodeError, UnicodeEncodeError):
            return context

    def __repr__(self) -> str:
        return "WhitespaceProjection()"

4.4 Unicode Normalization Projection

import unicodedata

class UnicodeNormalizationProjection(Projection):
    """
    Unicode normalization projection.

    π_unicode(x, C) = normalize(x, form)
    """

    def __init__(self, form: str = 'NFC'):
        """
        Args:
            form: Unicode normalization form ('NFC', 'NFD', 'NFKC', 'NFKD')
        """
        if form not in ('NFC', 'NFD', 'NFKC', 'NFKD'):
            raise ValueError(f"Invalid normalization form: {form}")
        self.form = form

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        try:
            text = bytes(context).decode('utf-8')
            normalized = unicodedata.normalize(self.form, text)
            return list(normalized.encode('utf-8'))
        except (UnicodeDecodeError, UnicodeEncodeError):
            return context

    def __repr__(self) -> str:
        return f"UnicodeNormalizationProjection(form='{self.form}')"

5. Advanced Projections

5.1 Edit Distance Projection

class EditDistanceProjection(Projection):
    """
    Edit distance projection: find most similar suffix in corpus.

    π_edit(x, C) = argmin_{s ∈ Suffixes(C)} {edit(x, s) : edit(x, s) ≤ d}

    WARNING: This is expensive (O(|x| * |C|)). Use sparingly.
    """

    def __init__(self, max_distance: int = 2, suffix_length: Optional[int] = None):
        """
        Args:
            max_distance: Maximum edit distance to consider
            suffix_length: Only check suffixes of this length (for efficiency)
        """
        self.max_distance = max_distance
        self.suffix_length = suffix_length

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        if not context:
            return context

        # Limit search to suffixes of specific length if specified
        search_length = self.suffix_length or len(context)

        # Find best matching suffix (simplified implementation)
        best_suffix = None
        best_distance = float('inf')

        # Search through corpus for matching suffixes
        # (In practice, would use suffix array for efficiency)
        for i in range(len(corpus)):
            suffix = corpus[max(0, i - search_length):i]
            if not suffix:
                continue

            distance = self._edit_distance(context, suffix)
            if distance <= self.max_distance and distance < best_distance:
                best_distance = distance
                best_suffix = suffix

        return best_suffix if best_suffix is not None else context

    def _edit_distance(self, s1: List[int], s2: List[int]) -> int:
        """Compute Levenshtein distance between two sequences."""
        if len(s1) < len(s2):
            return self._edit_distance(s2, s1)

        if not s2:
            return len(s1)

        previous_row = range(len(s2) + 1)
        for i, c1 in enumerate(s1):
            current_row = [i + 1]
            for j, c2 in enumerate(s2):
                # Cost of insertions, deletions, or substitutions
                insertions = previous_row[j + 1] + 1
                deletions = current_row[j] + 1
                substitutions = previous_row[j] + (c1 != c2)
                current_row.append(min(insertions, deletions, substitutions))
            previous_row = current_row

        return previous_row[-1]

    def __repr__(self) -> str:
        return f"EditDistanceProjection(max_distance={self.max_distance})"

5.2 Longest Suffix Projection

class LongestSuffixProjection(Projection):
    """
    Longest suffix projection: find longest matching suffix.

    π_lms(x, C) = LMS(x, C)

    Uses suffix array for efficient lookup.
    """

    def __init__(self, min_length: int = 1):
        """
        Args:
            min_length: Minimum suffix length to consider
        """
        self.min_length = min_length
        self._suffix_array = None

    def project(self, context: List[int], corpus: List[int]) -> List[int]:
        # Build suffix array if not cached
        # (In practice, would build once and reuse)
        if self._suffix_array is None:
            from infinigram import Infinigram
            self._infinigram = Infinigram(corpus=corpus)

        # Find longest matching suffix
        # This is a simplified version - actual implementation would
        # use suffix array binary search
        for length in range(len(context), self.min_length - 1, -1):
            suffix = context[-length:]
            # Check if this suffix exists in corpus
            # (Would use suffix array lookup in practice)
            if self._exists_in_corpus(suffix, corpus):
                return suffix

        return context[-self.min_length:] if len(context) >= self.min_length else context

    def _exists_in_corpus(self, pattern: List[int], corpus: List[int]) -> bool:
        """Check if pattern exists in corpus (naive implementation)."""
        if not pattern:
            return False
        pattern_tuple = tuple(pattern)
        for i in range(len(corpus) - len(pattern) + 1):
            if tuple(corpus[i:i + len(pattern)]) == pattern_tuple:
                return True
        return False

    def __repr__(self) -> str:
        return f"LongestSuffixProjection(min_length={self.min_length})"

6. Basic Augmentations

6.1 Lowercase Augmentation

class LowercaseAugmentation(Augmentation):
    """
    Lowercase augmentation: α_lower(C) = C ∪ {lowercase(C)}

    Doubles corpus size.
    """

    def augment(self, corpus: List[int]) -> List[int]:
        try:
            text = bytes(corpus).decode('utf-8')
            lower_text = text.lower()
            lower_bytes = list(lower_text.encode('utf-8'))
            # Return original + lowercase
            return corpus + lower_bytes
        except UnicodeDecodeError:
            return corpus

    def __repr__(self) -> str:
        return "LowercaseAugmentation()"

6.2 Case Augmentation

class CaseAugmentation(Augmentation):
    """
    Full case augmentation: α_case(C) = C ∪ {lower, upper, title}

    Quadruples corpus size.
    """

    def augment(self, corpus: List[int]) -> List[int]:
        try:
            text = bytes(corpus).decode('utf-8')
            variants = [
                text,
                text.lower(),
                text.upper(),
                text.title(),
            ]
            return [byte for variant in variants
                    for byte in variant.encode('utf-8')]
        except UnicodeDecodeError:
            return corpus

    def __repr__(self) -> str:
        return "CaseAugmentation()"

6.3 Whitespace Augmentation

class WhitespaceAugmentation(Augmentation):
    """
    Whitespace augmentation: α_ws(C) = C ∪ {normalize_ws(C)}

    Doubles corpus size.
    """

    def augment(self, corpus: List[int]) -> List[int]:
        try:
            text = bytes(corpus).decode('utf-8')
            normalized = re.sub(r'\s+', ' ', text)
            return corpus + list(normalized.encode('utf-8'))
        except UnicodeDecodeError:
            return corpus

    def __repr__(self) -> str:
        return "WhitespaceAugmentation()"

7. Model Integration

7.1 Projected Language Model

class ProjectedModel(LanguageModel):
    """
    Language model with projection applied to context.

    M^π(x, a) = M(π(x, C), a)
    """

    def __init__(self, base_model: LanguageModel, projection: Projection, corpus: List[int]):
        """
        Args:
            base_model: Underlying language model
            projection: Projection to apply to context
            corpus: Corpus (needed for projection)
        """
        self.base_model = base_model
        self.projection = projection
        self.corpus = corpus

    def logprobs(self, tokens: List[int], context: Optional[List[int]] = None) -> np.ndarray:
        if context is None:
            context = []

        # Apply projection to context
        projected_context = self.projection.project(context, self.corpus)

        # Query base model with projected context
        return self.base_model.logprobs(tokens, projected_context)

    def sample(self, context: Optional[List[int]] = None,
               temperature: float = 1.0, max_tokens: int = 100) -> List[int]:
        if context is None:
            context = []

        projected_context = self.projection.project(context, self.corpus)
        return self.base_model.sample(projected_context, temperature, max_tokens)

    def score(self, sequence: List[int]) -> float:
        # For scoring, apply projection to increasingly long prefixes
        # This is one possible interpretation
        return self.base_model.score(sequence)

    def __repr__(self) -> str:
        return f"ProjectedModel({self.base_model} @ {self.projection})"

7.2 Multi-Projection Model

class MultiProjectionModel(LanguageModel):
    """
    Model that tries multiple projections and combines results.

    M^{π_i, w_i}(x, a) = Σ_i w_i M(π_i(x, C), a)
    """

    def __init__(self, base_model: LanguageModel,
                 weighted_projections: List[Tuple[Projection, float]],
                 corpus: List[int]):
        """
        Args:
            base_model: Underlying language model
            weighted_projections: List of (projection, weight) pairs
            corpus: Corpus
        """
        self.base_model = base_model
        self.weighted_projections = weighted_projections
        self.corpus = corpus

        # Normalize weights
        total_weight = sum(w for _, w in weighted_projections)
        self.weighted_projections = [
            (proj, w / total_weight)
            for proj, w in weighted_projections
        ]

    def logprobs(self, tokens: List[int], context: Optional[List[int]] = None) -> np.ndarray:
        if context is None:
            context = []

        # Weighted mixture of projections
        result = np.zeros(len(tokens))
        for projection, weight in self.weighted_projections:
            projected_context = projection.project(context, self.corpus)
            logprobs = self.base_model.logprobs(tokens, projected_context)
            result += weight * np.exp(logprobs)  # Convert to probs, mix, convert back

        return np.log(result + 1e-10)  # Back to log space

    def sample(self, context: Optional[List[int]] = None,
               temperature: float = 1.0, max_tokens: int = 100) -> List[int]:
        if context is None:
            context = []

        # Randomly choose projection based on weights
        import random
        rand = random.random()
        cumsum = 0
        for projection, weight in self.weighted_projections:
            cumsum += weight
            if rand < cumsum:
                projected_context = projection.project(context, self.corpus)
                return self.base_model.sample(projected_context, temperature, max_tokens)

        # Fallback to last projection
        projected_context = self.weighted_projections[-1][0].project(context, self.corpus)
        return self.base_model.sample(projected_context, temperature, max_tokens)

    def score(self, sequence: List[int]) -> float:
        return self.base_model.score(sequence)

    def __repr__(self) -> str:
        proj_str = ", ".join(f"{w}*{p}" for p, w in self.weighted_projections)
        return f"MultiProjectionModel({self.base_model} @ [{proj_str}])"

8. Usage Examples

8.1 Simple Case-Insensitive Model

from langcalc.models import InfinigramModel
from langcalc.projections import LowercaseProjection

# Approach 1: Query-time projection
corpus = list("Hello World".encode('utf-8'))
projection = LowercaseProjection()
model = ProjectedModel(
    InfinigramModel(corpus),
    projection=projection,
    corpus=corpus
)

# Approach 2: Training-time augmentation (more efficient)
from langcalc.augmentations import LowercaseAugmentation

augmented_corpus = LowercaseAugmentation().augment(corpus)
model = InfinigramModel(augmented_corpus)

8.2 Composed Projections

# Normalize whitespace, then lowercase, then truncate to 10 tokens
projection = (
    WhitespaceProjection() >>
    LowercaseProjection() >>
    RecencyProjection(max_length=10)
)

model = ProjectedModel(InfinigramModel(corpus), projection, corpus)

8.3 Multi-Projection Model

# Try multiple projections with different weights
projections = [
    (IdentityProjection(), 0.5),          # Original context
    (LowercaseProjection(), 0.3),         # Lowercase
    (RecencyProjection(5), 0.2),          # Recent tokens only
]

model = MultiProjectionModel(InfinigramModel(corpus), projections, corpus)

8.4 Standard Normalization

# Common preprocessing pipeline
projection = (
    WhitespaceProjection() >>
    LowercaseProjection() >>
    UnicodeNormalizationProjection('NFC')
)

model = ProjectedModel(InfinigramModel(corpus), projection, corpus)

9. Testing

9.1 Projection Tests

def test_identity_projection():
    proj = IdentityProjection()
    context = [1, 2, 3]
    corpus = [4, 5, 6]
    assert proj.project(context, corpus) == context

def test_recency_projection():
    proj = RecencyProjection(max_length=3)
    context = [1, 2, 3, 4, 5]
    corpus = []
    assert proj.project(context, corpus) == [3, 4, 5]

def test_lowercase_projection():
    proj = LowercaseProjection()
    context = list("Hello".encode('utf-8'))
    corpus = []
    result = bytes(proj.project(context, corpus)).decode('utf-8')
    assert result == "hello"

def test_sequential_composition():
    proj = WhitespaceProjection() >> LowercaseProjection()
    context = list("Hello  World".encode('utf-8'))
    corpus = []
    result = bytes(proj.project(context, corpus)).decode('utf-8')
    assert result == "hello world"

9.2 Augmentation Tests

def test_lowercase_augmentation():
    aug = LowercaseAugmentation()
    corpus = list("Hello".encode('utf-8'))
    result = aug.augment(corpus)

    text = bytes(result).decode('utf-8')
    assert "Hello" in text
    assert "hello" in text
    assert len(result) == 2 * len(corpus)

def test_augmentation_composition():
    aug = LowercaseAugmentation() + WhitespaceAugmentation()
    corpus = list("Hello  World".encode('utf-8'))
    result = aug.augment(corpus)

    # Should contain: original, lowercase, normalized whitespace, and combinations
    text = bytes(result).decode('utf-8')
    assert "Hello  World" in text
    assert "hello  world" in text

10. Conclusion

This reference implementation provides:

  1. Core abstractions - Projection and Augmentation base classes
  2. Composition operators - Sequential (>>), parallel (|), weighted (@)
  3. Basic projections - Identity, recency, truncation, normalization
  4. Advanced projections - Edit distance, longest suffix
  5. Augmentations - Case, whitespace, Unicode normalization
  6. Model integration - ProjectedModel and MultiProjectionModel
  7. Usage examples - Common patterns and workflows
  8. Testing strategy - Unit tests for each component

This serves as the specification for implementing the projection system in LangCalc.