Public API: last package

LAST API.

last.RecognitionLattice

Recognition lattice in GNAT-style formulation and operations over it.

last.semirings

Semirings.

last.contexts

Context dependencies.

last.alignments

Alignment lattices.

last.weight_fns

Weight functions.

last

class last.RecognitionLattice(context: ~last.contexts.ContextDependency, alignment: ~last.alignments.TimeSyncAlignmentLattice, weight_fn_cacher_factory: ~collections.abc.Callable[[~last.contexts.ContextDependency], ~last.weight_fns.WeightFnCacher[~last.lattices.T]], weight_fn_factory: ~collections.abc.Callable[[~last.contexts.ContextDependency], ~last.weight_fns.WeightFn[~last.lattices.T]], parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)[source]

Recognition lattice in GNAT-style formulation and operations over it.

A RecognitionLattice provides operations used in training and inference, such as computing the negative-log-probability loss, or finding the highest scoring alignment path.

Following the formulation in GNAT, we combine the three modelling components to define a RecognitionLattice:

  • Context dependency: The finite automaton that models output history. See last.contexts.ContextDependency for details.

  • Alignment lattice: The finite automaton that models the alignment between input frames and output labels. See last.alignments.TimeSyncAlignmentLattice for details.

  • Weight function: The neural network that produces arc weights from any context state given an input frame. See last.weight_fns for details.

Given a sequence of T input frames, its recognition lattice is the following finite automaton:

  • States: The states are pairs of an alignment state and a context state. For any alignment state (t, a) (t is the index of the current frame) and context state c, there is a state (t, a, c) in the recognition lattice.

  • Start state: (0, s_a, s_c), where s_a is the start state in the frame-local alignment lattice (see last.alignments.TimeSyncAlignmentLattice), and s_c is the start state in the context dependency.

  • Final states: (T, s_a, c) for any context state c.

  • Arcs:

    • Blank arcs: For any blank arc (t, a) -> (t', a') in the alignment lattice, and any context state c, there is an arc (t, a, c) --blank-> (t', a', c) in the recognition lattice.

    • Lexical arcs: For any lexical arc (t, a) -> (t', a') in the alignment lattice, and any arc c --y-> c' in the context dependency, there is an arc (t, a, c) --y-> (t', a', c') in the recognition lattice.

  • Arc weights: For each state (t, a, c), the weight function receives the t-th frame and the context state c, and produces arc weights for blank and lexical arcs. Notably while in principle weight functions can also depend on the alignment state, in practice we haven’t yet encountered a compelling reason for such weight functions. Thus for the sake of simplicity, weight functions currently only depend on the frame and the context state.
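To make the construction concrete, the sketch below enumerates the arcs of a small recognition lattice. It assumes a frame-dependent alignment lattice with a single frame-local alignment state (a = 0), so every frame takes exactly one blank or lexical arc. It also assumes a context dependency stored as a next-state table following the last.contexts.NextStateTable convention (next_state_table[p, y - 1] is the state reached from p with label y). The function `recognition_lattice_arcs` is hypothetical and not part of LAST.

```python
import numpy as np


def recognition_lattice_arcs(num_frames, next_state_table):
    """Lists (source, label, destination) arcs of a toy recognition lattice.

    Assumes one frame-local alignment state (a = 0). States are (t, a, c)
    triples and label 0 denotes blank.
    """
    num_context_states, vocab_size = next_state_table.shape
    arcs = []
    for t in range(num_frames):
        for c in range(num_context_states):
            # Blank arc: advance the frame, keep the context state.
            arcs.append(((t, 0, c), 0, (t + 1, 0, c)))
            # Lexical arcs: advance the frame and follow c --y-> c'.
            for y in range(1, vocab_size + 1):
                c_next = int(next_state_table[c, y - 1])
                arcs.append(((t, 0, c), y, (t + 1, 0, c_next)))
    return arcs


# Bigram-style context dependency over a 2-label vocabulary: state 0 is the
# empty history, states 1 and 2 remember the last emitted label.
table = np.array([[1, 2], [1, 2], [1, 2]])
arcs = recognition_lattice_arcs(num_frames=2, next_state_table=table)
```

Every one of the T * num_context_states source states has one blank arc and vocab_size lexical arcs, so this toy lattice has 2 * 3 * (1 + 2) = 18 arcs. Its start state is (0, 0, 0) and its final states are (2, 0, c) for every c.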

The arc weights can be used to model the conditional distribution of the paths in the recognition lattice, especially those that lead to the reference label sequence. A RecognitionLattice can be either a locally normalized model or a globally normalized model:

  • A locally normalized model uses last.weight_fns.LocallyNormalizedWeightFn, where the exponentiated arc weights leaving the same recognition lattice state sum to 1. The probability P(alignment labels | frames) is then simply the product of the exponentiated arc weights along the alignment path.

  • A globally normalized model uses a WeightFn that is not a subclass of last.weight_fns.LocallyNormalizedWeightFn. To obtain a probabilistic interpretation, we normalize the path weights with the sum of the exponentiated weights of all possible paths in the recognition lattice.

Globally normalized models are more expensive to train, but they have various advantages. See the GNAT paper for more details.
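The difference between the two normalizations can be illustrated on a toy lattice. This sketch assumes 3 frames and, in each frame, exactly one choice between a blank arc and a single lexical arc, with no context dependency. Paths are enumerated explicitly, which is only feasible for tiny examples; LAST computes the normalizer with the forward algorithm instead. All names here are hypothetical.

```python
import itertools

import numpy as np


def log_softmax(x):
    """Numerically stable log-softmax over the last axis."""
    x = x - np.max(x, axis=-1, keepdims=True)
    return x - np.log(np.sum(np.exp(x), axis=-1, keepdims=True))


def path_log_weights(arc_weights):
    """Maps each path to the sum of its arc log-weights.

    arc_weights: [T, 2] log weights of the (blank, lexical) arcs per frame.
    A path is a tuple of T choices: 0 for blank, 1 for lexical.
    """
    num_frames = arc_weights.shape[0]
    return {
        path: float(sum(arc_weights[t, k] for t, k in enumerate(path)))
        for path in itertools.product([0, 1], repeat=num_frames)
    }


rng = np.random.default_rng(0)
scores = rng.normal(size=(3, 2))  # unnormalized arc weights for 3 frames

# Locally normalized: exp(arc weights) leaving each state sum to 1, so the
# exponentiated path weights already form a distribution over paths.
local = path_log_weights(log_softmax(scores))
local_total = sum(np.exp(w) for w in local.values())

# Globally normalized: divide by the sum over all paths (log_z).
unnormalized = path_log_weights(scores)
log_z = np.log(sum(np.exp(w) for w in unnormalized.values()))
global_probs = {p: np.exp(w - log_z) for p, w in unnormalized.items()}
```

In this toy, the arc weights do not depend on any context state, so both normalizations give the same distribution. They generally differ once the weights depend on the output history.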

context

Context dependency.

Type:

last.contexts.ContextDependency

alignment

Alignment lattice.

Type:

last.alignments.TimeSyncAlignmentLattice

weight_fn_cacher_factory

Callable that builds a WeightFnCacher given the context dependency.

Type:

collections.abc.Callable[[last.contexts.ContextDependency], last.weight_fns.WeightFnCacher[last.lattices.T]]

weight_fn_factory

Callable that builds a WeightFn given the context dependency.

Type:

collections.abc.Callable[[last.contexts.ContextDependency], last.weight_fns.WeightFn[last.lattices.T]]

class BackwardStepCallback(*args, **kwargs)[source]

Callback signature used in the backward algorithm loop.

Blank

alias of Array

CacheGrad = typing.Any
FrameGrad

alias of Array

Lexical

alias of Array

Output = typing.Any
ParamsGrad = typing.Any
build_cache() T[source]

Builds the weight function cache.

Weight functions are implemented as a pair of WeightFn and WeightFnCacher to avoid unnecessary recomputation (see last.weight_fns for more details). build_cache() builds the cached static data that can be used in other public methods.

Returns:

Cached data.

setup()[source]

Initializes a Module lazily (similar to a lazy __init__).

setup is called once lazily on a module instance when a module is bound, immediately before any other methods like __call__ are invoked, or before a setup-defined attribute on self is accessed.

This can happen in three cases:

  1. Immediately when invoking apply(), init() or init_and_output().

  2. Once the module is given a name by being assigned to an attribute of another module inside the other module’s setup method (see __setattr__()):

    import flax.linen as nn

    class MyModule(nn.Module):
      def setup(self):
        submodule = nn.Conv(features=8, kernel_size=(3,))
    
        # Accessing `submodule` attributes does not yet work here.
    
        # The following line invokes `self.__setattr__`, which gives
        # `submodule` the name "conv1".
        self.conv1 = submodule
    
        # Accessing `submodule` attributes or methods is now safe and
        # may trigger the submodule's setup(), which runs at most once.
    
  3. Once a module is constructed inside a method wrapped with compact(), immediately before another method is called or a setup-defined attribute is accessed.

shortest_path(frames: Array, num_frames: Array, cache: T | None = None) tuple[Array, Array, Array][source]

Computes the shortest path in the recognition lattice.

The shortest path is the path with the highest score; in other words, it is the “shortest” path under the max-tropical semiring.

Parameters:
  • frames – [batch_dims…, max_num_frames, feature_size] padded frame sequences.

  • num_frames – [batch_dims…] number of frames.

  • cache – Optional weight function cache data.

Returns:

(alignment_labels, num_alignment_labels, path_weights) tuple:

  • alignment_labels: [batch_dims…, max_num_alignment_labels] padded alignment labels, either blank (0) or lexical (1 to vocab_size).

  • num_alignment_labels: [batch_dims…] number of alignment labels.

  • path_weights: [batch_dims…] path weights.
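A common follow-up to shortest_path is recovering the output label sequence. Alignment labels are either blank (0) or lexical (1 to vocab_size), so dropping the blanks within the first num_alignment_labels entries gives the lexical labels along the best path. The helper below is a hypothetical NumPy sketch that assumes a single batch dimension. It is not a LAST function.

```python
import numpy as np


def alignment_to_output_labels(alignment_labels, num_alignment_labels):
    """Strips padding and blanks from padded alignment label sequences.

    alignment_labels: [batch, max_num_alignment_labels] int array, where 0 is
      blank and 1..vocab_size are lexical labels.
    num_alignment_labels: [batch] int array of valid lengths.
    Returns one Python list of lexical labels per batch element.
    """
    outputs = []
    for labels, length in zip(alignment_labels, num_alignment_labels):
        valid = labels[: int(length)]  # drop padding past the valid length
        outputs.append([int(y) for y in valid if y != 0])  # drop blanks
    return outputs


alignment_labels = np.array([[0, 3, 0, 2, 0], [1, 0, 0, 0, 0]])
num_alignment_labels = np.array([5, 2])
print(alignment_to_output_labels(alignment_labels, num_alignment_labels))
# [[3, 2], [1]]
```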

last.semirings

Semirings.

class last.semirings.Cartesian(x: Semiring[T], y: Semiring[S])[source]

Cartesian product of 2 semirings.

x

The first semiring.

Type:

last.semirings.Semiring[last.semirings.T]

y

The second semiring.

Type:

last.semirings.Semiring[last.semirings.S]

ones(shape: Sequence[int], dtype: Any | None = None) tuple[T, S][source]

Semiring ones in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring one values in the specified shape with reasonable default dtypes. Otherwise, semiring one values in the specified shape with the specified dtypes.

plus(a: tuple[T, S], b: tuple[T, S]) tuple[T, S][source]

Semiring addition between two values.

prod(a: tuple[T, S], axis: int) tuple[T, S][source]

Semiring multiplication along a single axis.

sum(a: tuple[T, S], axis: int) tuple[T, S][source]

Semiring addition along a single axis.

times(a: tuple[T, S], b: tuple[T, S]) tuple[T, S][source]

Semiring multiplication between two values.

zeros(shape: Sequence[int], dtype: Any | None = None) tuple[T, S][source]

Semiring zeros in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring zero values in the specified shape with reasonable default dtypes. Otherwise, semiring zero values in the specified shape with the specified dtypes.
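A Cartesian product semiring lets two semiring computations share a single pass, because every operation is applied componentwise to the (x, y) pair. The classes below are hypothetical, minimal NumPy stand-ins for a real semiring, a max-tropical semiring and their product. They are not the LAST implementations and omit dtype handling and prod.

```python
import numpy as np


class RealSR:
    """Minimal real semiring: plus is +, times is *."""

    def zeros(self, shape):
        return np.zeros(shape)

    def ones(self, shape):
        return np.ones(shape)

    def plus(self, a, b):
        return a + b

    def times(self, a, b):
        return a * b

    def sum(self, a, axis):
        return np.sum(a, axis=axis)


class MaxTropicalSR:
    """Minimal max-tropical semiring: plus is max, times is +."""

    def zeros(self, shape):
        return np.full(shape, -np.inf)

    def ones(self, shape):
        return np.zeros(shape)

    def plus(self, a, b):
        return np.maximum(a, b)

    def times(self, a, b):
        return a + b

    def sum(self, a, axis):
        return np.max(a, axis=axis)


class CartesianSR:
    """Applies every operation componentwise to (x, y) value pairs."""

    def __init__(self, x, y):
        self.x, self.y = x, y

    def zeros(self, shape):
        return (self.x.zeros(shape), self.y.zeros(shape))

    def ones(self, shape):
        return (self.x.ones(shape), self.y.ones(shape))

    def plus(self, a, b):
        return (self.x.plus(a[0], b[0]), self.y.plus(a[1], b[1]))

    def times(self, a, b):
        return (self.x.times(a[0], b[0]), self.y.times(a[1], b[1]))

    def sum(self, a, axis):
        return (self.x.sum(a[0], axis), self.y.sum(a[1], axis))


sr = CartesianSR(RealSR(), MaxTropicalSR())
# Each of three paths contributes (count = 1, score); one reduction yields
# both the number of paths and the best score.
paths = (np.ones(3), np.array([1.0, 3.0, 2.0]))
num_paths, best_score = sr.sum(paths, axis=0)
```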

class last.semirings.Expectation(w: Semiring[T], x: Semiring[S], w_to_x: Callable[[T], S])[source]

Jason Eisner’s expectation semiring.

In most cases, use LogLogExpectation below directly.

See https://www.cs.jhu.edu/~jason/papers/eisner.fsmnlp01.pdf for reference.

Each semiring value is a tuple (w, x):

  • w: The weight of this tuple, expressed in the self.w semiring.

  • x: The weighted sum for some corresponding weighted values, expressed in the self.x semiring.

To create a semiring value from a weight-value pair, use self.weighted(). See ExpectationTest.test_entropy for an example of using the expectation semiring to compute the entropy of probability distributions.

w

Semiring for representing weights.

Type:

last.semirings.Semiring[last.semirings.T]

x

Semiring for representing weighted sums.

Type:

last.semirings.Semiring[last.semirings.S]

w_to_x

Function to convert a value from semiring w to semiring x.

Type:

Callable[[last.semirings.T], last.semirings.S]

ones(shape: Sequence[int], dtype: Any | None = None) tuple[T, S][source]

Semiring ones in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring one values in the specified shape with reasonable default dtypes. Otherwise, semiring one values in the specified shape with the specified dtypes.

plus(a: tuple[T, S], b: tuple[T, S]) tuple[T, S][source]

Semiring addition between two values.

sum(a: tuple[T, S], axis: int) tuple[T, S][source]

Semiring addition along a single axis.

times(a: tuple[T, S], b: tuple[T, S]) tuple[T, S][source]

Semiring multiplication between two values.

zeros(shape: Sequence[int], dtype: Any | None = None) tuple[T, S][source]

Semiring zeros in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring zero values in the specified shape with reasonable default dtypes. Otherwise, semiring zero values in the specified shape with the specified dtypes.
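This sketch shows how the (w, x) pairs behave, using the expectation semiring with both components in the real semiring. LogLogExpectation instead keeps both components in the log semiring. Following Eisner's construction, plus adds componentwise and times is (w1 * w2, w1 * x2 + w2 * x1). If each outcome with probability p is lifted to (p, p * -log p), the semiring sum is (1, H), where H is the entropy. The helper names are hypothetical; `weighted` only loosely mirrors the role given to self.weighted() above.

```python
import math


def plus(a, b):
    """Componentwise addition of (w, x) pairs."""
    return (a[0] + b[0], a[1] + b[1])


def times(a, b):
    """Product rule: (w1 * w2, w1 * x2 + w2 * x1)."""
    return (a[0] * b[0], a[0] * b[1] + b[0] * a[1])


# Semiring identities.
ZERO = (0.0, 0.0)
ONE = (1.0, 0.0)


def weighted(w, v):
    """Lifts weight w and value v into the semiring as (w, w * v)."""
    return (w, w * v)


def entropy(probs):
    """Returns (total probability, entropy) of a discrete distribution."""
    total = ZERO
    for p in probs:
        total = plus(total, weighted(p, -math.log(p)))
    return total


p = [0.5, 0.5]
q = [0.25, 0.75]
# times distributes over plus, so multiplying the two sums gives the
# (probability, entropy) pair of the independent joint distribution,
# i.e. (1.0, H(p) + H(q)).
joint = times(entropy(p), entropy(q))
```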

class last.semirings.Semiring[source]

Base Semiring interface.

See https://en.wikipedia.org/wiki/Semiring for what a semiring is. A Semiring object holds methods that implement the semiring operations. To simplify non-semiring operations on the semiring values, the semiring values are not typed: for most basic semirings, each value is a single ndarray; for some more complex semirings (e.g. Expectation or Cartesian), the values can be a tuple of ndarrays.

In general, a semiring value under some semiring is represented as a PyTree of identically shaped ndarrays, with possibly different dtypes. The shape and dtypes of a semiring value can be obtained with methods last.semirings.value_shape() and last.semirings.value_dtype().

Semiring is not an abstract base class because we allow operations to be unimplemented (e.g. prod, which is not commonly used).

Implementation tips:

  • Binary operations (times & plus) should support broadcasting. If a binary operation requires a custom vjp implementation, it may not be straightforward to write one that handles input broadcasting properly. Instead, write a custom vjp function that requires inputs with identical shapes, then use jnp.broadcast_arrays() to preprocess the operands in the implementation of the semiring operations. See {_Log,_MaxTropical}.plus for examples.

  • Reductions (prod & sum) can be tricky to implement correctly. Here are two important things to watch out for:

    • axis can be in the range [-rank, rank).

    • The input can have 0-sized dimensions.
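The sketch below puts these tips into practice with hypothetical NumPy versions of the log and max-tropical semirings. It is not LAST's implementation, which is written in JAX and uses custom vjps for plus. It only shows three points: binary operations that broadcast, axis normalization from [-rank, rank), and reductions over 0-sized dimensions that return the semiring zero.

```python
import numpy as np


def _normalize_axis(axis, rank):
    """Maps axis in [-rank, rank) to [0, rank); rejects anything else."""
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    return axis % rank


class LogSketch:
    """Log semiring: plus is log-add-exp, times is +, zero -inf, one 0."""

    def zeros(self, shape, dtype=np.float32):
        return np.full(shape, -np.inf, dtype=dtype)

    def ones(self, shape, dtype=np.float32):
        return np.zeros(shape, dtype=dtype)

    def plus(self, a, b):
        return np.logaddexp(a, b)  # NumPy ufuncs broadcast their operands

    def times(self, a, b):
        return np.add(a, b)

    def sum(self, a, axis):
        a = np.asarray(a)
        axis = _normalize_axis(axis, a.ndim)
        if a.shape[axis] == 0:
            # An empty sum is the semiring zero.
            return self.zeros(a.shape[:axis] + a.shape[axis + 1:], a.dtype)
        m = np.max(a, axis=axis, keepdims=True)
        # Avoid (-inf) - (-inf) = nan when a whole slice is the semiring zero.
        m = np.where(np.isfinite(m), m, 0)
        with np.errstate(divide="ignore"):
            s = np.log(np.sum(np.exp(a - m), axis=axis))
        return s + np.squeeze(m, axis=axis)


class MaxTropicalSketch:
    """Max-tropical semiring: plus is max, times is +, zero -inf, one 0."""

    def zeros(self, shape, dtype=np.float32):
        return np.full(shape, -np.inf, dtype=dtype)

    def ones(self, shape, dtype=np.float32):
        return np.zeros(shape, dtype=dtype)

    def plus(self, a, b):
        return np.maximum(a, b)

    def times(self, a, b):
        return np.add(a, b)

    def sum(self, a, axis):
        a = np.asarray(a)
        axis = _normalize_axis(axis, a.ndim)
        if a.shape[axis] == 0:
            return self.zeros(a.shape[:axis] + a.shape[axis + 1:], a.dtype)
        return np.max(a, axis=axis)
```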

ones(shape: Sequence[int], dtype: Any | None = None) T[source]

Semiring ones in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring one values in the specified shape with reasonable default dtypes. Otherwise, semiring one values in the specified shape with the specified dtypes.

plus(a: T, b: T) T[source]

Semiring addition between two values.

prod(a: T, axis: int) T[source]

Semiring multiplication along a single axis.

sum(a: T, axis: int) T[source]

Semiring addition along a single axis.

times(a: T, b: T) T[source]

Semiring multiplication between two values.

zeros(shape: Sequence[int], dtype: Any | None = None) T[source]

Semiring zeros in the given shape and dtype.

Parameters:
  • shape – Desired output shape.

  • dtype – Optional PyTree of dtypes.

Returns:

If dtype is None, semiring zero values in the specified shape with reasonable default dtypes. Otherwise, semiring zero values in the specified shape with the specified dtypes.

last.semirings.value_dtype(x: Any) Any[source]

Obtains the dtypes of a semiring value.

Different leaves of a semiring value may have different dtypes. Methods such as Semiring.{zeros,ones} can take a PyTree of dtypes in the same structure as the corresponding semiring values. This function can be used to extract such a dtype PyTree from a semiring value.

Parameters:

x – Some semiring value.

Returns:

dtypes in the same structure as x.

last.semirings.value_shape(x: Any) tuple[int, ...][source]

Obtains the shape of a semiring value.

A semiring value is a PyTree of one or more identically shaped ndarrays. The shape of a semiring value is thus the common shape of its leaves.

Parameters:

x – Some semiring value.

Returns:

The common shape of the leaves of x.

Raises:

ValueError – If the leaves of x do not have a common shape.
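A rough sketch of what value_shape and value_dtype compute, assuming semiring values are nested tuples or lists of NumPy arrays. LAST's versions work on general JAX PyTrees. The `_leaves` helper is a hypothetical stand-in for a PyTree flattening utility and handles only tuples and lists.

```python
import numpy as np


def _leaves(x):
    """Flattens nested tuples/lists into a list of array leaves."""
    if isinstance(x, (tuple, list)):
        return [leaf for item in x for leaf in _leaves(item)]
    return [np.asarray(x)]


def value_shape(x):
    """Returns the common shape of all leaves, or raises ValueError."""
    shapes = {leaf.shape for leaf in _leaves(x)}
    if len(shapes) != 1:
        raise ValueError(f"Leaves do not have a common shape: {sorted(shapes)}")
    return shapes.pop()


def value_dtype(x):
    """Returns dtypes in the same tuple/list structure as x."""
    if isinstance(x, (tuple, list)):
        return type(x)(value_dtype(item) for item in x)
    return np.asarray(x).dtype


# A Cartesian-style value whose two leaves share a shape but not a dtype.
value = (np.zeros((2, 3), np.float32), np.zeros((2, 3), np.int32))
```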

last.semirings.Real = <last.semirings._Real object>

Real semiring.

last.semirings.Log = <last.semirings._Log object>

Log semiring.

last.semirings.MaxTropical = <last.semirings._MaxTropical object>

Max tropical semiring.

The gradients of plus and sum are guaranteed to be non-zero on exactly 1 input element, even in the event of a tie.
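The following sketch shows why the tie guarantee matters by comparing two ways of defining the gradient of max over a vector. A mask of all maximal elements is non-zero on every tied element, whereas a one-hot vector at a single argmax is non-zero on exactly one. This NumPy sketch only illustrates the property; how LAST breaks ties internally is not specified here.

```python
import numpy as np


def max_grad_all_ties(x):
    """Marks every maximal element; has several non-zeros on a tie."""
    return (x == np.max(x)).astype(x.dtype)


def max_grad_one_hot(x):
    """One-hot at the first maximal element (np.argmax picks the first)."""
    g = np.zeros_like(x)
    g[np.argmax(x)] = 1
    return g


x = np.array([1.0, 3.0, 3.0])
print(max_grad_all_ties(x))  # [0. 1. 1.]
print(max_grad_one_hot(x))   # [0. 1. 0.]
```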

last.semirings.LogLogExpectation = Expectation(w=<last.semirings._Log object>, x=<last.semirings._Log object>, w_to_x=<function <lambda>>)

Jason Eisner’s expectation semiring (see last.semirings.Expectation), with both the weights w and the weighted sums x in the log semiring (last.semirings.Log).

LogLogExpectation.w

Semiring for representing weights (last.semirings.Log).

LogLogExpectation.x

Semiring for representing weighted sums (last.semirings.Log).

LogLogExpectation.w_to_x

Function to convert a value from semiring w to semiring x.

last.contexts

Context dependencies.

class last.contexts.ContextDependency[source]

Interface for context dependencies.

A context dependency is a deterministic finite automaton (DFA) that accepts $\Sigma^*$ ($\Sigma$ is the lexical output vocabulary). The state ids in [0, num_states) of a context dependency encode the output history. See Sections 3 and 4 of the GNAT paper for more details.

Note: we assume all context dependency states to be final.

Subclasses should implement the following methods:

  • shape

  • start

  • next_state

  • forward_reduce

  • backward_broadcast

abstract backward_broadcast(weights: Array) Array[source]

The broadcast used in the backward algorithm.

For each state q, we broadcast its weight to all the (source state p, label y) pairs leading to state q, i.e.

result[…, p, y] = weights[…, q]

Parameters:

weights – [batch_dims…, num_states] weights.

Returns:

[batch_dims…, num_states, vocab_size] broadcasted weights.

abstract forward_reduce(weights: Array, semiring: Semiring[Array]) Array[source]

The reduction used in the forward algorithm.

For each state q, we sum over all source states p and labels y that lead to state q, i.e.

result[…, q] = sum_{p-y->q} weights[…, p, y]

Parameters:
  • weights – [batch_dims…, num_states, vocab_size] weights.

  • semiring – The semiring for carrying out the summation.

Returns:

[batch_dims…, num_states] reduced weights.

abstract next_state(state: Array, label: Array) Array[source]

Takes a transition in the DFA.

Note: because 0 is the epsilon label, it is normally not fed to next_state. For consistency, we require that next_state should return state[i] when label[i] == 0.

Parameters:
  • state – [batch_dims…] int32 source state ids.

  • label – [batch_dims…] int32 output labels in the range [0, vocab_size].

Returns:

[batch_dims…] next state ids.

abstract shape() tuple[int, int][source]

Shape of a context dependency.

Returns:

  • num_states: The number of states in the context dependency DFA.

  • vocab_size: The size of the output vocabulary, $|\Sigma|$.

Return type:

(num_states, vocab_size) tuple

abstract start() int[source]

The start state id.

walk_states(labels: Array) Array[source]

Walks a context dependency following label sequences.

Parameters:

labels – [batch_dims…, num_labels] int32 label sequences. Each element is in the range [0, vocab_size].

Returns:

[batch_dims…, num_labels + 1] int32 context states. states[…, 0] equals the start state of the context dependency; states[…, i] for i > 0 is the state after observing labels[…, i - 1].
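This is a sketch of next_state and walk_states for a context dependency stored as a next-state table, using the last.contexts.NextStateTable convention next_state_table[p, y - 1]. Both are hypothetical free functions written with NumPy, not the ContextDependency methods themselves. A Python loop over label positions is used for clarity.

```python
import numpy as np


def next_state(table, state, label):
    """Vectorized DFA transition; label 0 (epsilon) keeps the state."""
    state = np.asarray(state)
    label = np.asarray(label)
    # table[p, y - 1] is the state reached from p with label y. For label 0
    # this indexes column -1, but np.where discards that entry.
    return np.where(label == 0, state, table[state, label - 1])


def walk_states(table, start, labels):
    """Returns [..., num_labels + 1] states; states[..., 0] is the start."""
    labels = np.asarray(labels)
    state = np.full(labels.shape[:-1], start, dtype=np.int32)
    states = [state]
    for i in range(labels.shape[-1]):
        state = next_state(table, state, labels[..., i])
        states.append(state)
    return np.stack(states, axis=-1)


# Bigram-style table over 2 labels: state 0 is the empty history.
table = np.array([[1, 2], [1, 2], [1, 2]])
states = walk_states(table, 0, [[2, 0, 1], [1, 1, 0]])
```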

class last.contexts.FullNGram(vocab_size: int, context_size: int)[source]

Full n-gram context dependency as described in Section 4.1 of the GNAT paper.

For a given vocab_size > 0 and context_size >= 0:

  • The set of states represents the set of all possible n-grams from length 0 to length context_size for an output vocabulary of vocab_size.

  • Each n-gram is assigned its lexicographic order as the id. The empty n-gram is state 0, followed by the vocab_size unigrams as states 1 to vocab_size, and so on.

  • The start state is 0 (the empty n-gram).

  • All states are final.

  • From each n-gram state, there is an arc for each label in the vocabulary leading to the n-gram with the label appended to the end, with the length of the n-gram capped at context_size.
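Under the numbering described above, the next state table can be built as in the following sketch. The numbering orders n-grams by length, then lexicographically, with labels 1 to vocab_size. It also follows that num_states = 1 + vocab_size + ... + vocab_size^context_size. This code illustrates the described construction and is not LAST's next_state_table() implementation. It enumerates every state, so it is only practical for small vocabularies.

```python
import itertools

import numpy as np


def full_ngram_states(vocab_size, context_size):
    """Lists all n-grams of length 0..context_size in state-id order."""
    states = []
    for n in range(context_size + 1):
        # itertools.product yields tuples in lexicographic order.
        states.extend(itertools.product(range(1, vocab_size + 1), repeat=n))
    return states


def full_ngram_next_state_table(vocab_size, context_size):
    """Returns a [num_states, vocab_size] table; entry [p, y - 1] = next state."""
    states = full_ngram_states(vocab_size, context_size)
    state_id = {ngram: i for i, ngram in enumerate(states)}
    table = np.zeros((len(states), vocab_size), dtype=np.int32)
    for p, ngram in enumerate(states):
        for y in range(1, vocab_size + 1):
            extended = ngram + (y,)
            # Keep only the last context_size labels (none if context_size is 0).
            capped = extended[max(0, len(extended) - context_size):]
            table[p, y - 1] = state_id[capped]
    return table


table = full_ngram_next_state_table(vocab_size=2, context_size=2)
```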

vocab_size

Lexical output vocabulary size.

Type:

int

context_size

Maximum n-gram context size.

Type:

int

backward_broadcast(weights: Array) Array[source]

The broadcast used in the backward algorithm.

For each state q, we broadcast its weight to all the (source state p, label y) pairs leading to state q, i.e.

result[…, p, y] = weights[…, q]

Parameters:

weights – [batch_dims…, num_states] weights.

Returns:

[batch_dims…, num_states, vocab_size] broadcasted weights.

forward_reduce(weights: Array, semiring: Semiring[Array]) Array[source]

The reduction used in the forward algorithm.

For each state q, we sum over all source states p and labels y that lead to state q, i.e.

result[…, q] = sum_{p-y->q} weights[…, p, y]

Parameters:
  • weights – [batch_dims…, num_states, vocab_size] weights.

  • semiring – The semiring for carrying out the summation.

Returns:

[batch_dims…, num_states] reduced weights.

next_state(state: Array, label: Array) Array[source]

Takes a transition in the DFA.

Note: because 0 is the epsilon label, it is normally not fed to next_state. For consistency, we require that next_state should return state[i] when label[i] == 0.

Parameters:
  • state – [batch_dims…] int32 source state ids.

  • label – [batch_dims…] int32 output labels in the range [0, vocab_size].

Returns:

[batch_dims…] next state ids.

next_state_table() Array[source]

Generates the next state table (see NextStateTable).

shape() tuple[int, int][source]

Shape of a context dependency.

Returns:

  • num_states: The number of states in the context dependency DFA.

  • vocab_size: The size of the output vocabulary, $|\Sigma|$.

Return type:

(num_states, vocab_size) tuple

start() int[source]

The start state id.

class last.contexts.NextStateTable(next_state_table: Array)[source]

Context dependency described as a transition lookup table.

next_state_table

[num_states, vocab_size] int32 array. next_state_table[p, y - 1] is the state reached from p with label y.

Type:

jax.Array

backward_broadcast(weights: Array) Array[source]

The broadcast used in the backward algorithm.

For each state q, we broadcast its weight to all the (source state p, label y) pairs leading to state q, i.e.

result[…, p, y] = weights[…, q]

Parameters:

weights – [batch_dims…, num_states] weights.

Returns:

[batch_dims…, num_states, vocab_size] broadcasted weights.

forward_reduce(weights: Array, semiring: Semiring[Array]) Array[source]

The reduction used in the forward algorithm.

For each state q, we sum over all source states p and labels y that lead to state q, i.e.

result[…, q] = sum_{p-y->q} weights[…, p, y]

Parameters:
  • weights – [batch_dims…, num_states, vocab_size] weights.

  • semiring – The semiring for carrying out the summation.

Returns:

[batch_dims…, num_states] reduced weights.
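For a table-based context dependency, both operations come down to indexing with next_state_table. The sketch below specializes forward_reduce to the log semiring, whereas LAST's version takes the semiring as an argument. Weights are indexed by column j = y - 1. The sketch builds a dense [num_states, num_states, vocab_size] incoming-arc mask, which is simple but uses O(num_states^2 * vocab_size) memory. It is an illustration only, not LAST's implementation.

```python
import numpy as np


def log_sum(x, axis):
    """Stable log-sum-exp over axis; all -inf slices give -inf, not nan."""
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)
    with np.errstate(divide="ignore"):
        out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m
    return np.squeeze(out, axis=axis)


def forward_reduce(table, weights):
    """result[..., q] = log-sum of weights[..., p, j] over table[p, j] == q."""
    num_states = table.shape[0]
    # incoming[q, p, j] is True when the arc p --(j + 1)-> leads to q.
    incoming = table[None, :, :] == np.arange(num_states)[:, None, None]
    masked = np.where(incoming, weights[..., None, :, :], -np.inf)
    # Flatten the (p, j) axes and reduce them in one go.
    return log_sum(masked.reshape(masked.shape[:-2] + (-1,)), axis=-1)


def backward_broadcast(table, weights):
    """result[..., p, j] = weights[..., table[p, j]]."""
    return weights[..., table]


table = np.array([[1, 2], [1, 2], [1, 2]])
reduced = forward_reduce(table, np.zeros((3, 2)))  # state 0 has no incoming arcs
```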

next_state(state: Array, label: Array) Array[source]

Takes a transition in the DFA.

Note: because 0 is the epsilon label, it is normally not fed to next_state. For consistency, we require that next_state should return state[i] when label[i] == 0.

Parameters:
  • state – [batch_dims…] int32 source state ids.

  • label – [batch_dims…] int32 output labels in the range [0, vocab_size].

Returns:

[batch_dims…] next state ids.

shape() tuple[int, int][source]

Shape of a context dependency.

Returns:

  • num_states: The number of states in the context dependency DFA.

  • vocab_size: The size of the output vocabulary, $|\Sigma|$.

Return type:

(num_states, vocab_size) tuple

start() int[source]

The start state id.

last.alignments

Alignment lattices.

class last.alignments.FrameDependent[source]

Frame dependent alignment lattice.

Each frame is aligned to either a lexical label or a blank label.

backward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], beta: Array, log_z: Array, context: ContextDependency) tuple[Array, list[Array], list[Array]][source]

Processes one frame in the recognition lattice backward algorithm.

On a recognition lattice, the backward algorithm computes the arc marginals under last.semirings.Log (the marginal probability of taking each lexical or blank arc), as well as the sum of path weights from a recognition lattice state to all final states (backward weights). The marginals and backward weights can be computed frame by frame with the help of TimeSyncAlignmentLattice.backward:

  1. At time t, from the forward algorithm’s results, we know the forward weights of states (t, s, c) for all context states c (alpha) and the overall shortest distance (sum of weights of all accepting paths, log_z).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. At time t, after observing frame (t+1), we know the backward weights of states (t + 1, s, c) for all context states c (beta).

  4. We also know the context dependency.

  5. With the information above, TimeSyncAlignmentLattice.backward computes the backward weights of states (t, s, c) for all context states c.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • beta – [batch_dims…, num_context_states] backward weights after observing the next frame.

  • log_z – [batch_dims…] denominator, i.e. the sum of weights of all accepting paths.

  • context – Context dependency.

Returns:

  • next_beta: [batch_dims…, num_context_states] backward weights after observing the current frame.

  • blank_marginal: length num_alignment_states list of [batch_dims…, num_context_states] marginals of blank arcs, one for each frame-local alignment state.

  • lexical_marginal: length num_alignment_states list of [batch_dims…, num_context_states, vocab_size] marginals of lexical arcs, one for each frame-local alignment state.

Return type:

(next_beta, blank_marginal, lexical_marginal)

blank_next(state: int) int | None[source]

Next alignment state id when taking the blank arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no blank arc leaving state. The start state id if the blank arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], context: ContextDependency, semiring: Semiring[Array]) Array[source]

Processes one frame in the recognition lattice forward algorithm.

On a recognition lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c (alpha).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. We also know the context dependency.

  4. With the information above, TimeSyncAlignmentLattice.forward computes the forward weights of states (t+1, s, c) for all context states c.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • context – Context dependency.

  • semiring – Semiring.

Returns:

[batch_dims…, num_context_states] forward weights after observing the current frame.
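A one-frame forward step for this lattice can be sketched in the log semiring as follows. The sketch assumes a single frame-local alignment state, which is what "each frame is aligned to either a lexical label or a blank label" suggests, so blank and lexical are passed as single arrays rather than sequences. It also assumes a context dependency given as a next-state table and omits batch dimensions. The function name is hypothetical.

```python
import numpy as np


def log_sum(x, axis):
    """Stable log-sum-exp over axis; all -inf slices give -inf, not nan."""
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)
    with np.errstate(divide="ignore"):
        out = np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)) + m
    return np.squeeze(out, axis=axis)


def frame_dependent_forward(alpha, blank, lexical, table):
    """One frame of the forward algorithm in the log semiring.

    alpha: [num_context_states] forward weights before this frame.
    blank: [num_context_states] blank arc weights for this frame.
    lexical: [num_context_states, vocab_size] lexical arc weights.
    table: [num_context_states, vocab_size] next-state table.
    """
    num_states = table.shape[0]
    # Blank arcs keep the context state: (t, 0, c) -> (t + 1, 0, c).
    via_blank = alpha + blank
    # Lexical arcs follow c --y-> table[c, y - 1]; sum over incoming arcs.
    arc_weights = alpha[:, None] + lexical
    incoming = table[None, :, :] == np.arange(num_states)[:, None, None]
    masked = np.where(incoming, arc_weights[None, :, :], -np.inf)
    via_lexical = log_sum(masked.reshape(num_states, -1), axis=-1)
    return np.logaddexp(via_blank, via_lexical)


table = np.array([[1, 2], [1, 2], [1, 2]])
alpha0 = np.array([0.0, -np.inf, -np.inf])  # all mass on the start state
alpha1 = frame_dependent_forward(alpha0, np.zeros(3), np.zeros((3, 2)), table)
alpha2 = frame_dependent_forward(alpha1, np.zeros(3), np.zeros((3, 2)), table)
```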

lexical_next(state: int) int | None[source]

Next alignment state id when taking the lexical arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no lexical arc leaving state. The start state id if the lexical arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

num_states() int[source]

Number of non-final frame-local alignment states, num_alignment_states.

start() int[source]

Start state of the frame-local alignment lattice.

string_forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], semiring: Semiring[Array]) Array[source]

Processes one frame in the recognition lattice forward algorithm after the intersection with an output string.

Because the recognition lattice’s topology is the intersection of an alignment lattice and the context dependency, the intersection between the recognition lattice and an output string is thus equivalent to first intersecting the context dependency with the output string (i.e. last.contexts.ContextDependency.walk_states), and then intersecting the alignment lattice with the result.

On this intersected lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.string_forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c after the first intersection. This is passed to TimeSyncAlignmentLattice.string_forward as alpha. Note that the first intersection results in a string acceptor with only output_length + 1 context states.

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c after the first intersection. These are passed to TimeSyncAlignmentLattice.string_forward as blank and lexical.

  3. We no longer need to know the context dependency because the result of the first intersection is a simple chain.

  4. With the information above, TimeSyncAlignmentLattice.string_forward computes the forward weights of states (t+1, s, c) for all context states c after the first intersection.

Parameters:
  • alpha – [batch_dims…, output_length + 1] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, output_length + 1] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, output_length + 1] lexical weights for the current frame, one for each frame-local alignment state.

  • semiring – Semiring.

Returns:

[batch_dims…, output_length + 1] forward weights after observing the current frame.
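On the intersected chain, one frame of the forward algorithm becomes a shift-and-add. The sketch below works in the log semiring, omits batch dimensions and assumes a single frame-local alignment state. It further assumes that lexical[i] is the weight of emitting the (i+1)-th output label from chain position i, so the last entry is unused. That indexing is an assumption for illustration, and the function name is hypothetical.

```python
import numpy as np


def frame_dependent_string_forward(alpha, blank, lexical):
    """One forward step on the string-intersected lattice (log semiring).

    alpha, blank, lexical: [output_length + 1] log weights indexed by the
    chain position i, the number of output labels consumed so far.
    """
    # Blank arcs stay at the same chain position.
    via_blank = alpha + blank
    # Lexical arcs advance the position by one; nothing reaches position 0.
    via_lexical = np.concatenate([[-np.inf], alpha[:-1] + lexical[:-1]])
    return np.logaddexp(via_blank, via_lexical)


half = np.full(3, np.log(0.5))  # each arc has probability 0.5
alpha0 = np.array([0.0, -np.inf, -np.inf])
alpha1 = frame_dependent_string_forward(alpha0, half, half)
```

After T frames, alpha[-1] would hold the log of the total weight of all alignments of the full output string.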

topological_visit() list[int][source]

Produces non-final frame-local alignment state ids in a topological order.

class last.alignments.FrameLabelDependent(max_expansions: int)[source]

k-constrained frame-label-dependent alignment lattice.

Each frame is aligned to up to k lexical labels followed by a blank label.

max_expansions

The maximum number of lexical labels allowed per frame.

Type:

int
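Each frame takes up to max_expansions lexical arcs and then exactly one blank arc. An alignment of an output string of length U to T frames therefore amounts to splitting the U labels into T consecutive groups, each of size 0 to max_expansions. The hypothetical helper below counts such alignments with a small dynamic program to show how max_expansions constrains the lattice. It does not use LAST.

```python
def count_alignments(num_frames, output_length, max_expansions):
    """Counts alignments of output_length labels to num_frames frames,
    where each frame emits 0..max_expansions labels and then one blank.
    """
    # counts[u] = number of ways to emit u labels in the frames so far.
    counts = [1] + [0] * output_length
    for _ in range(num_frames):
        new_counts = [0] * (output_length + 1)
        for u, c in enumerate(counts):
            for k in range(max_expansions + 1):
                if u + k <= output_length:
                    new_counts[u + k] += c
        counts = new_counts
    return counts[output_length]


# With max_expansions=1 the count is C(T, U): choose which frames emit.
print(count_alignments(num_frames=3, output_length=2, max_expansions=1))  # 3
```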

backward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], beta: Array, log_z: Array, context: ContextDependency) tuple[Array, list[Array], list[Array]][source]

Processes one frame in the recognition lattice backward algorithm.

On a recognition lattice, the backward algorithm computes the arc marginals under last.semirings.Log (the marginal probability of taking each lexical or blank arc), as well as the sum of path weights from a recognition lattice state to all final states (backward weights). The marginals and backward weights can be computed frame by frame with the help of TimeSyncAlignmentLattice.backward:

  1. At time t, from the forward algorithm’s results, we know the forward weights of states (t, s, c) for all context states c (alpha) and the overall shortest distance (sum of weights of all accepting paths, log_z).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. At time t, after observing frame (t+1), we know the backward weights of states (t + 1, s, c) for all context states c (beta).

  4. We also know the context dependency.

  5. With the information above, TimeSyncAlignmentLattice.backward computes the backward weights of states (t, s, c) for all context states c.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • beta – [batch_dims…, num_context_states] backward weights after observing the next frame.

  • log_z – [batch_dims…] denominator, i.e. the sum of weights of all accepting paths.

  • context – Context dependency.

Returns:

  • next_beta: [batch_dims…, num_context_states] backward weights after observing the current frame.

  • blank_marginal: length num_alignment_states list of [batch_dims…, num_context_states] marginals of blank arcs, one for each frame-local alignment state.

  • lexical_marginal: length num_alignment_states list of [batch_dims…, num_context_states, vocab_size] marginals of lexical arcs, one for each frame-local alignment state.

Return type:

(next_beta, blank_marginal, lexical_marginal)
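As an illustration of one backward step, here is a NumPy sketch (not the library's implementation) for the simplest frame-local alignment lattice: a single non-final state s whose blank and lexical arcs both lead to the final state (frame-dependent alignment). It assumes the log semiring, no batch dimensions, blank arcs that keep the context state, and a hypothetical dense table next_state[c, y] standing in for the context dependency. Each arc marginal is the forward weight of its source times the arc weight times the backward weight of its destination, divided by the total weight.

```python
import numpy as np


def frame_dependent_backward_step(alpha, blank, lexical, beta, log_z, next_state):
    """One log-semiring backward step for a single-state frame-local lattice (sketch).

    alpha:      [C] log forward weights of states (t, s, c).
    blank:      [C] log weights of blank arcs (t, s, c) -> (t + 1, s, c).
    lexical:    [C, V] log weights of lexical arcs (t, s, c) -> (t + 1, s, next_state[c, y]).
    beta:       [C] log backward weights of states (t + 1, s, c).
    log_z:      scalar log of the total weight of all accepting paths.
    next_state: [C, V] hypothetical dense context-dependency table.
    """
    # Backward weight of the destination of every lexical arc.
    beta_after_lexical = beta[next_state]  # [C, V]
    # Backward weights of states (t, s, c): log-sum over all outgoing arcs.
    next_beta = np.logaddexp(
        blank + beta, np.logaddexp.reduce(lexical + beta_after_lexical, axis=-1)
    )
    # Marginal = forward(source) * weight * backward(destination) / Z, as probabilities.
    blank_marginal = np.exp(alpha + blank + beta - log_z)
    lexical_marginal = np.exp(alpha[:, None] + lexical + beta_after_lexical - log_z)
    # One entry per frame-local alignment state; this lattice has only one.
    return next_beta, [blank_marginal], [lexical_marginal]
```

On a single-frame lattice starting in context 0, the marginals of all arcs sum to 1, because every path takes exactly one arc per frame.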

blank_next(state: int) → int | None

Next alignment state id when taking the blank arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no blank arc leaving state. The start state id if the blank arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], context: ContextDependency, semiring: Semiring[Array]) → Array

Processes one frame in the recognition lattice forward algorithm.

On a recognition lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c (alpha).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. We also know the context dependency.

  4. With the information above, TimeSyncAlignmentLattice.forward computes the forward weights of states (t+1, s, c) for all context states c.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • context – Context dependency.

  • semiring – Semiring.

Returns:

[batch_dims…, num_context_states] forward weights after observing the current frame.
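A minimal NumPy sketch of one forward step under the log semiring, again for the single-state (frame-dependent) frame-local lattice and without batch dimensions. The dense next_state[c, y] table is a hypothetical stand-in for the context dependency, and blank arcs are assumed to keep the context state. Lexical contributions are scattered into their destination context states with np.logaddexp.at, the log-semiring analogue of a scatter-add.

```python
import numpy as np


def frame_dependent_forward_step(alpha, blank, lexical, next_state):
    """One log-semiring forward step for a single-state frame-local lattice (sketch).

    alpha:      [C] log forward weights of states (t, s, c).
    blank:      [C] log blank-arc weights; blank arcs keep the context state.
    lexical:    [C, V] log lexical-arc weights.
    next_state: [C, V] hypothetical dense table of destination context states.
    Returns [C] log forward weights of states (t + 1, s, c).
    """
    # Blank arcs: (t, s, c) -> (t + 1, s, c).
    next_alpha = alpha + blank
    # Lexical arcs: (t, s, c) -> (t + 1, s, next_state[c, y]), summed in log space.
    np.logaddexp.at(next_alpha, next_state.ravel(), (alpha[:, None] + lexical).ravel())
    return next_alpha
```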

lexical_next(state: int) → int | None

Next alignment state id when taking the lexical arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no lexical arc leaving state. The start state id if the lexical arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

num_states() → int

Number of non-final frame-local alignment states, num_alignment_states.

replace(**updates)

Returns a new object replacing the specified fields with new values.

start() → int

Start state of the frame-local alignment lattice.

string_forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], semiring: Semiring[Array]) → Array

Processes one frame in the recognition lattice forward algorithm after the intersection with an output string.

Because the recognition lattice’s topology is the intersection of an alignment lattice and the context dependency, intersecting the recognition lattice with an output string is equivalent to first intersecting the context dependency with the output string (i.e. last.contexts.ContextDependency.walk_states), and then intersecting the alignment lattice with the result.

On this intersected lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.string_forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c after the first intersection. These are passed to TimeSyncAlignmentLattice.string_forward as alpha. Note that the first intersection results in a string acceptor with only output_length + 1 context states.

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c after the first intersection. These are passed to TimeSyncAlignmentLattice.string_forward as blank and lexical.

  3. We no longer need to know the context dependency, because the result of the first intersection is a simple chain.

  4. With the information above, TimeSyncAlignmentLattice.string_forward computes the forward weights of states (t+1, s, c) for all context states c after the first intersection.

Parameters:
  • alpha – [batch_dims…, output_length + 1] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, output_length + 1] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, output_length + 1] lexical weights for the current frame, one for each frame-local alignment state.

  • semiring – Semiring.

Returns:

[batch_dims…, output_length + 1] forward weights after observing the current frame.
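Because the intersected context is a simple chain, one string_forward step for the single-state (frame-dependent) lattice reduces to "stay at position i on blank" plus "move from i to i + 1 on the next label". The NumPy sketch below assumes the log semiring and that lexical[..., i] is the weight of emitting label i + 1 from position i; the shift by one position is done inline, with the log-semiring zero (-inf) filling position 0.

```python
import numpy as np


def frame_dependent_string_forward_step(alpha, blank, lexical):
    """One log-semiring string_forward step for a single-state frame-local lattice (sketch).

    After the first intersection, context state i (0 <= i <= U) means that the first i
    labels of the output string have been emitted.
    alpha:   [..., U + 1] log forward weights of states (t, s, i).
    blank:   [..., U + 1] log weights of blank arcs, which keep position i.
    lexical: [..., U + 1] log weights of emitting label i + 1 from position i.
    """
    stay = alpha + blank
    emit = alpha + lexical
    # Shift down by one position (i -> i + 1), filling position 0 with -inf.
    advance = np.concatenate(
        [np.full(emit.shape[:-1] + (1,), -np.inf), emit[..., :-1]], axis=-1
    )
    return np.logaddexp(stay, advance)
```

The last entry of lexical is never used: it is shifted out, since there is no label left to emit after position U.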

topological_visit() → list[int]

Produces non-final frame-local alignment state ids in a topological order.
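To show how the conventions of blank_next, lexical_next, num_states, start and topological_visit fit together, here is a hypothetical frame-local lattice that allows at most k lexical labels per frame. It is an illustration of the interface conventions only, not the library's class. State i means "i lexical labels emitted in this frame", and an arc into the final state is reported as the start state id, as documented above.

```python
from typing import Optional


class ToyKConstrainedLattice:
    """Hypothetical frame-local alignment lattice with at most k lexical labels per frame.

    Non-final states are 0..k. Every blank arc ends the frame, i.e. leads to the
    final state, which the interface encodes by returning the start state id.
    """

    def __init__(self, k: int):
        self.k = k

    def num_states(self) -> int:
        return self.k + 1

    def start(self) -> int:
        return 0

    def blank_next(self, state: int) -> Optional[int]:
        return self.start()  # blank always leads to the final state

    def lexical_next(self, state: int) -> Optional[int]:
        # Another lexical label is allowed only while fewer than k were emitted.
        return state + 1 if state < self.k else None

    def topological_visit(self) -> list[int]:
        # Arcs go from i to i + 1 or to the final state, so 0..k is topological.
        return list(range(self.num_states()))
```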

class last.alignments.TimeSyncAlignmentLattice

Interface for time synchronous alignment lattices.

Frame-dependent and k-constrained label-frame-dependent alignment lattices are examples of time synchronous alignment lattices. See Sections 3 and 4 of the GNAT paper for more details.

The alignment lattice is intersected with the context dependency to form the topology of a recognition lattice. last.RecognitionLattice carries out this intersection on the fly with the help of the methods of TimeSyncAlignmentLattice.

To describe time synchronous alignment lattices, we first introduce the notion of a frame-local alignment lattice. A frame-local alignment lattice is an acyclic deterministic finite automaton with two input labels, “lexical” and “blank”, and a single final state f.

Let Q be the set of states in the frame-local alignment lattice, E be the set of its arcs, and s be its start state. A time synchronous alignment lattice is then the same frame-local alignment lattice repeated num_frames times:

  • The states are {(t, a) | 0 <= t < num_frames, a in Q - {f}} ∪ {(num_frames, s)};

  • The start state is (0, s);

  • The final state is (num_frames, s);

  • For any arc (a, y, b) in E with b != f, and any 0 <= t < num_frames, there is an arc ((t, a), y, (t, b));

  • For any arc (a, y, f) in E, and any 0 <= t < num_frames, there is an arc ((t, a), y, (t + 1, s)).
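The construction above can be spelled out in plain Python. This sketch uses a hypothetical explicit representation (a list of (a, y, b) arc tuples and a list of frame-local states); the library itself never materializes the unrolled lattice and works frame by frame instead.

```python
def expand_time_sync(arcs, states, start, final, num_frames):
    """Unroll a frame-local alignment lattice into a time synchronous one (sketch).

    arcs:   list of (a, y, b) frame-local arcs, with y in {"blank", "lexical"}.
    states: frame-local states Q, including the final state.
    Returns (states, start_state, final_state, arcs) of the unrolled lattice.
    """
    ts_states = [(t, a) for t in range(num_frames) for a in states if a != final]
    ts_states.append((num_frames, start))
    ts_arcs = []
    for t in range(num_frames):
        for a, y, b in arcs:
            if b != final:
                ts_arcs.append(((t, a), y, (t, b)))  # stays within frame t
            else:
                ts_arcs.append(((t, a), y, (t + 1, start)))  # moves on to frame t + 1
    return ts_states, (0, start), (num_frames, start), ts_arcs
```

For the frame-dependent lattice, expand_time_sync([("s", "blank", "f"), ("s", "lexical", "f")], ["s", "f"], "s", "f", 2) yields the states (0, "s"), (1, "s"), (2, "s") and four arcs, two per frame.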

abstract backward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], beta: Array, log_z: Array, context: Array) → tuple[Array, list[Array], list[Array]]

Processes one frame in the recognition lattice backward algorithm.

On a recognition lattice, the backward algorithm computes the arc marginals under last.semirings.Log (the marginal probability of taking each lexical or blank arc), as well as the sum of path weights from a recognition lattice state to all final states (backward weights). The marginals and backward weights can be computed frame by frame with the help of TimeSyncAlignmentLattice.backward:

  1. At time t, from the forward algorithm’s results, we know the forward weights of states (t, s, c) for all context states c (alpha) and the overall shortest distance (the sum of weights of all accepting paths, log_z).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. At time t, after observing frame (t+1), we know the backward weights of states (t+1, s, c) for all context states c (beta).

  4. We also know the context dependency.

  5. With the information above, TimeSyncAlignmentLattice.backward computes the backward weights of states (t, s, c) for all context states c, together with the marginals of the blank and lexical arcs of frame t.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • beta – [batch_dims…, num_context_states] backward weights after observing the next frame.

  • log_z – [batch_dims…] denominator, i.e. the sum of weights of all accepting paths.

  • context – Context dependency.

Returns:

  • next_beta: [batch_dims…, num_context_states] backward weights after observing the current frame.

  • blank_marginal: length num_alignment_states list of [batch_dims…, num_context_states] marginals of blank arcs, one for each frame-local alignment state.

  • lexical_marginal: length num_alignment_states list of [batch_dims…, num_context_states, vocab_size] marginals of lexical arcs, one for each frame-local alignment state.

Return type:

(next_beta, blank_marginal, lexical_marginal)
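A useful sanity check for any backward implementation is the identity that, under the log semiring, the marginal probability of an arc equals the partial derivative of log_z with respect to that arc's log weight. The NumPy sketch below checks this with central finite differences on a frame-dependent (single-state) lattice. It assumes context state 0 is the start state, every context state is final, blank arcs keep the context state, and a hypothetical dense next_state[c, y] table replaces the context dependency.

```python
import numpy as np


def forward_log_z(blank, lexical, next_state):
    """log_z of a frame-dependent lattice; blank [T, C], lexical [T, C, V]."""
    alpha = np.full(blank.shape[1], -np.inf)
    alpha[0] = 0.0  # assume context state 0 is the start state
    for t in range(blank.shape[0]):
        nxt = alpha + blank[t]
        np.logaddexp.at(nxt, next_state.ravel(), (alpha[:, None] + lexical[t]).ravel())
        alpha = nxt
    return np.logaddexp.reduce(alpha)  # assume every context state is final


def blank_marginals(blank, lexical, next_state):
    """Marginal probabilities of all blank arcs, [T, C], via forward-backward."""
    T, C = blank.shape
    alphas = [np.where(np.arange(C) == 0, 0.0, -np.inf)]
    for t in range(T):
        nxt = alphas[-1] + blank[t]
        np.logaddexp.at(nxt, next_state.ravel(), (alphas[-1][:, None] + lexical[t]).ravel())
        alphas.append(nxt)
    log_z = np.logaddexp.reduce(alphas[-1])
    beta = np.zeros(C)  # backward weights of the final frame
    out = np.zeros((T, C))
    for t in reversed(range(T)):
        # Here beta holds the backward weights of states (t + 1, s, c).
        out[t] = np.exp(alphas[t] + blank[t] + beta - log_z)
        beta = np.logaddexp(
            blank[t] + beta, np.logaddexp.reduce(lexical[t] + beta[next_state], axis=-1)
        )
    return out
```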

abstract blank_next(state: int) → int | None

Next alignment state id when taking the blank arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no blank arc leaving state. The start state id if the blank arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

abstract forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], context: ContextDependency, semiring: Semiring[Array]) → Array

Processes one frame in the recognition lattice forward algorithm.

On a recognition lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c (alpha).

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c (blank and lexical).

  3. We also know the context dependency.

  4. With the information above, TimeSyncAlignmentLattice.forward computes the forward weights of states (t+1, s, c) for all context states c.

Parameters:
  • alpha – [batch_dims…, num_context_states] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, num_context_states] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, num_context_states, vocab_size] lexical weights for the current frame, one for each frame-local alignment state.

  • context – Context dependency.

  • semiring – Semiring.

Returns:

[batch_dims…, num_context_states] forward weights after observing the current frame.
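For a general frame-local lattice, one forward step can be driven purely by the interface methods: visit the non-final states in topological order, push weights along blank and lexical arcs, and route arcs that return the start state id into the next frame. This is a NumPy sketch under stated assumptions (log semiring, no batch dimensions, blank arcs keep the context state, hypothetical dense next_state[c, y] table for the context dependency), not the library's implementation.

```python
import numpy as np


def generic_forward_step(lattice, alpha, blank, lexical, next_state):
    """One log-semiring forward step driven only by the frame-local lattice interface.

    lattice:    object with start(), num_states(), topological_visit(),
                blank_next() and lexical_next() following the conventions above.
    alpha:      [C] log forward weights of states (t, s, c).
    blank:      length num_states list of [C] blank weights.
    lexical:    length num_states list of [C, V] lexical weights.
    next_state: [C, V] hypothetical dense context-dependency table.
    """
    start = lattice.start()
    num_contexts = alpha.shape[0]
    # Forward weights of states (t, a, c) for every non-final frame-local state a.
    weights = {a: np.full(num_contexts, -np.inf) for a in range(lattice.num_states())}
    weights[start] = alpha
    next_alpha = np.full(num_contexts, -np.inf)
    for a in lattice.topological_visit():
        w = weights[a]
        b = lattice.blank_next(a)
        if b is not None:
            # The start id means the arc reaches the final state, i.e. frame t + 1.
            target = next_alpha if b == start else weights[b]
            target[:] = np.logaddexp(target, w + blank[a])  # context state unchanged
        b = lattice.lexical_next(a)
        if b is not None:
            target = next_alpha if b == start else weights[b]
            np.logaddexp.at(target, next_state.ravel(), (w[:, None] + lexical[a]).ravel())
    return next_alpha
```

Processing states in topological order guarantees that all incoming weight of a state has been accumulated before its outgoing arcs are expanded.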

abstract lexical_next(state: int) → int | None

Next alignment state id when taking the lexical arc.

Parameters:

state – A state id in the range [0, num_alignment_states).

Returns:

None if there is no lexical arc leaving state. The start state id if the lexical arc leads to the final state. Otherwise, an ordinary state id in the range [0, num_alignment_states).

abstract num_states() → int

Number of non-final frame-local alignment states, num_alignment_states.

abstract start() → int

Start state of the frame-local alignment lattice.

abstract string_forward(alpha: Array, blank: Sequence[Array], lexical: Sequence[Array], semiring: Semiring[Array]) → Array

Processes one frame in the recognition lattice forward algorithm after the intersection with an output string.

Because the recognition lattice’s topology is the intersection of an alignment lattice and the context dependency, intersecting the recognition lattice with an output string is equivalent to first intersecting the context dependency with the output string (i.e. last.contexts.ContextDependency.walk_states), and then intersecting the alignment lattice with the result.

On this intersected lattice, the forward algorithm computes the shortest distance, i.e. the sum of path weights reaching all final states. The shortest distance can be computed frame by frame with the help of TimeSyncAlignmentLattice.string_forward:

  1. At time t, after observing the previous frame (t-1), we know the forward weights of states (t, s, c) for all context states c after the first intersection. These are passed to TimeSyncAlignmentLattice.string_forward as alpha. Note that the first intersection results in a string acceptor with only output_length + 1 context states.

  2. We also know all the arc weights from states (t, a, c), for all non-final frame-local alignment states a and all context states c after the first intersection. These are passed to TimeSyncAlignmentLattice.string_forward as blank and lexical.

  3. We no longer need to know the context dependency, because the result of the first intersection is a simple chain.

  4. With the information above, TimeSyncAlignmentLattice.string_forward computes the forward weights of states (t+1, s, c) for all context states c after the first intersection.

Parameters:
  • alpha – [batch_dims…, output_length + 1] forward weights after observing the previous frame.

  • blank – length num_alignment_states Sequence of [batch_dims…, output_length + 1] blank weights for the current frame, one for each frame-local alignment state.

  • lexical – length num_alignment_states Sequence of [batch_dims…, output_length + 1] lexical weights for the current frame, one for each frame-local alignment state.

  • semiring – Semiring.

Returns:

[batch_dims…, output_length + 1] forward weights after observing the current frame.
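The first intersection described above can be pictured as walking the context dependency along the output string and gathering the weights of the visited context states. The helper below is hypothetical (it is not ContextDependency.walk_states, whose exact signature is not shown here) and uses a dense next_state[c, y] table with context state 0 as the start state. It turns full per-context weights for one frame and one alignment state into the [output_length + 1] arrays that string_forward expects.

```python
import numpy as np


def intersect_with_string(blank, lexical, next_state, labels, start_context=0):
    """Gather per-context weights along one output string (hypothetical helper).

    blank:      [C] blank weights for one frame and one alignment state.
    lexical:    [C, V] lexical weights for the same frame and alignment state.
    next_state: [C, V] hypothetical dense context-dependency table.
    labels:     non-empty output string y_1..y_U as 0-based label ids.
    Returns ([U + 1] blank weights, [U + 1] lexical weights) of the string acceptor.
    """
    # Walk the context dependency: contexts[i] is the state after emitting i labels.
    contexts = [start_context]
    for y in labels:
        contexts.append(next_state[contexts[-1], y])
    contexts = np.array(contexts)
    string_blank = blank[contexts]
    # Position i may only emit y_{i+1}; position U has no lexical arc, so it gets -inf.
    string_lexical = np.full(len(labels) + 1, -np.inf)
    string_lexical[:-1] = lexical[contexts[:-1], labels]
    return string_blank, string_lexical
```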

abstract topological_visit() → list[int]

Produces non-final frame-local alignment state ids in a topological order.
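When a frame-local lattice is defined only through blank_next and lexical_next, a valid topological order can be derived with Kahn's algorithm. In this plain-Python sketch, arcs that return the start state id are treated as arcs into the final state and ignored, since the final state is not part of the order; the function name and signature are illustrative only.

```python
from collections import deque


def topological_order(num_states, start, blank_next, lexical_next):
    """Kahn's algorithm over the non-final frame-local states (illustrative sketch)."""
    successors = {a: [] for a in range(num_states)}
    indegree = {a: 0 for a in range(num_states)}
    for a in range(num_states):
        for b in (blank_next(a), lexical_next(a)):
            # Skip missing arcs and arcs into the final state (encoded as `start`).
            if b is not None and b != start:
                successors[a].append(b)
                indegree[b] += 1
    queue = deque(a for a in range(num_states) if indegree[a] == 0)
    order = []
    while queue:
        a = queue.popleft()
        order.append(a)
        for b in successors[a]:
            indegree[b] -= 1
            if indegree[b] == 0:
                queue.append(b)
    if len(order) != num_states:
        raise ValueError("frame-local alignment lattice is not acyclic")
    return order
```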

last.alignments.check_num_weights(alignment: TimeSyncAlignmentLattice, blank: Sequence[Array], lexical: Sequence[Array])

Ensures that there are correct numbers of weight arrays.

last.alignments.shift_down(x: Array, semiring: Semiring[Array]) → Array

Shifts values down by 1 position.

This is a useful helper function for implementing string_forward().

Parameters:
  • x – [batch_dims…, N] input values.

  • semiring – Semiring to use for filling in zero values.

Returns:

[batch_dims…, N] output values, where output[…, i + 1] = x[…, i] and output[…, 0] = semiring zero.
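A NumPy sketch of the same operation. Unlike the library function, it takes the semiring zero value directly (e.g. -inf for the log semiring) instead of a Semiring object; that parameter is an assumption made to keep the example self-contained.

```python
import numpy as np


def shift_down(x, zero):
    """Shift values one position down the last axis (sketch; `zero` is the semiring zero).

    output[..., i + 1] = x[..., i] and output[..., 0] = zero; x[..., -1] is dropped.
    """
    fill = np.full(x.shape[:-1] + (1,), zero, dtype=x.dtype)
    return np.concatenate([fill, x[..., :-1]], axis=-1)
```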

last.weight_fns

Weight functions.

class last.weight_fns.JointWeightFn(vocab_size: int, hidden_size: int, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Common implementation of both the shared-emb and shared-rnn weight functions.

To use shared-emb weight functions, pair this with a SharedEmbCacher. To use shared-rnn weight functions, pair this with a SharedRNNCacher. More generally, this weight function works with any WeightFnCacher that produces a [num_context_states, embedding_size] context embedding table.
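The layers inside JointWeightFn are not documented here. Purely to show how a [num_context_states, embedding_size] context embedding table yields weights for every context state at once, the NumPy sketch below assumes a conventional RNN-T style joint network (project frame and context, add, tanh, project to 1 + vocab_size outputs). The parameter names and the architecture are assumptions, not the library's code.

```python
import numpy as np


def joint_weights(frame, context_embeddings, params):
    """RNN-T style joint network over all context states (hypothetical sketch).

    frame:              [D] encoder output for one frame.
    context_embeddings: [C, E] table, e.g. produced by a cacher.
    params:             dict of hypothetical matrices w_frame [D, H],
                        w_context [E, H] and w_out [H, 1 + vocab_size].
    Returns (blank [C], lexical [C, vocab_size]) unnormalized weights.
    """
    # [H] + [C, H] broadcasts to one hidden vector per context state.
    hidden = np.tanh(frame @ params["w_frame"] + context_embeddings @ params["w_context"])
    logits = hidden @ params["w_out"]  # [C, 1 + vocab_size]
    return logits[:, 0], logits[:, 1:]
```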

vocab_size

Size of the lexical output vocabulary (not including the blank), i.e. |Σ|.

Type:

int

hidden_size

Hidden layer size.

Type:

int

class last.weight_fns.LocallyNormalizedWeightFn(weight_fn: ~last.weight_fns.WeightFn[~last.weight_fns.T], normalize: ~typing.Callable[[~jax.Array, ~jax.Array], tuple[~jax.Array, ~jax.Array]] = <function hat_normalize>, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Wrapper for turning any weight function into a locally normalized one.

This is the recommended way of obtaining a locally normalized weight function. Algorithms such as those that compute the sequence log-loss may rely on a weight function being of this type to eliminate unnecessary denominator computation.

It is thus also important for the normalize function to be mathematically correct: if (blank, lexical) is the pair of weights produced by the normalize function, then jnp.exp(blank) + jnp.sum(jnp.exp(lexical), axis=-1) should be approximately equal to 1.

weight_fn

Underlying weight function.

Type:

last.weight_fns.WeightFn[last.weight_fns.T]

normalize

Callable that produces normalized log-probabilities from (blank, lexical) weights, e.g. hat_normalize() or log_softmax_normalize().

Type:

Callable[[jax.Array, jax.Array], tuple[jax.Array, jax.Array]]

normalize(blank: Array, lexical: Array) → tuple[Array, Array]

Local normalization used in the Hybrid Autoregressive Transducer (HAT) paper.

The sigmoid of the blank weight is directly interpreted as the probability of blank. The lexical probability is then normalized with a log-softmax.

Parameters:
  • blank – [batch_dims…] blank weight.

  • lexical – [batch_dims…, vocab_size] lexical weights.

Returns:

Normalized (blank, lexical) weights.

class last.weight_fns.NullCacher(parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

A cacher that simply returns None.

Mainly used with TableWeightFn for unit testing.

class last.weight_fns.SharedEmbCacher(num_context_states: int, embedding_size: int, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

A randomly initialized, independent context embedding table.

The resulting context embedding table can be used with JointWeightFn.

class last.weight_fns.SharedRNNCacher(vocab_size: int, context_size: int, rnn_size: int, rnn_embedding_size: int, rnn_cell: ~flax.linen.recurrent.RNNCellBase | None = None, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Builds a context embedding table by running n-gram context labels through an RNN.

This is usually used with last.contexts.FullNGram, where num_context_states = sum(vocab_size**i for i in range(context_size + 1)). The resulting context embedding table can be used with JointWeightFn.
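The formula for num_context_states counts every label history of length 0 up to context_size. The plain-Python sketch below enumerates those histories to make the count concrete; the ordering, and the mapping from histories to context state ids used by FullNGram, are not specified here and are assumptions of this illustration. Each history is what would be fed through the RNN to obtain one row of the embedding table.

```python
import itertools


def ngram_contexts(vocab_size, context_size):
    """Enumerate label histories of length 0..context_size (hypothetical illustration)."""
    return [
        history
        for n in range(context_size + 1)
        for history in itertools.product(range(vocab_size), repeat=n)
    ]


print(len(ngram_contexts(3, 2)))  # 13, i.e. 1 + 3 + 9
```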

class last.weight_fns.TableWeightFn(table: ~jax.Array, parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Weight function that looks up a fixed table, useful for testing.

table

[batch_dims…, input_vocab_size, num_context_states, 1 + vocab_size] arc weight table. For each input frame, we simply cast the 0-th element into an integer “input label” and look up the corresponding weights. The weights of blank arcs are stored at table[…, 0], and the weights of lexical arcs at table[…, 1:].
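The lookup described for table can be sketched in NumPy as follows, ignoring batch dimensions for brevity. The function name is hypothetical; it only mirrors the documented behaviour of casting the 0-th element of a frame to an integer input label and splitting the looked-up row into blank and lexical weights.

```python
import numpy as np


def table_lookup(table, frame):
    """Look up arc weights as described for TableWeightFn (no batch dims, sketch).

    table: [input_vocab_size, num_context_states, 1 + vocab_size] arc weights.
    frame: [D] input frame whose 0-th element encodes the integer "input label".
    Returns (blank [num_context_states], lexical [num_context_states, vocab_size]).
    """
    weights = table[int(frame[0])]
    return weights[..., 0], weights[..., 1:]
```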

Type:

jax.Array

class last.weight_fns.WeightFn(parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Interface for weight functions.

A weight function is a neural network that computes the arc weights from all or some context states for a given frame. A WeightFn is used together with a WeightFnCacher that produces the static data cache, e.g. JointWeightFn can be used with SharedEmbCacher or SharedRNNCacher.

class last.weight_fns.WeightFnCacher(parent: ~typing.Type[~flax.linen.module.Module] | ~typing.Type[~flax.core.scope.Scope] | ~typing.Type[~flax.linen.module._Sentinel] | None = <flax.linen.module._Sentinel object>, name: str | None = None)

Interface for weight function cachers.

A weight function cacher prepares static data that may require expensive computational work. For example, the context state embeddings used by JointWeightFn can come from running an RNN on n-gram label sequences.

last.weight_fns.hat_normalize(blank: Array, lexical: Array) → tuple[Array, Array]

Local normalization used in the Hybrid Autoregressive Transducer (HAT) paper.

The sigmoid of the blank weight is directly interpreted as the probability of blank. The lexical probability is then normalized with a log-softmax.
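A NumPy sketch of HAT-style normalization, not the library code. For the result to satisfy the normalization requirement stated for LocallyNormalizedWeightFn, the lexical log-softmax must be offset by log(1 - sigmoid(blank)), i.e. log_sigmoid(-blank); the sketch assumes exactly that.

```python
import numpy as np


def log_sigmoid(x):
    # log(sigmoid(x)) = -log(1 + exp(-x)), computed stably.
    return -np.logaddexp(0.0, -x)


def hat_normalize_sketch(blank, lexical):
    """HAT-style local normalization (NumPy sketch).

    blank:   [...] blank weight; sigmoid(blank) is the blank probability.
    lexical: [..., vocab_size] lexical weights.
    """
    lexical_log_softmax = lexical - np.logaddexp.reduce(lexical, axis=-1, keepdims=True)
    # The remaining mass 1 - sigmoid(blank) = sigmoid(-blank) is shared among labels.
    return log_sigmoid(blank), log_sigmoid(-blank)[..., None] + lexical_log_softmax
```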

Parameters:
  • blank – [batch_dims…] blank weight.

  • lexical – [batch_dims…, vocab_size] lexical weights.

Returns:

Normalized (blank, lexical) weights.

last.weight_fns.log_softmax_normalize(blank: Array, lexical: Array) → tuple[Array, Array]

Standard log-softmax local normalization.

Weights are concatenated and then normalized together.
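A NumPy sketch of the concatenate-then-normalize scheme (not the library code). Unlike hat_normalize, blank and lexical weights compete in a single softmax, so only their differences matter.

```python
import numpy as np


def log_softmax_normalize_sketch(blank, lexical):
    """Concatenate blank and lexical weights, then apply one log-softmax (sketch)."""
    logits = np.concatenate([blank[..., None], lexical], axis=-1)
    normalized = logits - np.logaddexp.reduce(logits, axis=-1, keepdims=True)
    return normalized[..., 0], normalized[..., 1:]
```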

Parameters:
  • blank – [batch_dims…] blank weight.

  • lexical – [batch_dims…, vocab_size] lexical weights.

Returns:

Normalized (blank, lexical) weights.