Ensemble Learning
ExtraTrees
Extra Trees, short for Extremely Randomized Trees, is an ensemble learning technique closely related to bagging methods such as Random Forest. It combines many randomized decision trees and aggregates their predictions to create a more robust and diverse model.
How does the ExtraTrees algorithm work?
It is a variation of the Random Forest algorithm but introduces even more randomness into the tree-building process:
- Like Random Forest, the Extra Trees algorithm builds many decision trees. By default, however, each tree is trained on the whole training set rather than on a bootstrap sample (in scikit-learn, bootstrap=False is the default for Extra Trees);
- At each split, only a randomly selected subset of the features is considered as split candidates;
- The most distinctive characteristic of Extra Trees is how split values are chosen. Instead of searching for the locally optimal threshold of each feature using Gini impurity or entropy, the algorithm draws a random threshold for each candidate feature and keeps the best of these random splits. This makes the trees more diverse and less correlated with one another.
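The split-selection step described above can be sketched in a few lines of NumPy. This is only an illustrative sketch, not scikit-learn's actual implementation: the helper names `gini` and `pick_random_split` are hypothetical, and it assumes the common formulation in which one random threshold is drawn per candidate feature and the candidate with the lowest weighted Gini impurity is kept.

```python
import numpy as np


def gini(labels):
    """Gini impurity of a 1-D array of class labels."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p ** 2)


def pick_random_split(X, y, n_candidate_features, rng):
    """Choose a split for one node the Extra Trees way (hypothetical helper).

    Returns (feature_index, threshold, weighted_gini), or None if every
    candidate feature is constant.
    """
    n_samples, n_features = X.shape
    # Random subset of features considered at this node
    features = rng.choice(n_features, size=n_candidate_features, replace=False)
    best = None
    for f in features:
        lo, hi = X[:, f].min(), X[:, f].max()
        if lo == hi:
            continue  # constant feature, cannot split on it
        # Random threshold instead of an exhaustive search for the best one
        threshold = rng.uniform(lo, hi)
        left = X[:, f] <= threshold
        right = ~left
        score = (left.sum() * gini(y[left]) + right.sum() * gini(y[right])) / n_samples
        # Keep the best of the random candidates (lower impurity is better)
        if best is None or score < best[2]:
            best = (int(f), float(threshold), float(score))
    return best


rng = np.random.default_rng(0)
X = rng.random((200, 4))          # 200 samples, 4 features
y = (X[:, 2] > 0.5).astype(int)   # only feature 2 determines the label
feature, threshold, score = pick_random_split(X, y, n_candidate_features=2, rng=rng)
print(f'Feature: {feature}, threshold: {threshold:.3f}, weighted Gini: {score:.3f}')
```

Because the threshold is drawn at random, a single split is usually worse than the optimal one a Random Forest tree would find; the ensemble compensates by averaging many such trees.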
Note
We can also use the .feature_importances_ attribute to measure each feature's impact on the model's result.
Example
We can use ExtraTrees in Python just like Random Forest, using the ExtraTreesClassifier or ExtraTreesRegressor classes:
# Import necessary libraries
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.metrics import mean_squared_error
import numpy as np

# Generate example data with a nonlinear relationship
np.random.seed(42)
X = np.random.rand(100, 2)  # 100 samples with 2 features
y = 3*X[:, 0]**2 + 5*X[:, 1]**3 + np.random.normal(0, 2, 100)  # Nonlinear relationship with noise

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and train the ExtraTrees Regressor
regressor = ExtraTreesRegressor(n_estimators=100, random_state=42)
regressor.fit(X_train, y_train)

# Make predictions
y_pred = regressor.predict(X_test)

# Calculate Mean Squared Error (MSE) as the evaluation metric
mse = mean_squared_error(y_test, y_pred)
print(f'Mean Squared Error: {mse:.4f}')

# Get and print feature importances
feature_importances = regressor.feature_importances_
print('Feature Importances:')
for feature, importance in enumerate(feature_importances):
    print(f'Feature {feature}: {importance:.4f}')
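The regressor above has a classification counterpart with the same interface. The following is a minimal sketch using ExtraTreesClassifier; the synthetic dataset produced by `make_classification` and its parameters are assumptions chosen only for illustration.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic binary classification data (illustrative assumption)
X, y = make_classification(n_samples=500, n_features=8, n_informative=4, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and train the ExtraTrees classifier
clf = ExtraTreesClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

# Evaluate on the held-out test set
accuracy = accuracy_score(y_test, clf.predict(X_test))
print(f'Accuracy: {accuracy:.4f}')

# By default, each tree is trained on the whole training set (no bootstrap)
print(f'Bootstrap sampling enabled: {clf.bootstrap}')
```

Passing bootstrap=True makes each tree train on a bootstrap sample instead, which brings the sampling step closer to what Random Forest does.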