# Copyright 2023 The LAST Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Recognition lattice."""
from collections.abc import Callable, Sequence
import functools
from typing import Any, Generic, Optional, Protocol, TypeVar
import flax.linen as nn
import jax
import jax.numpy as jnp
from last import alignments
from last import contexts
from last import semirings
from last import weight_fns
# Annotation used for dtype arguments in this module (see
# _init_context_state_weights); deliberately unconstrained.
DType = Any
# Type of the weight function cache data: what WeightFnCacher[T] produces,
# RecognitionLattice.build_cache() returns, and WeightFn[T] consumes.
T = TypeVar('T')
class RecognitionLattice(nn.Module, Generic[T]):
  """Recognition lattice in GNAT-style formulation and operations over it.

  A RecognitionLattice provides operations used in training and inference, such
  as computing the negative-log-probability loss, or finding the highest scoring
  alignment path.

  Following the formulation in GNAT, we combine the three modelling components
  to define a RecognitionLattice:

  -   Context dependency: The finite automaton that models output history. See
      last.contexts.ContextDependency for details.
  -   Alignment lattice: The finite automaton that models the alignment between
      input frames and output labels. See
      last.alignments.TimeSyncAlignmentLattice for details.
  -   Weight function: The neural network that produces arc weights from any
      context state given an input frame. See last.weight_fns for details.

  Given a sequence of `T` input frames, its recognition lattice is the following
  finite automaton:

  -   States: The states are pairs of an alignment state and a context state.
      For any alignment state (t, a) (t is the index of the current frame) and
      context state c, there is a state (t, a, c) in the recognition lattice.
  -   Start state: (0, s_a, s_c), where s_a is the start state in the
      frame-local alignment lattice (see
      last.alignments.TimeSyncAlignmentLattice), and s_c is the start state in
      the context dependency.
  -   Final states: (T, s_a, c) for any context state c.
  -   Arcs:
      -   Blank arcs: For any blank arc `(t, a) -> (t', a')` in the alignment
          lattice, and any context state c, there is an arc
          `(t, a, c) --blank-> (t', a', c)` in the recognition lattice.
      -   Lexical arcs: For any lexical arc `(t, a) -> (t', a')` in the
          alignment lattice, and any arc `c --y-> c'` in the context
          dependency, there is an arc `(t, a, c) --y-> (t', a', c')` in the
          recognition lattice.
  -   Arc weights: For each state (t, a, c), the weight function receives the
      t-th frame and the context state c, and produces arc weights for blank
      and lexical arcs. Notably while in principle weight functions can also
      depend on the alignment state, in practice we haven't yet encountered a
      compelling reason for such weight functions. Thus for the sake of
      simplicity, weight functions currently only depend on the frame and the
      context state.

  The arc weights can be used to model the conditional distribution of the paths
  in the recognition lattice, especially those that lead to the reference label
  sequence. A RecognitionLattice can be either a locally normalized model, or
  a globally normalized model,

  -   A locally normalized model uses last.weight_fns.LocallyNormalizedWeightFn,
      where the arc weights from the same recognition lattice state add up to 1
      after taking an exponential. The probability of
      P(alignment labels | frames) is simply the product of arc weights on the
      alignment path after exponential.
  -   A globally normalized model uses a WeightFn that is not a subclass of
      last.weight_fns.LocallyNormalizedWeightFn. To obtain a probabilistic
      interpretation, we normalize the path weights with the sum of the
      exponentiated weights of all possible paths in the recognition lattice.

  Globally normalized models are more expensive to train, but they have various
  advantages. See the GNAT paper for more details.

  Attributes:
    context: Context dependency.
    alignment: Alignment lattice.
    weight_fn_cacher_factory: Callable that builds a WeightFnCacher given the
      context dependency.
    weight_fn_factory: Callable that builds a WeightFn given the context
      dependency.
  """

  context: contexts.ContextDependency
  alignment: alignments.TimeSyncAlignmentLattice
  # We use factories instead of fields of nn.Module because of restrictions from
  # nn.custom_vjp (needed for _forward_backward). The important thing though is
  # that weight_fn and weight_fn_cacher have to be child modules, so that the
  # variable ownership is clear for nn.custom_vjp.
  weight_fn_cacher_factory: Callable[[contexts.ContextDependency],
                                     weight_fns.WeightFnCacher[T]]
  weight_fn_factory: Callable[[contexts.ContextDependency],
                              weight_fns.WeightFn[T]]

  def setup(self):
    """Instantiates the weight function and its cacher from their factories."""
    self.weight_fn_cacher = self.weight_fn_cacher_factory(self.context)
    self.weight_fn = self.weight_fn_factory(self.context)

  def build_cache(self) -> T:
    """Builds the weight function cache.

    Weight functions are implemented as a pair of WeightFn and WeightFnCacher to
    avoid unnecessary recomputation (see last.weight_fns for more details).
    build_cache() builds the cached static data that can be used in other public
    methods.

    Returns:
      Cached data.
    """
    return self.weight_fn_cacher()

  # Should a public method take an optional cache argument?
  #
  # - If the method operates on a sequence: cache is optional. The method
  #   should build the cache itself if necessary.
  # - If the method operates on a single frame: cache should be required. The
  #   method should NOT build the cache itself. Methods operating on a single
  #   frame are expected to be called repeatedly. We want to prevent the users
  #   from making the mistake of recomputing the cache repeatedly.

  def __call__(self,
               frames: jnp.ndarray,
               num_frames: jnp.ndarray,
               labels: jnp.ndarray,
               num_labels: jnp.ndarray,
               cache: Optional[T] = None) -> jnp.ndarray:
    """Computes the negative sequence log-probability loss.

    There can be multiple alignment paths from the input frames to the output
    labels. The conditional probability P(labels | frames) is thus the sum
    of probabilities P(alignment labels | frames) for all possible alignments
    that produce the given label sequence. Interpreting the arc weights as
    (possibly unnormalized) log-probabilities, this function computes
    -log P(labels | frames) for both locally and globally normalized models.

    Args:
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.
      labels: [batch_dims..., max_num_labels] padded label sequences.
      num_labels: [batch_dims...] number of labels.
      cache: Optional weight function cache data. Built with build_cache() when
        not provided.

    Returns:
      [batch_dims...] negative sequence log-prob loss.

    Raises:
      ValueError: If the batch dimensions of frames, labels or num_labels do not
        match those of num_frames.
    """
    batch_dims = num_frames.shape
    if frames.shape[:-2] != batch_dims:
      raise ValueError('frames and num_frames have different batch_dims: '
                       f'{frames.shape[:-2]} vs {batch_dims}')
    if labels.shape[:-1] != batch_dims:
      raise ValueError('labels and num_frames have different batch_dims: '
                       f'{labels.shape[:-1]} vs {batch_dims}')
    if num_labels.shape != batch_dims:
      raise ValueError('num_labels and num_frames have different batch_dims: '
                       f'{num_labels.shape} vs {batch_dims}')

    semiring = semirings.Log
    if cache is None:
      cache = self.build_cache()
    # Log-weight of all alignment paths that produce `labels`.
    numerator = self._string_forward(
        cache=cache,
        frames=frames,
        num_frames=num_frames,
        labels=labels,
        num_labels=num_labels,
        semiring=semiring)
    # Locally normalized weights need no separate normalizer.
    if isinstance(self.weight_fn, weight_fns.LocallyNormalizedWeightFn):
      return -numerator
    # Globally normalized: subtract from the log-weight of all paths in the
    # recognition lattice.
    denominator = self._forward_backward(
        cache=cache, frames=frames, num_frames=num_frames)
    return denominator - numerator

  def shortest_path(
      self,
      frames: jnp.ndarray,
      num_frames: jnp.ndarray,
      cache: Optional[T] = None
  ) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Computes the shortest path in the recognition lattice.

    The shortest path is the path with the highest score, in other words, the
    "shortest" path under the max-tropical semiring.

    Args:
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.
      cache: Optional weight function cache data. Built with build_cache() when
        not provided.

    Returns:
      (alignment_labels, num_alignment_labels, path_weights) tuple,

      -   alignment_labels: [batch_dims..., max_num_alignment_labels] padded
          alignment labels, either blank (0) or lexical (1 to vocab_size).
      -   num_alignment_labels: [batch_dims...] number of alignment labels.
      -   path_weights: [batch_dims...] path weights.

    Raises:
      ValueError: If frames and num_frames have different batch dimensions.
    """
    batch_dims = num_frames.shape
    if frames.shape[:-2] != batch_dims:
      raise ValueError('frames and num_frames have different batch_dims: '
                       f'{frames.shape[:-2]} vs {batch_dims}')
    max_num_frames = frames.shape[-2]
    num_alignment_states = self.alignment.num_states()
    if cache is None:
      cache = self.build_cache()

    # Find shortest path by differentiating shortest distance under the tropical
    # semiring.
    def forward(lattice: 'RecognitionLattice', lexical_mask: jnp.ndarray):
      path_weights, _ = lattice._forward(  # pylint: disable=protected-access
          cache=cache,
          frames=frames,
          num_frames=num_frames,
          semiring=semirings.MaxTropical,
          # _forward expects lexical_mask to be a length num_alignment_states
          # Sequence whose elements are broadcastable to
          # [batch_dims..., max_num_frames, num_context_states, vocab_size].
          lexical_mask=[
              lexical_mask[..., i, jnp.newaxis, :]
              for i in range(num_alignment_states)
          ])
      return path_weights

    _, vocab_size = self.context.shape()
    lexical_mask = jnp.zeros(
        [*batch_dims, max_num_frames, num_alignment_states, vocab_size])
    path_weights, vjp_fn = nn.vjp(
        forward, self, lexical_mask, vjp_variables=False)
    _, viterbi_lexical_mask = vjp_fn(jnp.ones_like(path_weights))
    # A position whose lexical gradients are all zero is treated as blank;
    # otherwise the argmax over the vocabulary axis picks the lexical label.
    is_blank = jnp.all(viterbi_lexical_mask == 0, axis=-1)
    alignment_labels = jnp.where(is_blank, 0,
                                 1 + jnp.argmax(viterbi_lexical_mask, axis=-1))
    # Flatten into [batch_dims..., max_num_frames * num_alignment_states]
    alignment_labels = alignment_labels.reshape([*batch_dims, -1])
    num_alignment_labels = num_alignment_states * num_frames
    return alignment_labels, num_alignment_labels, path_weights

  # Private methods.

  # TODO(wuke): Add support for lexical_mask for forced decoding.
  # TODO(wuke): Add support for composite semirings (e.g. expectation).
  def _string_forward(self, cache: T, frames: jnp.ndarray,
                      num_frames: jnp.ndarray, labels: jnp.ndarray,
                      num_labels: jnp.ndarray,
                      semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
    """Shortest distance restricted to paths producing a given label sequence.

    This is the shortest distance on the intersection of the recognition lattice
    and an output string (the label sequence), computed using the forward
    algorithm.

    Args:
      cache: Weight function cache data.
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.
      labels: [batch_dims..., max_num_labels] padded label sequence.
      num_labels: [batch_dims...] number of labels.
      semiring: Semiring to use for shortest distance computation.

    Returns:
      [batch_dims...] shortest distance.

    Raises:
      ValueError: If the batch dimensions of frames, labels or num_labels do not
        match those of num_frames.
    """
    batch_dims = num_frames.shape
    if frames.shape[:-2] != batch_dims:
      raise ValueError('frames and num_frames have different batch_dims: '
                       f'{frames.shape[:-2]} vs {batch_dims}')
    if labels.shape[:-1] != batch_dims:
      raise ValueError('labels and num_frames have different batch_dims: '
                       f'{labels.shape[:-1]} vs {batch_dims}')
    if num_labels.shape != batch_dims:
      raise ValueError('num_labels and num_frames have different batch_dims: '
                       f'{num_labels.shape} vs {batch_dims}')

    # Calculate arc weights for all visited context states.
    #
    # We can't fit into memory all
    # O(batch_size * max_num_frames * (max_num_labels + 1) * (vocab_size + 1))
    # arcs, thus we use a scan loop over the (max_num_labels + 1) axis to
    # produce just the O(batch_size * max_num_frames * (max_num_labels + 1))
    # arcs actually needed later. This is better than scanning over the
    # max_num_frames axis because weight_fn can be vectorized over multiple
    # frames for the same state (weight_fn with states often involves
    # gathering).

    # Before vmaps
    # - frame is [batch_dims..., hidden_size]
    # - state is [batch_dims...]
    # Results are ([batch_dims...], [batch_dims..., vocab_size]).
    compute_weights = (
        lambda weight_fn, frame, state: weight_fn(cache, frame, state))
    # Add time dimension on frame
    # - frame is [batch_dims..., max_num_frames, hidden_size]
    # - state is [batch_dims...]
    # Results are ([batch_dims..., max_num_frames],
    #              [batch_dims..., max_num_frames, vocab_size]).
    compute_weights = nn.vmap(
        compute_weights,
        variable_axes={'params': None},
        split_rngs={'params': False},
        in_axes=(-2, None),
        out_axes=(-1, -2))

    def gather_weight(weights, y):
      # weights: [batch_dims..., max_num_frames, vocab_size]
      # y: [batch_dims..., max_num_frames]
      # weights are for labels [1, vocab_size], so y-1 are the corresponding
      # indices. one_hot(-1) is safe (all zeros).
      mask = nn.one_hot(y - 1, weights.shape[-1])
      return jnp.einsum('...TV,...V->...T', weights, mask)

    def weight_step(weight_fn, carry, inputs):
      del carry
      state, next_label = inputs
      blank_weight, lexical_weights = compute_weights(weight_fn, frames, state)
      lexical_weight = gather_weight(lexical_weights, next_label)
      return None, (blank_weight, lexical_weight)

    # prevent_cse is not needed in loops. Turning it off allows the compiler to
    # better optimize the loop step.
    weight_step = nn.remat(weight_step, prevent_cse=False)
    # [batch_dims..., max_num_labels + 1]
    context_states = self.context.walk_states(labels)
    # The last context state has no next label. Use 0 as the sentinel (lexical
    # labels are 1 to vocab_size), for which gather_weight's one_hot(-1) mask is
    # all zeros, instead of gathering the weight of a real lexical label.
    context_next_labels = jnp.concatenate(
        [labels, jnp.zeros_like(labels[..., :1])], axis=-1)
    # [batch_dims..., max_num_frames, max_num_labels+1]
    _, (blank_weight, lexical_weight) = nn.scan(
        weight_step,
        variable_broadcast='params',
        split_rngs={'params': False},
        in_axes=len(batch_dims),
        out_axes=len(batch_dims) + 1)(self.weight_fn, None,
                                      (context_states, context_next_labels))

    # Dynamic program for summing up all alignment paths. Actual work is done by
    # alignment.string_forward(). This function mostly takes care of padding
    # frames.
    def shortest_distance_step(carry, inputs):
      # alpha: [batch_dims..., max_num_labels + 1]
      t, alpha = carry
      # blank, lexical: [batch_dims..., max_num_labels + 1]
      blank, lexical = inputs
      # We currently only support alignment-state-invariant weights.
      blank = [blank for _ in range(self.alignment.num_states())]
      lexical = [lexical for _ in range(self.alignment.num_states())]
      next_alpha = self.alignment.string_forward(
          alpha=alpha, blank=blank, lexical=lexical, semiring=semiring)
      # Carry alpha through unchanged on padding frames.
      is_padding = (t >= num_frames)[..., jnp.newaxis]
      next_alpha = jnp.where(is_padding, alpha, next_alpha)
      return (t + 1, next_alpha), None

    num_alpha_states = labels.shape[-1] + 1
    init_alpha = _init_context_state_weights(
        batch_dims=batch_dims,
        dtype=lexical_weight.dtype,
        num_states=num_alpha_states,
        start=0,
        semiring=semiring)
    (_, alpha), _ = jax.lax.scan(
        shortest_distance_step, (0, init_alpha),
        jax.tree_util.tree_map(
            functools.partial(_to_time_major, num_batch_dims=len(batch_dims)),
            (blank_weight, lexical_weight)))
    # Only the state reached after consuming exactly num_labels labels is final.
    is_final = num_labels[..., jnp.newaxis] == jnp.arange(num_alpha_states)
    return semiring.sum(
        jnp.where(is_final, alpha, semiring.zeros([], alpha.dtype)), axis=-1)

  # TODO(wuke): Find a way to create a public shortest_distance() method whose
  # interface isn't overly complex.
  def _forward(
      self,
      cache: T,
      frames: jnp.ndarray,
      num_frames: jnp.ndarray,
      semiring: semirings.Semiring[jnp.ndarray],
      blank_mask: Optional[Sequence[jnp.ndarray]] = None,
      lexical_mask: Optional[Sequence[jnp.ndarray]] = None,
  ) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Shortest distance on the recognition lattice via the forward algorithm.

    It is often useful to differentiate through shortest distance with respect
    to arc weights. For example, under the log semiring, that gives us arc
    marginals; whereas under the tropical semiring, that gives us shortest path.
    To make that possible while the arc weights are being computed on the fly,
    we can pass in zero-valued masks. The masks are added onto arc weights, and
    because d f(x + y) / dy | y=0 = d f(x) / dx, we can get gradients with
    respect to arc weights (i.e. x) by differentiating over the masks (i.e. y).

    Args:
      cache: Weight function cache data.
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.
      semiring: Semiring to use for shortest distance computation.
      blank_mask: Optional length num_alignment_states sequence whose elements
        are broadcastable to [batch_dims..., max_num_frames,
        num_context_states].
      lexical_mask: Optional length num_alignment_states sequence whose elements
        are broadcastable to [batch_dims..., max_num_frames, num_context_states,
        vocab_size].

    Returns:
      (shortest_distance, alpha_0_to_T_minus_1) tuple,

      -   shortest_distance: [batch_dims...] shortest distance.
      -   alpha_0_to_T_minus_1: [batch_dims..., max_num_frames,
          num_context_states] forward weights after observing 0 to T-1 frames
          (T is the number of frames in the sequence).

    Raises:
      ValueError: If frames and num_frames have different batch dimensions, or
        if a provided mask does not have one element per alignment state.
    """
    batch_dims = num_frames.shape
    if frames.shape[:-2] != batch_dims:
      raise ValueError('frames and num_frames have different batch_dims: '
                       f'{frames.shape[:-2]} vs {batch_dims}')
    if blank_mask is not None and len(
        blank_mask) != self.alignment.num_states():
      raise ValueError(
          'The length of blank_mask should be equal to '
          f'{self.alignment.num_states()} (the number of alignment states), '
          f'but is {len(blank_mask)}')
    if lexical_mask is not None and len(
        lexical_mask) != self.alignment.num_states():
      raise ValueError(
          'The length of lexical_mask should be equal to '
          f'{self.alignment.num_states()} (the number of alignment states), '
          f'but is {len(lexical_mask)}')

    # Dynamic program for summing up all alignment paths.
    def step(weight_fn, carry, inputs):
      # alpha: [batch_dims..., num_context_states]
      t, alpha = carry
      # frame: [batch_dims..., hidden_size]
      # blank_mask: None or [batch_dims...]
      # lexical_mask: None or broadcastable to
      #   [batch_dims..., num_alignment_states, vocab_size]
      frame, blank_mask, lexical_mask = inputs
      # blank: [batch_dims..., num_context_states]
      # lexical: [batch_dims..., num_context_states, vocab_size]
      blank, lexical = weight_fn(cache, frame)
      # We currently only support alignment-state-invariant weights.
      blank = [blank for _ in range(self.alignment.num_states())]
      lexical = [lexical for _ in range(self.alignment.num_states())]
      if blank_mask is not None:
        blank = [b + m for b, m in zip(blank, blank_mask)]
      if lexical_mask is not None:
        lexical = [l + m for l, m in zip(lexical, lexical_mask)]
      next_alpha = self.alignment.forward(
          alpha=alpha,
          blank=blank,
          lexical=lexical,
          context=self.context,
          semiring=semiring)
      # Carry alpha through unchanged on padding frames.
      is_padding = (t >= num_frames)[..., jnp.newaxis]
      next_alpha = jnp.where(is_padding, alpha, next_alpha)
      # The step output is the alpha *before* consuming this frame.
      return (t + 1, next_alpha), alpha

    # Reduce memory footprint when using autodiff.
    #
    # For the log semiring, this is not as fast or memory efficient as forward-
    # backward, but still better than the defaults (i.e. no remat or no saving
    # intermediates at all).
    #
    # For the tropical semiring, this should be equivalent to no remat.
    def save_small(prim, *args, **params):
      y, _ = prim.abstract_eval(*args, **params)
      greater_than_1_dims = len([None for i in y.shape if i > 1])
      save = greater_than_1_dims <= (len(batch_dims) + 1)
      return save

    # prevent_cse is not needed in loops. Turning it off allows the compiler to
    # better optimize the loop step.
    step = nn.remat(step, prevent_cse=False, policy=save_small)

    init_t = jnp.array(0)
    init_alpha = _init_context_state_weights(
        batch_dims=batch_dims,
        # TODO(wuke): Find a way to do this with jax.eval_shape.
        dtype=self.weight_fn(cache, frames[..., 0, :])[0].dtype,
        num_states=self.context.shape()[0],
        start=self.context.start(),
        semiring=semiring)
    init_carry = (init_t, init_alpha)
    inputs = (frames, blank_mask, lexical_mask)
    (_, alpha_T), alpha_0_to_T_minus_1 = nn.scan(  # pylint: disable=invalid-name
        step,
        variable_broadcast='params',
        split_rngs={'params': False},
        in_axes=len(batch_dims),
        out_axes=len(batch_dims))(self.weight_fn, init_carry, inputs)
    return semiring.sum(alpha_T, axis=-1), alpha_0_to_T_minus_1

  class BackwardStepCallback(Protocol):
    """Callback signature used in the backward algorithm loop."""

    # Type names.
    # pylint: disable=invalid-name
    Blank = jnp.ndarray
    Lexical = jnp.ndarray
    ParamsGrad = Any
    CacheGrad = Any
    FrameGrad = jnp.ndarray
    Carry = TypeVar('Carry')
    Output = Any

    # pylint: enable=invalid-name

    def __call__(self, weight_vjp_fn: Callable[[Blank, Lexical],
                                               tuple[ParamsGrad, CacheGrad,
                                                     FrameGrad]], carry: Carry,
                 blank_marginal: Blank,
                 lexical_marginals: Lexical) -> tuple[Carry, Output]:
      """Callback used in the backward algorithm loop.

      Standard backward algorithm simply computes the arc marginals and backward
      weights. Through a custom callback, we can perform on the fly processing
      beyond these without having to store all the arc marginals. An example is
      accumulating the gradients with respect to weight function parameters (see
      _forward_backward).

      Args:
        weight_vjp_fn: VJP function of the weight function. Maps the gradients
          (blank_grad, lexical_grad) to (params_grad, cache_grad, frame_grad).
          Note that _forward_backward invokes it with a single
          (blank_grad, lexical_grad) tuple argument.
        carry: PyTree of custom callback carry data.
        blank_marginal: [batch_dims..., num_context_states] marginal probability
          of blank arcs.
        lexical_marginals: [batch_dims..., num_context_states, vocab_size]
          marginal probability of lexical arcs.

      Returns:
        next_carry and step outputs.
      """
      raise NotImplementedError

  def _backward(
      self,
      cache: T,
      frames: jnp.ndarray,
      num_frames: jnp.ndarray,
      log_z: jnp.ndarray,
      alpha_0_to_T_minus_1: jnp.ndarray,  # pylint: disable=invalid-name
      init_callback_carry: Any,
      callback: BackwardStepCallback) -> tuple[Any, Any]:
    """Computes arc marginals under the log semiring (backward algorithm).

    Under the log semiring, arc weights can be viewed as unnormalized log
    probabilities, and a conditional distribution over paths can be defined by
    normalizing with respect to the exponentiated shortest distance (i.e. sum of
    unnormalized path probabilities). The marginal probability of each arc can
    then be computed with the backward algorithm.

    Mathematically, under the log semiring, arc marginals are equal to the
    gradients of shortest distance with respect to arc weights. The backward
    algorithm offers a slightly more efficient method for computing these
    gradients than reverse mode automatic differentiation with gradient
    rematerialization:

    -   Both methods compute the arc weights twice: once in the forward pass,
        once in the backward pass.
    -   Both methods carry out the "backward-broadcast" operation, i.e.
        broadcasting the backward weights from a destination state to all source
        states, once in the backward pass.
    -   Autodiff carries out the "forward-reduce" operation, i.e. summing up
        path weights to the same destination state, twice: once in the forward
        pass, once in the backward pass.
    -   Forward-backward only carries out the "forward-reduce" operation once,
        in the forward pass.

    In other words, forward-backward saves one "forward-reduce" operation. The
    savings can be significant when the "forward-reduce" call is often
    expensive, which is the main justification for all this added complexity.

    Args:
      cache: Weight function cache data.
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.
      log_z: [batch_dims...] shortest distance from _forward(). Under the log
        semiring, the shortest distance is the log-normalizer, thus the name.
      alpha_0_to_T_minus_1: [batch_dims..., max_num_frames, num_context_states]
        forward weights from _forward().
      init_callback_carry: PyTree of initial carry value for the callback.
      callback: Callback used in the backward algorithm loop.

    Returns:
      (final_callback_carry, callback_outputs) tuple.

    Raises:
      ValueError: If the batch dimensions of frames, log_z or
        alpha_0_to_T_minus_1 do not match those of num_frames.
    """
    batch_dims = num_frames.shape
    if frames.shape[:-2] != batch_dims:
      raise ValueError('frames and num_frames have different batch_dims: '
                       f'{frames.shape[:-2]} vs {batch_dims}')
    if log_z.shape != batch_dims:
      raise ValueError('log_z and num_frames have different batch_dims: '
                       f'{log_z.shape} vs {batch_dims}')
    if alpha_0_to_T_minus_1.shape[:-2] != batch_dims:
      raise ValueError(
          'alpha_0_to_T_minus_1 and num_frames have different '
          f'batch_dims: {alpha_0_to_T_minus_1.shape[:-2]} vs {batch_dims}')

    def step(lattice, carry, inputs):
      # beta: [batch_dims..., num_context_states]
      t, beta, callback_carry = carry
      # alpha: [batch_dims..., num_context_states]
      # frame: [batch_dims..., hidden_size]
      alpha, frame = inputs
      # blank: [batch_dims..., num_context_states]
      # lexical: [batch_dims..., num_context_states, vocab_size]
      (blank, lexical), weight_vjp_fn = nn.vjp(
          lambda lattice, cache, frame: lattice.weight_fn(cache, frame),
          lattice, cache, frame)
      # We currently only support alignment-state-invariant weights.
      blank = [blank for _ in range(self.alignment.num_states())]
      lexical = [lexical for _ in range(self.alignment.num_states())]
      next_beta, blank_marginal, lexical_marginals = self.alignment.backward(
          alpha=alpha,
          blank=blank,
          lexical=lexical,
          beta=beta,
          log_z=log_z,
          context=self.context)
      # We currently only support alignment-state-invariant weights.
      blank_marginal = jnp.sum(jnp.stack(blank_marginal), axis=0)
      lexical_marginals = jnp.sum(jnp.stack(lexical_marginals), axis=0)
      # Mask out marginals on padding positions.
      is_padding = (t >= num_frames)[..., jnp.newaxis]
      next_beta = jnp.where(is_padding, beta, next_beta)
      blank_marginal = jnp.where(is_padding, 0, blank_marginal)
      lexical_marginals = jnp.where(is_padding[..., jnp.newaxis], 0,
                                    lexical_marginals)
      next_callback_carry, callback_outputs = callback(
          weight_vjp_fn=weight_vjp_fn,
          carry=callback_carry,
          blank_marginal=blank_marginal,
          lexical_marginals=lexical_marginals)
      # The scan runs in reverse, so the frame index counts down.
      return (t - 1, next_beta, next_callback_carry), callback_outputs

    num_context_states, _ = self.context.shape()
    init_beta = semirings.Log.ones([*batch_dims, num_context_states],
                                   log_z.dtype)
    init_t = jnp.array(frames.shape[-2] - 1)
    init_carry = (init_t, init_beta, init_callback_carry)
    inputs = (alpha_0_to_T_minus_1, frames)
    (_, _, final_callback_carry), callback_outputs = nn.scan(
        step,
        variable_broadcast='params',
        split_rngs={'params': False},
        in_axes=len(batch_dims),
        out_axes=len(batch_dims),
        reverse=True)(self, init_carry, inputs)
    return final_callback_carry, callback_outputs

  def _forward_backward(self, cache: T, frames: jnp.ndarray,
                        num_frames: jnp.ndarray) -> jnp.ndarray:
    """Log-semiring shortest distance with backward-algorithm gradients.

    The value is the shortest distance under the log semiring computed by
    _forward(); its gradients are computed using _backward() through a custom
    VJP rule.

    Args:
      cache: Weight function cache data.
      frames: [batch_dims..., max_num_frames, feature_size] padded frame
        sequences.
      num_frames: [batch_dims...] number of frames.

    Returns:
      [batch_dims...] shortest distance.
    """
    semiring = semirings.Log

    # This function is mostly flax wizardry to make custom_vjp work.
    #
    # The high level idea is to call _forward() in fwd() and _backward() in
    # bwd().
    #
    # The tricky part is that flax disallows returning a linen Module as part of
    # the residuals in fwd() for fear of the Module being mutated in bwd(). To
    # get around this, we obtain an immutable copy of model parameters in fwd(),
    # and then call Module.apply with that in bwd, to guarantee that no mutation
    # to the Module object will occur.

    def f(lattice: 'RecognitionLattice', cache: T, frames: jnp.ndarray):
      """Normal function evaluation."""
      log_z, _ = lattice._forward(  # pylint: disable=protected-access
          cache=cache,
          frames=frames,
          num_frames=num_frames,
          semiring=semiring)
      return log_z

    def fwd(lattice: 'RecognitionLattice', cache: T, frames: jnp.ndarray):
      """Forward pass function evaluation."""
      log_z, alpha_0_to_T_minus_1 = lattice._forward(  # pylint: disable=invalid-name,protected-access
          cache=cache,
          frames=frames,
          num_frames=num_frames,
          semiring=semiring)

      # jax.tree_util.Partial makes the function an "empty" PyTree so that we
      # can pass this function as part of res. All this function really does is
      # keeping a reference to the `lattice` object.
      @jax.tree_util.Partial
      def apply_backward(params, **kwargs):
        # Module.apply is functional so it's safe to call it in the backward
        # pass.
        return lattice.apply(params, **kwargs, method=lattice._backward)  # pylint: disable=protected-access

      # Obtain the values of weight function parameters. All parameters have
      # been created by the time the _forward() call above returns.
      params = {'params': lattice.variables['params']}
      res = (apply_backward, params, cache, frames, log_z, alpha_0_to_T_minus_1)
      return log_z, res

    def bwd(res, g: jnp.ndarray):
      """Backward pass: accumulates gradients inside the backward algorithm."""
      apply_backward, params, cache, frames, log_z, alpha_0_to_T_minus_1 = res  # pylint: disable=invalid-name

      def vjp_callback(weight_vjp_fn, carry, blank_marginal, lexical_marginals):
        params_grad, cache_grad = carry
        # Scale the arc marginals by the incoming gradient g, broadcasting g
        # over the context-state (and vocabulary) axes.
        params_grad_t, cache_grad_t, frame_grad_t = weight_vjp_fn(
            (jnp.expand_dims(g, -1) * blank_marginal,
             jnp.expand_dims(g, (-1, -2)) * lexical_marginals))
        # Parameter and cache gradients are summed over frames; frame
        # gradients are emitted per frame.
        next_carry = jax.tree_util.tree_map(jnp.add, (params_grad, cache_grad),
                                            (params_grad_t, cache_grad_t))
        outputs = frame_grad_t
        return next_carry, outputs

      # Zero accumulators of params_grad/cache_grad.
      init_params_grad = jax.tree_util.tree_map(jnp.zeros_like, params)
      init_cache_grad = jax.tree_util.tree_map(jnp.zeros_like, cache)
      init_callback_carry = init_params_grad, init_cache_grad
      (params_grad, cache_grad), frames_grad = apply_backward(
          params,
          cache=cache,
          frames=frames,
          num_frames=num_frames,
          log_z=log_z,
          alpha_0_to_T_minus_1=alpha_0_to_T_minus_1,
          init_callback_carry=init_callback_carry,
          callback=vjp_callback)
      return params_grad, cache_grad, frames_grad

    return nn.custom_vjp(f, fwd, bwd)(self, cache, frames)
def _init_context_state_weights(
    batch_dims: Sequence[int], dtype: DType, num_states: int, start: int,
    semiring: semirings.Semiring[jnp.ndarray]) -> jnp.ndarray:
  """Builds initial state weights that put all the mass on the start state.

  Args:
    batch_dims: Leading batch dimensions of the result.
    dtype: dtype passed to the semiring's ones()/zeros().
    num_states: Size of the trailing state axis.
    start: Index of the start state.
    semiring: Semiring providing the one and zero weights.

  Returns:
    [batch_dims..., num_states] weights holding the semiring one at index
    `start` and the semiring zero everywhere else.
  """
  one = semiring.ones([], dtype)
  zero = semiring.zeros([], dtype)
  # Build a single [num_states] row, then broadcast it across the batch.
  per_state = jnp.where(jnp.arange(num_states) == start, one, zero)
  return jnp.broadcast_to(per_state, (*batch_dims, num_states))
def _to_time_major(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray:
  """Moves the time axis in front of the batch axes.

  Args:
    x: Array laid out as [batch_dims..., time, ...].
    num_batch_dims: Number of leading batch axes, i.e. the index of the time
      axis.

  Returns:
    The same data laid out as [time, batch_dims..., ...].
  """
  # Moving the axis at index num_batch_dims to the front keeps all other axes
  # in their original relative order.
  return jnp.moveaxis(x, num_batch_dims, 0)
def _to_batch_major(x: jnp.ndarray, num_batch_dims: int) -> jnp.ndarray:
  """Moves a leading time axis back behind the batch axes.

  Args:
    x: Array laid out as [time, batch_dims..., ...].
    num_batch_dims: Number of batch axes that follow the time axis.

  Returns:
    The same data laid out as [batch_dims..., time, ...].
  """
  # Moving axis 0 to index num_batch_dims keeps all other axes in their
  # original relative order.
  return jnp.moveaxis(x, 0, num_batch_dims)