# Copyright 2023 The LAST Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Alignment lattices."""
import abc
from collections.abc import Sequence
from typing import Optional
import flax.struct
import jax.numpy as jnp
from last import contexts
from last import semirings
# TODO(wuke): Improve documentation.
class TimeSyncAlignmentLattice(abc.ABC):
  r"""Interface for time synchronous alignment lattices.

  Frame-dependent and k-constrained label-frame-dependent alignment lattices
  are examples of time synchronous alignment lattices. See Sections 3 and 4 of
  the GNAT paper for more details.

  The alignment lattice is intersected with the context dependency to form the
  topology of a recognition lattice. `last.RecognitionLattice` carries out this
  intersection on the fly with the help of the methods of
  TimeSyncAlignmentLattice.

  To describe time synchronous alignment lattices, we first introduce the
  notion of a frame-local alignment lattice. A frame-local alignment lattice is
  an acyclic deterministic finite automaton with 2 input labels, "lexical" and
  "blank", a start state s, and a single final state f.

  Let Q be the set of states in the frame-local alignment lattice and E be its
  set of arcs. A time synchronous alignment lattice is then the same
  frame-local alignment lattice repeated num_frames times:

  -   The states are {(t, a) | 0 <= t < num_frames, a \in Q - {f}} U
      {(num_frames, s)};
  -   The start state is (0, s);
  -   The final state is (num_frames, s);
  -   For any arc (a, y, b), b != f, in E, there is an arc
      ((t, a), y, (t, b));
  -   For any arc (a, y, f) in E, there is an arc ((t, a), y, (t + 1, s)).
  """

  @abc.abstractmethod
  def num_states(self) -> int:
    """Number of non-final frame-local alignment states.

    This quantity is referred to as `num_alignment_states` elsewhere in this
    interface.
    """

  @abc.abstractmethod
  def start(self) -> int:
    """Start state of the frame-local alignment lattice."""

  @abc.abstractmethod
  def blank_next(self, state: int) -> Optional[int]:
    """Next alignment state id when taking the blank arc.

    Args:
      state: A state id in the range [0, num_alignment_states).

    Returns:
      None if there is no blank arc leaving `state`.
      The start state id if the blank arc leads to the final state.
      Otherwise, an ordinary state id in the range [0, num_alignment_states).
    """

  @abc.abstractmethod
  def lexical_next(self, state: int) -> Optional[int]:
    """Next alignment state id when taking the lexical arc.

    Args:
      state: A state id in the range [0, num_alignment_states).

    Returns:
      None if there is no lexical arc leaving `state`.
      The start state id if the lexical arc leads to the final state.
      Otherwise, an ordinary state id in the range [0, num_alignment_states).
    """

  @abc.abstractmethod
  def topological_visit(self) -> list[int]:
    """Produces non-final frame-local state ids in a topological order."""

  @abc.abstractmethod
  def forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
              lexical: Sequence[jnp.ndarray],
              context: contexts.ContextDependency,
              semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Processes one frame in the recognition lattice forward algorithm.

    On a recognition lattice, the forward algorithm computes the shortest
    distance, i.e. the sum of path weights reaching all final states.

    The shortest distance can be computed frame by frame with the help of
    TimeSyncAlignmentLattice.forward:

    1.  At time t, after observing the previous frame (t-1), we know the
        forward weights of states (t, s, c) for all context states c
        (`alpha`).
    2.  We also know all the arc weights from states (t, a, c), for all
        non-final frame-local alignment states a and all context states c
        (`blank` and `lexical`).
    3.  We also know the context dependency.
    4.  With the information above, TimeSyncAlignmentLattice.forward computes
        the forward weights of states (t+1, s, c) for all context states c.

    Args:
      alpha: [batch_dims..., num_context_states] forward weights after
        observing the previous frame.
      blank: length num_alignment_states Sequence of [batch_dims...,
        num_context_states] blank weights for the current frame, one for each
        frame-local alignment state.
      lexical: length num_alignment_states Sequence of [batch_dims...,
        num_context_states, vocab_size] lexical weights for the current frame,
        one for each frame-local alignment state.
      context: Context dependency.
      semiring: Semiring.

    Returns:
      [batch_dims..., num_context_states] forward weights after observing the
      current frame.
    """

  @abc.abstractmethod
  def backward(
      self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
      lexical: Sequence[jnp.ndarray], beta: jnp.ndarray, log_z: jnp.ndarray,
      context: contexts.ContextDependency
  ) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]:
    """Processes one frame in the recognition lattice backward algorithm.

    On a recognition lattice, the backward algorithm computes the arc
    marginals under last.semirings.Log (the marginal probability of taking
    each lexical or blank arc), as well as the sum of path weights from a
    recognition lattice state to all final states (backward weights).

    The marginals and backward weights can be computed frame by frame with the
    help of TimeSyncAlignmentLattice.backward:

    1.  At time t, from the forward algorithm's results, we know the forward
        weights of states (t, s, c) for all context states c (`alpha`) and the
        overall shortest distance (sum of weights of all accepting paths,
        `log_z`).
    2.  We also know all the arc weights from states (t, a, c), for all
        non-final frame-local alignment states a and all context states c
        (`blank` and `lexical`).
    3.  At time t, after observing frame (t+1), we know the backward weights
        of states (t + 1, s, c) for all context states c (`beta`).
    4.  We also know the context dependency.
    5.  With the information above, TimeSyncAlignmentLattice.backward computes
        the arc marginals for the current frame and the backward weights of
        states (t, s, c) for all context states c.

    Args:
      alpha: [batch_dims..., num_context_states] forward weights after
        observing the previous frame.
      blank: length num_alignment_states Sequence of [batch_dims...,
        num_context_states] blank weights for the current frame, one for each
        frame-local alignment state.
      lexical: length num_alignment_states Sequence of [batch_dims...,
        num_context_states, vocab_size] lexical weights for the current frame,
        one for each frame-local alignment state.
      beta: [batch_dims..., num_context_states] backward weights after
        observing the next frame.
      log_z: [batch_dims...] denominator, i.e. the sum of weights of all
        accepting paths.
      context: Context dependency.

    Returns:
      (next_beta, blank_marginal, lexical_marginal):
      -   next_beta: [batch_dims..., num_context_states] backward weights
          after observing the current frame.
      -   blank_marginal: length num_alignment_states list of [batch_dims...,
          num_context_states] marginals of blank arcs, one for each
          frame-local alignment state.
      -   lexical_marginal: length num_alignment_states list of
          [batch_dims..., num_context_states, vocab_size] marginals of lexical
          arcs, one for each frame-local alignment state.
    """

  @abc.abstractmethod
  def string_forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
                     lexical: Sequence[jnp.ndarray],
                     semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Processes one frame of the forward algorithm on a string intersection.

    This is the recognition lattice forward algorithm after the intersection
    with an output string.

    Because the recognition lattice's topology is the intersection of an
    alignment lattice and the context dependency, the intersection between
    the recognition lattice and an output string is thus equivalent to
    first intersecting the context dependency with the output string (i.e.
    last.contexts.ContextDependency.walk_states), and then intersecting the
    alignment lattice with the result.

    On this intersected lattice, the forward algorithm computes the shortest
    distance, i.e. the sum of path weights reaching all final states.

    The shortest distance can be computed frame by frame with the help of
    TimeSyncAlignmentLattice.string_forward:

    1.  At time t, after observing the previous frame (t-1), we know the
        forward weights of states (t, s, c) for all context states c after
        the first intersection. This is passed to
        TimeSyncAlignmentLattice.string_forward as `alpha`. Note the first
        intersection results in a string acceptor, with only
        `output_length + 1` context states.
    2.  We also know all the arc weights from states (t, a, c), for all
        non-final frame-local alignment states a and all context states c
        after the first intersection. These are passed to
        TimeSyncAlignmentLattice.string_forward as `blank` and `lexical`.
    3.  We no longer need to know the context dependency because the result of
        the first intersection is a simple chain.
    4.  With the information above, TimeSyncAlignmentLattice.string_forward
        computes the forward weights of states (t+1, s, c) for all context
        states c after the first intersection.

    Args:
      alpha: [batch_dims..., output_length + 1] forward weights after
        observing the previous frame.
      blank: length num_alignment_states Sequence of [batch_dims...,
        output_length + 1] blank weights for the current frame, one for each
        frame-local alignment state.
      lexical: length num_alignment_states Sequence of [batch_dims...,
        output_length + 1] lexical weights for the current frame, one for each
        frame-local alignment state.
      semiring: Semiring.

    Returns:
      [batch_dims..., output_length + 1] forward weights after observing the
      current frame.
    """
def shift_down(x: jnp.ndarray,
               semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
  """Moves every entry one slot later along the last axis.

  Handy when implementing string_forward(): position i of the input lands at
  position i + 1 of the output, the last input entry is dropped, and the
  vacated first slot is filled with the semiring's zero.

  Args:
    x: [batch_dims..., N] input values.
    semiring: Semiring that supplies the zero used for the first slot.

  Returns:
    [batch_dims..., N] output values, where output[..., i + 1] = x[..., i] and
    output[..., 0] = semiring zero.
  """
  leading_shape = x.shape[:-1]
  # A single column of semiring zeros, matching x in batch shape and dtype.
  zero_column = semiring.zeros((*leading_shape, 1), x.dtype)
  all_but_last = x[..., :-1]
  return jnp.concatenate([zero_column, all_but_last], axis=-1)
def check_num_weights(alignment: TimeSyncAlignmentLattice,
                      blank: Sequence[jnp.ndarray],
                      lexical: Sequence[jnp.ndarray]):
  """Validates that `blank` and `lexical` each hold one array per state.

  Args:
    alignment: Alignment lattice whose num_states() gives the expected length.
    blank: Sequence of blank weight arrays.
    lexical: Sequence of lexical weight arrays.

  Raises:
    ValueError: If `blank` or `lexical` (checked in that order) does not have
      exactly alignment.num_states() entries.
  """
  expected = alignment.num_states()
  for label, weights in (('blank', blank), ('lexical', lexical)):
    actual = len(weights)
    if actual != expected:
      raise ValueError(
          f'{label} should be a length {expected} sequence of ndarrays, '
          f'but got length {actual}')
class FrameDependent(TimeSyncAlignmentLattice):
  """Frame dependent alignment lattice.

  Each frame is aligned to either a lexical label or a blank label, so the
  frame-local lattice has a single non-final state (id 0).
  """

  def num_states(self) -> int:
    """Returns 1: there is a single non-final frame-local state."""
    return 1

  def start(self) -> int:
    """Returns 0, the id of the only frame-local state."""
    return 0

  def blank_next(self, state: int) -> Optional[int]:
    """Returns 0 (the start state id) for any `state`."""
    return 0

  def lexical_next(self, state: int) -> Optional[int]:
    """Returns 0 (the start state id) for any `state`."""
    return 0

  def topological_visit(self) -> list[int]:
    """Returns [0], the only frame-local state."""
    return [0]

  def forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
              lexical: Sequence[jnp.ndarray],
              context: contexts.ContextDependency,
              semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Advances forward weights by one frame.

    Shapes:
      alpha: [batch_dims..., num_context_states]
      blank[0]: [batch_dims..., num_context_states]
      lexical[0]: [batch_dims..., num_context_states, vocab_size]

    Returns:
      The semiring sum of the blank-arc contribution and the lexical-arc
      contribution reduced through `context.forward_reduce`.
    """
    check_num_weights(self, blank, lexical)
    # Paths that consume the frame with a blank stay in the same context.
    via_blank = semiring.times(alpha, blank[0])
    # Paths that consume the frame with a lexical label; the context
    # dependency folds the vocab axis back onto context states.
    via_lexical = context.forward_reduce(
        semiring.times(alpha[..., jnp.newaxis], lexical[0]), semiring)
    return semiring.plus(via_blank, via_lexical)

  def backward(
      self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
      lexical: Sequence[jnp.ndarray], beta: jnp.ndarray, log_z: jnp.ndarray,
      context: contexts.ContextDependency
  ) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]:
    """Computes one frame of backward weights and arc marginals.

    All arithmetic is done in the log domain (semirings.Log).

    Shapes:
      alpha: [batch_dims..., num_context_states]
      blank[0]: [batch_dims..., num_context_states]
      lexical[0]: [batch_dims..., num_context_states, vocab_size]
      beta: [batch_dims..., num_context_states]
      log_z: [batch_dims...]

    Returns:
      (next_beta, [blank_marginal], [lexical_marginal]).
    """
    check_num_weights(self, blank, lexical)
    # Arc weight combined with the backward weight of the arc's destination.
    blank_suffix = blank[0] + beta
    lexical_suffix = lexical[0] + context.backward_broadcast(beta)
    # Forward weight of the arc's source, normalized by the total weight.
    normalized_alpha = alpha - log_z[..., jnp.newaxis]
    blank_marginal = jnp.exp(blank_suffix + normalized_alpha)
    lexical_marginal = jnp.exp(
        lexical_suffix + normalized_alpha[..., jnp.newaxis])
    next_beta = semirings.Log.plus(
        blank_suffix, semirings.Log.sum(lexical_suffix, axis=-1))
    return next_beta, [blank_marginal], [lexical_marginal]

  def string_forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
                     lexical: Sequence[jnp.ndarray],
                     semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Advances forward weights by one frame on a string-intersected lattice.

    Shapes:
      alpha, blank[0], lexical[0]: [batch_dims..., output_length + 1]

    Returns:
      [batch_dims..., output_length + 1] forward weights.
    """
    check_num_weights(self, blank, lexical)
    # Blank keeps the output position; a lexical label moves it one step on.
    stay = semiring.times(alpha, blank[0])
    advance = shift_down(semiring.times(alpha, lexical[0]), semiring)
    return semiring.plus(stay, advance)
@flax.struct.dataclass
class FrameLabelDependent(TimeSyncAlignmentLattice):
  """k-constrained frame-label-dependent alignment lattice.

  Each frame is aligned to up to k lexical labels followed by a blank label.
  Frame-local state i (0 <= i <= max_expansions) is reached after emitting i
  lexical labels in the current frame.

  Attributes:
    max_expansions: The maximum number of lexical labels allowed per frame.
  """
  max_expansions: int

  def num_states(self) -> int:
    """Returns max_expansions + 1, one state per lexical label count."""
    return 1 + self.max_expansions

  def start(self) -> int:
    """Returns 0, the state where no lexical label has been emitted yet."""
    return 0

  def blank_next(self, state: int) -> Optional[int]:
    """Returns 0 (the start state id) for any `state`."""
    return 0

  def lexical_next(self, state: int) -> Optional[int]:
    """Returns `state + 1`, or None once max_expansions would be exceeded."""
    successor = state + 1
    return successor if successor <= self.max_expansions else None

  def topological_visit(self) -> list[int]:
    """Returns state ids 0..max_expansions in increasing order."""
    return list(range(1 + self.max_expansions))

  def forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
              lexical: Sequence[jnp.ndarray],
              context: contexts.ContextDependency,
              semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Advances forward weights by one frame.

    Shapes:
      alpha: [batch_dims..., num_context_states]
      blank[i]: [batch_dims..., num_context_states]
      lexical[i]: [batch_dims..., num_context_states, vocab_size]

    Returns:
      [batch_dims..., num_context_states] semiring sum, over every state i,
      of the weight of reaching state i and then taking its blank arc.
    """
    check_num_weights(self, blank, lexical)
    # `reached` holds the weights of frame-local state i as i grows.
    reached = alpha
    exits = [semiring.times(reached, blank[0])]
    for num_labels in range(1, self.max_expansions + 1):
      reached = context.forward_reduce(
          semiring.times(reached[..., jnp.newaxis], lexical[num_labels - 1]),
          semiring)
      exits.append(semiring.times(reached, blank[num_labels]))
    return semiring.sum(jnp.stack(exits), axis=0)

  def backward(
      self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
      lexical: Sequence[jnp.ndarray], beta: jnp.ndarray, log_z: jnp.ndarray,
      context: contexts.ContextDependency
  ) -> tuple[jnp.ndarray, list[jnp.ndarray], list[jnp.ndarray]]:
    """Computes one frame of backward weights and arc marginals.

    All arithmetic is done in the log domain (semirings.Log).

    Shapes:
      alpha: [batch_dims..., num_context_states]
      blank[i]: [batch_dims..., num_context_states]
      lexical[i]: [batch_dims..., num_context_states, vocab_size]
      beta: [batch_dims..., num_context_states]
      log_z: [batch_dims...]

    Returns:
      (next_beta, blank_marginals, lexical_marginals). Both lists have
      max_expansions + 1 entries; the last lexical marginal is all zeros.
    """
    check_num_weights(self, blank, lexical)
    k = self.max_expansions
    log_z_column = log_z[..., jnp.newaxis]

    # Forward weights of every frame-local state within this frame.
    state_alphas = [alpha]
    for i in range(k):
      state_alphas.append(
          context.forward_reduce(
              state_alphas[-1][..., jnp.newaxis] + lexical[i], semirings.Log))

    # Every blank arc goes straight to the next frame, whose backward weight
    # is `beta`.
    blank_log_scale = beta - log_z_column
    blank_marginals = [
        jnp.exp(state_alphas[i] + blank[i] + blank_log_scale)
        for i in range(k + 1)
    ]

    # Walk the states from last to first, carrying the backward weight of the
    # state just after the current one. The last state only has a blank arc.
    next_beta = blank[k] + beta
    lexical_marginals = []
    for j in reversed(range(k)):
      lexical_beta = lexical[j] + context.backward_broadcast(next_beta)
      log_scale = state_alphas[j] - log_z_column
      lexical_marginals.append(
          jnp.exp(lexical_beta + log_scale[..., jnp.newaxis]))
      next_beta = semirings.Log.plus(
          blank[j] + beta, semirings.Log.sum(lexical_beta, axis=-1))
    # Marginals were collected last-state-first; restore state order.
    lexical_marginals.reverse()
    # The last state's marginal is reported as all zeros.
    lexical_marginals.append(jnp.zeros_like(lexical[k]))
    return next_beta, blank_marginals, lexical_marginals

  def string_forward(self, alpha: jnp.ndarray, blank: Sequence[jnp.ndarray],
                     lexical: Sequence[jnp.ndarray],
                     semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Advances forward weights by one frame on a string-intersected lattice.

    Shapes:
      alpha, blank[i], lexical[i]: [batch_dims..., output_length + 1]

    Returns:
      [batch_dims..., output_length + 1] forward weights.
    """
    check_num_weights(self, blank, lexical)
    # `reached` holds the weights of frame-local state i as i grows.
    reached = alpha
    exits = [semiring.times(reached, blank[0])]
    for num_labels in range(1, self.max_expansions + 1):
      # Each lexical label moves one position along the output string.
      reached = shift_down(
          semiring.times(reached, lexical[num_labels - 1]), semiring)
      exits.append(semiring.times(reached, blank[num_labels]))
    return semiring.sum(jnp.stack(exits), axis=0)