Tutorial for Wasserstein Wormhole

[1]:
import os
# Number GPUs by PCI bus order and expose only device 0 to the frameworks loaded below.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",   # see issue #152
    "CUDA_VISIBLE_DEVICES": "0",
})
[2]:
import umap
import numpy as np
import pandas as pd
import scipy.stats

import matplotlib
import matplotlib.pyplot as plt
import colorcet
import seaborn as sns


import sklearn.metrics
import sklearn.neural_network
import sklearn.neighbors
import skimage.filters
/home/havivd/miniconda3/envs/WormholeUpdate/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
2024-04-25 16:24:37.098371: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-04-25 16:24:37.148586: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-04-25 16:24:48.467709: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT
[3]:
from wassersteinwormhole import Wormhole
2024-04-25 16:25:16.095512: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.

Load MNIST data

[4]:
import tensorflow.keras.datasets
# load_data() returns two (images, labels) tuples: the training split, then the test split.
mnist = tensorflow.keras.datasets.mnist
train_split, test_split = mnist.load_data()
image_train, label_train = train_split
image_test, label_test = test_split



[5]:
def image_to_point_cloud(image):
    """Convert a grayscale image into a weighted point cloud.

    Every pixel with positive intensity becomes one point whose coordinates are its
    pixel index (row, column for a 2-D image). Its weight is its intensity divided by
    the total foreground intensity, so the weights sum to 1.

    Returns:
        (coords, weights): an (N, image.ndim) integer array of pixel indices and a
        length-N float array of normalised intensities, both in row-major pixel order.
    """
    mask = image > 0  # compute the foreground mask once, not once per expression
    coords = np.argwhere(mask)  # same as np.stack(np.where(mask), axis=1)
    intensities = image[mask]
    return coords, intensities / intensities.sum()


def images_to_point_clouds(images):
    """Apply `image_to_point_cloud` to each image.

    Returns:
        (point_clouds, weights): two lists with one entry per input image.
    """
    point_clouds, weights = [], []
    for image in images:
        coords, w = image_to_point_cloud(image)
        point_clouds.append(coords)
        weights.append(w)
    return point_clouds, weights


pc_train, weight_train = images_to_point_clouds(image_train)
pc_test, weight_test = images_to_point_clouds(image_test)

Show some examples

[6]:
# Plot 16 randomly chosen training point clouds, coloured by their normalised weights.
# Samples are drawn independently, so the same digit may appear more than once.
fig = plt.figure(figsize = (25,25))
for panel in range(16):
    i = np.random.choice(np.arange(len(pc_train)))

    plt.subplot(4, 4, 1 + panel)
    # Column index on x, negated row index on y, so the digit is drawn upright.
    plt.scatter(pc_train[i][:, 1], -pc_train[i][:, 0], c = weight_train[i], s = 80)
    plt.axis('equal')
    plt.xticks([])
    plt.yticks([])
plt.show()
../_images/tutorial_Wormhole_Tutorial_8_0.png

Or load your own set of point clouds (pickled point clouds and weights, plus labels saved as .npy files)

[7]:
import pickle

# To use your own data instead of the MNIST point clouds above, set USE_OWN_DATA = True
# and point DATA_DIR at your files. It defaults to False so that running the notebook top
# to bottom neither fails on the placeholder paths nor silently overwrites the MNIST data.
USE_OWN_DATA = False
DATA_DIR = '/path/to/data'


def load_pickle(path):
    """Load and return the single pickled object stored at `path`.

    Security: unpickling can execute arbitrary code, so only load files you trust.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


if USE_OWN_DATA:
    pc_train = load_pickle(os.path.join(DATA_DIR, 'pc_train.pickle'))
    weight_train = load_pickle(os.path.join(DATA_DIR, 'weight_train.pickle'))
    pc_test = load_pickle(os.path.join(DATA_DIR, 'pc_test.pickle'))
    weight_test = load_pickle(os.path.join(DATA_DIR, 'weight_test.pickle'))

    label_train = np.load(os.path.join(DATA_DIR, 'label_train.npy'))
    label_test = np.load(os.path.join(DATA_DIR, 'label_test.npy'))

Wormhole Model

[8]:
# Build the model from the training point clouds and weights, also passing the test set.
WormholeModel = Wormhole(
    point_clouds = pc_train,
    weights = weight_train,
    point_clouds_test = pc_test,
    weights_test = weight_test,
)
Using Calculated Min Max Scaling Values

Train model

[9]:
# Train the model. The return value is not used; later cells read the trained state
# from WormholeModel itself (e.g. WormholeModel.params, encode/decode).
# The logged run below took 10000 steps and reports encoder loss, decoder loss and
# encoder correlation.
WormholeModel.train()
 enc_loss: 2.434e-05 dec_loss: 6.003e-04 enc_corr: 0.986: 100%|██████████| 10000/10000 [06:48<00:00, 24.48it/s]

Save parameters

[10]:
# Persist only the trained parameters (WormholeModel.params), not the whole model object.
model_path = '/path/to/savedmodel/WormholeModel.pickle'
with open(model_path, 'wb') as f:
    pickle.dump(WormholeModel.params, f)

Or load a trained model

[11]:
# Restore previously saved parameters into the already-constructed WormholeModel.
# Only load files you trust: unpickling can execute arbitrary code.
with open('/path/to/savedmodel/WormholeModel.pickle', 'rb') as f:
    WormholeModel.params = pickle.load(f)

Generate encodings

[12]:
# Embed the train and test point clouds (with their weights) using the trained encoder.
train_encodings, test_encodings = (
    WormholeModel.encode(clouds, cloud_weights)
    for clouds, cloud_weights in (
        (WormholeModel.point_clouds, WormholeModel.weights),
        (WormholeModel.point_clouds_test, WormholeModel.weights_test),
    )
)

Compare encoding distances to the true Wasserstein distances

Train

[15]:
# Sample 128 distinct training items and form every unordered pair (i < j) among them.
train_rand_ind = np.random.choice(np.arange(train_encodings.shape[0]), 128, replace = False)
pair_ind = np.stack(np.triu_indices(train_rand_ind.shape[0], 1), axis = 1)
idx_a = train_rand_ind[pair_ind[:, 0]]
idx_b = train_rand_ind[pair_ind[:, 1]]

# Pairwise distance computed by the model's own jit_dist_enc (using its eps_enc/lse_enc).
true_w2 = np.asarray(WormholeModel.jit_dist_enc(
    [WormholeModel.point_clouds[idx_a], WormholeModel.weights[idx_a]],
    [WormholeModel.point_clouds[idx_b], WormholeModel.weights[idx_b]],
    WormholeModel.eps_enc, WormholeModel.lse_enc))

# Mean squared difference between the two encodings of each pair.
diff = train_encodings[idx_a] - train_encodings[idx_b]
enc_L2 = np.mean(np.square(diff), axis = -1)
[16]:
# Scatter true pairwise distances against encoding distances; the title shows Pearson's r.
fig, ax = plt.subplots(figsize = (5,5))
ax.scatter(true_w2, enc_L2, s = 0.1)
ax.axis('square')
ax.set_xlabel('True W2')
ax.set_ylabel('Transformer Encoder L2')
corr = scipy.stats.pearsonr(true_w2, enc_L2)[0]
ax.set_title('Corr: ' + str(np.round(corr, 3)))
plt.show()
../_images/tutorial_Wormhole_Tutorial_24_0.png

Test

[17]:
# Sample 128 distinct test items and form every unordered pair (i < j) among them.
test_rand_ind = np.random.choice(np.arange(test_encodings.shape[0]), 128, replace = False)
pair_ind = np.stack(np.triu_indices(test_rand_ind.shape[0], 1), axis = 1)
idx_a = test_rand_ind[pair_ind[:, 0]]
idx_b = test_rand_ind[pair_ind[:, 1]]

# Pairwise distance computed by the model's own jit_dist_enc (using its eps_enc/lse_enc).
true_w2 = np.asarray(WormholeModel.jit_dist_enc(
    [WormholeModel.point_clouds_test[idx_a], WormholeModel.weights_test[idx_a]],
    [WormholeModel.point_clouds_test[idx_b], WormholeModel.weights_test[idx_b]],
    WormholeModel.eps_enc, WormholeModel.lse_enc))

# Mean squared difference between the two encodings of each pair.
diff = test_encodings[idx_a] - test_encodings[idx_b]
enc_L2 = np.mean(np.square(diff), axis = -1)
[18]:
# Scatter true pairwise distances against encoding distances; the title shows Pearson's r.
fig, ax = plt.subplots(figsize = (5,5))
ax.scatter(true_w2, enc_L2, s = 0.1)
ax.axis('square')
ax.set_xlabel('True W2')
ax.set_ylabel('Transformer Encoder L2')
corr = scipy.stats.pearsonr(true_w2, enc_L2)[0]
ax.set_title('Corr: ' + str(np.round(corr, 3)))
plt.show()
../_images/tutorial_Wormhole_Tutorial_27_0.png

Decodings

[19]:
# Map the train and test encodings back to point clouds with the trained decoder.
train_decodings, test_decodings = map(WormholeModel.decode, (train_encodings, test_encodings))

Plot train

[20]:
# Show 25 decoded training point clouds; each is drawn independently, so repeats are possible.
fig = plt.figure(figsize = (25,25))
candidates = np.arange(len(train_decodings))
for panel in range(1, 26):
    ind = np.random.choice(candidates)
    cloud = train_decodings[ind]

    plt.subplot(5, 5, panel)
    plt.scatter(cloud[:, 1], -cloud[:, 0], s = 80)
    plt.xticks([])
    plt.yticks([])
    plt.axis('equal')
plt.show()
../_images/tutorial_Wormhole_Tutorial_31_0.png

Plot test

[21]:
# Show 25 decoded test point clouds; each is drawn independently, so repeats are possible.
fig = plt.figure(figsize = (25,25))
candidates = np.arange(len(test_decodings))
for panel in range(1, 26):
    ind = np.random.choice(candidates)
    cloud = test_decodings[ind]

    plt.subplot(5, 5, panel)
    plt.scatter(cloud[:, 1], -cloud[:, 0], s = 80)
    plt.xticks([])
    plt.yticks([])
    plt.axis('equal')
plt.show()
../_images/tutorial_Wormhole_Tutorial_33_0.png

Barycenters

Calculate per-class encoding mean

[22]:
# Mean training encoding of each class, stacked in the iteration order of set(label_train).
class_means = []
for digit in list(set(label_train)):
    class_means.append(np.mean(train_encodings[label_train == digit], axis = 0))
class_encoding_mean = np.asarray(class_means)

Decode the class means to obtain Wasserstein barycenters

[23]:
# Decode each class's mean encoding back into a point cloud. These decoded means are
# what this tutorial presents as the per-class Wasserstein barycenters.
class_barycenters = WormholeModel.decode(class_encoding_mean)
[24]:
# One panel per class: the decoded barycenter, titled with its label, five panels per row.
# Panels are enumerated in list(set(label_train)) order, the same order used to build
# class_encoding_mean. Labels therefore do not need to be the integers 0..9, and the grid
# grows with the number of classes (10 classes -> the original 2 x 5 layout).
class_labels = list(set(label_train))
n_cols = 5
n_rows = (len(class_labels) + n_cols - 1) // n_cols
fig = plt.figure(figsize=(25, 5 * n_rows))
for panel, label in enumerate(class_labels):

    plt.subplot(n_rows, n_cols, 1 + panel)
    plt.scatter(class_barycenters[panel][:, 1], -class_barycenters[panel][:, 0])
    plt.axis('equal')
    plt.xticks([])
    plt.yticks([])
    plt.title(str(label))
plt.show()

../_images/tutorial_Wormhole_Tutorial_39_0.png

Point Cloud Interpolation

Choose start and end encodings

[25]:
# Pick two training encodings uniformly at random (they may coincide) as interpolation ends.
n_train = train_encodings.shape[0]
start_ind, end_ind = (np.random.randint(n_train) for _ in range(2))

start_enc, end_enc = train_encodings[start_ind], train_encodings[end_ind]

Linearly interpolate between them

[26]:
# Evenly spaced points on the straight line from start_enc to end_enc, both ends included.
n_interp_steps = 8
interpolated_encodings = np.linspace(start_enc, end_enc, num = n_interp_steps)

Decode interpolations

[27]:
# Decode every interpolated encoding back into a point cloud, one per interpolation step.
interpolated_decodings = WormholeModel.decode(interpolated_encodings)
[28]:
# Draw the interpolation as a single row of panels, from the start encoding to the end.
n_panels = interpolated_decodings.shape[0]
fig = plt.figure(figsize=(n_panels*5, 5))

for i in range(n_panels):
    cloud = interpolated_decodings[i]

    plt.subplot(1, n_panels, 1 + i)
    plt.scatter(cloud[:, 1], -cloud[:, 0])
    plt.axis('equal')
    plt.xticks([])
    plt.yticks([])
plt.show()
../_images/tutorial_Wormhole_Tutorial_47_0.png

Draw UMAPs

[29]:
# Fit one 2-D UMAP on the train and test encodings together so both share the same
# embedding, then split the result back into its train and test parts.
fit = umap.UMAP(
    n_neighbors = 30,
    min_dist = 0.01,
    n_components = 2,
)

# Stored as umap_embedding (not `umap`) so the `umap` module is not shadowed. Otherwise
# re-running this cell fails, because `umap.UMAP` would be looked up on an ndarray.
umap_embedding = fit.fit_transform(np.concatenate([train_encodings, test_encodings], axis = 0))
umap_train = umap_embedding[:train_encodings.shape[0]]
umap_test = umap_embedding[train_encodings.shape[0]:]
[30]:
# Categorical palette from the 'cet_glasbey' colormap (provided by colorcet): the i-th
# colour goes to the i-th label in set(label_train) order.
# matplotlib.colormaps[...] replaces matplotlib.cm.get_cmap, which is deprecated since
# Matplotlib 3.7 and scheduled for removal.
cat_cmap = matplotlib.colormaps['cet_glasbey']
label_palette = {ctype: cat_cmap(i) for i, ctype in enumerate(list(set(label_train)))}
/scratch/lsftmp/5659470.tmpdir/ipykernel_63395/2627742441.py:1: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed two minor releases later. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap(obj)`` instead.
  cat_cmap = matplotlib.cm.get_cmap('cet_glasbey')
[31]:
# Side-by-side UMAP scatter plots of train and test encodings, coloured by label.
# Only the train panel carries the legend; both panels share label_palette.
fig, axes = plt.subplots(1, 2, figsize=(10,5))
panels = [
    (axes[0], umap_train, label_train, True, "Train"),
    (axes[1], umap_test, label_test, False, "Test"),
]
for ax, coords, labels, show_legend, title in panels:
    sns.scatterplot(x = coords[:, 0],
                    y = coords[:, 1], legend = show_legend,
                    hue = labels,
                    s = 10, palette = label_palette, ax = ax)
    ax.axis('off')
    ax.set_title(title)
plt.show()
../_images/tutorial_Wormhole_Tutorial_51_0.png

Classify Digits

[32]:
# Fit an MLP classifier (max_iter=1000) on the train encodings, then predict the test labels.
classifier = sklearn.neural_network.MLPClassifier(max_iter=1000)
classifier.fit(train_encodings, label_train)
test_predictions = classifier.predict(test_encodings)
[33]:
# Confusion matrix normalised over true labels (normalize='true'), shown as a heatmap
# whose title reports overall test accuracy.
cm_values = sklearn.metrics.confusion_matrix(label_test, test_predictions,
                                             labels=classifier.classes_, normalize = 'true')
cm = pd.DataFrame(cm_values, index = classifier.classes_, columns = classifier.classes_)
accuracy = np.mean(label_test == test_predictions)

ax = sns.heatmap(cm, cmap = 'viridis')
ax.set_ylabel("True Labels")
ax.set_xlabel("Pred Labels")
ax.set_title("Accuracy: " + str(round(accuracy, 3)))
plt.show()
../_images/tutorial_Wormhole_Tutorial_54_0.png