import torch
import torch.nn.functional as F
from torch import nn
[docs]class BatchSwapNoise(nn.Module):
"""Swap Noise module"""
def __init__(self, p):
super().__init__()
self.p = p
[docs] def forward(self, x):
if self.training:
mask = torch.rand(x.size()) > (1 - self.p)
idx = torch.add(
torch.arange(x.nelement()),
(
torch.floor(torch.rand(x.size()) * x.size(0)).type(torch.LongTensor)
* (mask.type(torch.LongTensor) * x.size(1))
).view(-1),
)
idx[idx >= x.nelement()] = idx[idx >= x.nelement()] - x.nelement()
return x.view(-1)[idx].view(x.size())
else:
return x
[docs]class Encoder(nn.Module):
"""Single modality encoder MLP with dropout and stochastic feature augmentation (SFA)
https://openaccess.thecvf.com/content/ICCV2021/papers
/Li_A_Simple_Feature_Augmentation_for_Domain_Generalization_ICCV_2021_paper.pdf"""
def __init__(
self,
n_input,
embedding_size,
dropout_rates,
dims_layers,
swap_noise_ratio,
noise_amount=0.0,
):
super(Encoder, self).__init__()
dropout = []
layers = [nn.Linear(n_input, dims_layers[0])]
for i in range(len(dims_layers) - 1):
layers.append(nn.Linear(dims_layers[i], dims_layers[i + 1]))
for i in range(len(dropout_rates)):
dropout.append(nn.Dropout(p=dropout_rates[i]))
layers.append(nn.Linear(dims_layers[-1], embedding_size))
self.fc_list = nn.ModuleList(layers)
print("dropout list", dropout)
self.dropout_list = nn.ModuleList(dropout)
self.noise_amount = noise_amount
print("SFA with noise:", noise_amount)
[docs] def forward(self, x):
for i in range(len(self.fc_list) - 1):
if i > 0 and self.training and self.noise_amount > 0:
x = torch.mul(
x,
torch.ones_like(x)
+ self.noise_amount * torch.randn_like(x, device=x.device),
)
x += self.noise_amount * torch.randn_like(x, device=x.device)
x = F.elu(self.fc_list[i](x))
if i < len(self.dropout_list):
x = self.dropout_list[i](x)
x = self.fc_list[-1](x)
return x
[docs]class Modality_CLIP(nn.Module):
"""CLIP-inspired architecture"""
def __init__(
self,
Encoder,
layers_dims,
dropout_rates,
dim_mod1,
dim_mod2,
output_dim,
T,
swap_rate_1=0.0,
swap_rate_2=0.0,
noise_amount=0.0,
):
super(Modality_CLIP, self).__init__()
self.encoder_modality1 = Encoder(
dim_mod1,
output_dim,
dropout_rates[0],
layers_dims[0],
swap_rate_1,
noise_amount=noise_amount,
)
self.encoder_modality2 = Encoder(
dim_mod2,
output_dim,
dropout_rates[1],
layers_dims[1],
swap_rate_2,
noise_amount=noise_amount,
)
self.logit_scale = nn.Parameter(torch.ones([]) * T)
[docs] def forward(self, features_first, features_second):
features_mod1 = self.encoder_modality1(features_first)
features_mod2 = self.encoder_modality2(features_second)
features_mod1 = features_mod1 / torch.norm(
features_mod1, p=2, dim=-1, keepdim=True
)
features_mod2 = features_mod2 / torch.norm(
features_mod2, p=2, dim=-1, keepdim=True
)
logit_scale = self.logit_scale.exp()
logits = logit_scale * features_mod1 @ features_mod2.T
return logits, features_mod1, features_mod2
[docs]def symmetric_npair_loss(logits, targets):
"""CLIP loss"""
loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.T, targets))
return loss