Source code for matchclot.embedding.catalyst_tools

import torch
from catalyst import runners, metrics, dl

from matchclot.embedding.models import symmetric_npair_loss


class scRNARunner(runners.Runner):
    """Catalyst runner that performs one train/valid step of the paired model.

    The model is called with two feature tensors and is expected to return a
    ``(logits, embeddings_first, embeddings_second)`` triple and to expose a
    ``logit_scale`` attribute.
    """

    def handle_batch(self, batch):
        """Run the forward pass, record loss/metrics and, when training, step.

        Args:
            batch: mapping with the keys ``"features_first"`` and
                ``"features_second"``; both values are cast to float.

        Side effects:
            Updates ``self.batch_metrics`` (``"loss"`` and ``"T"``) and
            overwrites ``self.batch``, ``self.input`` and ``self.output``.
            On a train loader it also back-propagates the loss and performs
            one optimizer step.
        """
        # Unpack the two inputs.
        first_inputs = batch["features_first"].float()
        second_inputs = batch["features_second"].float()

        # Forward pass.
        scores, first_embedded, second_embedded = self.model(
            first_inputs, second_inputs
        )

        # The target of row i is index i, i.e. the diagonal of the score matrix.
        targets = torch.arange(scores.shape[0], device=scores.device)
        loss = symmetric_npair_loss(scores, targets)

        # exp(logit_scale) is logged as "T" and forwarded as "temperature".
        temperature = self.model.logit_scale.exp().item()
        self.batch_metrics.update({"loss": loss, "T": temperature})

        # Expose everything downstream callbacks/metrics may want to read.
        self.batch = {
            "features_first": first_inputs,
            "features_second": second_inputs,
            "embeddings_first": first_embedded,
            "embeddings_second": second_embedded,
            "scores": scores,
            "targets": targets,
            "temperature": temperature,
        }
        self.input = {
            "features_first": first_inputs,
            "features_second": second_inputs,
        }
        self.output = {
            "scores": scores,
            "embeddings_first": first_embedded,
            "embeddings_second": second_embedded,
        }

        # Validation loaders stop here; only train loaders update the weights.
        if not self.is_train_loader:
            return
        self.engine.backward(loss)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def get_loggers(self):
        """Return the loggers used by this runner: a single console logger."""
        loggers = {"console": dl.ConsoleLogger()}
        return loggers
class CustomMetric(metrics.ICallbackLoaderMetric):
    """Top1, Top5 accuracy metrics and competition score, without applying the
    matching algorithm.

    Embeddings of both inputs are collected batch by batch via ``update`` and
    turned into one score matrix ``first @ second.T`` in ``compute_key_value``.
    Row ``i`` and column ``i`` are treated as the correct pair, so the two
    inputs must contribute the same number of samples.

    Attributes set in ``__init__``:
        prefix / suffix: stored on the instance; the keys returned by
            ``compute_key_value`` do not include them.
        batch_size: number of rows/columns of the score matrix processed per
            step when computing the metrics.
        extended_statistics: when true, the competition score and the top-5
            accuracies are computed in addition to the top-1 accuracies.
    """

    def __init__(
        self, compute_on_call: bool = True, prefix: str = None, suffix: str = None
    ):
        super().__init__(compute_on_call=compute_on_call)
        self.prefix = prefix or ""
        self.suffix = suffix or ""
        self.embeddings_list_first = []
        self.embeddings_list_second = []
        self.batch_size = 256  # For batched computation of metrics
        self.extended_statistics = False

    def reset(self, num_batches: int, num_samples: int) -> None:
        """Drop all collected embeddings and release cached CUDA memory."""
        self.embeddings_list_first = []
        self.embeddings_list_second = []
        torch.cuda.empty_cache()

    def update(self, *args, **kwargs) -> None:
        """Store the embeddings of one batch.

        Reads the keyword arguments ``embeddings_first``, ``embeddings_second``
        and ``temperature``; the first embeddings are scaled by the
        temperature before being stored.
        """
        embeddings_first = kwargs["embeddings_first"]
        embeddings_second = kwargs["embeddings_second"]
        temperature = kwargs["temperature"]
        # Detach so the stored tensors do not keep autograd history alive
        # until the end of the loader.
        self.embeddings_list_first.append((temperature * embeddings_first).detach())
        self.embeddings_list_second.append(embeddings_second.detach())

    def compute(self):
        """Not supported; use ``compute_key_value`` instead."""
        raise NotImplementedError("This method is not supported")

    def compute_key_value(self):
        """Compute the loader-level metrics from all collected embeddings.

        Returns:
            dict with ``forward_acc`` (row-wise top-1 accuracy),
            ``backward_acc`` (column-wise top-1 accuracy) and ``avg_acc``
            (their mean). When ``self.extended_statistics`` is true it also
            contains ``top1_competition_metric``, ``top5_forward_acc``,
            ``top5_backward_acc`` and ``top5_avg_acc``.

        Raises:
            RuntimeError: raised by ``torch.cat`` when no batch was collected.
        """
        all_embeddings_first = torch.cat(self.embeddings_list_first).detach().cpu()
        all_embeddings_second = torch.cat(self.embeddings_list_second).detach().cpu()
        N = all_embeddings_first.shape[0]

        logits = all_embeddings_first @ all_embeddings_second.T
        del all_embeddings_first
        del all_embeddings_second

        labels = torch.arange(logits.shape[0])
        # Guards the divisions below if only empty batches were collected.
        denominator = max(N, 1)

        # Forward direction: for each row, is the best column the diagonal one?
        # argmax over dim=1 already yields a global column index, so it is
        # compared to the global labels directly (no batch offset).
        forward_correct = 0
        for start in range(0, N, self.batch_size):
            stop = min(start + self.batch_size, N)
            row_batch = logits[start:stop, :]
            forward_correct += (
                (torch.argmax(row_batch, dim=1) == labels[start:stop]).sum().item()
            )
        forward_accuracy = forward_correct / denominator

        # Backward direction: for each column, is the best row the diagonal one?
        backward_correct = 0
        for start in range(0, N, self.batch_size):
            stop = min(start + self.batch_size, N)
            column_batch = logits[:, start:stop]
            backward_correct += (
                (torch.argmax(column_batch, dim=0) == labels[start:stop]).sum().item()
            )
        backward_accuracy = backward_correct / denominator

        avg_accuracy = 0.5 * (forward_accuracy + backward_accuracy)

        loader_metrics = {
            "forward_acc": forward_accuracy,
            "backward_acc": backward_accuracy,
            "avg_acc": avg_accuracy,
        }

        if self.extended_statistics:
            # Share of each row's clipped (non-negative) score mass that lies
            # on the correct pair, averaged over all rows.
            # NOTE(review): a row whose clipped scores are all zero yields
            # 0/0 = nan here, as in the original computation.
            competition_total = 0.0
            for start in range(0, N, self.batch_size):
                stop = min(start + self.batch_size, N)
                row_batch = logits[start:stop, :].clip(min=0)
                row_sums = row_batch.sum(dim=1)
                # diagonal(offset=start) picks element [k, start + k], i.e.
                # the entries of the global diagonal inside this row batch.
                competition_total += (
                    row_batch.diagonal(offset=start).div(row_sums).sum().item()
                )
            top1_competition_metric = competition_total / denominator

            # topk raises if k exceeds the dimension size, so clamp for
            # loaders with fewer than five samples.
            k = min(5, N)
            _, top_indexes_forward = logits.topk(k, dim=1)  # (N, k)
            _, top_indexes_backward = logits.topk(k, dim=0)  # (k, N)
            del logits

            top5_forward_accuracy = (
                torch.any(top_indexes_forward == labels.unsqueeze(1), 1)
                .float()
                .mean()
                .item()
            )
            top5_backward_accuracy = (
                torch.any(top_indexes_backward == labels.unsqueeze(0), 0)
                .float()
                .mean()
                .item()
            )
            top5_avg_accuracy = 0.5 * (top5_forward_accuracy + top5_backward_accuracy)

            loader_metrics.update(
                {
                    "top1_competition_metric": top1_competition_metric,
                    "top5_forward_acc": top5_forward_accuracy,
                    "top5_backward_acc": top5_backward_accuracy,
                    "top5_avg_acc": top5_avg_accuracy,
                }
            )

        return loader_metrics