wassersteinwormhole.Wormhole module

class wassersteinwormhole.Wormhole.Wormhole(point_clouds, weights=None, point_clouds_test=None, weights_test=None, config=<class 'wassersteinwormhole.DefaultConfig.DefaultConfig'>)

Bases: object

Initializes the Wormhole model and processes the input point clouds

Parameters:
  • point_clouds – (list of np.array) list of train-set point clouds to train Wormhole on

  • weights – (list of np.array) list of per-point weights for each train-set point cloud (default None, indicating uniform weights)

  • point_clouds_test – (list of np.array) list of test-set point clouds (default None)

  • weights_test – (list of np.array) list of per-point weights for each test-set point cloud (default None, indicating uniform weights)

  • config – (flax struct.dataclass) object with parameters for Wormhole such as the OT metric choice, embedding dimension, etc. See the DefaultConfig documentation and the tutorial for details.

Returns:

initialized Wormhole model
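The weights=None default (and weights_test=None) means uniform weights: each point in a cloud of n points gets weight 1/n. The rule is spelled out below in plain Python. default_weights is a hypothetical helper, not part of wassersteinwormhole, and point clouds are represented as lists of coordinate tuples instead of np.array only to keep the example dependency-free.

```python
from typing import List, Optional, Sequence

def default_weights(point_clouds: List[List[Sequence[float]]],
                    weights: Optional[List[List[float]]] = None) -> List[List[float]]:
    """Return the given weights, or uniform weights 1/n for each cloud of n points."""
    if weights is not None:
        return weights
    return [[1.0 / len(pc)] * len(pc) for pc in point_clouds]

clouds = [
    [(0.0, 0.0), (1.0, 0.0)],              # cloud with 2 points in 2D
    [(0.0, 1.0), (1.0, 1.0), (0.5, 0.5)],  # cloud with 3 points in 2D
]
print(default_weights(clouds)[0])  # [0.5, 0.5]
```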

decode(enc, max_batch=256)

Decode embeddings back into point clouds using the Wormhole decoder

Parameters:
  • enc – (np.array) embeddings to decode

  • max_batch – (int) maximum size of batch during inference calls to Wormhole (default 256)

Return dec:

decoded point clouds from embeddings
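Both decode and encode cap the size of each inference call at max_batch. A minimal sketch of that kind of chunking follows, assuming the inputs are simply split into consecutive slices of at most max_batch items. batched_apply and toy_decoder are hypothetical, and the library's own batching may differ in detail (for example in how it handles the final, shorter batch).

```python
from typing import Callable, List, Sequence, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def batched_apply(fn: Callable[[Sequence[T]], List[R]],
                  items: Sequence[T], max_batch: int = 256) -> List[R]:
    """Call fn on consecutive slices of at most max_batch items and concatenate the results."""
    if max_batch < 1:
        raise ValueError("max_batch must be at least 1")
    out: List[R] = []
    for start in range(0, len(items), max_batch):
        out.extend(fn(items[start:start + max_batch]))
    return out

seen: List[int] = []  # batch sizes received by the toy decoder

def toy_decoder(batch: Sequence[int]) -> List[int]:
    """Stand-in for a model call: records the batch size and doubles each input."""
    seen.append(len(batch))
    return [2 * e for e in batch]

result = batched_apply(toy_decoder, list(range(600)), max_batch=256)
print(seen)  # [256, 256, 88]
```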

encode(pc, weights, max_batch=256)

Encode point clouds with the trained Wormhole model

Parameters:
  • pc – (np.array) array of point clouds to encode

  • weights – (np.array) point weights for the input point clouds (Wormhole calculates the padding for the train- and test-set point clouds)

  • max_batch – (int) maximum size of batch during inference calls to Wormhole (default 256)

Return enc:

per point cloud embeddings
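encode takes an array of point clouds together with matching weights, so clouds with different numbers of points have to share one point count. One common convention, assumed here rather than taken from the library, is to append dummy points with zero weight so they carry no mass, and to give the real points uniform weights. pad_point_clouds is a hypothetical helper; check the tutorial for the exact input format encode expects.

```python
from typing import List, Sequence, Tuple

def pad_point_clouds(point_clouds: List[List[Sequence[float]]],
                     size: int) -> Tuple[List[List[List[float]]], List[List[float]]]:
    """Pad every cloud to `size` points; real points share weight 1, padding gets 0."""
    dim = len(point_clouds[0][0])
    padded, weights = [], []
    for pc in point_clouds:
        n = len(pc)
        if n > size:
            raise ValueError(f"cloud has {n} points but size={size}")
        # Padding coordinates are arbitrary (zeros here) because their weight is 0.
        padded.append([list(p) for p in pc] + [[0.0] * dim for _ in range(size - n)])
        weights.append([1.0 / n] * n + [0.0] * (size - n))
    return padded, weights

clouds = [[(0.0, 0.0), (1.0, 1.0)], [(2.0, 2.0), (3.0, 3.0), (4.0, 4.0)]]
pc, w = pad_point_clouds(clouds, size=3)
print(w[0])  # [0.5, 0.5, 0.0]
```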

train(training_steps=10000, batch_size=16, verbose=8, init_lr=0.0001, decay_steps=2000, key=jax.random.key(0))

Set up optimization parameters and train the Wormhole model

Parameters:
  • training_steps – (int) number of gradient descent steps to train Wormhole (default 10000)

  • batch_size – (int) number of train-set point clouds sampled for each training step (default 16)

  • verbose – (int) number of steps between loss print statements (default 8)

  • init_lr – (float) initial learning rate for the Adam optimizer with exponential decay (default 1e-4)

  • decay_steps – (int) number of steps before each learning rate decay (default 2000)

  • key – (jax.random.key) random seed (default jax.random.key(0))

Returns:

nothing
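train optimizes with Adam under an exponentially decaying learning rate and prints the loss every verbose steps. The schedule can be pictured as below. This is a sketch under stated assumptions: the decay is taken to be staircase-style (one multiplicative drop every decay_steps steps, matching the parameter description above), decay_rate=0.9 is a hypothetical value that these docs do not give, and exp_decay_lr and logged_steps are illustrative helpers, not library functions.

```python
from typing import List

def exp_decay_lr(step: int, init_lr: float = 1e-4, decay_steps: int = 2000,
                 decay_rate: float = 0.9) -> float:
    """Staircase exponential decay: multiply the rate by decay_rate every decay_steps steps.

    decay_rate=0.9 is a hypothetical value, not taken from Wormhole.
    """
    return init_lr * decay_rate ** (step // decay_steps)

def logged_steps(training_steps: int, verbose: int) -> List[int]:
    """Steps that would print a loss line if printing happens every `verbose` steps."""
    return [s for s in range(training_steps) if s % verbose == 0]

for step in (0, 1999, 2000, 4000):
    print(step, exp_decay_lr(step))
print(logged_steps(40, verbose=8))  # [0, 8, 16, 24, 32]
```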

wassersteinwormhole.DefaultConfig module

class wassersteinwormhole.DefaultConfig.DefaultConfig(dtype: ~typing.Any = <class 'jax.numpy.float32'>, dist_func_enc: str = 'S2', dist_func_dec: str = 'S2', eps_enc: float = 0.1, eps_dec: float = 0.01, lse_enc: bool = False, lse_dec: bool = True, coeff_dec: float = 1, scale: str = 'min_max_total', factor: float = 1.0, emb_dim: int = 128, num_heads: int = 4, num_layers: int = 3, qkv_dim: int = 512, mlp_dim: int = 512, attention_dropout_rate: float = 0.1, kernel_init: ~typing.Callable = <function variance_scaling.<locals>.init>, bias_init: ~typing.Callable = <function zeros>)

Bases: object

Object with configuration parameters for Wormhole

Parameters:
  • dtype – (data type) floating-point precision for the Wormhole model (default jnp.float32)

  • dist_func_enc – (str) OT metric used for the embedding space (default ‘S2’; one of ‘W1’, ‘S1’, ‘W2’, ‘S2’, ‘GW’ or ‘GS’)

  • dist_func_dec – (str) OT metric used for the Wormhole decoder loss (default ‘S2’; one of ‘W1’, ‘S1’, ‘W2’, ‘S2’, ‘GW’ or ‘GS’)

  • eps_enc – (float) entropic regularization for embedding OT (default 0.1)

  • eps_dec – (float) entropic regularization for the Wormhole decoder loss (default 0.01)

  • lse_enc – (bool) whether to use log-sum-exp mode (True) or kernel mode (False) for the embedding OT (default False)

  • lse_dec – (bool) whether to use log-sum-exp mode (True) or kernel mode (False) for the decoder OT (default True)

  • coeff_dec – (float) coefficient for decoder loss (default 1)

  • scale – (str) how to scale input point clouds (default ‘min_max_total’, which scales all point clouds so their values lie between -1 and 1)

  • factor – (float) multiplicative factor applied on point cloud coordinates after scaling (default 1)

  • emb_dim – (int) Wormhole embedding dimension (default 128)

  • num_heads – (int) number of heads in multi-head attention (default 4)

  • num_layers – (int) number of layers of multi-head attention for Wormhole encoder and decoder (default 3)

  • qkv_dim – (int) dimension of the query, key and value attributes in attention (default 512)

  • mlp_dim – (int) dimension of the hidden layer of the fully-connected network after every multi-head attention layer (default 512)

  • attention_dropout_rate – (float) dropout rate for attention matrices during training (default 0.1)

  • kernel_init – (Callable) initializer of kernel weights (default nn.initializers.glorot_uniform())

  • bias_init – (Callable) initializer of bias weights (default nn.initializers.zeros_init())
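With scale=‘min_max_total’, all point clouds are rescaled so their values lie between -1 and 1, and factor is then applied multiplicatively. A minimal sketch follows, assuming a single minimum and maximum taken over every coordinate of every point cloud: "total" suggests a global rather than per-cloud range, but the docs do not say whether the range is also shared across dimensions. min_max_total_scale is a hypothetical helper.

```python
from typing import List, Sequence

def min_max_total_scale(point_clouds: List[List[Sequence[float]]],
                        factor: float = 1.0) -> List[List[List[float]]]:
    """Map all coordinates into [-1, 1] with one global min/max, then multiply by factor.

    Assumption: the min and max are taken over every coordinate of every cloud.
    """
    values = [v for pc in point_clouds for p in pc for v in p]
    lo, hi = min(values), max(values)
    span = hi - lo
    if span == 0:
        raise ValueError("all coordinates are equal; cannot min-max scale")
    return [[[factor * (2.0 * (v - lo) / span - 1.0) for v in p] for p in pc]
            for pc in point_clouds]

clouds = [[(0.0, 5.0), (10.0, 0.0)], [(2.5, 7.5)]]
print(min_max_total_scale(clouds))  # [[[-1.0, 0.0], [1.0, -1.0]], [[-0.5, 0.5]]]
```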

attention_dropout_rate: float = 0.1
bias_init(shape: ~collections.abc.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Array

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
coeff_dec: float = 1
dist_func_dec: str = 'S2'
dist_func_enc: str = 'S2'
dtype

alias of float32

emb_dim: int = 128
eps_dec: float = 0.01
eps_enc: float = 0.1
factor: float = 1.0
kernel_init(shape: ~collections.abc.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Array
lse_dec: bool = True
lse_enc: bool = False
mlp_dim: int = 512
num_heads: int = 4
num_layers: int = 3
qkv_dim: int = 512
replace(**updates)

Returns a new object replacing the specified fields with new values.

scale: str = 'min_max_total'
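Config objects of this kind are immutable: fields cannot be assigned after construction, and replace returns a modified copy. The sketch uses the standard library's frozen dataclasses.dataclass and dataclasses.replace as a stand-in for flax's struct.dataclass, assuming the two behave alike for this purpose. MiniConfig is hypothetical and mirrors only a handful of the DefaultConfig fields, with defaults copied from the signature above.

```python
from dataclasses import FrozenInstanceError, dataclass, replace

@dataclass(frozen=True)
class MiniConfig:
    """Hypothetical stand-in for a few DefaultConfig fields (defaults copied from its signature)."""
    dist_func_enc: str = "S2"
    dist_func_dec: str = "S2"
    eps_enc: float = 0.1
    eps_dec: float = 0.01
    emb_dim: int = 128

base = MiniConfig()
try:
    base.emb_dim = 64  # frozen dataclass: in-place assignment is rejected
except FrozenInstanceError:
    print("fields are read-only; build a modified copy instead")

# replace() returns a new object; `base` keeps its defaults.
custom = replace(base, dist_func_enc="GW", emb_dim=64)
print(base.emb_dim, custom.emb_dim, custom.dist_func_enc)  # 128 64 GW
```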

wassersteinwormhole.utils_OT module

wassersteinwormhole.utils_OT.GS(x, y, eps, lse_mode=False)

Calculate Gromov-Wasserstein based Sinkhorn Divergence (GS) distance between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return GS:

Gromov-Wasserstein Sinkhorn Divergence between x and y
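Every function in this module takes an lse_mode flag that selects how the entropic OT problem is solved. The sketch below contrasts the two standard Sinkhorn formulations on a plain cost matrix: kernel mode rescales the Gibbs kernel exp(-C/eps) directly, while log-sum-exp mode iterates dual potentials in the log domain. It is a textbook Sinkhorn loop written for illustration (sinkhorn_plan is hypothetical), not wassersteinwormhole's solver, and it requires strictly positive weights because the log-domain update takes log(a) and log(b).

```python
import math
from typing import List

def _logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x))) by factoring out the maximum."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def sinkhorn_plan(C: List[List[float]], a: List[float], b: List[float],
                  eps: float, lse_mode: bool = False,
                  n_iter: int = 500) -> List[List[float]]:
    """Entropic OT plan between weights a and b (all > 0) for cost matrix C."""
    n, m = len(a), len(b)
    if lse_mode:
        # Log domain: iterate potentials f, g; plan = exp((f_i + g_j - C_ij) / eps).
        f, g = [0.0] * n, [0.0] * m
        for _ in range(n_iter):
            f = [eps * math.log(a[i])
                 - eps * _logsumexp([(g[j] - C[i][j]) / eps for j in range(m)])
                 for i in range(n)]
            g = [eps * math.log(b[j])
                 - eps * _logsumexp([(f[i] - C[i][j]) / eps for i in range(n)])
                 for j in range(m)]
        return [[math.exp((f[i] + g[j] - C[i][j]) / eps) for j in range(m)]
                for i in range(n)]
    # Kernel mode: iterate scalings u, v of K = exp(-C / eps); plan = diag(u) K diag(v).
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

# Moderate eps: squared distances from x = [0, 1] to y = [0.5, 1.5]; both modes agree.
C = [[0.25, 2.25], [0.25, 0.25]]
a = b = [0.5, 0.5]
P_kernel = sinkhorn_plan(C, a, b, eps=0.5)
P_lse = sinkhorn_plan(C, a, b, eps=0.5, lse_mode=True)

# Tiny eps with a large cost: exp(-25 / 0.01) underflows to 0.0 in kernel mode.
try:
    sinkhorn_plan([[25.0]], [1.0], [1.0], eps=0.01)
except ZeroDivisionError:
    print("kernel mode failed: the Gibbs kernel underflowed to zero")
print(sinkhorn_plan([[25.0]], [1.0], [1.0], eps=0.01, lse_mode=True))
```

The last two calls show why log-sum-exp mode is described as more stable for smaller eps: kernel mode has to form exp(-C/eps), which underflows, whereas the log-domain iteration never builds that kernel. The price is a logsumexp over a full row or column for every entry update, which is why it is slower.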

wassersteinwormhole.utils_OT.GS_scale(x, y, eps, lse_mode=False)

Calculate Gromov-Wasserstein based Sinkhorn Divergence (GS) distance between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return GS:

Gromov-Wasserstein Sinkhorn Divergence between x and y

wassersteinwormhole.utils_OT.GW(x, y, eps, lse_mode=False)

Calculate Gromov-Wasserstein distance between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return GW:

GW distance between x and y
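GW compares distances within each point cloud rather than distances between the two clouds, so x and y need not share a coordinate frame, and translating or rotating either cloud leaves the comparison unchanged. The sketch below only evaluates the GW objective for a fixed coupling T. Since GW also takes eps, it presumably optimizes an entropically regularized version of this objective over T, which the sketch does not attempt. Euclidean intra-cloud distances and the names pairwise_distances and gw_objective are assumptions for illustration.

```python
import math
from typing import List, Sequence

def pairwise_distances(pc: Sequence[Sequence[float]]) -> List[List[float]]:
    """Euclidean distance matrix between all points of one cloud."""
    return [[math.dist(p, q) for q in pc] for p in pc]

def gw_objective(x: Sequence[Sequence[float]], y: Sequence[Sequence[float]],
                 T: List[List[float]]) -> float:
    """GW cost of coupling T: sum of (Dx[i][k] - Dy[j][l])^2 * T[i][j] * T[k][l]."""
    Dx, Dy = pairwise_distances(x), pairwise_distances(y)
    n, m = len(x), len(y)
    total = 0.0
    for i in range(n):
        for j in range(m):
            if T[i][j] == 0.0:
                continue  # pairs that carry no mass contribute nothing
            for k in range(n):
                for l in range(m):
                    total += (Dx[i][k] - Dy[j][l]) ** 2 * T[i][j] * T[k][l]
    return total

x = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
y = [(px + 5.0, py + 5.0) for px, py in x]     # x translated by (5, 5)
identity = [[1 / 3 if i == j else 0.0 for j in range(3)] for i in range(3)]
independent = [[1 / 9] * 3 for _ in range(3)]  # product coupling of uniform weights

print(gw_objective(x, y, identity))     # 0.0: matched points have identical distances
print(gw_objective(x, y, independent))  # positive: unrelated points are paired
```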

wassersteinwormhole.utils_OT.S1(x, y, eps, lse_mode=False)

Calculate EMD Sinkhorn divergence between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return S1:

L1 Sinkhorn divergence between x and y

wassersteinwormhole.utils_OT.S2(x, y, eps, lse_mode=False)

Calculate the Sinkhorn divergence (S2) between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return S2:

Sinkhorn divergence between x and y
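Entropic OT is biased: the regularized cost between a point cloud and itself is not zero. The standard Sinkhorn divergence removes this bias by subtracting half of each self-comparison, S(x, y) = OT_eps(x, y) - OT_eps(x, x)/2 - OT_eps(y, y)/2. The sketch assumes that this standard definition is what S2 refers to (the docs above do not state the formula), uses a squared-Euclidean ground cost with uniform weights, and measures OT_eps as the transport cost of the Sinkhorn plan rather than the full regularized objective, which is one of several conventions in use. entropic_ot_cost and sinkhorn_divergence are hypothetical helpers.

```python
import math
from typing import Sequence

def entropic_ot_cost(x: Sequence[Sequence[float]], y: Sequence[Sequence[float]],
                     eps: float, n_iter: int = 500) -> float:
    """Transport cost <P, C> of the kernel-mode Sinkhorn plan (squared Euclidean, uniform weights)."""
    n, m = len(x), len(y)
    a, b = [1.0 / n] * n, [1.0 / m] * m
    C = [[sum((pi - qi) ** 2 for pi, qi in zip(p, q)) for q in y] for p in x]
    K = [[math.exp(-C[i][j] / eps) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return sum(u[i] * K[i][j] * v[j] * C[i][j] for i in range(n) for j in range(m))

def sinkhorn_divergence(x: Sequence[Sequence[float]], y: Sequence[Sequence[float]],
                        eps: float) -> float:
    """Debiased cost: OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2."""
    return (entropic_ot_cost(x, y, eps)
            - 0.5 * entropic_ot_cost(x, x, eps)
            - 0.5 * entropic_ot_cost(y, y, eps))

x = [(0.0,), (1.0,)]
y = [(3.0,), (4.0,)]
print(entropic_ot_cost(x, x, eps=1.0))     # ≈ 0.2689 = 1 / (1 + e): biased, not 0
print(sinkhorn_divergence(x, x, eps=1.0))  # 0.0
print(sinkhorn_divergence(x, y, eps=1.0))  # ≈ 9.0
```

In this toy case the raw entropic cost OT_eps(x, y) is about 9.27, while the debiased value is 9, which equals the exact squared W2 distance between the two clouds.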

wassersteinwormhole.utils_OT.W1(x, y, eps, lse_mode=False)

Calculate W1 (EMD) between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return W1:

EMD distance between x and y
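Because W1 takes eps, it presumably solves an entropically regularized problem, so its values need not match the exact distance. For intuition about the unregularized quantity there is a well-known special case: for two 1D point clouds of equal size with uniform weights, the optimal plan pairs points in sorted order, so W_p reduces to a mean over sorted differences. wasserstein_1d is a hypothetical helper covering only that case; p=1 gives the EMD and p=2 gives the unregularized W2.

```python
from typing import Sequence

def wasserstein_1d(x: Sequence[float], y: Sequence[float], p: int = 1) -> float:
    """Exact (unregularized) W_p between equal-size 1D point clouds with uniform weights."""
    if len(x) != len(y) or not x:
        raise ValueError("this shortcut needs two non-empty clouds of equal size")
    # In 1D the optimal plan matches the i-th smallest x with the i-th smallest y.
    diffs = (abs(a - b) ** p for a, b in zip(sorted(x), sorted(y)))
    return (sum(diffs) / len(x)) ** (1.0 / p)

x = [0.0, 1.0, 2.0]
y = [3.0, 1.0, 5.0]
print(wasserstein_1d(x, y, p=1))  # 2.0
print(wasserstein_1d(x, y, p=2))  # sqrt(14 / 3) ≈ 2.16
```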

wassersteinwormhole.utils_OT.W2(x, y, eps, lse_mode=False)

Calculate W2 between two weighted point clouds

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return W2:

Wasserstein distance between x and y

wassersteinwormhole.utils_OT.Zeros(x, y, eps, lse_mode=False)

Always returns 0. Used when Wormhole is trained only to embed, to avoid the computational overhead of the decoder

Parameters:
  • x – (list) list with two elements: x[0] holds the point-cloud coordinates and x[1] holds the weight of each point

  • y – (list) list with two elements: y[0] holds the point-cloud coordinates and y[1] holds the weight of each point

  • eps – (float) coefficient of entropic regularization

  • lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)

Return zeros:

array of zeros