diff --git a/CMakeLists.txt b/CMakeLists.txt index b2f51cd0..9780ab47 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -52,12 +52,17 @@ ADD_DEFINITIONS(-DPROJECT_PREFIX="${PROJECT_PREFIX}") ADD_DEFINITIONS(-DPROJECT_VERSION="${PROJECT_VERSION}") ADD_DEFINITIONS(-DPROJECT_NAME="${PROJECT_NAME}") -CONFIGURE_FILE(snap/snapcraft.yaml.in ${PROJECT_SOURCE_DIR}/snap/snapcraft.yaml) -message(STATUS "\"snap/snapcraft.yaml\" has been generated with version \"${PROJECT_VERSION}\". To create a snap, do \"snapcraft\" in find-object's root directory.") - ####### DEPENDENCIES ####### FIND_PACKAGE(OpenCV REQUIRED) # tested on 2.3.1 +# For SuperPoint +SET(TORCH 0) +FIND_PACKAGE(Torch QUIET) +IF(TORCH_FOUND) + MESSAGE(STATUS "Found Torch: ${TORCH_INCLUDE_DIRS}") + SET(TORCH 1) +ENDIF(TORCH_FOUND) + IF(OpenCV_VERSION_MAJOR EQUAL 4) IF(NOT MSVC) include(CheckCXXCompilerFlag) @@ -284,9 +289,14 @@ IF(NOT CATKIN_BUILD) ENDIF() ELSE() IF(OPENCV_XFEATURES2D_FOUND) - MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} xfeatures2d module (SIFT/SURF/BRIEF/FREAK) = YES") + MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} xfeatures2d module (BRIEF/FREAK/KAZE) = YES") ELSE() - MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} xfeatures2d module (SIFT/SURF/BRIEF/FREAK) = NO (not found)") + MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} xfeatures2d module (BRIEF/FREAK/KAZE) = NO (not found)") + ENDIF() + IF(NONFREE) + MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} nonfree module (SIFT/SURF) = YES") + ELSE() + MESSAGE(STATUS " With OpenCV ${OpenCV_VERSION_MAJOR} nonfree module (SIFT/SURF) = NO") ENDIF() ENDIF() @@ -303,6 +313,12 @@ IF(NOT CATKIN_BUILD) ELSE() MESSAGE(STATUS " With tcmalloc = NO (tcmalloc not found)") ENDIF(Tcmalloc_FOUND) + + IF(TORCH_FOUND) + MESSAGE(STATUS " With Torch = YES") + ELSE() + MESSAGE(STATUS " With Torch = NO (libtorch not found)") + ENDIF(TORCH_FOUND) IF(APPLE) MESSAGE(STATUS " BUILD_AS_BUNDLE = ${BUILD_AS_BUNDLE}") diff --git a/Version.h.in b/Version.h.in index 6f813509..31bfe1de 100644 --- a/Version.h.in +++ b/Version.h.in @@ -38,6 +38,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #define FINDOBJECT_VERSION_COMPARE(major, minor, patch) (major>=@PROJECT_VERSION_MAJOR@ || (major==@PROJECT_VERSION_MAJOR@ && minor>=@PROJECT_VERSION_MINOR@) || (major==@PROJECT_VERSION_MAJOR@ && minor==@PROJECT_VERSION_MINOR@ && patch >=@PROJECT_VERSION_PATCH@)) #define FINDOBJECT_NONFREE @NONFREE@ +#define FINDOBJECT_TORCH @TORCH@ #endif /* VERSION_H_ */ diff --git a/include/find_object/Settings.h b/include/find_object/Settings.h index 6f21dc0c..e2126701 100644 --- a/include/find_object/Settings.h +++ b/include/find_object/Settings.h @@ -113,8 +113,8 @@ class FINDOBJECT_EXP Settings //List format : [Index:item0;item1;item3;...] - PARAMETER_COND(Feature2D, 1Detector, QString, FINDOBJECT_NONFREE, "7:Dense;Fast;GFTT;MSER;ORB;SIFT;Star;SURF;BRISK;AGAST;KAZE;AKAZE" , "4:Dense;Fast;GFTT;MSER;ORB;SIFT;Star;SURF;BRISK;AGAST;KAZE;AKAZE", "Keypoint detector."); - PARAMETER_COND(Feature2D, 2Descriptor, QString, FINDOBJECT_NONFREE, "3:Brief;ORB;SIFT;SURF;BRISK;FREAK;KAZE;AKAZE;LUCID;LATCH;DAISY", "1:Brief;ORB;SIFT;SURF;BRISK;FREAK;KAZE;AKAZE;LUCID;LATCH;DAISY", "Keypoint descriptor."); + PARAMETER_COND(Feature2D, 1Detector, QString, FINDOBJECT_NONFREE, "7:Dense;Fast;GFTT;MSER;ORB;SIFT;Star;SURF;BRISK;AGAST;KAZE;AKAZE;SuperPointTorch" , "4:Dense;Fast;GFTT;MSER;ORB;SIFT;Star;SURF;BRISK;AGAST;KAZE;AKAZE;SuperPointTorch", "Keypoint detector."); + PARAMETER_COND(Feature2D, 2Descriptor, QString, FINDOBJECT_NONFREE, "3:Brief;ORB;SIFT;SURF;BRISK;FREAK;KAZE;AKAZE;LUCID;LATCH;DAISY;SuperPointTorch", "1:Brief;ORB;SIFT;SURF;BRISK;FREAK;KAZE;AKAZE;LUCID;LATCH;DAISY;SuperPointTorch", "Keypoint descriptor."); PARAMETER(Feature2D, 3MaxFeatures, int, 0, "Maximum features per image. If the number of features extracted is over this threshold, only X features with the highest response are kept. 0 means all features are kept."); PARAMETER(Feature2D, 4Affine, bool, false, "(ASIFT) Extract features on multiple affine transformations of the image."); PARAMETER(Feature2D, 5AffineCount, int, 6, "(ASIFT) Higher the value, more affine transformations will be done."); @@ -228,6 +228,12 @@ class FINDOBJECT_EXP Settings PARAMETER(Feature2D, DAISY_interpolation, bool, true, "Switch to disable interpolation for speed improvement at minor quality loss."); PARAMETER(Feature2D, DAISY_use_orientation, bool, false, "Sample patterns using keypoints orientation, disabled by default."); + PARAMETER(Feature2D, SuperPointTorch_modelPath, QString, "", "[Required] Path to pre-trained weights Torch file of SuperPoint (*.pt)."); + PARAMETER(Feature2D, SuperPointTorch_threshold, float, 0.2, "Detector response threshold to accept keypoint."); + PARAMETER(Feature2D, SuperPointTorch_NMS, bool, true, "If true, non-maximum suppression is applied to detected keypoints."); + PARAMETER(Feature2D, SuperPointTorch_NMS_radius, int, 4, "[%s=true] Minimum distance (pixels) between keypoints"); + PARAMETER(Feature2D, SuperPointTorch_cuda, bool, false, "Use Cuda device for Torch, otherwise CPU device is used by default."); + PARAMETER_COND(NearestNeighbor, 1Strategy, QString, FINDOBJECT_NONFREE, "1:Linear;KDTree;KMeans;Composite;Autotuned;Lsh;BruteForce", "6:Linear;KDTree;KMeans;Composite;Autotuned;Lsh;BruteForce", "Nearest neighbor strategy."); PARAMETER_COND(NearestNeighbor, 2Distance_type, QString, FINDOBJECT_NONFREE, "0:EUCLIDEAN_L2;MANHATTAN_L1;MINKOWSKI;MAX;HIST_INTERSECT;HELLINGER;CHI_SQUARE_CS;KULLBACK_LEIBLER_KL;HAMMING", "1:EUCLIDEAN_L2;MANHATTAN_L1;MINKOWSKI;MAX;HIST_INTERSECT;HELLINGER;CHI_SQUARE_CS;KULLBACK_LEIBLER_KL;HAMMING", "Distance type."); PARAMETER(NearestNeighbor, 3nndrRatioUsed, bool, true, "Nearest neighbor distance ratio approach to accept the best match."); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2793d6fe..0100af99 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -119,6 +119,24 @@ IF(CATKIN_BUILD) ) ENDIF(CATKIN_BUILD) +IF(TORCH_FOUND) + SET(LIBRARIES + ${LIBRARIES} + ${TORCH_LIBRARIES} + ) + SET(SRC_FILES + ${SRC_FILES} + superpoint_torch/SuperPoint.cc + ) + SET(INCLUDE_DIRS + ${TORCH_INCLUDE_DIRS} + ${CMAKE_CURRENT_SOURCE_DIR}/superpoint_torch + ${INCLUDE_DIRS} + ) + ADD_DEFINITIONS("-DWITH_TORCH") +ENDIF(TORCH_FOUND) + + #include files INCLUDE_DIRECTORIES(${INCLUDE_DIRS}) diff --git a/src/ParametersToolBox.cpp b/src/ParametersToolBox.cpp index 71ac78f5..3f2fb5ad 100644 --- a/src/ParametersToolBox.cpp +++ b/src/ParametersToolBox.cpp @@ -358,6 +358,9 @@ void ParametersToolBox::addParameter(QVBoxLayout * layout, #ifndef HAVE_OPENCV_XFEATURES2D widget->setItemData(6, 0, Qt::UserRole - 1); // disable Star #endif +#endif +#if FINDOBJECT_TORCH == 0 + widget->setItemData(12, 0, Qt::UserRole - 1); // disable SuperPointTorch #endif } if(key.compare(Settings::kFeature2D_2Descriptor()) == 0) @@ -381,6 +384,9 @@ void ParametersToolBox::addParameter(QVBoxLayout * layout, widget->setItemData(9, 0, Qt::UserRole - 1); // disable LATCH widget->setItemData(10, 0, Qt::UserRole - 1); // disable DAISY #endif +#endif +#if FINDOBJECT_TORCH == 0 + widget->setItemData(11, 0, Qt::UserRole - 1); // disable SuperPointTorch #endif } if(key.compare(Settings::kNearestNeighbor_1Strategy()) == 0) @@ -628,7 +634,7 @@ void ParametersToolBox::changeParameter(QObject * sender, int value) { QStringList tmp = Settings::getFeature2D_2Descriptor().split(':'); UASSERT(tmp.size() == 2); - QString newTmp = QString('0'+index)+":"+tmp.back(); + QString newTmp = QString::number(index)+":"+tmp.back(); Settings::setFeature2D_2Descriptor(newTmp); descriptorBox->blockSignals(true); this->updateParameter(Settings::kFeature2D_2Descriptor()); diff --git a/src/Settings.cpp b/src/Settings.cpp index 05244b32..95696d3b 100644 --- a/src/Settings.cpp +++ b/src/Settings.cpp @@ -58,6 +58,10 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #include #endif +#if FINDOBJECT_TORCH == 1 +#include "superpoint_torch/SuperPoint.h" +#endif + namespace find_object { ParametersMap Settings::defaultParameters_; @@ -130,10 +134,10 @@ ParametersMap Settings::loadSettings(const QString & fileName) value = QVariant(str); UINFO("Updated list of parameter \"%s\"", key.toStdString().c_str()); } -#if FINDOBJECT_NONFREE == 0 int index = str.split(':').first().toInt(); if(key.compare(Settings::kFeature2D_1Detector()) == 0) { +#if FINDOBJECT_NONFREE == 0 if(index == 5 || index == 7) { index = Settings::defaultFeature2D_1Detector().split(':').first().toInt(); @@ -142,9 +146,21 @@ ParametersMap Settings::loadSettings(const QString & fileName) Settings::kFeature2D_1Detector().toStdString().c_str(), Settings::defaultFeature2D_1Detector().split(':').last().split(";").at(index).toStdString().c_str()); } +#endif +#if FINDOBJECT_TORCH == 0 + if(index == 12) + { + index = Settings::defaultFeature2D_1Detector().split(':').first().toInt(); + UWARN("Trying to set \"%s\" to SuperPointTorch but Find-Object isn't built " + "with the Torch. Keeping default combo value: %s.", + Settings::kFeature2D_1Detector().toStdString().c_str(), + Settings::defaultFeature2D_1Detector().split(':').last().split(";").at(index).toStdString().c_str()); + } +#endif } else if(key.compare(Settings::kFeature2D_2Descriptor()) == 0) { +#if FINDOBJECT_NONFREE == 0 if(index == 2 || index == 3) { index = Settings::defaultFeature2D_2Descriptor().split(':').first().toInt(); @@ -153,23 +169,21 @@ ParametersMap Settings::loadSettings(const QString & fileName) Settings::kFeature2D_2Descriptor().toStdString().c_str(), Settings::defaultFeature2D_2Descriptor().split(':').last().split(";").at(index).toStdString().c_str()); } - } - else if(key.compare(Settings::kNearestNeighbor_1Strategy()) == 0) - { - if(index <= 4) +#endif +#if FINDOBJECT_TORCH == 0 + if(index == 11) { - index = Settings::defaultNearestNeighbor_1Strategy().split(':').first().toInt(); - UWARN("Trying to set \"%s\" to one FLANN approach but Find-Object isn't built " - "with the nonfree module from OpenCV and FLANN cannot be used " - "with binary descriptors. Keeping default combo value: %s.", - Settings::kNearestNeighbor_1Strategy().toStdString().c_str(), - Settings::defaultNearestNeighbor_1Strategy().split(':').last().split(";").at(index).toStdString().c_str()); + index = Settings::defaultFeature2D_2Descriptor().split(':').first().toInt(); + UWARN("Trying to set \"%s\" to SuperPointTorch but Find-Object isn't built " + "with the Torch. Keeping default combo value: %s.", + Settings::kFeature2D_1Detector().toStdString().c_str(), + Settings::defaultFeature2D_1Detector().split(':').last().split(";").at(index).toStdString().c_str()); } +#endif } str = getParameter(key).toString(); str = QString::number(index)+":"+ str.split(':').back(); value = QVariant(str); -#endif } loadedParameters.insert(key, value); setParameter(key, value); @@ -403,7 +417,7 @@ public: : fast_(CVCUDA::FastFeatureDetector::create( threshold, nonmaxSuppression, - CVCUDA::FastFeatureDetector::TYPE_9_16, + cv::FastFeatureDetector::TYPE_9_16, max_npoints)) #endif #endif @@ -614,10 +628,54 @@ private: #endif }; +#if FINDOBJECT_TORCH == 1 +class SuperPointTorch : public Feature2D +{ +public: + SuperPointTorch( + const QString & modelPath, + float threshold = Settings::defaultFeature2D_SuperPointTorch_threshold(), + bool nms = Settings::defaultFeature2D_SuperPointTorch_NMS(), + int nmsRadius = Settings::defaultFeature2D_SuperPointTorch_NMS_radius(), + bool cuda = Settings::defaultFeature2D_SuperPointTorch_cuda()) + { + superPoint_ = cv::Ptr(new SPDetector(modelPath.toStdString(), threshold, nms, nmsRadius, cuda)); + } + + virtual ~SuperPointTorch() {} + + virtual void detect(const cv::Mat & image, + std::vector & keypoints, + const cv::Mat & mask = cv::Mat()) + { + keypoints = superPoint_->detect(image); + } + + virtual void compute( const cv::Mat& image, + std::vector& keypoints, + cv::Mat& descriptors) + { + descriptors = superPoint_->compute(keypoints); + } + + virtual void detectAndCompute( const cv::Mat& image, + std::vector& keypoints, + cv::Mat& descriptors, + const cv::Mat & mask = cv::Mat()) + { + keypoints = superPoint_->detect(image); + descriptors = superPoint_->compute(keypoints); + } +private: + cv::Ptr superPoint_; +}; +#endif + Feature2D * Settings::createKeypointDetector() { Feature2D * feature2D = 0; QString str = getFeature2D_1Detector(); + UDEBUG("Type=%s", str.toStdString().c_str()); QStringList split = str.split(':'); if(split.size()==2) { @@ -643,6 +701,18 @@ Feature2D * Settings::createKeypointDetector() } #endif +#if FINDOBJECT_TORCH == 0 + //check for nonfree stuff + if(strategies.at(index).compare("SuperPointTorch") == 0) + { + index = Settings::defaultFeature2D_1Detector().split(':').first().toInt(); + UERROR("Find-Object is not built with Torch so " + "SuperPointTorch cannot be used! Using default \"%s\" instead.", + strategies.at(index).toStdString().c_str()); + + } +#endif + #if CV_MAJOR_VERSION < 3 if(strategies.at(index).compare("AGAST") == 0 || strategies.at(index).compare("KAZE") == 0 || @@ -947,6 +1017,18 @@ Feature2D * Settings::createKeypointDetector() UDEBUG("type=%s", strategies.at(index).toStdString().c_str()); } } +#endif +#if FINDOBJECT_TORCH == 1 + else if(strategies.at(index).compare("SuperPointTorch") == 0) + { + feature2D = new SuperPointTorch( + getFeature2D_SuperPointTorch_modelPath(), + getFeature2D_SuperPointTorch_threshold(), + getFeature2D_SuperPointTorch_NMS(), + getFeature2D_SuperPointTorch_NMS_radius(), + getFeature2D_SuperPointTorch_cuda()); + UDEBUG("type=%s", strategies.at(index).toStdString().c_str()); + } #endif } } @@ -959,6 +1041,7 @@ Feature2D * Settings::createDescriptorExtractor() { Feature2D * feature2D = 0; QString str = getFeature2D_2Descriptor(); + UDEBUG("Type=%s", str.toStdString().c_str()); QStringList split = str.split(':'); if(split.size()==2) { @@ -983,6 +1066,18 @@ Feature2D * Settings::createDescriptorExtractor() } #endif +#if FINDOBJECT_TORCH == 0 + //check for nonfree stuff + if(strategies.at(index).compare("SuperPointTorch") == 0) + { + index = Settings::defaultFeature2D_2Descriptor().split(':').first().toInt(); + UERROR("Find-Object is not built with Torch so " + "SuperPointTorch cannot be used! Using default \"%s\" instead.", + strategies.at(index).toStdString().c_str()); + + } +#endif + #if CV_MAJOR_VERSION < 3 if(strategies.at(index).compare("KAZE") == 0 || strategies.at(index).compare("AKAZE") == 0) @@ -1227,6 +1322,18 @@ Feature2D * Settings::createDescriptorExtractor() UDEBUG("type=%s", strategies.at(index).toStdString().c_str()); } } +#endif +#if FINDOBJECT_TORCH == 1 + else if(strategies.at(index).compare("SuperPointTorch") == 0) + { + feature2D = new SuperPointTorch( + getFeature2D_SuperPointTorch_modelPath(), + getFeature2D_SuperPointTorch_threshold(), + getFeature2D_SuperPointTorch_NMS(), + getFeature2D_SuperPointTorch_NMS_radius(), + getFeature2D_SuperPointTorch_cuda()); + UDEBUG("type=%s", strategies.at(index).toStdString().c_str()); + } #endif } } diff --git a/src/Vocabulary.cpp b/src/Vocabulary.cpp index aeac1572..803b5551 100644 --- a/src/Vocabulary.cpp +++ b/src/Vocabulary.cpp @@ -450,6 +450,7 @@ void Vocabulary::search(const cv::Mat & descriptorsIn, cv::Mat & results, cv::Ma if(Settings::isBruteForceNearestNeighbor()) { std::vector > matches; + bool isL2NotSqr = false; if(Settings::getNearestNeighbor_BruteForce_gpu() && CVCUDA::getCudaEnabledDeviceCount()) { CVCUDA::GpuMat newDescriptorsGpu(descriptors); @@ -464,6 +465,7 @@ void Vocabulary::search(const cv::Mat & descriptorsIn, cv::Mat & results, cv::Ma { CVCUDA::BruteForceMatcher_GPU > gpuMatcher; gpuMatcher.knnMatch(newDescriptorsGpu, lastDescriptorsGpu, matches, k); + isL2NotSqr = true; } #else #ifdef HAVE_OPENCV_CUDAFEATURES2D @@ -477,6 +479,7 @@ void Vocabulary::search(const cv::Mat & descriptorsIn, cv::Mat & results, cv::Ma { gpuMatcher = cv::cuda::DescriptorMatcher::createBFMatcher(cv::NORM_L2); gpuMatcher->knnMatch(newDescriptorsGpu, lastDescriptorsGpu, matches, k); + isL2NotSqr = true; } #else UERROR("OpenCV3 is not built with CUDAFEATURES2D module, cannot do brute force matching on GPU!"); @@ -485,7 +488,7 @@ void Vocabulary::search(const cv::Mat & descriptorsIn, cv::Mat & results, cv::Ma } else { - cv::BFMatcher matcher(indexedDescriptors_.type()==CV_8U?cv::NORM_HAMMING:cv::NORM_L2); + cv::BFMatcher matcher(indexedDescriptors_.type()==CV_8U?cv::NORM_HAMMING:cv::NORM_L2SQR); matcher.knnMatch(descriptors, indexedDescriptors_, matches, k); } @@ -497,7 +500,16 @@ void Vocabulary::search(const cv::Mat & descriptorsIn, cv::Mat & results, cv::Ma for(int j=0; j(i, j) = matches[i].at(j).trainIdx; - dists.at(i, j) = matches[i].at(j).distance; + + if(isL2NotSqr) + { + // Make sure we use L2SQR to match FLANN + dists.at(i, j) = matches[i].at(j).distance*matches[i].at(j).distance; + } + else + { + dists.at(i, j) = matches[i].at(j).distance; + } } } } diff --git a/src/superpoint_torch/SuperPoint.cc b/src/superpoint_torch/SuperPoint.cc new file mode 100644 index 00000000..107bd5bb --- /dev/null +++ b/src/superpoint_torch/SuperPoint.cc @@ -0,0 +1,355 @@ +/** + * Original code from https://github.com/KinglittleQ/SuperPoint_SLAM + */ + +#include +#include + + +namespace find_object +{ + +const int c1 = 64; +const int c2 = 64; +const int c3 = 128; +const int c4 = 128; +const int c5 = 256; +const int d1 = 256; + + + +SuperPoint::SuperPoint() + : conv1a(torch::nn::Conv2dOptions( 1, c1, 3).stride(1).padding(1)), + conv1b(torch::nn::Conv2dOptions(c1, c1, 3).stride(1).padding(1)), + + conv2a(torch::nn::Conv2dOptions(c1, c2, 3).stride(1).padding(1)), + conv2b(torch::nn::Conv2dOptions(c2, c2, 3).stride(1).padding(1)), + + conv3a(torch::nn::Conv2dOptions(c2, c3, 3).stride(1).padding(1)), + conv3b(torch::nn::Conv2dOptions(c3, c3, 3).stride(1).padding(1)), + + conv4a(torch::nn::Conv2dOptions(c3, c4, 3).stride(1).padding(1)), + conv4b(torch::nn::Conv2dOptions(c4, c4, 3).stride(1).padding(1)), + + convPa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)), + convPb(torch::nn::Conv2dOptions(c5, 65, 1).stride(1).padding(0)), + + convDa(torch::nn::Conv2dOptions(c4, c5, 3).stride(1).padding(1)), + convDb(torch::nn::Conv2dOptions(c5, d1, 1).stride(1).padding(0)) + + { + register_module("conv1a", conv1a); + register_module("conv1b", conv1b); + + register_module("conv2a", conv2a); + register_module("conv2b", conv2b); + + register_module("conv3a", conv3a); + register_module("conv3b", conv3b); + + register_module("conv4a", conv4a); + register_module("conv4b", conv4b); + + register_module("convPa", convPa); + register_module("convPb", convPb); + + register_module("convDa", convDa); + register_module("convDb", convDb); + } + + +std::vector SuperPoint::forward(torch::Tensor x) { + + x = torch::relu(conv1a->forward(x)); + x = torch::relu(conv1b->forward(x)); + x = torch::max_pool2d(x, 2, 2); + + x = torch::relu(conv2a->forward(x)); + x = torch::relu(conv2b->forward(x)); + x = torch::max_pool2d(x, 2, 2); + + x = torch::relu(conv3a->forward(x)); + x = torch::relu(conv3b->forward(x)); + x = torch::max_pool2d(x, 2, 2); + + x = torch::relu(conv4a->forward(x)); + x = torch::relu(conv4b->forward(x)); + + auto cPa = torch::relu(convPa->forward(x)); + auto semi = convPb->forward(cPa); // [B, 65, H/8, W/8] + + auto cDa = torch::relu(convDa->forward(x)); + auto desc = convDb->forward(cDa); // [B, d1, H/8, W/8] + + auto dn = torch::norm(desc, 2, 1); + desc = desc.div(torch::unsqueeze(dn, 1)); + + semi = torch::softmax(semi, 1); + semi = semi.slice(1, 0, 64); + semi = semi.permute({0, 2, 3, 1}); // [B, H/8, W/8, 64] + + + int Hc = semi.size(1); + int Wc = semi.size(2); + semi = semi.contiguous().view({-1, Hc, Wc, 8, 8}); + semi = semi.permute({0, 1, 3, 2, 4}); + semi = semi.contiguous().view({-1, Hc * 8, Wc * 8}); // [B, H, W] + + + std::vector ret; + ret.push_back(semi); + ret.push_back(desc); + + return ret; + } + +void NMS(const std::vector & ptsIn, + const cv::Mat & conf, + const cv::Mat & descriptorsIn, + std::vector & ptsOut, + cv::Mat & descriptorsOut, + int border, int dist_thresh, int img_width, int img_height); + +SPDetector::SPDetector(const std::string & modelPath, float threshold, bool nms, int minDistance, bool cuda) : + threshold_(threshold), + nms_(nms), + minDistance_(minDistance), + detected_(false) +{ + UDEBUG("modelPath=%s thr=%f nms=%d cuda=%d", modelPath.c_str(), threshold, nms?1:0, cuda?1:0); + if(modelPath.empty()) + { + return; + } + model_ = std::make_shared(); + torch::load(model_, modelPath); + + if(cuda && !torch::cuda::is_available()) + { + UWARN("Cuda option is enabled but torch doesn't have cuda support on this platform, using CPU instead."); + } + cuda_ = cuda && torch::cuda::is_available(); + torch::Device device(cuda_?torch::kCUDA:torch::kCPU); + model_->to(device); +} + +SPDetector::~SPDetector() +{ +} + +std::vector SPDetector::detect(const cv::Mat &img) +{ + detected_ = false; + if(model_) + { + torch::NoGradGuard no_grad_guard; + auto x = torch::from_blob(img.data, {1, 1, img.rows, img.cols}, torch::kByte); + x = x.to(torch::kFloat) / 255; + + torch::Device device(cuda_?torch::kCUDA:torch::kCPU); + x = x.set_requires_grad(false); + auto out = model_->forward(x.to(device)); + + prob_ = out[0].squeeze(0); // [H, W] + desc_ = out[1]; // [1, 256, H/8, W/8] + + auto kpts = (prob_ > threshold_); + kpts = torch::nonzero(kpts); // [n_keypoints, 2] (y, x) + + std::vector keypoints_no_nms; + for (int i = 0; i < kpts.size(0); i++) { + float response = prob_[kpts[i][0]][kpts[i][1]].item(); + keypoints_no_nms.push_back(cv::KeyPoint(kpts[i][1].item(), kpts[i][0].item(), 8, -1, response)); + } + + detected_ = true; + if (nms_ && !keypoints_no_nms.empty()) { + cv::Mat conf(keypoints_no_nms.size(), 1, CV_32F); + for (size_t i = 0; i < keypoints_no_nms.size(); i++) { + int x = keypoints_no_nms[i].pt.x; + int y = keypoints_no_nms[i].pt.y; + conf.at(i, 0) = prob_[y][x].item(); + } + + int border = 0; + int dist_thresh = minDistance_; + int height = img.rows; + int width = img.cols; + + std::vector keypoints; + cv::Mat descEmpty; + NMS(keypoints_no_nms, conf, descEmpty, keypoints, descEmpty, border, dist_thresh, width, height); + return keypoints; + } + else { + return keypoints_no_nms; + } + } + else + { + UERROR("No model is loaded!"); + return std::vector(); + } +} + +cv::Mat SPDetector::compute(const std::vector &keypoints) +{ + if(!detected_) + { + UERROR("SPDetector has been reset before extracting the descriptors! detect() should be called before compute()."); + return cv::Mat(); + } + if(model_.get()) + { + cv::Mat kpt_mat(keypoints.size(), 2, CV_32F); // [n_keypoints, 2] (y, x) + + for (size_t i = 0; i < keypoints.size(); i++) { + kpt_mat.at(i, 0) = (float)keypoints[i].pt.y; + kpt_mat.at(i, 1) = (float)keypoints[i].pt.x; + } + + auto fkpts = torch::from_blob(kpt_mat.data, {(long int)keypoints.size(), 2}, torch::kFloat); + + torch::Device device(cuda_?torch::kCUDA:torch::kCPU); + auto grid = torch::zeros({1, 1, fkpts.size(0), 2}).to(device); // [1, 1, n_keypoints, 2] + grid[0][0].slice(1, 0, 1) = 2.0 * fkpts.slice(1, 1, 2) / prob_.size(1) - 1; // x + grid[0][0].slice(1, 1, 2) = 2.0 * fkpts.slice(1, 0, 1) / prob_.size(0) - 1; // y + + auto desc = torch::grid_sampler(desc_, grid, 0, 0, true); // [1, 256, 1, n_keypoints] + desc = desc.squeeze(0).squeeze(1); // [256, n_keypoints] + + // normalize to 1 + auto dn = torch::norm(desc, 2, 1); + desc = desc.div(torch::unsqueeze(dn, 1)); + + desc = desc.transpose(0, 1).contiguous(); // [n_keypoints, 256] + if(cuda_) + desc = desc.to(torch::kCPU); + + cv::Mat desc_mat(cv::Size(desc.size(1), desc.size(0)), CV_32FC1, desc.data()); + + return desc_mat.clone(); + } + else + { + UERROR("No model is loaded!"); + return cv::Mat(); + } +} + +void NMS(const std::vector & ptsIn, + const cv::Mat & conf, + const cv::Mat & descriptorsIn, + std::vector & ptsOut, + cv::Mat & descriptorsOut, + int border, int dist_thresh, int img_width, int img_height) +{ + + std::vector pts_raw; + + for (size_t i = 0; i < ptsIn.size(); i++) + { + int u = (int) ptsIn[i].pt.x; + int v = (int) ptsIn[i].pt.y; + + pts_raw.push_back(cv::Point2f(u, v)); + } + + //Grid Value Legend: + // 255 : Kept. + // 0 : Empty or suppressed. + // 100 : To be processed (converted to either kept or suppressed). + cv::Mat grid = cv::Mat(cv::Size(img_width, img_height), CV_8UC1); + cv::Mat inds = cv::Mat(cv::Size(img_width, img_height), CV_16UC1); + + cv::Mat confidence = cv::Mat(cv::Size(img_width, img_height), CV_32FC1); + + grid.setTo(0); + inds.setTo(0); + confidence.setTo(0); + + for (size_t i = 0; i < pts_raw.size(); i++) + { + int uu = (int) pts_raw[i].x; + int vv = (int) pts_raw[i].y; + + grid.at(vv, uu) = 100; + inds.at(vv, uu) = i; + + confidence.at(vv, uu) = conf.at(i, 0); + } + + // debug + //cv::Mat confidenceVis = confidence.clone() * 255; + //confidenceVis.convertTo(confidenceVis, CV_8UC1); + //cv::imwrite("confidence.bmp", confidenceVis); + //cv::imwrite("grid_in.bmp", grid); + + cv::copyMakeBorder(grid, grid, dist_thresh, dist_thresh, dist_thresh, dist_thresh, cv::BORDER_CONSTANT, 0); + + for (size_t i = 0; i < pts_raw.size(); i++) + { + // account for top left padding + int uu = (int) pts_raw[i].x + dist_thresh; + int vv = (int) pts_raw[i].y + dist_thresh; + float c = confidence.at(vv-dist_thresh, uu-dist_thresh); + + if (grid.at(vv, uu) == 100) // If not yet suppressed. + { + for(int k = -dist_thresh; k < (dist_thresh+1); k++) + { + for(int j = -dist_thresh; j < (dist_thresh+1); j++) + { + if(j==0 && k==0) + continue; + + if ( confidence.at(vv + k - dist_thresh, uu + j - dist_thresh) <= c ) + { + grid.at(vv + k, uu + j) = 0; + } + } + } + grid.at(vv, uu) = 255; + } + } + + size_t valid_cnt = 0; + std::vector select_indice; + + grid = cv::Mat(grid, cv::Rect(dist_thresh, dist_thresh, img_width, img_height)); + + //debug + //cv::imwrite("grid_nms.bmp", grid); + + for (int v = 0; v < img_height; v++) + { + for (int u = 0; u < img_width; u++) + { + if (grid.at(v,u) == 255) + { + int select_ind = (int) inds.at(v, u); + float response = conf.at(select_ind, 0); + ptsOut.push_back(cv::KeyPoint(pts_raw[select_ind], 8.0f, -1, response)); + + select_indice.push_back(select_ind); + valid_cnt++; + } + } + } + + if(!descriptorsIn.empty()) + { + UASSERT(descriptorsIn.rows == (int)ptsIn.size()); + descriptorsOut.create(select_indice.size(), 256, CV_32F); + + for (size_t i=0; i(i, j) = descriptorsIn.at(select_indice[i], j); + } + } + } +} + +} diff --git a/src/superpoint_torch/SuperPoint.h b/src/superpoint_torch/SuperPoint.h new file mode 100644 index 00000000..fa66d134 --- /dev/null +++ b/src/superpoint_torch/SuperPoint.h @@ -0,0 +1,75 @@ +/** + * Original code from https://github.com/KinglittleQ/SuperPoint_SLAM + */ + +#ifndef SUPERPOINT_H +#define SUPERPOINT_H + + +#include +#include + +#include + +#ifdef EIGEN_MPL2_ONLY +#undef EIGEN_MPL2_ONLY +#endif + + +namespace find_object +{ + +struct SuperPoint : torch::nn::Module { + SuperPoint(); + + std::vector forward(torch::Tensor x); + + + torch::nn::Conv2d conv1a; + torch::nn::Conv2d conv1b; + + torch::nn::Conv2d conv2a; + torch::nn::Conv2d conv2b; + + torch::nn::Conv2d conv3a; + torch::nn::Conv2d conv3b; + + torch::nn::Conv2d conv4a; + torch::nn::Conv2d conv4b; + + torch::nn::Conv2d convPa; + torch::nn::Conv2d convPb; + + // descriptor + torch::nn::Conv2d convDa; + torch::nn::Conv2d convDb; + +}; + +class SPDetector { +public: + SPDetector(const std::string & modelPath, float threshold = 0.2f, bool nms = true, int minDistance = 4, bool cuda = false); + virtual ~SPDetector(); + std::vector detect(const cv::Mat &img); + cv::Mat compute(const std::vector &keypoints); + + void setThreshold(float threshold) {threshold_ = threshold;} + void SetNMS(bool enabled) {nms_ = enabled;} + void setMinDistance(float minDistance) {minDistance_ = minDistance;} + +private: + std::shared_ptr model_; + torch::Tensor prob_; + torch::Tensor desc_; + + float threshold_; + bool nms_; + int minDistance_; + bool cuda_; + + bool detected_; +}; + +} + +#endif