Prerequisites: Introduction to Oracle Tribuo, knowledge of Java.
As we saw in Introduction to Supervised Learning, one of the basic machine learning tasks is classification, where we need to map a set of inputs (usually referred to as ‘features’) into two or more categories (or ‘classes’). While data for such tasks typically needs to be collected and prepared, there are many publicly available datasets that can be used for demonstration and benchmarking, and we will use one of them for our first Tribuo example.
Wine Quality Classification Dataset
There are actually two separate datasets, one for red wine and the other for white; both share the same structure:
- Each row represents one variant of the Portuguese “Vinho Verde” wine.
- Each row contains 11 numeric features representing attributes of the wine variant, such as acidity, density and pH.
- The output variable is the 12th number, an integer representing the quality of the wine variant (score between 0 and 10).
Note that although the output variable divides the wine variants into different classes, and the dataset can therefore be treated as a classification task, it can also be used as a regression task, since the output value increases with the wine quality.
Our sample code resides in https://github.com/ai4java/tribuo-examples, which is part of the Ai4Java GitHub repository.
The code is set up as a Maven project, and the pom.xml file contains the dependency for Tribuo.
The main file used for this example is the class WineQualityClassification; let's take a close look at this class.
<Label> and <Regressor> Types
As Tribuo is strongly typed, we need to use the type <Label> when handling a classification task. This is demonstrated in our class variables:
protected Trainer<Label> trainer;
protected Dataset<Label> trainSet;
protected Dataset<Label> testSet;
For regression tasks, the equivalent type to use is <Regressor>.
Preparing the Datasets
The dataset preparation is handled by our class's createDatasets() method. To read the local CSV file containing the red wine dataset, we start by creating a LabelFactory and using it to initialize a CSVLoader. As the CSV file is delimited by ‘;’ characters rather than commas, we use a constructor that lets us define the separator. We then load the file into a ListDataSource, marking the ‘quality’ column as the output variable:
LabelFactory labelFactory = new LabelFactory();
CSVLoader<Label> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, labelFactory);
ListDataSource<Label> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality");
Note that CSVLoader expects the first row of the file to contain the names of the columns; otherwise, the column names must be supplied explicitly. For more complex operations on the CSV file, such as ignoring metadata columns or marking multiple output columns, you can use the class CSVDataSource instead of CSVLoader.
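To make the role of the separator and the header row concrete, here is a minimal plain-Java sketch that does not use Tribuo at all. It splits a header line and a data line on ‘;’ and pairs each column name with its value. The class name SemicolonCsvSketch, its methods, and the shortened three-column sample lines are illustrative assumptions; they are not part of the sample project, and a real loader such as CSVLoader also deals with quoting rules that this sketch ignores.

```java
import java.util.LinkedHashMap;
import java.util.Map;

class SemicolonCsvSketch {

    // Splits one line on ';' and strips optional surrounding double quotes.
    // Simplification: separators inside quoted fields are not supported.
    static String[] splitLine(String line) {
        String[] cells = line.split(";", -1);
        for (int i = 0; i < cells.length; i++) {
            String cell = cells[i].trim();
            if (cell.length() >= 2 && cell.startsWith("\"") && cell.endsWith("\"")) {
                cell = cell.substring(1, cell.length() - 1);
            }
            cells[i] = cell;
        }
        return cells;
    }

    // Pairs the header names with the row values, preserving column order.
    static Map<String, String> toRecord(String headerLine, String rowLine) {
        String[] header = splitLine(headerLine);
        String[] row = splitLine(rowLine);
        if (header.length != row.length) {
            throw new IllegalArgumentException("Header and row have different column counts");
        }
        Map<String, String> record = new LinkedHashMap<>();
        for (int i = 0; i < header.length; i++) {
            record.put(header[i], row[i]);
        }
        return record;
    }

    public static void main(String[] args) {
        // Hypothetical lines, shortened to three columns for readability.
        String header = "\"density\";\"pH\";\"quality\"";
        String row = "0.9978;3.51;5";

        Map<String, String> record = toRecord(header, row);
        System.out.println(record);                                  // {density=0.9978, pH=3.51, quality=5}
        System.out.println("output column = " + record.get("quality")); // output column = 5
    }
}
```

Without the header row there would be no way to refer to the output column by the name ‘quality’, which is why the column names have to come from somewhere.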
We now use the TrainTestSplitter to split the dataset into a train set, used to train the model, and a test set, used to test its predictive performance on samples it was not trained with. In our split, the train set consists of 70% of the 1599 rows contained in the original red-wine dataset:
TrainTestSplitter<Label> dataSplitter = new TrainTestSplitter<>(dataSource, 0.7, 1L);
trainSet = new MutableDataset<>(dataSplitter.getTrain());
log.info(String.format("Train set size = %d, num of features = %d, classes = %s",
        trainSet.size(), trainSet.getFeatureMap().size(), trainSet.getOutputInfo().getDomain()));
testSet = new MutableDataset<>(dataSplitter.getTest());
log.info(String.format("Test set size = %d, num of features = %d, classes = %s",
        testSet.size(), testSet.getFeatureMap().size(), testSet.getOutputInfo().getDomain()));
We use some of the available methods of the Dataset class to print information pertaining to the two sets; we will see the resulting output later on.
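The idea behind a seeded train/test split can be sketched in plain Java as follows. This is not Tribuo's implementation: the class TrainTestSplitSketch and its split() method are hypothetical, and the choice to shuffle and then cut at floor(n * fraction) is an assumption made for illustration. It only shows why a fixed seed makes the split repeatable and how 1599 rows at a 0.7 fraction yield 1119 and 480 rows.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class TrainTestSplitSketch {

    // Returns a two-element list: [train, test].
    // Rows are shuffled with a seeded Random, then cut at floor(n * trainFraction).
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        if (trainFraction <= 0.0 || trainFraction >= 1.0) {
            throw new IllegalArgumentException("trainFraction must be between 0 and 1 (exclusive)");
        }
        List<T> shuffled = new ArrayList<>(rows);
        Collections.shuffle(shuffled, new Random(seed)); // same seed => same order
        int trainSize = (int) (shuffled.size() * trainFraction);

        List<List<T>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(shuffled.subList(0, trainSize)));
        parts.add(new ArrayList<>(shuffled.subList(trainSize, shuffled.size())));
        return parts;
    }

    public static void main(String[] args) {
        // Row indices stand in for the 1599 rows of the red-wine dataset.
        List<Integer> rowIds = new ArrayList<>();
        for (int i = 0; i < 1599; i++) {
            rowIds.add(i);
        }
        List<List<Integer>> parts = split(rowIds, 0.7, 1L);
        System.out.println("train=" + parts.get(0).size() + ", test=" + parts.get(1).size());
        // prints: train=1119, test=480
    }
}
```

The sizes agree with the 70% split described above, but which rows land in each part will generally differ from what TrainTestSplitter produces with the same seed.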
Training the Model
The model is created and trained using a class that implements the Trainer interface. There are many implementations available, representing various algorithms. For our case, I chose the XGBoostClassificationTrainer class, which is initialized in the createTrainer() method with the numTrees parameter set to 50:
trainer = new XGBoostClassificationTrainer(50);
Note that the createTrainer() method can be overridden by classes extending our WineQualityClassification class, providing other trainers (or other trainer parameters) to be used.
Next, the trainAndEvaluate() method performs supervised training using the train set, and evaluates the results using the evaluate() method:
Model<Label> model = trainer.train(trainSet);
evaluate(model, "trainSet", trainSet);
The same evaluate() method is then used to test the trained model on the test set, whose samples were not used in the training process:
evaluate(model, "testSet", testSet);
Evaluating the Results
The evaluate() method uses Tribuo’s LabelEvaluator to evaluate the model against a given dataset, which in turn produces a LabelEvaluation instance containing the results:
LabelEvaluator evaluator = new LabelEvaluator();
LabelEvaluation evaluation = evaluator.evaluate(model, dataset);
The LabelEvaluation object can now be used to extract various measures of the model’s performance; here we print out the accuracy of the model and the confusion matrix:
log.info("Accuracy: " + evaluation.accuracy());
log.info("Confusion Matrix: \n" + evaluation.getConfusionMatrix());
LabelEvaluation offers various other measurements as well; a summary of all measurements included in this class can be printed out using the toString() method of the LabelEvaluation object.
Running the Code
When running the code, the first few lines output the dataset information. An interesting observation is that out of the 11 possible classes (0..10), only six are actually represented in the dataset:
Train set size = 1119, num of features = 11, classes = [3, 4, 5, 6, 7, 8]
Test set size = 480, num of features = 11, classes = [3, 4, 5, 6, 7, 8]
Next, we see the model being trained and the evaluation of the training being printed. The measures show that at the end of the training, the classifier was able to correctly classify all the samples provided. This is evident from the accuracy value of 100%, as well as from the perfect confusion matrix, in which all non-zero values lie on the diagonal. For example, all 446 samples of the class ‘6’ were correctly identified as class ‘6’, yielding a value of 446 at the corresponding element of the matrix:
Results for trainSet---------------------
      3     4     5     6     7     8
3     6     0     0     0     0     0
4     0    42     0     0     0     0
5     0     0   473     0     0     0
6     0     0     0   446     0     0
7     0     0     0     0   142     0
8     0     0     0     0     0    10
However, a perfect score for the train set does not necessarily mean the same for the test set. In the final portion of the program’s output we can see that the accuracy of the predictions for the test set is much lower: it stands at 0.66875, which means that only about 67% of the samples were classified correctly. The confusion matrix provides more insight: the first row, for example, indicates that there were four samples of the class ‘3’, and they were all classified incorrectly. Three of them were classified as ‘5’, and the remaining one was classified as ‘6’.
Results for testSet---------------------
      3     4     5     6     7     8
3     0     0     3     1     0     0
4     0     2     3     6     0     0
5     0     2   161    42     3     0
6     0     3    35   129    25     0
7     0     0     3    24    29     1
8     0     0     0     1     7     0
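The relationship between this matrix and the reported accuracy can be checked by hand. The following plain-Java sketch (no Tribuo involved; the class ConfusionMatrixSketch and its method names are hypothetical) copies the test-set matrix above, sums the diagonal to count correct predictions, and divides by the total number of samples. It also computes per-class recall, assuming, as the description of the first row implies, that rows are true classes and columns are predicted classes.

```java
class ConfusionMatrixSketch {

    // Test-set confusion matrix from the output above.
    // Rows are true classes, columns are predicted classes, both in the order 3..8.
    static final int[][] TEST_MATRIX = {
        {0, 0,   3,   1,  0, 0},
        {0, 2,   3,   6,  0, 0},
        {0, 2, 161,  42,  3, 0},
        {0, 3,  35, 129, 25, 0},
        {0, 0,   3,  24, 29, 1},
        {0, 0,   0,   1,  7, 0}
    };

    // Total number of evaluated samples: the sum of all cells.
    static int total(int[][] m) {
        int sum = 0;
        for (int[] row : m) {
            for (int v : row) {
                sum += v;
            }
        }
        return sum;
    }

    // Number of correct predictions: the sum of the diagonal.
    static int correct(int[][] m) {
        int sum = 0;
        for (int i = 0; i < m.length; i++) {
            sum += m[i][i];
        }
        return sum;
    }

    static double accuracy(int[][] m) {
        return (double) correct(m) / total(m);
    }

    // Recall for one class: diagonal cell divided by the row sum (NaN for an empty row).
    static double recall(int[][] m, int classIndex) {
        int rowSum = 0;
        for (int v : m[classIndex]) {
            rowSum += v;
        }
        return rowSum == 0 ? Double.NaN : (double) m[classIndex][classIndex] / rowSum;
    }

    public static void main(String[] args) {
        System.out.println("correct = " + correct(TEST_MATRIX) + " of " + total(TEST_MATRIX));
        // prints: correct = 321 of 480
        System.out.println("accuracy = " + accuracy(TEST_MATRIX));
        // prints: accuracy = 0.66875
        for (int i = 0; i < TEST_MATRIX.length; i++) {
            System.out.println("recall for class " + (i + 3) + " = " + recall(TEST_MATRIX, i));
        }
    }
}
```

The per-class figures show what the single accuracy number hides: the rare classes ‘3’ and ‘8’ have a recall of 0, while most of the correct predictions come from the two largest classes, ‘5’ and ‘6’.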
Although a classification accuracy of 67% is typically not sufficient for real-life applications, this result is on par with other experiments that were done with the same dataset. To improve the results, we could experiment with, for example, changing the parameter values of the model/trainer, using other types of trainers, or applying feature engineering. These topics will be covered in future posts.
In future posts we will discover more aspects of the Tribuo library and its integration with other frameworks.