Source code for mlqtl.train

import numpy as np
import multiprocessing
from typing import List, Tuple
from scipy.stats import pearsonr, NearConstantInputWarning, ConstantInputWarning
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from sklearn.inspection import permutation_importance
from inspect import signature
from sklearn.base import RegressorMixin
from pandas import DataFrame
import warnings
import click

from .data import Dataset
from .nda_typing import (
    MatrixInt8,
    TensorFloat64,
    VectorFloat64,
    MatrixFloat64,
    VectorStr,
)

np.random.seed(42)


[docs] class MLMetrics(object): def __init__(self): self.corr = None self.p_value = None
[docs] def update(self, y, y_hat): with warnings.catch_warnings(): warnings.simplefilter("ignore", NearConstantInputWarning) warnings.simplefilter("ignore", ConstantInputWarning) with np.errstate(divide="ignore", invalid="ignore"): corr, p_value = pearsonr(y, y_hat) self.corr = corr self.p_value = p_value
[docs] def train_batch( X: MatrixInt8 | TensorFloat64, y: VectorFloat64, onehot: bool, models: List[RegressorMixin], importance: bool = False, ) -> MatrixFloat64: """ Train a batch of models on the given data Parameters ---------- X : MatrixInt8 | TensorFloat64 The encoded SNP data y : VectorFloat64 The trait values onehot : bool Whether the SNP data is one-hot encoded models : List[RegressorMixin] The list of models to train importance : bool Whether to calculate feature importance Returns ------- MatrixFloat64 - importance == False: a matrix of shape (n_models, 2) containing the correlation and p-value for each model - importance == True, a matrix of shape (n_models, n_features) containing the feature importance for each model """ if onehot: x_shape_ori: Tuple[int, int, int] = X.shape X: MatrixFloat64 = X.reshape(y.shape[0], -1) else: min_max_scaler = preprocessing.MinMaxScaler() X = min_max_scaler.fit_transform(X) instances = [] for model in models: if not issubclass(model, RegressorMixin): raise TypeError(f"Model {model} is not a valid regression model") try: init_params = signature(model).parameters if "random_state" in init_params: instances.append(model(random_state=42)) else: instances.append(model()) except TypeError: raise TypeError(f"Model {model} should be a class") mets, imp_matrix = [], [] for model in instances: met = MLMetrics() X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, random_state=42 ) model.fit(X_train, y_train) y_pred = model.predict(X_test) met.update(y_test, y_pred) mets.append([met.corr, met.p_value]) if importance: imp_matrix.append( permutation_importance( model, X_train, y_train, n_repeats=10, random_state=42 ).importances_mean ) if importance: imp_matrix = np.array(imp_matrix) if onehot: imp_matrix = imp_matrix.reshape(-1, x_shape_ori[1], x_shape_ori[2]) imp_matrix = np.mean(imp_matrix, axis=2) return imp_matrix return np.array(mets)
[docs] def init_worker(dataset: Dataset) -> None: """ Initialize the worker with the dataset for multiprocessing """ global g_dataset g_dataset = dataset
def _task( genes: VectorStr, trait: str, onehot: bool, models: RegressorMixin ) -> List[MatrixFloat64 | None]: """ task for each chunk of genes """ result = [] for gene in genes: try: snps = g_dataset.get(gene) trait_value, not_nan_idx = g_dataset.trait.get(trait) X = g_dataset.snp.encode(snps, onehot, filter=not_nan_idx) y = trait_value m1 = train_batch(X, y, onehot, models, False) result.append(m1) except ValueError: result.append(None) return result
[docs] def train( trait: str, models: List[RegressorMixin], dataset: Dataset, max_workers: int = 8, onehot: bool = False, ) -> List[List[MatrixFloat64 | None]]: """ Train models on the given dataset using multiprocessing Parameters ---------- trait : str The trait to train onehot : bool Whether the SNP data is one-hot encoded models : List[RegressorMixin] The list of models to train dataset : Dataset The dataset to train max_workers : int The number of workers to use for multiprocessing Returns ------- List[List[MatrixFloat64 | None]] A list of lists containing the correlation and p-value for each model for each gene, shape of MatrixFloat64 is (n_models, 2) """ with multiprocessing.Pool( processes=max_workers, initializer=init_worker, initargs=(dataset,) ) as pool: results = pool.starmap( _task, [ ( dataset.gene.name[chunk], trait, onehot, models, ) for chunk in dataset.gene.chunks(max_workers) ], ) return results
[docs] def train_single( gene: str, trait: str, models: List[RegressorMixin], dataset: Dataset, onehot: bool = False, ) -> MatrixFloat64: """ Train a single gene and trait Parameters ---------- gene : str The gene to train trait : str The trait to train dataset : Dataset The dataset to train models : List[RegressorMixin] The list of models to train onehot : bool Whether the SNP data is one-hot encoded Returns ------- MatrixFloat64 a matrix of shape (n_models, 2) containing the correlation and p-value for each model """ snps = dataset.get(gene) trait_value, not_nan_idx = dataset.trait.get(trait) X = dataset.snp.encode(snps, onehot, filter=not_nan_idx) y = trait_value res = train_batch(X, y, onehot, models) return res
[docs] def feature_importance( gene: str, trait: str, models: List[RegressorMixin], dataset: Dataset, onehot: bool = False, ) -> DataFrame: """ Train a single gene and trait Parameters ---------- gene : str The gene to train trait : str The trait to train dataset : Dataset The dataset to train models : List[RegressorMixin] The list of models to train onehot : bool Whether the SNP data is one-hot encoded Returns ------- DataFrame a pandas DataFrame containing the feature importance for each model with the SNP markers as columns and the model names as rows """ snps = dataset.get(gene) trait_value, not_nan_idx = dataset.trait.get(trait) X = dataset.snp.encode(snps, onehot, filter=not_nan_idx) y = trait_value importance = train_batch(X, y, onehot, models, True) feature = np.array([dataset.snp.idx2marker(i[0]) for i in snps], dtype=np.str_) models_name = np.array([model.__name__ for model in models], dtype=np.str_) res = DataFrame( importance, index=models_name, columns=feature, ) return res
# The following defines the train function with a progress bar and related helper functions used in the CLI def _progress_bar_manager(queue: multiprocessing, total_genes: int, trait: str): """ Manages the click.progressbar by listening to a queue. Receives 'None' to terminate. """ with click.progressbar( length=total_genes, label=f"{trait}", show_pos=True, show_percent=True, show_eta=True, ) as bar: processed_count = 0 while processed_count < total_genes: message = queue.get() if message is None: break if message == "TICK": bar.update(1) processed_count += 1 bar.finish() def _task_progressbar( genes: VectorStr, trait: str, onehot: bool, models: RegressorMixin, progress_queue: multiprocessing.Queue, ) -> List[MatrixFloat64 | None]: """ task for each chunk of genes """ result = [] for gene in genes: try: snps = g_dataset.get(gene) trait_value, not_nan_idx = g_dataset.trait.get(trait) X = g_dataset.snp.encode(snps, onehot, filter=not_nan_idx) y = trait_value m1 = train_batch(X, y, onehot, models, False) result.append(m1) except ValueError: result.append(None) finally: progress_queue.put("TICK") return result
[docs] def train_with_progressbar( trait: str, models: List[RegressorMixin], dataset: Dataset, max_workers: int = 8, onehot: bool = False, ) -> List[List[MatrixFloat64 | None]]: """ Train models on the given dataset using multiprocessing Parameters ---------- trait : str The trait to train onehot : bool Whether the SNP data is one-hot encoded models : List[RegressorMixin] The list of models to train dataset : Dataset The dataset to train max_workers : int The number of workers to use for multiprocessing Returns ------- List[List[MatrixFloat64 | None]] A list of lists containing the correlation and p-value for each model for each gene, shape of MatrixFloat64 is (n_models, 2) """ with multiprocessing.Manager() as manager: progress_queue = manager.Queue() progress_manager_proc = multiprocessing.Process( target=_progress_bar_manager, args=(progress_queue, len(dataset.gene.name), trait), ) progress_manager_proc.start() with multiprocessing.Pool( processes=max_workers, initializer=init_worker, initargs=(dataset,) ) as pool: results = pool.starmap( _task_progressbar, [ ( dataset.gene.name[chunk], trait, onehot, models, progress_queue, ) for chunk in dataset.gene.chunks(max_workers) ], ) progress_queue.put(None) progress_manager_proc.join() return results