Models Registry

Model registry and factory function.

pitch_sequencing.models

Model registry for all pitch prediction models.

MODEL_REGISTRY (module attribute):

MODEL_REGISTRY = {
    'logistic_regression': LogisticRegressionModel,
    'random_forest': RandomForestModel,
    'hmm': HMMModel,
    'autogluon': AutoGluonModel,
    'lstm': LSTMModel,
    'cnn1d': CNN1DModel,
    'transformer': TransformerModel,
}
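MODEL_REGISTRY is a plain dict from string keys to model classes, so listing the available names or adding a model is ordinary dict manipulation. The sketch below is a self-contained illustration and does not import pitch_sequencing: it uses hypothetical stand-in classes (`_StubModel`, `MyCustomModel`) and assumes only what this page shows, namely that each registered class accepts a single `config` argument.

```python
from typing import Any, Dict, Optional


class _StubModel:
    """Hypothetical stand-in for a real model class; it just stores its config."""

    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        self.config = config or {}


# Stand-in registry with the same keys as pitch_sequencing.models.MODEL_REGISTRY.
MODEL_REGISTRY: Dict[str, type] = {
    name: _StubModel
    for name in ["logistic_regression", "random_forest", "hmm",
                 "autogluon", "lstm", "cnn1d", "transformer"]
}

# Listing the registered model names is just iterating over the dict keys.
available = sorted(MODEL_REGISTRY)
print(available)


class MyCustomModel(_StubModel):
    """Hypothetical user-defined model following the same constructor convention."""


# Registering a new model is a plain dict assignment.
MODEL_REGISTRY["my_custom"] = MyCustomModel
```

Because get_model looks the name up in the module-level MODEL_REGISTRY on every call, an entry added to the real dict at runtime should become available through get_model without further changes.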

get_model(name, config=None)

Instantiate a model by registry name.

Parameters:

  name: Key in MODEL_REGISTRY (e.g. 'lstm', 'random_forest'). Required.
  config: Optional dict of hyperparameters. Default: None.

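Returns:

  Instance of the model class registered under name, constructed as MODEL_REGISTRY[name](config).

Raises:

  ValueError: If name is not a key in MODEL_REGISTRY. The message lists the available keys.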

Source code in src/pitch_sequencing/models/__init__.py
def get_model(name, config=None):
    """Instantiate a model by registry name.

    Args:
        name: Key in MODEL_REGISTRY (e.g. 'lstm', 'random_forest').
        config: Optional dict of hyperparameters.

    Returns:
        Instance of the model class.
    """
    if name not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{name}'. Available: {list(MODEL_REGISTRY.keys())}")
    return MODEL_REGISTRY[name](config)
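The sketch below exercises the lookup logic shown above: a known name returns an instance built with the given config, and an unknown name raises ValueError listing the valid keys. To stay self-contained, it copies get_model's body and pairs it with a one-entry registry holding a hypothetical stand-in `LSTMModel`. The `hidden_size` key is an illustrative assumption, not a documented hyperparameter.

```python
from typing import Any, Dict, Optional


class LSTMModel:
    """Hypothetical stand-in for pitch_sequencing's LSTMModel."""

    def __init__(self, config: Optional[Dict[str, Any]] = None) -> None:
        self.config = config


# One-entry stand-in registry.
MODEL_REGISTRY = {"lstm": LSTMModel}


def get_model(name, config=None):
    """Same lookup logic as pitch_sequencing.models.get_model."""
    if name not in MODEL_REGISTRY:
        raise ValueError(f"Unknown model '{name}'. Available: {list(MODEL_REGISTRY.keys())}")
    return MODEL_REGISTRY[name](config)


# Known name: the config dict is passed straight to the class constructor.
model = get_model("lstm", {"hidden_size": 64})
print(type(model).__name__, model.config)  # LSTMModel {'hidden_size': 64}

# Unknown name: ValueError lists the valid keys.
try:
    get_model("xgboost")
except ValueError as err:
    print(err)  # Unknown model 'xgboost'. Available: ['lstm']
```

When config is omitted, get_model passes None through unchanged, so each registered class is expected to handle a None config itself.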