Serket
serket aims to be the most intuitive and easy-to-use neural network library in jax. serket is fully transparent to jax transformations (e.g. vmap, grad, jit, …).
Installation
Install from GitHub:
pip install git+https://github.com/ASEM000/serket
Quick example
import functools as ft

import jax, jax.numpy as jnp
import serket as sk
# Placeholders (Ellipsis): substitute your own training data. The loop below
# zips them batch by batch, and the first Linear layer expects 28 * 28
# features per sample (presumably MNIST-like images -- adjust as needed).
x_train = y_train = ...

# One independent PRNG key per Linear layer, derived from a fixed seed.
k1, k2 = jax.random.split(jax.random.PRNGKey(0), num=2)
# Build a small MLP: flatten each sample, Linear 784 -> 64, ReLU, Linear 64 -> 10.
# The model is wrapped with sk.tree_mask before it is passed through jax
# transforms; loss_func calls sk.tree_unmask before applying it (presumably
# masking hides non-array leaves such as jnp.ravel / jax.nn.relu from
# jax.grad and jax.jit -- confirm against the serket docs).
net = sk.tree_mask(
    sk.Sequential(
        jnp.ravel,  # flatten one (unbatched) sample; loss_func vmaps over the batch
        sk.nn.Linear(28 * 28, 64, key=k1),
        jax.nn.relu,
        sk.nn.Linear(64, 10, key=k2),  # 10 output logits
    )
)
def softmax_cross_entropy(logits, labels):
    """Return the per-example softmax cross-entropy between logits and labels.

    The original example called ``softmax_cross_entropy`` without defining or
    importing it; this uses the same formula as ``optax.softmax_cross_entropy``.

    Args:
        logits: unnormalized scores; the class axis is the last axis.
        labels: target probabilities with the same shape as ``logits``
            (e.g. one-hot vectors).

    Returns:
        Array of losses with the last (class) axis summed out.
    """
    # log_softmax is the numerically stable form of log(softmax(logits)).
    return -jnp.sum(labels * jax.nn.log_softmax(logits, axis=-1), axis=-1)


@ft.partial(jax.grad, has_aux=True)
def loss_func(net, x, y):
    """Differentiate the mean cross-entropy loss of ``net`` on one batch.

    Because of the ``jax.grad(..., has_aux=True)`` decorator, calling this
    function returns ``(grads, (loss, logits))``: the gradient with respect to
    ``net`` plus the auxiliary tuple returned below.

    Args:
        net: the network as wrapped by ``sk.tree_mask``; it is unmasked before
            being applied.
        x: a batch of inputs; the network is vmapped over the leading axis.
        y: integer class labels for the batch (presumably shape ``(batch,)``).
    """
    logits = jax.vmap(sk.tree_unmask(net))(x)
    # 10 classes, matching the output width of the final Linear layer.
    onehot = jax.nn.one_hot(y, 10)
    loss = jnp.mean(softmax_cross_entropy(logits, onehot))
    # Return the loss again as auxiliary data so the caller gets the loss
    # value and logits alongside the gradients from a single pass.
    return loss, (loss, logits)
@jax.jit
def train_step(net, x, y):
    """Run one plain gradient-descent step (learning rate 1e-3) on a batch.

    Args:
        net: the masked network.
        x: a batch of inputs.
        y: the batch's integer labels.

    Returns:
        ``(net, (loss, logits))``: the updated network, plus the batch loss and
        logits computed with the parameters from *before* this update.
    """
    grads, (loss, logits) = loss_func(net, x, y)
    # jax.tree_map is deprecated (and removed in recent JAX releases);
    # jax.tree_util.tree_map is the stable spelling.
    net = jax.tree_util.tree_map(lambda p, g: p - g * 1e-3, net, grads)
    return net, (loss, logits)
# Training loop: one gradient step per (inputs, labels) batch.
for xb, yb in zip(x_train, y_train):
    net, (loss, logits) = train_step(net, xb, yb)
    # Batch accuracy: compare predictions with this batch's labels `yb`
    # (not the whole `y_train`, whose shape does not match `logits`).
    # Note the logits come from the parameters before this step's update.
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == yb)
Apache 2.0 License.