In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# Ask JAX to raise an error as soon as a computation produces a NaN
# (the `jax_debug_nans` flag).
# NOTE(review): this flag is a debugging aid and, per the JAX debugging docs,
# adds runtime overhead to every computation (including the timed MNIST
# training further down) — consider disabling it once the notebook works.
jax.config.update("jax_debug_nans", True)

In [None]:
t = jnp.asarray([1., 2., 3.])


def melu(x):
    """Sum of the element-wise exponential of ``x``.

    A scalar-valued toy objective used below to demonstrate ``jax.grad``.
    """
    exponentials = jnp.exp(x)
    return exponentials.sum()


# Compile the gradient of `melu` with XLA and call it once so the first
# (compiling) call happens here.
g_melu = jax.jit(jax.grad(melu))
g_melu(t)  # Warm up

# NumPy arrays can be mutated in place; JAX arrays are immutable, so the
# equivalent update goes through the functional `.at[...].set(...)` syntax,
# which returns a new array instead of modifying `a_jnp`.
a_np = np.array([1., 2., 3.])
a_jnp = jnp.array(a_np)
a_np[2:] = 10.
np.testing.assert_allclose(a_np, a_jnp.at[2:].set(10.))

# Higher-order derivatives by nesting jax.grad; for a scalar input melu is
# exp, so every derivative evaluates to exp(3.) here.
second_derivative = jax.grad(jax.grad(melu))
third_derivative = jax.grad(second_derivative)
print(second_derivative(3.))
print(third_derivative(3.))

In [None]:
# TODO: demo forward-mode (jax.jvp) and reverse-mode (jax.vjp) differentiation

In [None]:
# TODO: demo automatic vectorization with jax.vmap

In [None]:
def dense(params, inputs):
    """Affine layer ``jnp.dot(inputs, params["w_mat"]) + params["b"]``.

    Args:
      params: dict with a weight matrix under ``"w_mat"`` and a bias under ``"b"``.
      inputs: array contracted against the weight matrix.
    """
    weights = params["w_mat"]
    bias = params["b"]
    return jnp.dot(inputs, weights) + bias


def explu(inputs):
    """Element-wise activation: identity for negative inputs, ``exp`` otherwise.

    Note the two branches do not meet at zero: the output jumps from 0
    (left limit) to exp(0) = 1.
    """
    negative = inputs < 0.
    return jnp.where(negative, inputs, jnp.exp(inputs))


in_dim = 6
out_dim = 6

key = random.PRNGKey(42)
# One fresh `key` to carry forward plus three subkeys for the input features,
# the bias and the weight matrix.
key, feature_key, bias_key, weight_key = random.split(key, 4)

nn_params = {
    "w_mat": random.normal(weight_key, shape=(in_dim, out_dim)),
    "b": random.normal(bias_key, shape=(out_dim,)),
}
nn_features = 2. * random.normal(feature_key, shape=(in_dim,))

# Regression target: the element-wise square of the features.
target = nn_features**2

print(explu(dense(nn_params, nn_features)))

def signal_response(nn_params):
    """Network output ``explu(dense(nn_params, nn_features))`` for the fixed features."""
    return explu(dense(nn_params, nn_features))


def loss(nn_params):
    """Mean squared error between ``target`` and ``signal_response(nn_params)``."""
    return jnp.mean((target - signal_response(nn_params))**2)


print(loss(nn_params))
nn_grad = jax.grad(loss)
print(nn_grad(nn_params))


def update(params, learning_rate, loss_fn=None):
    """Take one plain gradient-descent step on ``params``.

    Args:
      params: pytree of parameter arrays (here the ``nn_params`` dict); nested
        pytrees are supported as well as flat dicts.
      learning_rate: step size multiplying the gradient.
      loss_fn: scalar function ``loss_fn(params)`` to differentiate. Defaults
        to the module-level ``loss``, looked up at call time. Leave it unset
        when calling through ``jax.jit(update)``, because a function is not a
        valid traced argument.

    Returns:
      A pytree with the same structure as ``params`` holding
      ``params - learning_rate * grad(loss_fn)(params)``.
    """
    if loss_fn is None:
        loss_fn = loss
    grads = jax.grad(loss_fn)(params)
    # tree_map pairs every parameter leaf with its gradient leaf, whatever the
    # container structure is.
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)


# Plain (un-jitted) gradient descent on the toy regression problem.
learning_rate = 1e-4
num_steps = 1000
for _step in range(num_steps):
    nn_params = update(nn_params, learning_rate)

In [None]:
from matplotlib import pyplot as plt

# Compare the regression target with the network output after plain
# gradient descent.
fig, ax = plt.subplots()
ax.scatter(nn_features, target, label="Truth", alpha=0.7)
ax.scatter(
    nn_features, signal_response(nn_params), label="Reconstruction", alpha=0.7
)
ax.set(
    xlabel="nn_features",
    ylabel="response",
    title="Plain gradient descent: truth vs. reconstruction",
)
ax.legend()
plt.show()

In [None]:
# Same training as before, but with `update` compiled by jax.jit and starting
# again from freshly sampled parameters.
update = jax.jit(update)

key, bias_key, weight_key = random.split(key, 3)
nn_params = {
    "w_mat": random.normal(weight_key, shape=(in_dim, out_dim)),
    "b": random.normal(bias_key, shape=(out_dim,)),
}

for _step in range(1000):
    nn_params = update(nn_params, learning_rate)

# Compare the regression target with the network output after the
# jit-compiled training run.
fig, ax = plt.subplots()
ax.scatter(nn_features, target, label="Truth", alpha=0.7)
ax.scatter(
    nn_features, signal_response(nn_params), label="Reconstruction", alpha=0.7
)
ax.set(
    xlabel="nn_features",
    ylabel="response",
    title="jit-compiled gradient descent: truth vs. reconstruction",
)
ax.legend()
plt.show()

In [None]:
# Switch from hand-written layers to the stax neural-network library (jax.example_libraries.stax)

In [None]:
from jax.example_libraries import optimizers, stax

# A stax layer constructor returns a pair of functions, printed here as a tuple.
dense_layer = stax.Dense(1024)
print(dense_layer)

In [None]:
# MLP for flattened 28x28 inputs: two ReLU hidden layers of width 1024 and
# log-softmax outputs over 10 classes.
layers = (
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax,
)
init_params, predict = stax.serial(*layers)

key, sk = random.split(key)
input_shape = (-1, 28 * 28)  # leading -1: batch size left unspecified
_output_shape, params = init_params(sk, input_shape)


def loss(params, batch):
    """Negative mean of the target-weighted log-probabilities from ``predict``.

    ``predict`` ends in ``stax.LogSoftmax``, so with one-hot ``targets`` this
    is the mean cross-entropy over the batch.

    Args:
      params: network parameters accepted by ``predict``.
      batch: ``(inputs, targets)`` pair; ``targets`` has one column per class.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)


def accuracy(params, batch):
    """Fraction of rows whose arg-max prediction equals the arg-max target.

    Args:
      params: network parameters accepted by ``predict``.
      batch: ``(inputs, targets)`` pair; ``targets`` has one column per class.
    """
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    hits = predicted_class == target_class
    return jnp.mean(hits)

In [None]:
# Fetch JAX's example MNIST loader once and cache it in the working directory.
# Unlike a `!curl` shell escape, urlretrieve raises on HTTP/network errors
# instead of continuing silently with a missing or stale file.
# NOTE(review): this downloads and then imports code from the `main` branch of
# a remote repository — pin a specific commit/tag for reproducibility and
# review the file before trusting it.
import urllib.request
from pathlib import Path

DATASETS_URL = "https://raw.githubusercontent.com/google/jax/main/examples/datasets.py"
if not Path("datasets.py").exists():
    urllib.request.urlretrieve(DATASETS_URL, "datasets.py")

# NOTE(review): a local module named `datasets` can be shadowed by an
# already-imported package of the same name (e.g. Hugging Face `datasets`).
from datasets import mnist

train_images, train_labels, test_images, test_labels = mnist()

plt.imshow(train_images[0].reshape(28, 28), cmap="gray")
plt.title("First training image")
plt.show()

In [None]:
# Optimization hyper-parameters.
step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9

# Batches per epoch, counting one extra partial batch when the training set
# size is not a multiple of batch_size.
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)


def data_stream(key):
    """Yield an endless stream of shuffled ``(images, labels)`` mini-batches.

    Each pass over the training set draws a fresh permutation from a new
    subkey split off ``key`` and yields ``num_batches`` slices of at most
    ``batch_size`` examples (the last slice may be shorter).

    Args:
      key: JAX PRNG key controlling the shuffling.
    """
    while True:
        key, sk = random.split(key, 2)
        # Move the permutation to host memory once per pass, so that each batch
        # is a pure-NumPy slice + fancy index rather than a JAX slice op
        # followed by a device-to-host conversion on every batch.
        perm = np.asarray(random.permutation(sk, num_train))
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images[batch_idx], train_labels[batch_idx]


key, sk = random.split(key)
batches = data_stream(sk)

# Momentum optimizer from jax.example_libraries: a triple of
# (init, update, get_params) functions.
optimizer = optimizers.momentum(step_size, mass=momentum_mass)
opt_init, opt_update, get_params = optimizer


@jax.jit
def update(i, opt_state, batch):
    """Jitted optimizer step: differentiate ``loss`` on ``batch`` and apply it.

    Args:
      i: step counter forwarded to ``opt_update``.
      opt_state: current optimizer state.
      batch: ``(inputs, targets)`` pair consumed by ``loss``.

    Returns:
      The optimizer state after one update.
    """
    current_params = get_params(opt_state)
    grads = jax.grad(loss)(current_params, batch)
    return opt_update(i, grads, opt_state)


# Wrap the freshly initialized stax `params` in the optimizer's state;
# `get_params(opt_state)` recovers the parameters from it.
opt_state = opt_init(params)

# Needed for the global step counter (`itertools.count()`) in the training loop.
import itertools

In [None]:
import time

itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.perf_counter()
    for _step in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))
    # JAX dispatches work asynchronously: wait for the last update to finish so
    # the epoch time measures computation, not just dispatch.
    jax.block_until_ready(opt_state)
    epoch_time = time.perf_counter() - start_time

    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))