This is the third post in a series of four posts.
In the previous post, First Tribuo Example, we saw how to implement a wine quality classifier with Tribuo, a Java-based Machine Learning library. This time, we will use the same dataset to demonstrate how to implement a regressor, and along the way discover a few more capabilities of Tribuo.
Regression vs. Classification
As we saw in Introduction to Supervised Learning, models used for regression tasks typically have a single output, providing a continuous value; given the input values (or ‘features’), the model is expected to predict the value of the output. This is in contrast to classification tasks, where the model maps the features into two or more categories.
Interestingly, the Wine Quality Dataset we used for classification in the previous post can also be used for regression: although the output variable divides the wine variants into different classes, 0, 1, 2, …, 10, its value increases with the wine quality, and therefore we can also think of the quality as a continuous value between 0 and 10.
The Code
The main file used for this example is the class WineQualityRegression. It has many similarities to the class WineQualityClassification, which we described in the previous post; here we will concentrate on the differences between the two classes.
<Regressor> Instead of <Label>
The type <Regressor> replaces <Label> throughout the code, reflecting our regression use case, as demonstrated in our class variables:
```java
protected Model<Regressor> model;
protected Trainer<Regressor> trainer;
protected Dataset<Regressor> trainSet;
protected Dataset<Regressor> testSet;
```
There are several other common types used in Tribuo, such as <Event>, used in anomaly detection, and <ClusterID>, used in clustering.
Random Forest Trainer
For this task, we are going to demonstrate the use of a Random Forest model. Just like the XGBoost model we used in the previous post, this type of model can be used for both classification and regression. During training, the Random Forest algorithm creates a collection (or ensemble) of decision trees, each built from a randomly selected subspace of the dataset. The outputs of all these ‘subsampling’ trees are then averaged to produce the prediction of the trained regressor.
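To make the averaging idea concrete, here is a small, Tribuo-independent sketch in plain Java (the class name and the tree outputs are hypothetical) showing how an ensemble's individual predictions are combined into a single regression prediction:

```java
public class EnsembleAveragingSketch {

    // The ensemble prediction is simply the mean of the individual tree outputs.
    static double combine(double[] treeOutputs) {
        double sum = 0.0;
        for (double out : treeOutputs) {
            sum += out;
        }
        return sum / treeOutputs.length;
    }

    public static void main(String[] args) {
        // Three hypothetical 'subsampling' trees predicting the quality of one wine sample.
        double[] treeOutputs = {5.0, 6.0, 7.0};
        System.out.println(combine(treeOutputs)); // prints 6.0
    }
}
```

In Tribuo, this combining step is what the AveragingCombiner (used below) performs for the ensemble.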
To make that happen, the method createTrainer() starts by creating a CART (Classification And Regression Tree) trainer that will be used to create the ‘subsampling’ trees—the building blocks of the Random Forest model. The constructor of this trainer is given a list of parameter values that will be used when creating these trees:
```java
CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer(
    Integer.MAX_VALUE,                // maxDepth - the maximum depth of the tree
    AbstractCARTTrainer.MIN_EXAMPLES, // minChildWeight - the minimum node weight to consider it for a split
    0.9f,                             // fractionFeaturesInSplit - the fraction of features available in each split
    new MeanSquaredError(),           // impurity - the impurity function to use to determine split quality
    Trainer.DEFAULT_SEED              // seed - the Random Number Generator seed
);
```
Next we create the actual Random Forest trainer, that receives the CART trainer we just defined as its first parameter:
```java
trainer = new RandomForestTrainer<>(
    subsamplingTree,         // trainer - the tree trainer
    new AveragingCombiner(), // combiner - the combining function for the ensemble
    10                       // numMembers - the number of ensemble members to train
);
```
When running the program, the output of the training process looks as follows:
```
Training model...
Feb 06, 2021 9:42:51 PM org.tribuo.ensemble.BaggingTrainer train
INFO: Building model 0
Feb 06, 2021 9:42:52 PM org.tribuo.ensemble.BaggingTrainer train
INFO: Building model 1
...
Feb 06, 2021 9:42:52 PM org.tribuo.ensemble.BaggingTrainer train
INFO: Building model 9
```
Tribuo’s BaggingTrainer class is responsible for creating the ensemble of decision trees for the Random Forest regressor. We can see 10 rounds of training, matching the value of numMembers we initialized the trainer with.
Evaluating the Regression Results
To evaluate the results, the evaluate() method utilizes Tribuo’s RegressionEvaluator and RegressionEvaluation, in a similar manner to what we saw in the case of classification. However, in order to extract the results we need, an ‘auxiliary’ regressor needs to be created first as follows:
```java
Regressor dimension0 = new Regressor("DIM-0", Double.NaN);
```
The reason is that Tribuo supports multidimensional regression: imagine, for example, that our dataset had two output columns, one for the wine’s quality and the other for its consistency. A single, two-dimensional regressor could be used for this dataset, with the first dimension predicting the quality and the second dimension the consistency. Each of these dimensions would be evaluated separately. In our case, there is a single dimension (‘dimension 0’), and we extract its evaluation as follows:
```java
log.info("MAE: " + evaluation.mae(dimension0));
log.info("RMSE: " + evaluation.rmse(dimension0));
log.info("R^2: " + evaluation.r2(dimension0));
```
The measures extracted from the evaluation represent three different ways to evaluate the results of a regressor:
- MAE (Mean Absolute Error): the average absolute distance between the predicted results and the actual data.
- RMSE (Root Mean Squared Error): the square root of the average squared distance between the predicted results and the actual data, emphasizing large errors.
- R^2 (R-Squared): a measure, between 0 and 1, of how well the predicted results fit the actual data.
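As a quick illustration of what these three measures capture, the following plain-Java sketch (independent of Tribuo; the class name and toy data are made up for this example) computes them by hand:

```java
public class RegressionMetricsSketch {

    // Mean Absolute Error: average absolute distance between prediction and truth.
    static double mae(double[] actual, double[] predicted) {
        double sum = 0.0;
        for (int i = 0; i < actual.length; i++) {
            sum += Math.abs(predicted[i] - actual[i]);
        }
        return sum / actual.length;
    }

    // Root Mean Squared Error: square root of the average squared distance.
    static double rmse(double[] actual, double[] predicted) {
        double sum = 0.0;
        for (int i = 0; i < actual.length; i++) {
            double diff = predicted[i] - actual[i];
            sum += diff * diff;
        }
        return Math.sqrt(sum / actual.length);
    }

    // R-Squared: 1 minus the ratio of the residual error to the variance of the data.
    static double r2(double[] actual, double[] predicted) {
        double mean = 0.0;
        for (double a : actual) {
            mean += a;
        }
        mean /= actual.length;
        double ssRes = 0.0;
        double ssTot = 0.0;
        for (int i = 0; i < actual.length; i++) {
            ssRes += (actual[i] - predicted[i]) * (actual[i] - predicted[i]);
            ssTot += (actual[i] - mean) * (actual[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }

    public static void main(String[] args) {
        double[] actual    = {3.0, 5.0, 7.0};
        double[] predicted = {2.0, 5.0, 9.0};
        System.out.println("MAE:  " + mae(actual, predicted));  // 1.0
        System.out.println("RMSE: " + rmse(actual, predicted)); // ~1.291
        System.out.println("R^2:  " + r2(actual, predicted));   // 0.375
    }
}
```

Note how the single large error (9 vs. 7) pulls the RMSE above the MAE, which is exactly the "emphasizing large errors" behavior described above.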
When running the program, the output of the evaluation looks as follows:
```
Results for trainSet---------------------
MAE: 0.1821105153410776
RMSE: 0.2628796435261567
R^2: 0.8932756394863142
Testing model...
Results for testSet---------------------
MAE: 0.47935416666666664
RMSE: 0.6523233317662848
R^2: 0.35668129055965125
```
As evident from these results, the performance on the test set is considerably worse than what we get for the train set: the two distance-related measures are larger, and the R-Squared is closer to 0 than to 1. This is consistent with the results we got for the same dataset in the case of classification.
Provenance
As we mentioned in our Introduction to Oracle Tribuo post, Tribuo’s Provenance is a unique feature integrated into all classes that represent Models, Datasets, and Evaluations. Provenance provides information about the parameters, transformations, and files that were used to create them; it allows each model and experiment to be recreated from scratch.
To illustrate the wealth of information stored in the model’s provenance, we added a few lines at the end of the trainAndEvaluate() method that print out the provenance of the dataset and the trainer used in our experiment:
```java
log.info("Dataset Provenance: --------------------");
log.info(ProvenanceUtil.formattedProvenanceString(model.getProvenance().getDatasetProvenance()));
log.info("Trainer Provenance: --------------------");
log.info(ProvenanceUtil.formattedProvenanceString(model.getProvenance().getTrainerProvenance()));
```
The resulting output is a very detailed structure containing all the pertinent information. Below is a small portion of it:
```
Dataset Provenance: --------------------
MutableDataset(
    class-name = org.tribuo.MutableDataset
    datasource = TrainTestSplitter(
        class-name = org.tribuo.evaluation.TrainTestSplitter
        source = CSVLoader(
            class-name = org.tribuo.data.csv.CSVLoader
            ...
Trainer Provenance: --------------------
RandomForestTrainer(
    class-name = org.tribuo.common.tree.RandomForestTrainer
    seed = 12345
    innerTrainer = CARTRegressionTrainer(
        class-name = org.tribuo.regression.rtree.CARTRegressionTrainer
        maxDepth = 2147483647
        ...
```
You can learn more about Provenance and how it can be used in conjunction with the Configuration Manager in the official Tribuo Configuration Tutorial.
Serializing the Model
We can now save the trained model into a file, so we can recreate it later and use it to make predictions. This is done by the code in the method saveModel():
```java
File modelFile = new File(MODEL_PATH);
try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(modelFile))) {
    oos.writeObject(model);
}
```
As seen above, all we need to do is write the Model object into a file.
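The mechanism at work here is Java's standard object serialization. The following self-contained sketch illustrates the full round trip with a plain HashMap standing in for the trained model (the class and file names are made up for this example, not part of the original code):

```java
import java.io.*;
import java.util.HashMap;

public class SerializationSketch {

    // Write any Serializable object to a file, just like saveModel() does with the Model.
    static void save(Object obj, File file) {
        try (ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(file))) {
            oos.writeObject(obj);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Read the object back; the caller casts it to the expected type.
    static Object load(File file) {
        try (ObjectInputStream ois = new ObjectInputStream(new FileInputStream(file))) {
            return ois.readObject();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        } catch (ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        // A HashMap stands in for the trained model in this sketch.
        HashMap<String, Double> model = new HashMap<>();
        model.put("quality", 6.0);

        File modelFile = new File(System.getProperty("java.io.tmpdir"), "wine-model-sketch.ser");
        save(model, modelFile);

        @SuppressWarnings("unchecked")
        HashMap<String, Double> restored = (HashMap<String, Double>) load(modelFile);
        System.out.println(restored.get("quality")); // prints 6.0
    }
}
```

Reading the saved model back with an ObjectInputStream, as sketched in load() above, is exactly what we will do in the next post.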
What’s Next?
In our next post, we will read the model back from the file, and use it to create an application that provides wine quality predictions by responding to REST queries.