# Source code for last.contexts

# Copyright 2023 The LAST Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Context dependencies."""

import abc
import dataclasses

import jax
import jax.numpy as jnp

from last import scatter_reduce
from last import semirings


class ContextDependency(abc.ABC):
    r"""Interface for context dependencies.

    A context dependency is a deterministic finite automaton (DFA) accepting
    $\Sigma^*$, where $\Sigma$ is the lexical output vocabulary. The state ids
    in [0, num_states) of a context dependency encode the output history. See
    Sections 3 and 4 of the GNAT paper for more details.

    Note: every context dependency state is assumed to be final.

    Subclasses must implement:

    - shape
    - start
    - next_state
    - forward_reduce
    - backward_broadcast
    """

    @abc.abstractmethod
    def shape(self) -> tuple[int, int]:
        r"""Shape of a context dependency.

        Returns:
          A (num_states, vocab_size) tuple:
          - num_states: The number of states in the context dependency DFA.
          - vocab_size: The size of the output vocabulary, $|\Sigma|$.
        """

    @abc.abstractmethod
    def start(self) -> int:
        """The start state id."""

    @abc.abstractmethod
    def next_state(self, state: jnp.ndarray, label: jnp.ndarray) -> jnp.ndarray:
        """Takes a transition in the DFA.

        Note: 0 is the epsilon label, so it is normally not fed to
        `next_state`. For consistency, implementations must return `state[i]`
        wherever `label[i] == 0`.

        Args:
          state: [batch_dims...] int32 source state ids.
          label: [batch_dims...] int32 output labels in the range
            [0, vocab_size].

        Returns:
          [batch_dims...] next state ids.
        """

    @abc.abstractmethod
    def forward_reduce(self, weights: jnp.ndarray,
                       semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
        """The reduction used in the forward algorithm.

        For each state q, sums over every (source state p, label y) pair whose
        arc leads to q, i.e.

          result[..., q] = sum_{p-y->q} weights[..., p, y]

        Args:
          weights: [batch_dims..., num_states, vocab_size] weights.
          semiring: The semiring used to carry out the summation.

        Returns:
          [batch_dims..., num_states] reduced weights.
        """

    @abc.abstractmethod
    def backward_broadcast(self, weights: jnp.ndarray) -> jnp.ndarray:
        """The broadcast used in the backward algorithm.

        For each state q, copies its weight to every (source state p, label y)
        pair whose arc leads to q, i.e.

          result[..., p, y] = weights[..., q]

        Args:
          weights: [batch_dims..., num_states] weights.

        Returns:
          [batch_dims..., num_states, vocab_size] broadcasted weights.
        """

    # The methods below are built purely on the abstract operations above.

    def walk_states(self, labels: jnp.ndarray) -> jnp.ndarray:
        """Walks a context dependency along label sequences.

        Args:
          labels: [batch_dims..., num_labels] int32 label sequences, each
            element in the range [0, vocab_size].

        Returns:
          [batch_dims..., num_labels + 1] int32 context states. states[..., 0]
          is the start state of the context dependency; states[..., i] for
          i > 0 is the state reached after observing labels[..., i - 1].
        """
        leading_shape = labels.shape[:-1]
        initial = jnp.broadcast_to(self.start(), leading_shape)

        def advance(current, symbol):
            successor = self.next_state(current, symbol)
            # Carry the new state forward and also emit it as this step's output.
            return successor, successor

        # lax.scan iterates over the leading axis, so bring time to the front...
        _, visited = jax.lax.scan(advance, initial, jnp.moveaxis(labels, -1, 0))
        # ...and send it back to the end once the scan is done.
        visited = jnp.moveaxis(visited, 0, -1)
        return jnp.concatenate([initial[..., jnp.newaxis], visited], axis=-1)
@dataclasses.dataclass(frozen=True)
class FullNGram(ContextDependency):
    """Full n-gram context dependency as described in Section 4.1 of the GNAT paper.

    For a given vocab_size > 0, context_size >= 0,

    - The set of states represents the set of all possible n-grams from length 0
      to length context_size for an output vocabulary of vocab_size.
    - Each n-gram is assigned its lexicographic order as the id. The empty
      n-gram is state 0, followed by the vocab_size unigrams as states 1 to
      vocab_size, and so on.
    - The start state is 0 (the empty n-gram).
    - All states are final.
    - From each n-gram state, there is an arc for each label in the vocabulary
      leading to the n-gram with the label appended to the end, with the length
      of the n-gram capped at context_size.

    In the comments below, states of length < context_size are called
    "ascending" states and states of length exactly context_size are called
    "full order" states.

    Attributes:
      vocab_size: Lexical output vocabulary size.
      context_size: Maximum n-gram context size.
    """

    vocab_size: int
    context_size: int

    def __post_init__(self):
        if self.vocab_size <= 0:
            raise ValueError('vocab_size should be > 0, but got '
                             f'vocab_size={self.vocab_size}')
        if self.context_size < 0:
            raise ValueError('context_size should be >= 0, but got '
                             f'context_size={self.context_size}')

    def _num_states_shorter_than(self, length: int) -> int:
        """Number of n-gram states of length < `length` (0 if length <= 0)."""
        # int() is just here to please pytype.
        return sum(int(self.vocab_size**i) for i in range(length))

    def num_states(self) -> int:
        """Total number of states: all n-grams of length 0..context_size."""
        return self._num_states_shorter_than(self.context_size + 1)

    def shape(self) -> tuple[int, int]:
        return self.num_states(), self.vocab_size

    def start(self) -> int:
        return 0

    def next_state(self, state: jnp.ndarray, label: jnp.ndarray) -> jnp.ndarray:
        # Note: lexical labels are in the range [1, vocab_size].
        num_ascending_states = self._num_states_shorter_than(self.context_size)
        # From an ascending state, appending label y to the n-gram with
        # lexicographic id `state` yields id state * vocab_size + y.
        ascend_nextstate = state * self.vocab_size + label
        if self.context_size == 0:
            # Only the empty n-gram exists; every arc loops back to state 0.
            full_nextstate = jnp.zeros_like(ascend_nextstate)
        else:
            # From a full order state, drop the oldest label (the modulo) and
            # append the new one, staying within the full order id range.
            full_nextstate = (
                (state - num_ascending_states)
                % (self.vocab_size**(self.context_size - 1)) * self.vocab_size
                + num_ascending_states + label - 1)
        nextstate = jnp.where(state < num_ascending_states, ascend_nextstate,
                              full_nextstate)
        # Remain where we were for epsilons.
        nextstate = jnp.where(label == 0, state, nextstate)
        return nextstate

    def forward_reduce(self, weights: jnp.ndarray,
                       semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
        batch_dims = weights.shape[:-2]
        if weights.shape[-2:] != self.shape():
            raise ValueError(f'weights.shape[-2:] should be {self.shape()} but got'
                             f' {weights.shape[-2:]}')
        num_states, vocab_size = self.shape()
        # weights can be partitioned into two blocks: arcs leading to ascending
        # states, and arcs leading to the full context_size order states.
        # Source states of length < context_size - 1 lead to ascending states.
        num_into_ascending = self._num_states_shorter_than(self.context_size - 1)
        # All remaining source states lead to full order states.
        num_into_full = num_states - num_into_ascending
        num_full_order = vocab_size**self.context_size
        # Sizes are spelled out (instead of -1) so the reshapes stay valid when
        # some batch dimension has size 0.
        next_accum_parts = []
        if self.context_size > 0:
            # The start state (empty n-gram) has no incoming arcs.
            next_accum_parts.append(
                semiring.zeros(batch_dims + (1,), weights.dtype))
            # Each (source, label) arc here reaches a distinct ascending state,
            # in the same order as the flattened weights.
            next_accum_parts.append(weights[..., :num_into_ascending, :].reshape(
                batch_dims + (num_into_ascending * vocab_size,)))
        # Flattened arcs into full order states hit target ids periodically with
        # period num_full_order, so fold them into rows and sum the rows.
        into_full = weights[..., num_into_ascending:, :].reshape(
            batch_dims +
            (num_into_full * vocab_size // num_full_order, num_full_order))
        next_accum_parts.append(semiring.sum(into_full, axis=-2))
        return jnp.concatenate(next_accum_parts, axis=-1)

    def backward_broadcast(self, weights: jnp.ndarray) -> jnp.ndarray:
        batch_dims = weights.shape[:-1]
        num_states = weights.shape[-1]
        if num_states != self.num_states():
            raise ValueError(f'weights.shape[-1] should be {self.num_states()} but '
                             f'got {num_states}')
        vocab_size = self.vocab_size
        if self.context_size == 0:
            return jnp.broadcast_to(weights[..., jnp.newaxis],
                                    weights.shape + (vocab_size,))
        num_ascending = self._num_states_shorter_than(self.context_size)
        num_into_ascending = self._num_states_shorter_than(self.context_size - 1)
        num_full_order = vocab_size**self.context_size
        # Non-start ascending states have a unique incoming arc, thus a unique
        # incoming state: state p * vocab_size + y is reached only via (p, y).
        part_a = weights[..., 1:num_ascending].reshape(
            batch_dims + (num_into_ascending, vocab_size))
        # States with arcs into full context_size order states. Explicit sizes
        # keep the reshape valid when some batch dimension has size 0.
        part_b = jnp.broadcast_to(
            weights[..., jnp.newaxis, num_ascending:],
            batch_dims + (1 + vocab_size, num_full_order)).reshape(
                batch_dims + (num_states - num_into_ascending, vocab_size))
        return jnp.concatenate([part_a, part_b], axis=-2)

    def next_state_table(self) -> jnp.ndarray:
        """Generates the next state table (see NextStateTable).

        Returns:
          [num_states, vocab_size] int32 array; entry [p, y - 1] is the state
          reached from state p with label y.
        """
        num_states, vocab_size = self.shape()
        # Explicit int32 so the table passes NextStateTable's dtype check even
        # when 64-bit mode (jax_enable_x64) changes jnp.arange's default dtype.
        source_states = jnp.arange(num_states, dtype=jnp.int32)[:, jnp.newaxis]
        labels = jnp.arange(1, vocab_size + 1, dtype=jnp.int32)[jnp.newaxis, :]
        return self.next_state(source_states, labels)
@dataclasses.dataclass(frozen=True)
class NextStateTable(ContextDependency):
    """Context dependency described as a transition lookup table.

    The start state is always 0.

    Attributes:
      next_state_table: [num_states, vocab_size] int32 array.
        next_state_table[p, y - 1] is the state reached from p with label y.
    """

    next_state_table: jnp.ndarray

    def __post_init__(self):
        if self.next_state_table.ndim != 2:
            raise ValueError(
                'next_state_table should have shape [num_states, vocab_size], but '
                f'got shape {self.next_state_table.shape}')
        if self.next_state_table.size == 0:
            raise ValueError('next_state_table should have a non-zero size, but '
                             f'got shape {self.next_state_table.shape}')
        if self.next_state_table.dtype != jnp.int32:
            raise ValueError('next_state_table should be an int32 ndarray, but '
                             f'got dtype {self.next_state_table.dtype}')

    def shape(self) -> tuple[int, int]:
        return self.next_state_table.shape

    def start(self) -> int:
        return 0

    def next_state(self, state: jnp.ndarray, label: jnp.ndarray) -> jnp.ndarray:
        # Note: lexical labels are in the range [1, vocab_size].
        is_epsilon = label == 0
        # Map epsilon to column 0 just so the gather stays in bounds; the
        # looked-up value is discarded for epsilons below.
        zero_based_label = jnp.where(is_epsilon, 0, label - 1)
        nextstate = self.next_state_table[state, zero_based_label]
        # Remain where we were for epsilons.
        nextstate = jnp.where(is_epsilon, state, nextstate)
        return nextstate

    def forward_reduce(self, weights: jnp.ndarray,
                       semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
        batch_dims = weights.shape[:-2]
        if weights.shape[-2:] != self.shape():
            raise ValueError(f'weights.shape[-2:] should be {self.shape()} but got'
                             f' {weights.shape[-2:]}')
        num_states, _ = self.shape()
        # Build the scatter operation: each weights[..., p, y] is combined with
        # semiring.plus into operand[..., next_state_table[p, y]].
        operand = semiring.zeros(batch_dims + (num_states,), weights.dtype)
        updates = weights
        # [num_states, vocab_size, 1]: one target state index per (p, y) arc.
        scatter_indices = jnp.expand_dims(self.next_state_table, axis=-1)
        # The leading batch dims of `updates` are copied through as windows; the
        # trailing (num_states, vocab_size) dims are the scatter dims.
        update_window_dims = tuple(range(len(batch_dims)))
        # The operand's state axis is indexed by scatter_indices rather than
        # being part of the window.
        inserted_window_dims = (len(batch_dims),)
        scatter_dims_to_operand_dims = (len(batch_dims),)
        return scatter_reduce.scatter_reduce(
            operand=operand,
            scatter_indices=scatter_indices,
            updates=updates,
            computation=semiring.plus,
            dimension_numbers=jax.lax.ScatterDimensionNumbers(
                update_window_dims=update_window_dims,
                inserted_window_dims=inserted_window_dims,
                scatter_dims_to_operand_dims=scatter_dims_to_operand_dims))

    def backward_broadcast(self, weights: jnp.ndarray) -> jnp.ndarray:
        num_states = weights.shape[-1]
        if num_states != self.shape()[0]:
            raise ValueError(f'weights.shape[-1] should be {self.shape()[0]} but '
                             f'got {num_states}')
        # Gathering with the whole table gives result[..., p, y - 1] =
        # weights[..., next_state_table[p, y - 1]].
        return weights[..., self.next_state_table]