wassersteinwormhole.Wormhole module
- class wassersteinwormhole.Wormhole.Wormhole(point_clouds, weights=None, point_clouds_test=None, weights_test=None, config=<class 'wassersteinwormhole.DefaultConfig.DefaultConfig'>)
Bases: object
Initializes the Wormhole model and processes the input point clouds.
- Parameters:
point_clouds – (list of np.array) list of train-set point clouds to train Wormhole on
weights – (list of np.array) list of per-point weights for each train-set point cloud (default None, indicating uniform weights)
point_clouds_test – (list of np.array) list of test-set point clouds (default None)
weights_test – (list of np.array) list of per-point weights for each test-set point cloud (default None, indicating uniform weights)
config – (flax struct.dataclass) object with parameters for Wormhole, such as the OT metric, embedding dimension, etc. See the DefaultConfig documentation and the tutorial for details.
- Returns:
initialized Wormhole model
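The constructor takes lists of per-cloud arrays that may differ in size, and the docs state that Wormhole handles padding itself. As a rough illustration of what that involves, here is a minimal NumPy sketch, assuming clouds of shape (n_i, d) are zero-padded to the largest n_i, padded points receive zero weight, and weights are normalized to sum to 1. `pad_point_clouds` is a hypothetical helper, not part of wassersteinwormhole.

```python
import numpy as np

def pad_point_clouds(point_clouds, weights=None):
    """Hypothetical helper: pad a list of (n_i, d) arrays to a common size.

    Padded points sit at the origin with zero weight, so they carry no mass
    in an OT computation. If weights is None, each cloud gets uniform
    weights 1/n_i, mirroring the documented default.
    """
    max_n = max(pc.shape[0] for pc in point_clouds)
    dim = point_clouds[0].shape[1]
    coords = np.zeros((len(point_clouds), max_n, dim))
    masses = np.zeros((len(point_clouds), max_n))
    for i, pc in enumerate(point_clouds):
        n = pc.shape[0]
        w = np.full(n, 1.0 / n) if weights is None else np.asarray(weights[i], dtype=float)
        coords[i, :n] = pc
        masses[i, :n] = w / w.sum()  # assumed: weights normalized to a probability vector
    return coords, masses

rng = np.random.default_rng(0)
clouds = [rng.normal(size=(n, 2)) for n in (3, 5, 4)]
coords, masses = pad_point_clouds(clouds)
print(coords.shape, masses.sum(axis=1))
```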
- decode(enc, max_batch=256)
Decode embeddings back into point clouds using the Wormhole decoder.
- Parameters:
enc – (np.array) embeddings to decode
max_batch – (int) maximum size of batch during inference calls to Wormhole (default 256)
- Return dec:
decoded point clouds from embeddings
- encode(pc, weights, max_batch=256)
Encode point clouds with a trained Wormhole model.
- Parameters:
pc – (np.array) array of point clouds to encode
weights – (np.array) point weights for the input point clouds. Wormhole calculates padding for train- and test-set point clouds.
max_batch – (int) maximum size of batch during inference calls to Wormhole (default 256)
- Return enc:
per point cloud embeddings
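Both encode() and decode() cap the number of items passed to the model per call with max_batch. A minimal sketch of that kind of chunked inference, assuming the model maps a batch along axis 0 and results can be concatenated; `apply_in_batches` and `fake_encoder` are hypothetical stand-ins, not library functions.

```python
import numpy as np

def apply_in_batches(fn, inputs, max_batch=256):
    """Hypothetical helper: call fn on consecutive slices of at most
    max_batch items along axis 0 and concatenate the results."""
    outputs = [fn(inputs[start:start + max_batch])
               for start in range(0, inputs.shape[0], max_batch)]
    return np.concatenate(outputs, axis=0)

calls = []

def fake_encoder(batch):
    # Stand-in for a model call: records the batch size, maps (b, n, d) -> (b, d)
    calls.append(batch.shape[0])
    return batch.mean(axis=1)

clouds = np.random.default_rng(0).normal(size=(600, 10, 3))
emb = apply_in_batches(fake_encoder, clouds, max_batch=256)
print(emb.shape, calls)  # (600, 3) [256, 256, 88]
```

Smaller max_batch values lower peak memory at the cost of more model calls; the output is the same either way.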
- train(training_steps=10000, batch_size=16, verbose=8, init_lr=0.0001, decay_steps=2000, key=jax.random.key(0))
Set up optimization parameters and train the Wormhole model.
- Parameters:
training_steps – (int) number of gradient descent steps to train Wormhole (default 10000)
batch_size – (int) number of train-set point clouds sampled for each training step (default 16)
verbose – (int) number of steps between loss printouts (default 8)
init_lr – (float) initial learning rate for the Adam optimizer with exponential decay (default 1e-4)
decay_steps – (int) number of steps before each learning rate decay (default 2000)
key – (jax.random.key) random seed (default jax.random.key(0))
- Returns:
nothing
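The init_lr and decay_steps parameters describe an exponentially decaying learning rate. The docs do not give the decay factor or say whether the decay is stepwise, so the sketch below uses a hypothetical decay_rate and a staircase flag purely to show how the two documented parameters interact. In JAX code such a schedule would typically be built with optax.exponential_decay; that Wormhole does so is an assumption.

```python
def exponential_decay_lr(step, init_lr=1e-4, decay_steps=2000,
                         decay_rate=0.5, staircase=True):
    """Learning rate after `step` optimizer steps under exponential decay.

    init_lr and decay_steps mirror the documented train() defaults;
    decay_rate and staircase are hypothetical values chosen for illustration.
    """
    if staircase:
        exponent = step // decay_steps  # one decay every decay_steps steps
    else:
        exponent = step / decay_steps   # smooth decay between boundaries
    return init_lr * decay_rate ** exponent

for step in (0, 1999, 2000, 9999):
    print(step, exponential_decay_lr(step))
```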
wassersteinwormhole.DefaultConfig module
- class wassersteinwormhole.DefaultConfig.DefaultConfig(dtype: ~typing.Any = <class 'jax.numpy.float32'>, dist_func_enc: str = 'S2', dist_func_dec: str = 'S2', eps_enc: float = 0.1, eps_dec: float = 0.01, lse_enc: bool = False, lse_dec: bool = True, coeff_dec: float = 1, scale: str = 'min_max_total', factor: float = 1.0, emb_dim: int = 128, num_heads: int = 4, num_layers: int = 3, qkv_dim: int = 512, mlp_dim: int = 512, attention_dropout_rate: float = 0.1, kernel_init: ~typing.Callable = <function variance_scaling.<locals>.init>, bias_init: ~typing.Callable = <function zeros>)
Bases: object
Object with configuration parameters for Wormhole.
- Parameters:
dtype – (data type) floating-point precision for the Wormhole model (default jnp.float32)
dist_func_enc – (str) OT metric used for the embedding space (default ‘S2’; one of ‘W1’, ‘S1’, ‘W2’, ‘S2’, ‘GW’ and ‘GS’)
dist_func_dec – (str) OT metric used for the Wormhole decoder loss (default ‘S2’; one of ‘W1’, ‘S1’, ‘W2’, ‘S2’, ‘GW’ and ‘GS’)
eps_enc – (float) entropic regularization for the embedding OT (default 0.1)
eps_dec – (float) entropic regularization for the Wormhole decoder loss (default 0.01)
lse_enc – (bool) whether to use log-sum-exp mode (True) or kernel mode (False) for the embedding OT (default False)
lse_dec – (bool) whether to use log-sum-exp mode (True) or kernel mode (False) for the decoder OT (default True)
coeff_dec – (float) coefficient for decoder loss (default 1)
scale – (str) how to scale input point clouds (default ‘min_max_total’, which scales all point clouds so that values lie between -1 and 1)
factor – (float) multiplicative factor applied to point cloud coordinates after scaling (default 1)
emb_dim – (int) Wormhole embedding dimension (default 128)
num_heads – (int) number of heads in multi-head attention (default 4)
num_layers – (int) number of layers of multi-head attention for Wormhole encoder and decoder (default 3)
qkv_dim – (int) dimension of the query, key and value projections in attention (default 512)
mlp_dim – (int) dimension of the hidden layer of the fully-connected network after every multi-head attention layer (default 512)
attention_dropout_rate – (float) dropout rate for attention matrices during training (default 0.1)
kernel_init – (Callable) initializer of kernel weights (default nn.initializers.glorot_uniform())
bias_init – (Callable) initializer of bias weights (default nn.initializers.zeros_init())
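The scale and factor options describe a preprocessing step. A minimal NumPy sketch of ‘min_max_total’ followed by factor, assuming (as the name "total" suggests, but the docs do not confirm) that a single global minimum and maximum are taken over every coordinate of every point cloud; `min_max_total_scale` is a hypothetical name.

```python
import numpy as np

def min_max_total_scale(point_clouds, factor=1.0):
    """Sketch of 'min_max_total' scaling followed by `factor`.

    Assumption: one global min and max over all coordinates of all clouds,
    mapped affinely onto [-1, 1]; requires that not all values are equal.
    """
    lo = min(pc.min() for pc in point_clouds)
    hi = max(pc.max() for pc in point_clouds)
    center = (hi + lo) / 2.0
    half_range = (hi - lo) / 2.0
    return [factor * (pc - center) / half_range for pc in point_clouds]

clouds = [np.array([[0.0, 10.0], [5.0, 2.0]]), np.array([[1.0, 4.0]])]
scaled = min_max_total_scale(clouds)
print(scaled[0], scaled[1])
```

Using one global range (rather than one per axis) preserves the relative geometry of the clouds, which matters because the OT metrics compare distances between points.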
- attention_dropout_rate: float = 0.1
- bias_init(shape: ~collections.abc.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Array
An initializer that returns a constant array full of zeros.
The key argument is ignored.
>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.key(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
- coeff_dec: float = 1
- dist_func_dec: str = 'S2'
- dist_func_enc: str = 'S2'
- dtype
alias of float32
- emb_dim: int = 128
- eps_dec: float = 0.01
- eps_enc: float = 0.1
- factor: float = 1.0
- kernel_init(shape: ~collections.abc.Sequence[int | ~typing.Any], dtype: ~typing.Any = <class 'jax.numpy.float64'>) Array
- lse_dec: bool = True
- lse_enc: bool = False
- mlp_dim: int = 512
- num_heads: int = 4
- num_layers: int = 3
- qkv_dim: int = 512
- replace(**updates)
Returns a new object with the specified fields replaced by new values.
- scale: str = 'min_max_total'
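Because DefaultConfig is a flax struct.dataclass, its instances are frozen: settings are changed either by passing keyword arguments to the constructor (as in the signature above) or with DefaultConfig.replace, which returns a modified copy. The sketch below mimics those semantics with a hypothetical stdlib stand-in, `ToyConfig`, so it runs without flax; only a few field names and defaults are copied from DefaultConfig.

```python
import dataclasses

@dataclasses.dataclass(frozen=True)
class ToyConfig:
    """Hypothetical stand-in reusing a few DefaultConfig fields and defaults."""
    dist_func_enc: str = 'S2'
    eps_enc: float = 0.1
    emb_dim: int = 128

    def replace(self, **updates):
        # Same idea as the documented method: build a modified copy
        return dataclasses.replace(self, **updates)

base = ToyConfig()
custom = base.replace(dist_func_enc='GW', emb_dim=64)
print(base.emb_dim, custom.emb_dim, custom.dist_func_enc)  # 128 64 GW
```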
wassersteinwormhole.utils_OT module
- wassersteinwormhole.utils_OT.GS(x, y, eps, lse_mode=False)
Calculate the Gromov-Wasserstein-based Sinkhorn divergence (GS) between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return GS:
Gromov-Wasserstein Sinkhorn Divergence between x and y
- wassersteinwormhole.utils_OT.GS_scale(x, y, eps, lse_mode=False)
Calculate the Gromov-Wasserstein-based Sinkhorn divergence (GS) between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return GS:
Gromov-Wasserstein Sinkhorn Divergence between x and y
- wassersteinwormhole.utils_OT.GW(x, y, eps, lse_mode=False)
Calculate the Gromov-Wasserstein distance between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return GW:
GW distance between x and y
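Gromov-Wasserstein compares the distance matrices within each cloud rather than the coordinates themselves, which makes it insensitive to rotations and lets x and y live in different dimensions. The NumPy sketch below only evaluates the square-loss GW objective for a fixed coupling; GW and GS above presumably minimize an entropically regularized version of it over couplings, which this sketch does not do. The use of Euclidean intra-cloud distances is an assumption, and `gw_objective` is a hypothetical name.

```python
import numpy as np

def pairwise_dist(coords):
    # Euclidean distance matrix between the points of a single cloud
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt(np.sum(diff ** 2, axis=-1))

def gw_objective(x, y, plan):
    """Square-loss GW objective of a fixed coupling `plan`:
    sum_{i,j,k,l} (Dx[i,k] - Dy[j,l])**2 * plan[i,j] * plan[k,l].
    `plan` should have row sums x[1] and column sums y[1].
    """
    dx, dy = pairwise_dist(x[0]), pairwise_dist(y[0])
    diff = dx[:, None, :, None] - dy[None, :, None, :]  # diff[i, j, k, l]
    return np.einsum('ijkl,ij,kl->', diff ** 2, plan, plan)

rng = np.random.default_rng(0)
pts = rng.normal(size=(5, 3))
w = np.full(5, 1 / 5)
rot, _ = np.linalg.qr(rng.normal(size=(3, 3)))  # random orthogonal matrix
matched = np.diag(w)  # pairs point i with its own rotated copy
print(gw_objective([pts, w], [pts @ rot.T, w], matched))  # near 0, up to round-off
```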
- wassersteinwormhole.utils_OT.S1(x, y, eps, lse_mode=False)
Calculate the EMD Sinkhorn divergence between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return S1:
L1 Sinkhorn divergence between x and y
- wassersteinwormhole.utils_OT.S2(x, y, eps, lse_mode=False)
Calculate the Sinkhorn divergence (S2) between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return S2:
Sinkhorn divergence between x and y
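To make eps and lse_mode concrete, here is a minimal NumPy sketch of a debiased Sinkhorn divergence with squared-Euclidean cost, using the documented [coordinates, weights] convention for x and y. It is not the library's implementation, which presumably delegates to a JAX OT solver. In particular, it scores each entropic plan P by its transport cost sum(P * C), while solvers such as ott-jax report the entropy-regularized cost, so absolute values will differ. Weights are assumed strictly positive, and a fixed iteration count stands in for a convergence check.

```python
import numpy as np

def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along one axis
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))

def sinkhorn_plan(a, b, cost, eps, lse_mode=False, n_iter=500):
    """Entropic OT plan between weight vectors a and b for a cost matrix."""
    if lse_mode:
        # Log-domain updates of dual potentials f, g: stable for small eps
        f, g = np.zeros_like(a), np.zeros_like(b)
        for _ in range(n_iter):
            f = eps * np.log(a) - eps * _logsumexp((g[None, :] - cost) / eps, axis=1)
            g = eps * np.log(b) - eps * _logsumexp((f[:, None] - cost) / eps, axis=0)
        return np.exp((f[:, None] + g[None, :] - cost) / eps)
    # Kernel mode: rescale the Gibbs kernel, whose entries underflow for small eps
    kernel = np.exp(-cost / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iter):
        u = a / (kernel @ v)
        v = b / (kernel.T @ u)
    return u[:, None] * kernel * v[None, :]

def sq_euclidean(p, q):
    return np.sum((p[:, None, :] - q[None, :, :]) ** 2, axis=-1)

def entropic_cost(x, y, eps, lse_mode):
    cost = sq_euclidean(x[0], y[0])
    plan = sinkhorn_plan(x[1], y[1], cost, eps, lse_mode)
    return np.sum(plan * cost)

def sinkhorn_divergence(x, y, eps, lse_mode=False):
    # Debiasing subtracts the self-transport terms so that S(x, x) == 0
    return entropic_cost(x, y, eps, lse_mode) - 0.5 * (
        entropic_cost(x, x, eps, lse_mode) + entropic_cost(y, y, eps, lse_mode))

rng = np.random.default_rng(0)
x = [rng.uniform(size=(6, 2)), np.full(6, 1 / 6)]
y = [rng.uniform(size=(4, 2)) + 1.0, np.full(4, 1 / 4)]
print(sinkhorn_divergence(x, y, eps=0.1), sinkhorn_divergence(x, y, eps=0.1, lse_mode=True))
```

With the same initialization, both modes perform the same updates (the log-domain potentials are eps times the log of the kernel-mode scalings), so they agree here. For much smaller eps, entries of exp(-cost / eps) underflow to zero in kernel mode, which is the failure the log-sum-exp path avoids at the cost of extra work per iteration.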
- wassersteinwormhole.utils_OT.W1(x, y, eps, lse_mode=False)
Calculate W1 (EMD) between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return W1:
EMD distance between x and y
- wassersteinwormhole.utils_OT.W2(x, y, eps, lse_mode=False)
Calculate W2 between two weighted point clouds.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return W2:
Wasserstein distance between x and y
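Since W1 and W2 take an entropic regularization eps, their values approximate the exact distances. For sanity-checking on tiny inputs, an exact reference is easy to write when both clouds have the same size and uniform weights: some optimal plan is then a permutation (Birkhoff-von Neumann), so brute force over permutations is exact. This is a hypothetical testing aid, not part of utils_OT; it uses Euclidean ground cost for p=1 and returns the square root of the optimal squared-Euclidean cost for p=2, and the docs do not say which of these conventions W2 itself reports.

```python
import itertools
import numpy as np

def exact_wasserstein(x, y, p):
    """Exact (unregularized) W_p between equal-size, uniformly weighted clouds.

    Brute force over all permutations, so only feasible for a few points.
    """
    dist = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
    rows = np.arange(len(x))
    best = min(np.mean(dist[rows, list(perm)] ** p)
               for perm in itertools.permutations(range(len(x))))
    return best ** (1.0 / p)

rng = np.random.default_rng(1)
a, b = rng.normal(size=(5, 2)), rng.normal(size=(5, 2))
print(exact_wasserstein(a, b, p=1), exact_wasserstein(a, b, p=2))
```

In one dimension the sorted (monotone) matching is optimal, which gives an independent check of the brute-force result.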
- wassersteinwormhole.utils_OT.Zeros(x, y, eps, lse_mode=False)
Always returns 0. Used when Wormhole is trained only to embed, avoiding the computational overhead of the decoder.
- Parameters:
x – (list) two-element list: the point-cloud coordinates (x[0]) and the per-point weights (x[1])
y – (list) two-element list: the point-cloud coordinates (y[0]) and the per-point weights (y[1])
eps – (float) coefficient of entropic regularization
lse_mode – (bool) whether to use log-sum-exp mode (if True, more stable for smaller eps, but slower) or kernel mode (default False)
- Return zeros:
array of zeros