# Source code for flex.gp.jax_primitives

from .primitives import PrimitiveParams
import jax.numpy as jnp
import dctkit as dt


def protectedDiv(left, right):
    """Divide ``left`` by ``right`` with ``jnp.divide``.

    Args:
        left: Numerator, forwarded unchanged to ``jnp.divide``.
        right: Denominator, forwarded unchanged to ``jnp.divide``.

    Returns:
        The value of ``jnp.divide(left, right)``, or ``jnp.nan`` if that call
        raises ``ZeroDivisionError``.
    """
    try:
        return jnp.divide(left, right)
    except ZeroDivisionError:
        # NOTE(review): jnp.divide is expected to follow IEEE-754 semantics
        # (x / 0 -> inf, 0 / 0 -> nan) instead of raising, so this fallback
        # is probably never taken -- confirm before relying on it.
        return jnp.nan
def protectedLog(x):
    """Compute the natural logarithm of ``x`` with ``jnp.log``.

    Args:
        x: Input value, forwarded unchanged to ``jnp.log``.

    Returns:
        The value of ``jnp.log(x)``, or ``jnp.nan`` if that call raises
        ``ValueError``.
    """
    try:
        return jnp.log(x)
    except ValueError:
        # NOTE(review): jnp.log is expected to return nan for negative input
        # and -inf for zero rather than raise, so this fallback is probably
        # never taken -- confirm before relying on it.
        return jnp.nan
def protectedSqrt(x):
    """Compute the square root of ``x`` with ``jnp.sqrt``.

    Args:
        x: Input value, forwarded unchanged to ``jnp.sqrt``.

    Returns:
        The value of ``jnp.sqrt(x)``, or ``jnp.nan`` if that call raises
        ``ValueError``.
    """
    try:
        return jnp.sqrt(x)
    except ValueError:
        # NOTE(review): jnp.sqrt is expected to return nan for negative real
        # input rather than raise, so this fallback is probably never taken
        # -- confirm before relying on it.
        return jnp.nan
def inv_float(x):
    """Return the reciprocal of ``x``, computed as ``protectedDiv(1.0, x)``.

    Args:
        x: Denominator passed to ``protectedDiv``.

    Returns:
        Whatever ``protectedDiv(1.0, x)`` returns.
    """
    return protectedDiv(1.0, x)
def square_mod(x):
    """Square ``x`` and cast the result to ``dt.float_dtype``.

    Args:
        x: Input value, forwarded unchanged to ``jnp.square``.

    Returns:
        ``jnp.square(x)`` converted with ``.astype(dt.float_dtype)``.
    """
    return jnp.square(x).astype(dt.float_dtype)
# Registry of primitives for the JAX backend, keyed by primitive name.
# Each value is a PrimitiveParams; the positional arguments are presumably
# (callable, list of input types, output type) -- verify against
# PrimitiveParams in .primitives.
jax_primitives = {
    # scalar operations (JAX backend)
    "AddF": PrimitiveParams(jnp.add, [float, float], float),
    "SubF": PrimitiveParams(jnp.subtract, [float, float], float),
    "MulF": PrimitiveParams(jnp.multiply, [float, float], float),
    "Div": PrimitiveParams(protectedDiv, [float, float], float),
    "SinF": PrimitiveParams(jnp.sin, [float], float),
    "ArcsinF": PrimitiveParams(jnp.arcsin, [float], float),
    "CosF": PrimitiveParams(jnp.cos, [float], float),
    "ArccosF": PrimitiveParams(jnp.arccos, [float], float),
    "ExpF": PrimitiveParams(jnp.exp, [float], float),
    "LogF": PrimitiveParams(protectedLog, [float], float),
    "SqrtF": PrimitiveParams(protectedSqrt, [float], float),
    # NOTE(review): uses jnp.square directly, not square_mod defined in this
    # module -- confirm that is intentional.
    "SquareF": PrimitiveParams(jnp.square, [float], float),
    "InvF": PrimitiveParams(inv_float, [float], float),
}