# Copyright 2023 The LAST Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Semirings."""
from collections.abc import Sequence
import dataclasses
import functools
from typing import Any, Callable, Generic, Optional, TypeVar
import jax
import jax.numpy as jnp
# Type aliases for documentation purposes only: both are plain `Any`, so they
# convey intent to readers without constraining what callers may pass.
DType = Any
PyTree = Any
# Type variables for semiring values. `T` parameterizes `Semiring`; `T` and `S`
# together parameterize the two-component semirings (Expectation, Cartesian).
T = TypeVar('T')
S = TypeVar('S')
def value_shape(x: PyTree) -> tuple[int, ...]:
  """Returns the shape shared by every leaf of a semiring value.

  A semiring value is a PyTree of one or more identically shaped ndarrays, so
  the shape of a semiring value is the common shape of its leaves.

  Args:
    x: Some semiring value.

  Returns:
    The common shape of the leaves of x.

  Raises:
    ValueError: If x has no leaves, or if the leaves of x do not have a common
      shape.
  """
  leaf_shapes = [leaf.shape for leaf in jax.tree_util.tree_leaves(x)]
  if len(leaf_shapes) == 0:
    raise ValueError(
        f'No common shape can be derived for an empty PyTree: {x!r}'
    )
  expected, *remaining = leaf_shapes
  # Only the first disagreeing leaf is reported. `()` is a legitimate shape, so
  # the "nothing found" sentinel has to be None rather than any falsy value.
  mismatch = next((shape for shape in remaining if shape != expected), None)
  if mismatch is not None:
    raise ValueError(
        'A semiring value must consist of ndarrays of a common shape. '
        f'Got inconsistent shapes {expected} vs {mismatch} for PyTree: {x!r}'
    )
  return expected
def value_dtype(x: PyTree) -> DType:
  """Extracts the per-leaf dtypes of a semiring value.

  Different leaves of a semiring value may have different dtypes. Methods such
  as Semiring.{zeros,ones} can take a PyTree of dtypes in the same structure
  as the corresponding semiring values; this function builds such a dtype
  PyTree from an existing semiring value.

  Args:
    x: Some semiring value.

  Returns:
    dtypes in the same structure as x.
  """

  def _leaf_dtype(leaf):
    return leaf.dtype

  return jax.tree_util.tree_map(_leaf_dtype, x)
class Semiring(Generic[T]):
  """Base Semiring interface.

  See https://en.wikipedia.org/wiki/Semiring for what a semiring is. A Semiring
  object bundles the methods that implement the semiring operations. The
  semiring values themselves are left untyped so that non-semiring operations
  on them stay simple: for most basic semirings a value is a single ndarray,
  while for some more complex semirings (e.g. Expectation or Cartesian) a value
  can be a tuple of ndarrays.

  In general, a semiring value under some semiring is represented as a PyTree
  of identically shaped ndarrays, with possibly different dtypes. The shape
  and dtypes of a semiring value can be obtained with
  `last.semirings.value_shape()` and `last.semirings.value_dtype()`.

  Semiring is not an abstract base class because operations are allowed to be
  left unimplemented (e.g. `prod` is not commonly used). Every method here
  raises NotImplementedError until a subclass overrides it.

  Implementation tips:

  *   Binary operations (times & plus) should support broadcasting. If a binary
      operation requires a custom vjp implementation, it may not be
      straightforward to write one that handles input broadcasting properly.
      Instead, write a custom vjp function that requires inputs with identical
      shapes, then use `jnp.broadcast_arrays()` to preprocess the operands in
      the implementation of the semiring operations. See
      `{_Log,_MaxTropical}.plus` for examples.
  *   Reductions (prod & sum) can be tricky to implement correctly. Two
      important things to watch out for:

      *   `axis` can be in the range [-rank, rank).
      *   The input can have 0-sized dimensions.
  """

  def zeros(self, shape: Sequence[int], dtype: Optional[DType] = None) -> T:
    """Semiring zeros in the given shape and dtype.

    Args:
      shape: Desired output shape.
      dtype: Optional PyTree of dtypes.

    Returns:
      If dtype is None, semiring zero values in the specified shape with
      reasonable default dtypes. Otherwise, semiring zero values in the
      specified shape with the specified dtypes.
    """
    raise NotImplementedError()

  def ones(self, shape: Sequence[int], dtype: Optional[DType] = None) -> T:
    """Semiring ones in the given shape and dtype.

    Args:
      shape: Desired output shape.
      dtype: Optional PyTree of dtypes.

    Returns:
      If dtype is None, semiring one values in the specified shape with
      reasonable default dtypes. Otherwise, semiring one values in the
      specified shape with the specified dtypes.
    """
    raise NotImplementedError()

  def times(self, a: T, b: T) -> T:
    """Semiring multiplication between two values."""
    raise NotImplementedError()

  def plus(self, a: T, b: T) -> T:
    """Semiring addition between two values."""
    raise NotImplementedError()

  def prod(self, a: T, axis: int) -> T:
    """Semiring multiplication along a single axis."""
    raise NotImplementedError()

  def sum(self, a: T, axis: int) -> T:
    """Semiring addition along a single axis."""
    raise NotImplementedError()
class _Real(Semiring[jnp.ndarray]):
  """Real semiring: ordinary `+` and `*` on ndarrays."""

  @staticmethod
  def zeros(
      shape: Sequence[int], dtype: Optional[DType] = None
  ) -> jnp.ndarray:
    """Returns the additive identity, an array filled with 0."""
    return jnp.zeros(shape, dtype=dtype)

  @staticmethod
  def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> jnp.ndarray:
    """Returns the multiplicative identity, an array filled with 1."""
    return jnp.ones(shape, dtype=dtype)

  @staticmethod
  def times(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring multiplication: the elementwise product of `a` and `b`."""
    product = a * b
    return product

  @staticmethod
  def plus(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring addition: the elementwise sum of `a` and `b`."""
    total = a + b
    return total

  @staticmethod
  def prod(a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Multiplies the entries of `a` along `axis`."""
    reduced = jnp.prod(a, axis=axis)
    return reduced

  @staticmethod
  def sum(a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Adds the entries of `a` along `axis`."""
    reduced = jnp.sum(a, axis=axis)
    return reduced
# Module-level instance of the real semiring. All `_Real` operations are static
# methods, so the instance carries no state.
Real = _Real()
def _check_axis(a: jnp.ndarray, axis: int) -> None:
if not isinstance(axis, int):
raise ValueError(f'Only int axis is supported, got axis={axis!r}')
if not -a.ndim <= axis < a.ndim:
raise ValueError(
f'Invalid reduction axis={axis!r} for input shape {a.shape}')
class _Log(Semiring[jnp.ndarray]):
  """Log semiring.

  Semiring addition is `_logaddexp` / `_logsumexp` and semiring multiplication
  is ordinary addition. The semiring zero is -inf and the semiring one is 0.
  """

  @staticmethod
  def zeros(
      shape: Sequence[int], dtype: Optional[DType] = None
  ) -> jnp.ndarray:
    """Returns the semiring zero, an array filled with -inf."""
    return jnp.full(shape, -jnp.inf, dtype=dtype)

  @staticmethod
  def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> jnp.ndarray:
    """Returns the semiring one, an array filled with 0."""
    return jnp.zeros(shape, dtype=dtype)

  @staticmethod
  def times(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring multiplication: the elementwise sum `a + b`."""
    combined = a + b
    return combined

  @staticmethod
  def plus(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring addition: `_logaddexp` of the broadcast operands."""
    # _logaddexp has a custom vjp, so it is handed operands that were already
    # broadcast to a common shape.
    lhs, rhs = jnp.broadcast_arrays(a, b)
    return _logaddexp(lhs, rhs)

  @staticmethod
  def prod(a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Semiring product along `axis`: an ordinary sum."""
    reduced = jnp.sum(a, axis=axis)
    return reduced

  @classmethod
  def sum(cls, a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Semiring sum along `axis`: `_logsumexp`, or zeros for empty input.

    Args:
      a: The array to reduce.
      axis: A single int axis in the range [-a.ndim, a.ndim).

    Returns:
      The reduction of `a` with `axis` removed.

    Raises:
      ValueError: If `axis` is not an int or is out of range for `a`.
    """
    _check_axis(a, axis)
    if a.size == 0:
      # Special handling is needed because jnp.max (used in _logsumexp) doesn't
      # support reduction on 0-sized dimensions. Summing empty input should
      # result in semiring zeros.
      positive_axis = axis + a.ndim if axis < 0 else axis
      reduced_shape = a.shape[:positive_axis] + a.shape[positive_axis + 1:]
      return cls.zeros(reduced_shape, a.dtype)
    return _logsumexp(a, axis=axis)
# Specialized log{add,sum}exp with safe gradients.
#
# Scenarios:
# - All operands are finite: As expected.
# - All operands are -inf: Sum should be -inf. Gradient should be 0.
# - All operands are +inf: Sum should be +inf. Gradient should be NaN.
# - Mixed finite & -inf operands: Sum as expected. Gradient should be 0 for
# -inf; non-0 for others.
# - Mixed finite & +inf operands: Sum should be +inf. Gradient should be NaN
# for +inf; 0 for others.
# - Mixed -inf & +inf operands: Sum should be +inf. Gradient should be NaN for
# +inf; 0 for -inf.
# - Mixed finite, -inf & +inf operands: Sum should be +inf. Gradient should be
# NaN for +inf; 0 for others.
#
# The different treatment of -inf & +inf comes from their different sources.
# - +inf is an indicator of a true error, e.g. an overflow somewhere. It's
# thus desirable to not silence such issues.
# - -inf often arises from perfectly legitimate computations such as
# `logaddexp(-inf, -inf + x)`, where `x` should not receive a NaN gradient.
@jax.custom_vjp
def _logaddexp(a, b):
  """Elementwise log(exp(a) + exp(b)) with a custom vjp.

  The primal value is the first element of what `_logaddexp_fwd` returns; the
  residuals it also returns are discarded here.
  """
  primal_out, _ = _logaddexp_fwd(a, b)
  return primal_out
def _logaddexp_fwd(a, b):
c = jnp.maximum(a, b)
safe = jnp.isfinite(c)
c = jnp.where(safe, c, 0)
ea = jnp.exp(a - c)
eb = jnp.exp(b - c)
z = ea + eb
return c + jnp.log(z), (ea, eb, z)
def _logaddexp_bwd(res, g):
ea, eb, z = res
safe = z != 0
z = jnp.where(safe, z, 1)
scale = g / z
return scale * ea, scale * eb
# Register the forward and backward rules above as the custom vjp of
# `_logaddexp`.
_logaddexp.defvjp(_logaddexp_fwd, _logaddexp_bwd)
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _logsumexp(a, axis):
  """log(sum(exp(a))) along `axis`, with a custom vjp.

  `axis` is declared non-differentiable. The primal value is the first element
  of what `_logsumexp_fwd` returns; the residuals are discarded here.
  """
  primal_out, _ = _logsumexp_fwd(a, axis)
  return primal_out
def _logsumexp_fwd(a, axis):
c = jnp.max(a, axis=axis, keepdims=True)
safe = jnp.isfinite(c)
c = jnp.where(safe, c, 0)
e = jnp.exp(a - c)
z = jnp.sum(e, axis=axis, keepdims=True)
r = jnp.squeeze(c, axis=axis) + jnp.log(jnp.squeeze(z, axis=axis))
return r, (e, z)
def _logsumexp_bwd(axis, res, g):
e, z = res
safe = z != 0
z = jnp.where(safe, z, 1)
g = jnp.expand_dims(g, axis=axis)
# g & z are smaller than e, doing the division between g & z instead e & z is
# thus faster.
return (g / z * e,)
# Register the forward and backward rules above as the custom vjp of
# `_logsumexp`.
_logsumexp.defvjp(_logsumexp_fwd, _logsumexp_bwd)
# Module-level instance of the log semiring.
Log = _Log()
class _MaxTropical(Semiring[jnp.ndarray]):
  """Max tropical semiring.

  Semiring addition is `max` and semiring multiplication is ordinary addition.
  The semiring zero is -inf and the semiring one is 0.

  The gradients of `plus` and `sum` are guaranteed to be non-zero on exactly 1
  input element, even in the event of a tie.
  """

  @staticmethod
  def zeros(
      shape: Sequence[int], dtype: Optional[DType] = None
  ) -> jnp.ndarray:
    """Returns the semiring zero, an array filled with -inf."""
    return jnp.full(shape, -jnp.inf, dtype)

  @staticmethod
  def ones(shape: Sequence[int], dtype: Optional[DType] = None) -> jnp.ndarray:
    """Returns the semiring one, an array filled with 0."""
    return jnp.zeros(shape, dtype)

  @staticmethod
  def times(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring multiplication: the elementwise sum `a + b`."""
    return a + b

  @staticmethod
  def plus(a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
    """Semiring addition: `_maximum` of the broadcast operands."""
    # _maximum has a custom vjp, so it is handed operands that were already
    # broadcast to a common shape.
    a, b = jnp.broadcast_arrays(a, b)
    return _maximum(a, b)

  @staticmethod
  def prod(a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Semiring product along `axis`: an ordinary sum."""
    return jnp.sum(a, axis=axis)

  @classmethod
  def sum(cls, a: jnp.ndarray, axis: int) -> jnp.ndarray:
    """Semiring sum along `axis`: `_max`, or zeros for empty input.

    Args:
      a: The array to reduce.
      axis: A single int axis in the range [-a.ndim, a.ndim).

    Returns:
      The reduction of `a` with `axis` removed.

    Raises:
      ValueError: If `axis` is not an int or is out of range for `a`.
    """
    _check_axis(a, axis)
    # Special handling is needed because jnp.max doesn't support reduction on
    # 0-sized dimensions.
    if a.size > 0:
      return _max(a, axis=axis)
    # Summing empty input should result in zeros.
    if axis < 0:
      axis += a.ndim
    result_shape = a.shape[:axis] + a.shape[axis + 1:]
    return cls.zeros(result_shape, a.dtype)
# Module-level instance of the max tropical semiring.
MaxTropical = _MaxTropical()
@jax.custom_vjp
def _maximum(a, b):
  """Elementwise maximum of `a` and `b`, with a custom vjp.

  The primal value is the first element of what `_maximum_fwd` returns; the
  residual it also returns is discarded here.
  """
  primal_out, _ = _maximum_fwd(a, b)
  return primal_out
def _maximum_fwd(a, b):
return jnp.maximum(a, b), a >= b
def _maximum_bwd(res, g):
choose_a = res
return g * choose_a, g * (1 - choose_a)
# Register the forward and backward rules above as the custom vjp of
# `_maximum`.
_maximum.defvjp(_maximum_fwd, _maximum_bwd)
@functools.partial(jax.custom_vjp, nondiff_argnums=(1,))
def _max(a, axis):
  """Maximum of `a` along `axis`, with a custom vjp.

  `axis` is declared non-differentiable. The primal value is the first element
  of what `_max_fwd` returns; the residuals are discarded here.
  """
  primal_out, _ = _max_fwd(a, axis)
  return primal_out
def _max_fwd(a, axis):
argmax = jnp.argmax(a, axis)
width = a.shape[axis]
return jnp.max(a, axis), (argmax, width)
def _max_bwd(axis, res, g):
argmax, width = res
mask = jax.nn.one_hot(argmax, width, axis=axis, dtype=g.dtype)
g = jnp.expand_dims(g, axis=axis)
return (g * mask,)
# Register the forward and backward rules above as the custom vjp of `_max`.
_max.defvjp(_max_fwd, _max_bwd)
@dataclasses.dataclass(frozen=True)
class Expectation(Generic[T, S], Semiring[tuple[T, S]]):
  """Jason Eisner's expectation semiring.

  In most cases, use LogLogExpectation below directly.

  See https://www.cs.jhu.edu/~jason/papers/eisner.fsmnlp01.pdf for reference.

  Each semiring value is a tuple (w, x):

  -   w: The weight of this tuple, expressed in the self.w semiring.
  -   x: The weighted sum for some corresponding weighted values, expressed in
      the self.x semiring.

  To create a semiring value from a weight-value pair, use `self.weighted()`.
  See `ExpectationTest.test_entropy` for an example of using the expectation
  semiring to compute the entropy of probability distributions.

  Attributes:
    w: Semiring for representing weights.
    x: Semiring for representing weighted sums.
    w_to_x: Function to convert a value from semiring `w` to semiring `x`.
  """

  w: Semiring[T]
  x: Semiring[S]
  w_to_x: Callable[[T], S]

  def weighted(self, w: T, v: S) -> tuple[T, S]:
    """Builds a semiring value from a weight `w` and a value `v`.

    Args:
      w: The weight, in semiring `self.w`.
      v: The value to be weighted, in semiring `self.x`.

    Returns:
      The tuple `(w, self.x.times(self.w_to_x(w), v'))`, where `v'` is `v` with
      the entries at zero-weight positions replaced by 0.
    """
    zero_weight = self.w.zeros([], w.dtype)
    # When w is zero in semiring self.w, self.w_to_x(w) is zero in semiring
    # self.x, and we stipulate that the weighted value should then always be
    # zero in semiring self.x. Overwriting v at those positions avoids NaNs
    # when both semirings are Log and w is -inf and v is +inf (e.g. computing
    # 0 log 0 under Log).
    guarded_v = jnp.where(w == zero_weight, 0, v)
    weight_in_x = self.w_to_x(w)
    return w, self.x.times(weight_in_x, guarded_v)

  def zeros(
      self, shape: Sequence[int], dtype: Optional[DType] = None
  ) -> tuple[T, S]:
    """Semiring zeros: `(zero of self.w, zero of self.x)`.

    Args:
      shape: Desired output shape.
      dtype: Optional `(dtype_w, dtype_x)` pair; None uses the component
        semirings' defaults.

    Returns:
      A `(w, x)` tuple of zeros from the two component semirings.
    """
    dtype_w, dtype_x = (None, None) if dtype is None else dtype
    return self.w.zeros(shape, dtype_w), self.x.zeros(shape, dtype_x)

  def ones(
      self, shape: Sequence[int], dtype: Optional[DType] = None
  ) -> tuple[T, S]:
    """Semiring ones: `(one of self.w, zero of self.x)`.

    Args:
      shape: Desired output shape.
      dtype: Optional `(dtype_w, dtype_x)` pair; None uses the component
        semirings' defaults.

    Returns:
      A `(w, x)` tuple whose weight is one in `self.w` and whose weighted sum
      is zero in `self.x`.
    """
    dtype_w, dtype_x = (None, None) if dtype is None else dtype
    return self.w.ones(shape, dtype_w), self.x.zeros(shape, dtype_x)

  def times(self, a: tuple[T, S], b: tuple[T, S]) -> tuple[T, S]:
    """Semiring multiplication.

    The weights are multiplied in `self.w`. The weighted sum is
    `w_a * x_b + w_b * x_a`, computed in `self.x` after converting each weight
    with `self.w_to_x`.
    """
    (w_a, x_a), (w_b, x_b) = a, b
    w = self.w.times(w_a, w_b)
    a_weighted_b = self.x.times(self.w_to_x(w_a), x_b)
    b_weighted_a = self.x.times(self.w_to_x(w_b), x_a)
    return w, self.x.plus(a_weighted_b, b_weighted_a)

  def plus(self, a: tuple[T, S], b: tuple[T, S]) -> tuple[T, S]:
    """Semiring addition, applied to each component in its own semiring."""
    (w_a, x_a), (w_b, x_b) = a, b
    return self.w.plus(w_a, w_b), self.x.plus(x_a, x_b)

  def sum(self, a: tuple[T, S], axis: int) -> tuple[T, S]:
    """Semiring sum along `axis`, applied to each component separately."""
    w, x = a
    return self.w.sum(w, axis), self.x.sum(x, axis)
# Expectation semiring with both the weight and the weighted sum represented in
# the Log semiring. Because the weighted sums are stored as logs, only
# summation over non-negative values is allowed.
LogLogExpectation = Expectation(w=Log, x=Log, w_to_x=lambda x: x)
@dataclasses.dataclass(frozen=True)
class Cartesian(Generic[T, S], Semiring[tuple[T, S]]):
  """Cartesian product of 2 semirings.

  Each semiring value is a tuple whose first element lives in semiring `x` and
  whose second element lives in semiring `y`. Every operation is applied to
  the two components independently.

  Attributes:
    x: The first semiring.
    y: The second semiring.
  """

  x: Semiring[T]
  y: Semiring[S]

  def zeros(
      self, shape: Sequence[int], dtype: Optional[DType] = None
  ) -> tuple[T, S]:
    """Semiring zeros: the pair of zeros from `x` and `y`.

    Args:
      shape: Desired output shape.
      dtype: Optional `(dtype_x, dtype_y)` pair; None uses the component
        semirings' defaults.

    Returns:
      A tuple of zeros from the two component semirings.
    """
    dtype_x, dtype_y = (None, None) if dtype is None else dtype
    return self.x.zeros(shape, dtype_x), self.y.zeros(shape, dtype_y)

  def ones(
      self, shape: Sequence[int], dtype: Optional[DType] = None
  ) -> tuple[T, S]:
    """Semiring ones: the pair of ones from `x` and `y`.

    Args:
      shape: Desired output shape.
      dtype: Optional `(dtype_x, dtype_y)` pair; None uses the component
        semirings' defaults.

    Returns:
      A tuple of ones from the two component semirings.
    """
    dtype_x, dtype_y = (None, None) if dtype is None else dtype
    return self.x.ones(shape, dtype_x), self.y.ones(shape, dtype_y)

  def times(self, a: tuple[T, S], b: tuple[T, S]) -> tuple[T, S]:
    """Semiring multiplication, applied to each component separately."""
    (a_x, a_y), (b_x, b_y) = a, b
    return self.x.times(a_x, b_x), self.y.times(a_y, b_y)

  def plus(self, a: tuple[T, S], b: tuple[T, S]) -> tuple[T, S]:
    """Semiring addition, applied to each component separately."""
    (a_x, a_y), (b_x, b_y) = a, b
    return self.x.plus(a_x, b_x), self.y.plus(a_y, b_y)

  def sum(self, a: tuple[T, S], axis: int) -> tuple[T, S]:
    """Semiring sum along `axis`, applied to each component separately."""
    first, second = a
    return self.x.sum(first, axis), self.y.sum(second, axis)

  def prod(self, a: tuple[T, S], axis: int) -> tuple[T, S]:
    """Semiring product along `axis`, applied to each component separately."""
    first, second = a
    return self.x.prod(first, axis), self.y.prod(second, axis)