EasySegmentSupervisedTrain
Required licenses: EasySegment
Recommended images: N/A
Location: Deep Learning Inspection\EasySegmentSupervisedTrain\
Purpose
This sample program demonstrates how to:
• Train an EasySegment Supervised tool in console mode.
Code highlights
By default, the sample program trains a tool on the Coffee dataset provided in the Deep Learning Additional Resources.
NOTE: In the code, replace DEEP_LEARNING_ADDITIONAL_RESOURCES with the path to the Deep Learning Additional Resources.
1. Select the first available GPU for the training.
// Initialize the supervised segmenter that will be trained
ESupervisedSegmenter newSupervisedSegmenter;
// Make sure that a GPU is available (note: GPUs can only be detected in the x64 configuration)
std::vector<EDeepLearningDevice> devices = newSupervisedSegmenter.GetAvailableDevices();
bool foundGPU = false;
for (size_t i = 0; i < devices.size(); i++)
{
    if (devices[i].GetDeviceType() == EDeepLearningDeviceType_GPU)
    {
        // Use the first GPU found for the training
        newSupervisedSegmenter.SetActiveDevice(devices[i]);
        foundGPU = true;
        break;
    }
}
if (!foundGPU)
{
    std::cout << "No GPU found: the active device is left unchanged." << std::endl;
}
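The device selection in step 1 can be modeled outside Open eVision as follows. This is a minimal sketch: `Device`, `DeviceType` and `selectDevice` are hypothetical stand-ins, not library types, and falling back to the first CPU is an illustrative choice rather than the sample's documented behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// HYPOTHETICAL stand-ins for the Open eVision device types; they are not the
// library's classes and only model what the selection loop relies on.
enum class DeviceType { CPU, GPU };

struct Device {
    std::string name;
    DeviceType type;
};

// Returns the index of the first GPU in `devices`. If there is no GPU, returns
// the index of the first CPU instead, or -1 if the list is empty.
int selectDevice(const std::vector<Device>& devices) {
    int fallback = -1;
    for (std::size_t i = 0; i < devices.size(); ++i) {
        if (devices[i].type == DeviceType::GPU) {
            return static_cast<int>(i);  // first GPU wins, as in the sample
        }
        if (fallback < 0 && devices[i].type == DeviceType::CPU) {
            fallback = static_cast<int>(i);  // remember the first CPU
        }
    }
    return fallback;
}
```

Stopping at the first GPU mirrors the `break` in the sample loop; on a machine with several GPUs, a different rule (for example, most free memory) would need extra device information.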
2. Load a dataset from a Deep Learning Studio project.
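// Load a Deep Learning Studio project and use its dataset
// (the project owns the dataset, so it must stay in scope while "dataset" is used)
const EClassificationDataset* dataset;
EDeepLearningProject project;
project.Load(PROJECT_PATH);
dataset = &project.GetDataset();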
3. Alternatively, load a dataset directly from a .edldataset file.
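// Load the dataset from a .edldataset file into a local holder object
// (datasetHolder owns the dataset, so it must stay in scope while "dataset" is used)
const EClassificationDataset* dataset;
EClassificationDataset datasetHolder;
datasetHolder.Load(DATASET_PATH);
dataset = &datasetHolder;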
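Steps 2 and 3 both end with a pointer to a dataset owned by another object (the project or `datasetHolder`), so the rest of the sample does not depend on where the dataset came from. The sketch below models that ownership pattern with hypothetical `Project` and `Dataset` structs that only mimic `EDeepLearningProject` and `EClassificationDataset`; they are not the library's types.

```cpp
#include <cassert>
#include <string>
#include <vector>

// HYPOTHETICAL stand-in for EClassificationDataset: just a list of image names.
struct Dataset {
    std::vector<std::string> images;
};

// HYPOTHETICAL stand-in for EDeepLearningProject: it owns its dataset.
struct Project {
    Dataset dataset;
    const Dataset& GetDataset() const { return dataset; }
};

// Returns a pointer to the dataset to train on. The pointee is owned either by
// `project` or by `holder`, so both must outlive the returned pointer.
const Dataset* chooseDataset(bool useProject, const Project& project, const Dataset& holder) {
    return useProject ? &project.GetDataset() : &holder;
}
```

Because only a non-owning pointer is passed on, letting the project or holder go out of scope before training would leave the pointer dangling.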
4. Split the dataset into a training part and a validation part.
NOTE: The method used to split a dataset differs for each type of tool.
// Put 80% of the images in the training dataset and 20% in the validation dataset
EClassificationDataset trainingDataset;
EClassificationDataset validationDataset;
dataset->SplitDatasetForSegmentation(trainingDataset, validationDataset, 0.8f);
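As a rough illustration of what an 80/20 split means, the following sketch shuffles image indices with a fixed seed and cuts the list at 80%. `splitIndices` is a hypothetical helper; it performs a plain random split and is not claimed to reproduce `SplitDatasetForSegmentation`, which may apply segmentation-specific rules.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <utility>
#include <vector>

// HYPOTHETICAL helper: splits image indices 0..count-1 into a training part and
// a validation part using a seeded random shuffle (reproducible across runs).
std::pair<std::vector<std::size_t>, std::vector<std::size_t>>
splitIndices(std::size_t count, float trainingRatio, unsigned seed = 42) {
    assert(trainingRatio >= 0.0f && trainingRatio <= 1.0f);

    std::vector<std::size_t> indices(count);
    for (std::size_t i = 0; i < count; ++i) indices[i] = i;

    std::mt19937 rng(seed);
    std::shuffle(indices.begin(), indices.end(), rng);

    // Round the training size to the nearest whole image.
    const std::size_t trainingCount =
        static_cast<std::size_t>(trainingRatio * static_cast<float>(count) + 0.5f);
    const auto cut = indices.begin() + static_cast<std::ptrdiff_t>(trainingCount);

    std::vector<std::size_t> training(indices.begin(), cut);
    std::vector<std::size_t> validation(cut, indices.end());
    return {training, validation};
}
```

With 10 images and a ratio of 0.8f, this yields 8 training indices and 2 validation indices, and every image lands in exactly one part.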
5. Train the EasySegment Supervised tool and wait for the training to end.
// Train for 5 iterations, then block until the training has finished
newSupervisedSegmenter.Train(trainingDataset, validationDataset, 5);
newSupervisedSegmenter.WaitForTrainingCompletion();
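Because training is started by `Train()` and finished by a separate `WaitForTrainingCompletion()` call, the training appears to run in the background. Assuming that, the hypothetical `ToyTrainer` class below sketches the same start-then-wait pattern with `std::async`; it models the calling pattern only, not what the library does internally.

```cpp
#include <atomic>
#include <cassert>
#include <chrono>
#include <future>
#include <thread>

// HYPOTHETICAL trainer illustrating "start training, then wait for it".
// A dummy loop runs on a background thread in place of real training.
class ToyTrainer {
public:
    // Starts `iterations` dummy iterations in the background and returns immediately.
    void Train(int iterations) {
        assert(iterations >= 0);
        WaitForTrainingCompletion();  // never run two trainings at once
        completed_ = 0;
        task_ = std::async(std::launch::async, [this, iterations] {
            for (int i = 0; i < iterations; ++i) {
                std::this_thread::sleep_for(std::chrono::milliseconds(10));  // stand-in for one iteration
                ++completed_;
            }
        });
    }

    // Blocks until the background training task has finished.
    void WaitForTrainingCompletion() {
        if (task_.valid()) task_.get();
    }

    int CompletedIterations() const { return completed_.load(); }

private:
    std::atomic<int> completed_{0};
    std::future<void> task_;
};
```

Separating start and wait lets the calling thread do other work (such as reporting progress) between the two calls; a console sample that has nothing else to do simply waits right away.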
6. Check the validation IoU (intersection over union) of the newly trained tool.
// Weighted IoU measured on the validation dataset at the best training iteration
float newSupervisedSegmenterIoU = newSupervisedSegmenter
    .GetValidationMetrics(newSupervisedSegmenter.GetBestIteration())
    .GetWeightedIntersectionOverUnion();
std::cout << "New segmenter IoU: " << newSupervisedSegmenterIoU << std::endl;
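To make the IoU metric concrete, the sketch below computes a frequency-weighted IoU from a ground-truth mask and a predicted mask stored as one class label per pixel. `weightedIoU` is a hypothetical helper, and the weighting (each class weighted by its share of ground-truth pixels) is a common definition assumed here; it is not confirmed to match `GetWeightedIntersectionOverUnion()` exactly.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// HYPOTHETICAL helper: frequency-weighted IoU between two label masks
// (labels 0 .. numClasses-1, one per pixel).
// Per class: IoU = |truth AND pred| / |truth OR pred|.
// Weight: the class's share of ground-truth pixels.
// ASSUMPTION: this definition may differ from the library's weighted IoU.
double weightedIoU(const std::vector<int>& truth,
                   const std::vector<int>& predicted,
                   int numClasses) {
    assert(truth.size() == predicted.size() && numClasses > 0);
    const std::size_t n = static_cast<std::size_t>(numClasses);
    std::vector<std::size_t> intersection(n, 0), truthCount(n, 0), predCount(n, 0);

    for (std::size_t p = 0; p < truth.size(); ++p) {
        assert(truth[p] >= 0 && truth[p] < numClasses);
        assert(predicted[p] >= 0 && predicted[p] < numClasses);
        ++truthCount[truth[p]];
        ++predCount[predicted[p]];
        if (truth[p] == predicted[p]) ++intersection[truth[p]];
    }

    double result = 0.0;
    for (std::size_t c = 0; c < n; ++c) {
        const std::size_t unionCount = truthCount[c] + predCount[c] - intersection[c];
        if (unionCount == 0) continue;  // class absent from both masks: no contribution
        const double iou = static_cast<double>(intersection[c]) / static_cast<double>(unionCount);
        const double weight = static_cast<double>(truthCount[c]) / static_cast<double>(truth.size());
        result += weight * iou;
    }
    return result;
}
```

For the 4-pixel example truth {0, 0, 1, 1} and prediction {0, 1, 1, 1}, class 0 has an IoU of 1/2 and class 1 an IoU of 2/3, each weighted by 1/2, giving 7/12 (about 0.583).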