EasySegmentUnsupervisedTrain
Support
Required licenses | EasySegment
Recommended images | N/A
Location | Deep Learning Inspection\EasySegmentUnsupervisedTrain\
Purpose
This sample program demonstrates how to:
□ | Train an EasySegment Unsupervised tool in console mode. |
Code highlights
By default, the sample program trains a tool on the Fabric dataset from the Deep Learning Additional Resources.
NOTE: | In the code, replace DEEP_LEARNING_ADDITIONAL_RESOURCES with the path to the Deep Learning Additional Resources. |
1. | Select the first available GPU for the training. |
// Initialize the segmenter that will be trained
EUnsupervisedSegmenter newUnsupervisedSegmenter;
// Make sure that we have a GPU to use (note: GPUs can only be detected in x64 configuration)
std::vector<EDeepLearningDevice> devices = newUnsupervisedSegmenter.GetAvailableDevices();
bool foundGPU = false;
for (size_t i = 0; i < devices.size(); i++)
{
    if (devices[i].GetDeviceType() == EDeepLearningDeviceType_GPU)
    {
        // Use the first GPU found and stop searching
        newUnsupervisedSegmenter.SetActiveDevice(devices[i]);
        foundGPU = true;
        break;
    }
}
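The snippet above records the outcome of the search in foundGPU but does not show what happens when no GPU is present. The sketch below illustrates the same "first matching device" search in a form that makes the not-found case explicit. It is a minimal, self-contained illustration: DeviceType, Device and findFirstGpu are hypothetical stand-ins introduced here and are not part of the Open eVision API.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-ins for illustration only; these are NOT Open eVision types.
enum class DeviceType { CPU, GPU };

struct Device
{
    DeviceType type;
    std::string name;
};

// Returns the index of the first GPU in `devices`, or -1 if the list holds no GPU.
int findFirstGpu(const std::vector<Device>& devices)
{
    for (std::size_t i = 0; i < devices.size(); ++i)
    {
        if (devices[i].type == DeviceType::GPU)
        {
            return static_cast<int>(i);
        }
    }
    return -1;
}
```

A caller could test the result against -1 and, for example, print a warning that training will be slower before continuing on the default device. Whether the sample program itself does this is not shown in the excerpt.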
2. | Load a dataset from a Deep Learning Studio project. |
const EClassificationDataset* dataset;
EDeepLearningProject project;
project.Load(PROJECT_PATH);
dataset = &project.GetDataset();
3. | Alternatively, load a dataset directly using a .edldataset file. |
const EClassificationDataset* dataset;
EClassificationDataset datasetHolder;
datasetHolder.Load(DATASET_PATH);
dataset = &datasetHolder;
4. | Split the dataset into a training part and a validation part. |
NOTE: | The method for splitting a dataset differs from one type of tool to another. EasySegment Unsupervised uses the same method as EasyClassify. |
// Put 80% of the images into the training dataset and 20% into the validation dataset
EClassificationDataset trainingDataset;
EClassificationDataset validationDataset;
dataset->SplitDataset(trainingDataset, validationDataset, 0.8f);
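To make the effect of the 0.8f proportion concrete, the following sketch performs an 80/20 split on a plain list of image names. It is only an illustration of the arithmetic: splitItems, its seed parameter and the shuffle-then-cut strategy are assumptions made here, and the excerpt does not say how SplitDataset selects images internally.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <random>
#include <string>
#include <utility>
#include <vector>

// Illustrative only (not the library's implementation): shuffles `items` with a
// seeded generator, then puts the first round(trainFraction * N) items into the
// training part (first) and the remaining items into the validation part (second).
std::pair<std::vector<std::string>, std::vector<std::string>>
splitItems(std::vector<std::string> items, float trainFraction, unsigned seed)
{
    assert(trainFraction >= 0.0f && trainFraction <= 1.0f);

    std::mt19937 rng(seed);
    std::shuffle(items.begin(), items.end(), rng);

    // Number of items that go to the training part, rounded to the nearest integer
    const std::size_t trainCount = static_cast<std::size_t>(
        std::llround(static_cast<double>(items.size()) * trainFraction));

    std::vector<std::string> training(items.begin(), items.begin() + trainCount);
    std::vector<std::string> validation(items.begin() + trainCount, items.end());
    return std::make_pair(training, validation);
}
```

With 10 images and a proportion of 0.8f, this yields 8 training images and 2 validation images, and every image lands in exactly one of the two parts.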
5. | Configure and train the EasySegment Unsupervised tool and wait for the training to end. |
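// Configure the tool: small network capacity, grayscale processing, and the label that marks good images
newUnsupervisedSegmenter.SetCapacity(EUnsupervisedSegmenterCapacity_Small);
newUnsupervisedSegmenter.SetForceGrayscale(true);
newUnsupervisedSegmenter.SetGoodLabel("Good");
// Start the training for 50 iterations, then wait until it has finished
newUnsupervisedSegmenter.Train(trainingDataset, validationDataset, 50);
newUnsupervisedSegmenter.WaitForTrainingCompletion();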
6. | Check the validation accuracy of the tool we just trained. |
// Read the validation metrics of the best iteration and print the accuracy
float newUnsupervisedSegmenterAccuracy =
    newUnsupervisedSegmenter.GetValidationMetrics(newUnsupervisedSegmenter.GetBestIteration()).GetBestAccuracy();
std::cout << "New segmenter accuracy: " << newUnsupervisedSegmenterAccuracy << std::endl;
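For readers unfamiliar with threshold-based accuracy, the sketch below shows one common way such a figure can be defined for a good/defective decision: each image gets a score, a threshold separates good from defective, and the "best" accuracy is the highest value obtained over a set of candidate thresholds. This reading of the metric is an assumption; the excerpt does not state how GetBestAccuracy is computed. ScoredSample, accuracyAt and bestAccuracy are hypothetical names introduced for illustration.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical sample for illustration: a score and the ground-truth label.
struct ScoredSample
{
    float score;  // assumption: a higher score means "more likely defective"
    bool isGood;  // true if the image is labeled as good
};

// Fraction of samples classified correctly when "score < threshold" means good.
// Returns 0 for an empty sample list.
float accuracyAt(const std::vector<ScoredSample>& samples, float threshold)
{
    if (samples.empty())
    {
        return 0.0f;
    }
    std::size_t correct = 0;
    for (const ScoredSample& s : samples)
    {
        const bool predictedGood = s.score < threshold;
        if (predictedGood == s.isGood)
        {
            ++correct;
        }
    }
    return static_cast<float>(correct) / static_cast<float>(samples.size());
}

// Highest accuracy obtained over the candidate thresholds.
float bestAccuracy(const std::vector<ScoredSample>& samples,
                   const std::vector<float>& thresholds)
{
    float best = 0.0f;
    for (float t : thresholds)
    {
        best = std::max(best, accuracyAt(samples, t));
    }
    return best;
}
```

In this sketch a threshold that cleanly separates the good scores from the defective ones gives an accuracy of 1, while a poorly placed threshold misclassifies some images and lowers the value.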