GridSearchCV | Modeling
ML Introduction with scikit-learn
GridSearchCV

Now it is time to try to improve the model's performance! We do this by finding the hyperparameters that best fit our task.

This process is called hyperparameter tuning. The most common approach is to try different hyperparameter values, calculate a cross-validation score for each of them, and then choose the value that gives the best score.
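To make the idea concrete, here is a minimal sketch of tuning by hand with `cross_val_score`, before any automation. It assumes the built-in Iris dataset (`load_iris`) as a stand-in for the course's penguins data, and 5 folds as an illustrative choice.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Assumption: the built-in Iris dataset stands in for the course's penguins data
X, y = load_iris(return_X_y=True)

candidate_values = [1, 3, 5, 7, 9]
mean_scores = {}
for k in candidate_values:
    # 5-fold cross-validation; the call returns one accuracy score per fold
    fold_scores = cross_val_score(KNeighborsClassifier(n_neighbors=k), X, y, cv=5)
    mean_scores[k] = fold_scores.mean()

# Keep the value with the highest mean cross-validation score
best_k = max(mean_scores, key=mean_scores.get)
print(mean_scores)
print('Best n_neighbors:', best_k)
```

This loop is exactly the repetitive work that a grid search automates.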

This process can be automated with the GridSearchCV class from the sklearn.model_selection module.

When constructing a GridSearchCV object, we need to pass the model and the parameter grid (and, optionally, the scoring metric via `scoring` and the number of folds via `cv`). The parameter grid (param_grid) is a dictionary containing all the hyperparameter configurations we want to try. For example, param_grid={'n_neighbors': [1, 3, 5, 7]} will try the values 1, 3, 5, and 7 as the number of neighbors.
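When the grid contains several hyperparameters, GridSearchCV tries every combination of their values. The sketch below uses `ParameterGrid` to count those combinations and passes the optional `scoring` and `cv` arguments explicitly. The `weights` values and `cv=5` are illustrative choices, not taken from the course.

```python
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.neighbors import KNeighborsClassifier

# Illustrative grid with two hyperparameters of KNeighborsClassifier
param_grid = {
    'n_neighbors': [1, 3, 5, 7],
    'weights': ['uniform', 'distance'],
}

# ParameterGrid expands the dictionary into every combination (4 * 2 = 8)
combinations = list(ParameterGrid(param_grid))
print(len(combinations))  # 8

# The optional arguments: a scoring metric and the number of folds
grid_search = GridSearchCV(
    KNeighborsClassifier(), param_grid, scoring='accuracy', cv=5
)
```

With `cv=5`, searching this grid means 8 × 5 = 40 model fits, so the cost grows quickly as you add hyperparameters or values.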

Next, fit the GridSearchCV object using the .fit(X, y) method. Once training is complete, you can access the best-performing model, the one with the highest cross-validation score, through the .best_estimator_ attribute. By default, candidates are compared using the model's own .score() method, which is accuracy for classifiers.

To see this model's cross-validation score (the mean score across the folds), use the .best_score_ attribute.

```python
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

df = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/a65bbc96-309e-4df9-a790-a1eb8c815a1c/penguins_pipelined.csv')
# Assign X, y variables (X is already preprocessed and y is already encoded)
X, y = df.drop('species', axis=1), df['species']
# Create the param_grid and initialize the GridSearchCV object
param_grid = {'n_neighbors': [1, 3, 5, 7, 9]}
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid)
# Train the GridSearchCV object. During training, it finds the best parameters
grid_search.fit(X, y)
# Print the best estimator and its cross-validation score
print(grid_search.best_estimator_)
print(grid_search.best_score_)
```
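Besides `best_estimator_` and `best_score_`, a fitted search also exposes `best_params_` (the winning values as a dictionary) and `cv_results_` (scores for every candidate). The sketch below assumes the Iris dataset in place of the penguins CSV and shows that `best_score_` is the mean cross-validation score stored for the best candidate.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Assumption: Iris stands in for the penguins dataset used in the course
X, y = load_iris(return_X_y=True)

grid_search = GridSearchCV(KNeighborsClassifier(), {'n_neighbors': [1, 3, 5, 7, 9]})
grid_search.fit(X, y)

# The winning hyperparameter values as a plain dictionary
print(grid_search.best_params_)

# cv_results_ holds the mean cross-validation score of every candidate;
# best_score_ is the entry at best_index_
best_index = grid_search.best_index_
print(grid_search.cv_results_['mean_test_score'][best_index])
print(grid_search.best_score_)
```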

The natural next step would be to take the best hyperparameters and retrain the model on the whole dataset, since we already know they are the best among those we tried, and we know their score. This step is so common that GridSearchCV does it by default (its `refit` parameter defaults to True), and the retrained model is what .best_estimator_ holds.

So the object (grid_search in our example) becomes a trained model with the best parameters.
We can use it directly for predicting or evaluating, which is why GridSearchCV has the .predict() and .score() methods; they call the corresponding methods of .best_estimator_.
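The following sketch checks that behavior, assuming the Iris dataset in place of the penguins data: predictions from the search object match those of `best_estimator_`, and with `refit=False` no final model is trained at all, although the search results are still kept.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Assumption: Iris stands in for the penguins dataset used in the course
X, y = load_iris(return_X_y=True)
param_grid = {'n_neighbors': [1, 3, 5, 7, 9]}

# refit=True is the default: the best candidate is retrained on all of X, y
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid)
grid_search.fit(X, y)

# .predict() on the search object delegates to the refitted best_estimator_
same = np.array_equal(grid_search.predict(X), grid_search.best_estimator_.predict(X))
print(same)

# With refit=False, the search results are kept, but no final model is trained
no_refit = GridSearchCV(KNeighborsClassifier(), param_grid, refit=False)
no_refit.fit(X, y)
print(no_refit.best_params_)
print(hasattr(no_refit, 'best_estimator_'))
```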

```python
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV

df = pd.read_csv('https://codefinity-content-media.s3.eu-west-1.amazonaws.com/a65bbc96-309e-4df9-a790-a1eb8c815a1c/penguins_pipelined.csv')
# Assign X, y variables (X is already preprocessed and y is already encoded)
X, y = df.drop('species', axis=1), df['species']
# Create the param_grid and initialize the GridSearchCV object
param_grid = {'n_neighbors': [1, 3, 5, 7, 9]}
grid_search = GridSearchCV(KNeighborsClassifier(), param_grid)
# Train the GridSearchCV object. During training, it finds the best parameters
grid_search.fit(X, y)
# Evaluate grid_search on the training set.
# This only demonstrates that .score() works; evaluating on the training set is not reliable.
print(grid_search.score(X, y))
```
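A more reliable evaluation holds out data that the grid search never sees. Here is a minimal sketch, assuming the Iris dataset and a 25% stratified test split (both illustrative choices): the search is fitted on the training part only, and `.score()` is then called on the held-out part.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Assumption: Iris stands in for the penguins dataset used in the course
X, y = load_iris(return_X_y=True)

# Hold out 25% of the data; the grid search never sees it
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

grid_search = GridSearchCV(KNeighborsClassifier(), {'n_neighbors': [1, 3, 5, 7, 9]})
grid_search.fit(X_train, y_train)

# Cross-validation score on the training data vs. accuracy on unseen data
print(grid_search.best_score_)
test_accuracy = grid_search.score(X_test, y_test)
print(test_accuracy)
```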
Once you have trained a `GridSearchCV` object, you can use it to make predictions with the `.predict()` method. Is this correct?


Section 4. Chapter 6