""" Prediction de la survie d'un individu sur le Titanic """ import os from dotenv import load_dotenv import argparse from loguru import logger from joblib import dump import pathlib import pandas as pd from sklearn.model_selection import train_test_split from src.pipeline.build_pipeline import create_pipeline from src.models.train_evaluate import evaluate_model # ENVIRONMENT CONFIGURATION --------------------------- logger.add("recording.log", rotation="500 MB") load_dotenv() parser = argparse.ArgumentParser(description="Paramètres du random forest") parser.add_argument( "--n_trees", type=int, default=20, help="Nombre d'arbres" ) args = parser.parse_args() URL_RAW = "https://minio.lab.sspcloud.fr/lgaliana/ensae-reproductibilite/data/raw/data.csv" n_trees = args.n_trees jeton_api = os.environ.get("JETON_API", "") data_path = os.environ.get("data_path", URL_RAW) data_train_path = os.environ.get("train_path", "data/derived/train.parquet") data_test_path = os.environ.get("test_path", "data/derived/test.parquet") MAX_DEPTH = None MAX_FEATURES = "sqrt" if jeton_api.startswith("$"): logger.info("API token has been configured properly") else: logger.warning("API token has not been configured") # IMPORT ET STRUCTURATION DONNEES -------------------------------- p = pathlib.Path("data/derived/") p.mkdir(parents=True, exist_ok=True) TrainingData = pd.read_csv(data_path) y = TrainingData["Survived"] X = TrainingData.drop("Survived", axis="columns") X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.1 ) pd.concat([X_train, y_train], axis = 1).to_parquet(data_train_path) pd.concat([X_test, y_test], axis = 1).to_parquet(data_test_path) # PIPELINE ---------------------------- # Create the pipeline pipe = create_pipeline( n_trees, max_depth=MAX_DEPTH, max_features=MAX_FEATURES ) # ESTIMATION ET EVALUATION ---------------------- pipe.fit(X_train, y_train) with open("model.joblib", "wb") as f: dump(pipe, f) # Evaluate the model score, matrix = evaluate_model(pipe, X_test, y_test) logger.success(f"{score:.1%} de bonnes réponses sur les données de test pour validation") logger.debug(20 * "-") logger.info("Matrice de confusion") logger.debug(matrix)