# Source code for matchclot.run.run

import argparse
import os
import pickle
import sys

import anndata as ad
import numpy as np
import pandas as pd
import scipy.sparse
import torch
from distutils.util import strtobool

from catalyst.utils import set_global_seed

import matchclot
from matchclot.embedding.models import Modality_CLIP, Encoder
from matchclot.matching.matching import OT_matching, MWB_matching
from matchclot.preprocessing.preprocess import harmony
from matchclot.run.evaluate import evaluate
from matchclot.utils.dataloaders import ModalityMatchingDataset
from matchclot.utils.hyperparameters import (
    defaults_common,
    defaults_GEX2ADT,
    defaults_GEX2ATAC,
    baseline_GEX2ATAC,
    baseline_GEX2ADT,
)


# sys.path.append(".")
# Checkpoints are looked up as "<pretrain>/<fold>/model.best.pth" for these folds.
_N_FOLDS = 9

# Mini-batch size used when embedding the test set.
_EVAL_BATCH_SIZE = 32

# Path prefix (relative to DATASETS_PATH) of the bundled dataset files per task.
_DATASET_PREFIXES = {
    "GEX2ADT": (
        "openproblems_bmmc_cite_phase2_rna/"
        "openproblems_bmmc_cite_phase2_rna.censor_dataset.output_"
    ),
    "GEX2ATAC": (
        "openproblems_bmmc_multiome_phase2_rna/"
        "openproblems_bmmc_multiome_phase2_rna.censor_dataset.output_"
    ),
}


def _arg_type(default):
    """Return the argparse ``type`` converter matching a default value."""
    if isinstance(default, bool):
        # bool("False") is True, so booleans need an explicit string parser.
        # NOTE(review): distutils is deprecated (PEP 632) and removed in
        # Python 3.12; strtobool will need a replacement there.
        return lambda text: bool(strtobool(text))
    return type(default)


def _build_parser():
    """Build the CLI parser: common options plus one sub-command per task."""
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(dest="TASK")

    # Common args
    for key, value in defaults_common.items():
        parser.add_argument("--" + key, default=value, type=_arg_type(value))

    # Task-specific args
    task_specs = (
        ("GEX2ADT", "train GEX2ADT model", defaults_GEX2ADT),
        ("GEX2ATAC", "train GEX2ATAC model", defaults_GEX2ATAC),
    )
    for task, help_text, defaults in task_specs:
        task_parser = subparsers.add_parser(task, help=help_text)
        for key, value in defaults.items():
            task_parser.add_argument(
                "--" + key, default=value, type=_arg_type(value)
            )
    return parser


def _load_pickle(path):
    """Unpickle ``path``, retrying once with ``resources`` aliased to matchclot."""
    # SECURITY: pickle can execute arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        try:
            return pickle.load(f)
        except ModuleNotFoundError:
            # Presumably the pickle was written when the package was importable
            # as "resources" (verify); alias it, rewind and try again so the
            # caller actually receives the object.
            sys.modules["resources"] = matchclot
            f.seek(0)
            return pickle.load(f)


def _build_model(args, is_multiome, device):
    """Instantiate the Modality_CLIP model for the selected task on ``device``."""
    if is_multiome:
        mod1_layers = [args.LAYERS_DIM_ATAC]
        mod1_dropout = [args.DROPOUT_RATES_ATAC]
        dim_mod1 = args.N_LSI_COMPONENTS_ATAC
    else:
        mod1_layers = [args.LAYERS_DIM_ADT0, args.LAYERS_DIM_ADT1]
        mod1_dropout = [args.DROPOUT_RATES_ADT0, args.DROPOUT_RATES_ADT1]
        dim_mod1 = args.N_LSI_COMPONENTS_ADT

    return Modality_CLIP(
        Encoder=Encoder,
        layers_dims=(
            mod1_layers,
            [args.LAYERS_DIM_GEX0, args.LAYERS_DIM_GEX1],
        ),
        dropout_rates=(
            mod1_dropout,
            [args.DROPOUT_RATES_GEX0, args.DROPOUT_RATES_GEX1],
        ),
        dim_mod1=dim_mod1,
        dim_mod2=args.N_LSI_COMPONENTS_GEX,
        output_dim=args.EMBEDDING_DIM,
        T=args.LOG_T,
        noise_amount=args.SFA_NOISE,
    ).to(device)


def _embed_test_set(model, gex_test, mod2_test, device):
    """Run the model over the test set; return (GEX, other-modality) embeddings."""
    dataset_test = ModalityMatchingDataset(
        pd.DataFrame(gex_test), pd.DataFrame(mod2_test)
    )
    data_test = torch.utils.data.DataLoader(
        dataset_test, _EVAL_BATCH_SIZE, shuffle=False
    )

    all_emb_mod1 = []
    all_emb_mod2 = []
    model.eval()
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for batch in data_test:
            x1 = batch["features_first"].float()
            x2 = batch["features_second"].float()
            # The model applies the GEX encoder to the second argument, here x1
            _logits, features_mod2, features_mod1 = model(
                x2.to(device), x1.to(device)
            )
            all_emb_mod1.append(features_mod1.detach().cpu())
            all_emb_mod2.append(features_mod2.detach().cpu())
    return torch.cat(all_emb_mod1), torch.cat(all_emb_mod2)


def _save_embeddings(par, fold, all_emb_mod1, all_emb_mod2):
    """Save one fold's embeddings paired by true order and by predicted match."""
    # Assumes that the two modalities have the cells in the same order
    all_emb_mod12_true = torch.cat((all_emb_mod1, all_emb_mod2), dim=1)
    emb_file = par["output"] + "emb_mod12_fold" + str(fold) + "_truematch.pt"
    torch.save(all_emb_mod12_true, emb_file)
    print("Prediction saved to", emb_file)

    # NOTE(review): this reads the matching written at the end of main(), so it
    # relies on a prediction file left behind by an earlier run.
    predicted_match = ad.read_h5ad(par["output"] + ".h5ad")
    Xsol = torch.tensor(predicted_match.X)
    all_emb_mod12_pred = torch.cat(
        (all_emb_mod1, all_emb_mod2[Xsol.argmax(1)]), dim=1
    )
    emb_file = par["output"] + "emb_mod12_fold" + str(fold) + "_predmatch.pt"
    torch.save(all_emb_mod12_pred, emb_file)
    print("Prediction saved to", emb_file)


def _solve_matching(sim_matrix, args):
    """Match cells with OT or max-weight bipartite matching, per OT_MATCHING."""
    if args.OT_MATCHING:
        # Compute OT matching
        return OT_matching(sim_matrix, entropy_reg=args.OT_ENTROPY)
    # Max-weight bipartite matching
    return MWB_matching(sim_matrix)


def _match_by_batch(sim_matrix, input_test_mod1, input_test_mod2, args):
    """Match cells separately within each batch label; return a sparse matrix."""
    mod1_splits = set(input_test_mod1.obs["batch"])
    mod2_splits = set(input_test_mod2.obs["batch"])
    splits = mod1_splits | mod2_splits
    matching_matrices, mod1_obs_names, mod2_obs_names = [], [], []
    mod1_obs_index = input_test_mod1.obs.index
    mod2_obs_index = input_test_mod2.obs.index
    print("batch label splits", splits)

    for split in splits:
        print("matching split", split)
        mod1_split = input_test_mod1[input_test_mod1.obs["batch"] == split]
        mod2_split = input_test_mod2[input_test_mod2.obs["batch"] == split]
        mod1_obs_names.append(mod1_split.obs_names)
        mod2_obs_names.append(mod2_split.obs_names)
        mod1_indexes = mod1_obs_index.get_indexer(mod1_split.obs_names)
        mod2_indexes = mod2_obs_index.get_indexer(mod2_split.obs_names)
        sim_matrix_split = sim_matrix[np.ix_(mod1_indexes, mod2_indexes)]
        matching_matrices.append(_solve_matching(sim_matrix_split, args))

    # Assemble the per-batch blocks, then permute rows/columns back to the
    # original order of the cell profiles.
    matching_matrix = scipy.sparse.block_diag(matching_matrices, format="csc")
    mod1_obs_names = pd.Index(np.concatenate(mod1_obs_names))
    mod2_obs_names = pd.Index(np.concatenate(mod2_obs_names))
    return matching_matrix[mod1_obs_names.get_indexer(mod1_obs_index), :][
        :, mod2_obs_names.get_indexer(mod2_obs_index)
    ]


def main(argv=sys.argv[1:]):
    """Predict a cell matching with pretrained models and evaluate it.

    Steps: parse the command line, load the test (and, in the transductive
    setting, train) AnnData files, apply the pickled LSI transformers and
    optional Harmony correction, sum the embedding similarity matrices of all
    available fold checkpoints, compute the matching (OT or max-weight
    bipartite, optionally per batch label), write it to
    ``<PRETRAIN_PATH>/<OUT_NAME><TASK>.h5ad`` and score it with ``evaluate``.

    Args:
        argv: Command-line arguments without the program name. The task
            sub-command must be ``GEX2ADT`` or ``GEX2ATAC``.

    Raises:
        ValueError: If the task sub-command is missing or unknown.
        FileNotFoundError: If no fold checkpoint exists under the pretrain
            directory.
    """
    # Synchronous CUDA launches make errors surface at the failing call.
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    # setting device on GPU if available, else CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Parse args
    args, unknown_args = _build_parser().parse_known_args(argv)

    # Set global random seed
    set_global_seed(args.SEED)

    # Define file paths. TASK is None when no sub-command was given, so format
    # it rather than concatenating to keep the error a ValueError.
    if args.TASK not in _DATASET_PREFIXES:
        raise ValueError(f"Unknown task: {args.TASK}")
    is_multiome = args.TASK == "GEX2ATAC"
    dataset_path = os.path.join(args.DATASETS_PATH, _DATASET_PREFIXES[args.TASK])
    pretrain_path = os.path.join(args.PRETRAIN_PATH, args.TASK)

    if args.CUSTOM_DATASET_PATH != "":
        dataset_path = args.CUSTOM_DATASET_PATH
        assert (
            args.TRANSDUCTIVE is False
        ), "TRANSDUCTIVE must be False when CUSTOM_DATASET_PATH is set"

    par = {
        "input_train_mod1": f"{dataset_path}train_mod1.h5ad",
        "input_train_mod2": f"{dataset_path}train_mod2.h5ad",
        "input_test_mod1": f"{dataset_path}test_mod1.h5ad",
        "input_test_mod2": f"{dataset_path}test_mod2.h5ad",
        "input_pretrain": pretrain_path,
        "output": os.path.join(args.PRETRAIN_PATH, args.OUT_NAME + args.TASK),
    }

    # Overwrite configurations for ablation study
    if args.HYPERPARAMS is False:
        baseline = baseline_GEX2ATAC if is_multiome else baseline_GEX2ADT
        for hyperparam, baseline_value in baseline.items():
            setattr(args, hyperparam, baseline_value)

    print("args:", args, "unknown_args:", unknown_args)

    # Load data (train sets are only needed in the transductive setting)
    if args.TRANSDUCTIVE:
        input_train_mod1 = ad.read_h5ad(par["input_train_mod1"])
        input_train_mod2 = ad.read_h5ad(par["input_train_mod2"])
    input_test_mod1 = ad.read_h5ad(par["input_test_mod1"])
    input_test_mod2 = ad.read_h5ad(par["input_test_mod2"])

    N_cells = input_test_mod1.shape[0]
    print(
        "mod1 cells:",
        input_test_mod1.shape[0],
        "mod2 cells:",
        input_test_mod2.shape[0],
    )
    assert input_test_mod2.shape[0] == N_cells

    # Load and apply LSI transformation
    lsi_transformer_gex = _load_pickle(
        par["input_pretrain"] + "/lsi_GEX_transformer.pickle"
    )
    if is_multiome:
        lsi_transformer_atac = _load_pickle(
            par["input_pretrain"] + "/lsi_ATAC_transformer.pickle"
        )
        transform_mod2 = lsi_transformer_atac.transform
    else:
        # The second modality is used untransformed for GEX2ADT.
        def transform_mod2(adata):
            return adata.to_df()

    if args.TRANSDUCTIVE:
        gex_train = lsi_transformer_gex.transform(input_train_mod1)
        mod2_train = transform_mod2(input_train_mod2)
    gex_test = lsi_transformer_gex.transform(input_test_mod1)
    mod2_test = transform_mod2(input_test_mod2)

    if args.HARMONY:
        # Apply Harmony batch effect correction
        gex_test["batch"] = input_test_mod1.obs.batch
        mod2_test["batch"] = input_test_mod2.obs.batch
        if args.TRANSDUCTIVE:
            # Transductive setting: correct train and test together
            gex_train["batch"] = input_train_mod1.obs.batch
            gex_train, gex_test = harmony([gex_train, gex_test])
            mod2_train["batch"] = input_train_mod2.obs.batch
            mod2_train, mod2_test = harmony([mod2_train, mod2_test])
        else:
            (gex_test,) = harmony([gex_test])
            (mod2_test,) = harmony([mod2_test])

    # Load pretrained models and ensemble predictions
    sim_matrix = np.zeros((N_cells, N_cells))
    n_loaded = 0
    for fold in range(_N_FOLDS):
        weight_file = par["input_pretrain"] + "/" + str(fold) + "/model.best.pth"
        if not os.path.exists(weight_file):
            # Folds without a checkpoint are skipped.
            continue
        print("Loading weights from " + weight_file)
        weight = torch.load(weight_file, map_location="cpu")

        # Define modality encoders and load pretrained weights
        model = _build_model(args, is_multiome, device)
        model.load_state_dict(weight)

        # Predict on test set
        all_emb_mod1, all_emb_mod2 = _embed_test_set(
            model, gex_test, mod2_test, device
        )

        # Save the embeddings concatenated according to the true order and
        # predicted order
        if args.SAVE_EMBEDDINGS:
            _save_embeddings(par, fold, all_emb_mod1, all_emb_mod2)

        # Add this fold's dot-product similarity matrix to the ensemble
        # (a cosine similarity if the embeddings are L2-normalised -- TODO confirm).
        sim_matrix += (all_emb_mod1 @ all_emb_mod2.T).detach().cpu().numpy()
        n_loaded += 1

    if n_loaded == 0:
        # Matching an all-zero similarity matrix would yield a meaningless result.
        raise FileNotFoundError(
            "No pretrained weights found under " + par["input_pretrain"]
        )

    if args.BATCH_LABEL_MATCHING:
        # Split matching by batch label
        matching_matrix = _match_by_batch(
            sim_matrix, input_test_mod1, input_test_mod2, args
        )
    else:
        matching_matrix = pd.DataFrame(_solve_matching(sim_matrix, args))

    out = ad.AnnData(
        X=matching_matrix,
        uns={
            "method_id": "MatchCLOT",
        },
    )

    # Save the matching matrix
    out.write_h5ad(par["output"] + ".h5ad", compression="gzip")
    print("Prediction saved to", par["output"] + ".h5ad")

    # Load the solution for evaluation
    if args.CUSTOM_DATASET_PATH == "":
        sol_path = os.path.join(args.DATASETS_PATH, _DATASET_PREFIXES[args.TASK])
        sol = ad.read_h5ad(sol_path + "test_sol.h5ad")
    else:
        print("For evaluation assuming cell order is the same in the two modalities")
        # sol.X.toarray() will be called
        sol = ad.AnnData(X=scipy.sparse.eye(matching_matrix.shape[0]))

    # Score the prediction and save the results
    scores_path = os.path.join("scores", args.OUT_NAME + args.TASK + ".txt")
    evaluate(out, sol, scores_path)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == "__main__":
    main()