def train_main():
    """CLI entry point: train and evaluate one pitch prediction model.

    Parses command-line arguments, resolves the data and model configs,
    preprocesses the dataset, fits the requested model, prints held-out
    metrics, and records the run in a local MLflow experiment.
    """
    parser = argparse.ArgumentParser(description="Train a single pitch prediction model")
    parser.add_argument("--model", type=str, required=True, help="Model name (e.g. lstm, random_forest)")
    parser.add_argument("--config", type=str, default=None, help="Path to model config YAML")
    parser.add_argument("--data-config", type=str, default=None, help="Path to data config")
    args = parser.parse_args()

    # Heavy / project-local imports are deferred until after arg parsing.
    import mlflow
    import numpy as np
    from .config import DataConfig, load_config
    from .data.loader import load_pitch_data, create_sequences, load_hmm_sequences
    from .data.preprocessing import encode_categoricals, normalize_numericals
    from .models import get_model
    from .evaluation.metrics import compute_metrics

    # Data config: an explicit --data-config path wins over the packaged default.
    data_cfg = DataConfig.from_yaml(args.data_config or str(get_default_config("data.yaml")))

    # Model config: explicit --config, else the packaged default for this model
    # name (with the logistic_regression alias mapped to logistic), else empty.
    if args.config:
        model_cfg = load_config(args.config)
    else:
        default_path = str(
            get_default_config(f"models/{args.model.replace('logistic_regression', 'logistic')}.yaml")
        )
        model_cfg = load_config(default_path) if os.path.exists(default_path) else {}

    # Load the raw data, then encode categoricals and normalize numericals.
    print(f"Loading data from {data_cfg.data_path}...")
    frame = load_pitch_data(data_cfg.data_path)
    categorical_present = [c for c in data_cfg.categorical_features if c in frame.columns]
    frame, encoders = encode_categoricals(frame, categorical_present)
    frame, norm_stats = normalize_numericals(frame, data_cfg.numerical_features)

    model = get_model(args.model, model_cfg)
    print(f"Training {model.name} (type={model.model_type})...")

    if args.model == "hmm":
        # HMM path: trains on its own flattened sequence data with a random split.
        from sklearn.model_selection import train_test_split
        hmm_flat, hmm_enc = load_hmm_sequences(data_cfg.hmm_data_path)
        X_train, X_test = train_test_split(
            hmm_flat, test_size=data_cfg.test_size, random_state=data_cfg.random_state
        )
        model.fit(X_train, X_train.flatten(), X_val=X_test, y_val=X_test.flatten())
        y_pred = model.predict(X_test)
        y_test = X_test.flatten()
    elif model.model_type == "sequence":
        # Sequence models: sliding windows, chronological (unshuffled) split.
        X, y, _ = create_sequences(
            frame,
            window_size=data_cfg.window_size,
            feature_cols=data_cfg.sequence_features,
            target_col=f"{data_cfg.target_col}_enc",
        )
        cut = int(len(X) * (1 - data_cfg.test_size))
        X_train, X_test = X[:cut], X[cut:]
        y_train, y_test = y[:cut], y[cut:]
        model.fit(X_train, y_train, X_val=X_test, y_val=y_test)
        y_pred = model.predict(X_test)
    else:
        # Tabular models: prefer the encoded variant of each feature column
        # when one exists, fall back to the raw column, skip missing ones.
        feature_cols = []
        for col in data_cfg.tabular_features:
            encoded = f"{col}_enc"
            if encoded in frame.columns:
                feature_cols.append(encoded)
            elif col in frame.columns:
                feature_cols.append(col)
        X = frame[feature_cols].values
        y = frame[f"{data_cfg.target_col}_enc"].values
        cut = int(len(X) * (1 - data_cfg.test_size))
        X_train, X_test = X[:cut], X[cut:]
        y_train, y_test = y[:cut], y[cut:]
        model.fit(X_train, y_train)
        y_pred = model.predict(X_test)

    # Held-out evaluation.
    metrics = compute_metrics(y_test, y_pred)
    print(f"\nResults for {model.name}:")
    print(f" Accuracy: {metrics['accuracy']:.4f}")
    print(f" Balanced Accuracy: {metrics['balanced_accuracy']:.4f}")
    print(f" Macro F1: {metrics['macro_f1']:.4f}")
    print(f" Macro Precision: {metrics['macro_precision']:.4f}")
    print(f" Macro Recall: {metrics['macro_recall']:.4f}")

    # Record params and scalar metrics to a file-backed MLflow store.
    mlflow.set_tracking_uri(f"file://{os.path.abspath('experiments')}")
    mlflow.set_experiment(f"train_{args.model}")
    with mlflow.start_run(run_name=f"{args.model}_single"):
        mlflow.log_params(model.get_params())
        for k, v in metrics.items():
            if isinstance(v, (int, float)):
                mlflow.log_metric(k, v)
    print("\nLogged to MLflow.")