// NOTE(review): in the copy this was reconstructed from, the angle-bracketed
// header names had been lost (bare "#include" tokens). The list below is
// rebuilt from what the code uses - verify it against the project.
#include "user_data_manager.h"

#include <algorithm>
#include <cctype>
#include <exception>
#include <filesystem>
#include <fstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// cpr (curl) pulls in winsock2.h; include it before windows.h so the older
// winsock.h is not dragged in first.
#include <cpr/cpr.h>
#include <windows.h>

#include "json11.hpp"
#include "hid_handler.h"

UserDataManager userManager;

namespace {

// Upper bound for a module path when the MAX_PATH buffer turns out to be too
// small (extended-length paths are limited to roughly 32K wide characters).
constexpr DWORD kMaxModulePathChars = 32768;

// File extension (compared case-insensitively) that marks a trained model.
constexpr char kModelExtension[] = ".model";
constexpr std::size_t kModelExtensionLength = sizeof(kModelExtension) - 1;

/// Returns `text` without leading/trailing characters for which std::isspace is true.
std::string TrimWhitespace(const std::string& text) {
    auto isSpace = [](unsigned char c) { return std::isspace(c) != 0; };
    const auto first = std::find_if_not(text.begin(), text.end(), isSpace);
    const auto last = std::find_if_not(text.rbegin(), text.rend(), isSpace).base();
    return first < last ? std::string(first, last) : std::string{};
}

/// Removes every CR and LF from `token` so it can be placed in an HTTP header.
std::string StripLineBreaks(std::string token) {
    token.erase(std::remove_if(token.begin(), token.end(),
                               [](char c) { return c == '\r' || c == '\n'; }),
                token.end());
    return token;
}

/// Converts UTF-8 text to a std::wstring for hid_handler.log.
/// If the Win32 conversion fails, each byte is widened as an unsigned value
/// instead (never sign-extended into garbage).
std::wstring Utf8ToWide(const std::string& text) {
    if (text.empty()) {
        return {};
    }
    const int inputSize = static_cast<int>(text.size());
    const int wideSize = MultiByteToWideChar(CP_UTF8, 0, text.data(), inputSize, nullptr, 0);
    if (wideSize <= 0) {
        std::wstring fallback;
        fallback.reserve(text.size());
        for (unsigned char c : text) {
            fallback.push_back(static_cast<wchar_t>(c));
        }
        return fallback;
    }
    std::wstring wide(static_cast<std::size_t>(wideSize), L'\0');
    MultiByteToWideChar(CP_UTF8, 0, text.data(), inputSize, wide.data(), wideSize);
    return wide;
}

/// Converts UTF-16 text to UTF-8; returns an empty string if conversion fails.
std::string WideToUtf8(const std::wstring& text) {
    if (text.empty()) {
        return {};
    }
    const int inputSize = static_cast<int>(text.size());
    const int requiredSize =
        WideCharToMultiByte(CP_UTF8, 0, text.data(), inputSize, nullptr, 0, nullptr, nullptr);
    if (requiredSize <= 0) {
        return {};
    }
    std::string utf8(static_cast<std::size_t>(requiredSize), '\0');
    WideCharToMultiByte(CP_UTF8, 0, text.data(), inputSize, utf8.data(), requiredSize,
                        nullptr, nullptr);
    return utf8;
}

/// True when `name` ends with ".model", ignoring ASCII case.
bool EndsWithModelExtension(const std::string& name) {
    if (name.size() < kModelExtensionLength) {
        return false;
    }
    const std::size_t offset = name.size() - kModelExtensionLength;
    for (std::size_t i = 0; i < kModelExtensionLength; ++i) {
        if (std::tolower(static_cast<unsigned char>(name[offset + i])) != kModelExtension[i]) {
            return false;
        }
    }
    return true;
}

/// Appends to `models` every ".model" name in the JSON array `items`. An element
/// counts if it is a string, or an object whose "name" member is a string.
void CollectModelNames(const json11::Json& items, std::vector<std::string>& models) {
    for (const auto& item : items.array_items()) {
        // json11's const operator[] yields a null Json for a missing key.
        const json11::Json& candidate = item.is_object() ? item["name"] : item;
        if (candidate.is_string() && EndsWithModelExtension(candidate.string_value())) {
            models.push_back(candidate.string_value());
        }
    }
}

/// Extracts model names from a /get_user_files response body.
/// The body may be a top-level array, or an object whose array-valued members
/// are all scanned. Unparseable JSON yields an empty list.
std::vector<std::string> ParseModelList(const std::string& content) {
    std::vector<std::string> models;
    std::string err;
    const json11::Json json = json11::Json::parse(content, err);
    if (!err.empty()) {
        return models;
    }
    if (json.is_array()) {
        CollectModelNames(json, models);
    } else if (json.is_object()) {
        for (const auto& entry : json.object_items()) {
            if (entry.second.is_array()) {
                CollectModelNames(entry.second, models);
            }
        }
    }
    return models;
}

/// Extracts the user id from a /get_user_id response body.
/// The body is trimmed first. If it looks like JSON, the id is taken from a
/// string "user_id" member of the object, or of the first array element that
/// has one. Otherwise, or when no such member exists, the trimmed body is
/// returned unchanged.
std::string ExtractUserId(const std::string& body) {
    const std::string content = TrimWhitespace(body);
    if (content.empty() || (content.front() != '{' && content.front() != '[')) {
        return content;
    }
    std::string err;
    const json11::Json json = json11::Json::parse(content, err);
    if (!err.empty()) {
        return content;
    }
    if (json.is_object()) {
        const json11::Json& id = json["user_id"];
        if (id.is_string()) {
            return id.string_value();
        }
    } else if (json.is_array()) {
        for (const auto& item : json.array_items()) {
            if (!item.is_object()) {
                continue;
            }
            const json11::Json& id = item["user_id"];
            if (id.is_string()) {
                return id.string_value();
            }
        }
    }
    return content;
}

/// Logs a transport-level failure (no HTTP status) and reports 0 through `onAuth`, if set.
template <typename AuthCallback>
void ReportTransportError(const char* context, const cpr::Response& response,
                          const AuthCallback& onAuth) {
    hid_handler.log << "HTTP error in " << context << ": "
                    << Utf8ToWide(response.error.message) << std::endl;
    if (onAuth) {
        onAuth(0);
    }
}

/// Logs a non-2xx response and reports it through `onAuth`, if set:
/// 1 when the body mentions "expired", otherwise 0.
template <typename AuthCallback>
void ReportHttpFailure(const char* context, const cpr::Response& response,
                       const AuthCallback& onAuth) {
    hid_handler.log << "API request failed in " << context
                    << ". Status code: " << response.status_code << std::endl;
    hid_handler.log << "Error response: " << Utf8ToWide(response.text) << std::endl;
    if (onAuth) {
        onAuth(response.text.find("expired") != std::string::npos ? 1 : 0);
    }
}

/// Runs `task` on a detached thread. Any exception is logged instead of
/// escaping the thread, which would call std::terminate.
template <typename Task>
void RunDetached(const char* context, Task task) {
    std::thread([context, task = std::move(task)]() mutable {
        try {
            task();
        } catch (const std::exception& ex) {
            hid_handler.log << "Unhandled exception in " << context << ": "
                            << Utf8ToWide(ex.what()) << std::endl;
        } catch (...) {
            hid_handler.log << "Unhandled exception in " << context << std::endl;
        }
    }).detach();
}

}  // namespace

/// Clears the user id, testing token and cached model list, and logs the reset.
void UserDataManager::ResetData() {
    userId_.clear();
    testingToken_.clear();
    userTrainedModels_.clear();
    hid_handler.log << "User data has been reset" << std::endl;
}

/// Removes every single and double quote from `input`, then trims
/// leading/trailing whitespace.
std::string UserDataManager::SanitizeString(const std::string& input) {
    std::string withoutQuotes;
    withoutQuotes.reserve(input.size());
    for (char ch : input) {
        if (ch != '"' && ch != '\'') {
            withoutQuotes.push_back(ch);
        }
    }
    return TrimWhitespace(withoutQuotes);
}

/// Returns `url` without any trailing '/' or '\\' characters.
/// A string made only of slashes becomes empty.
std::string UserDataManager::TrimTrailingSlash(const std::string& url) {
    const std::size_t lastKept = url.find_last_not_of("/\\");
    return lastKept == std::string::npos ? std::string{} : url.substr(0, lastKept + 1);
}

/// Returns the directory containing the running executable, UTF-8 encoded.
/// Returns "." if the module path cannot be obtained.
std::string UserDataManager::GetExecutableDirectory() {
    std::wstring modulePath(MAX_PATH, L'\0');
    for (;;) {
        const DWORD capacity = static_cast<DWORD>(modulePath.size());
        const DWORD length = GetModuleFileNameW(nullptr, modulePath.data(), capacity);
        if (length == 0) {
            return ".";
        }
        if (length < capacity) {
            modulePath.resize(length);
            break;
        }
        // A result that fills the whole buffer means the path was truncated:
        // retry with a larger buffer instead of failing at MAX_PATH.
        if (capacity >= kMaxModulePathChars) {
            return ".";
        }
        modulePath.resize(static_cast<std::size_t>(capacity) * 2);
    }

    const std::size_t separator = modulePath.find_last_of(L"\\/");
    const std::wstring directory =
        separator == std::wstring::npos ? modulePath : modulePath.substr(0, separator);
    return WideToUtf8(directory);
}

/// Stores a new testing token (quotes removed, whitespace trimmed).
/// If the sanitized token differs from the current one, existing user data is
/// reset and, when the new token is non-empty, the user id and model list are
/// re-fetched.
void UserDataManager::SetTestingToken(const std::string& value) {
    // Compare in sanitized form: testingToken_ is stored sanitized, so
    // comparing the raw value would treat e.g. "\"abc\"" as a change every call.
    std::string sanitized = SanitizeString(value);
    if (testingToken_ == sanitized) {
        return;
    }

    ResetData();
    testingToken_ = std::move(sanitized);
    hid_handler.log << "Testing token updated." << std::endl;

    if (!testingToken_.empty()) {
        FetchUserIdFromApi();
        FetchUserTrainedModels();
    }
}

/// Requests <serverApiAddress_>/get_user_id on a detached thread, sending the
/// testing token as a Bearer token.
///
/// On a 2xx response, the sanitized user id is stored, OnUserIdChanged fires if
/// the id changed, and OnAuthenticationResponse(2) is reported. A transport
/// error reports 0. A non-2xx response reports 1 if its body contains
/// "expired", otherwise 0.
///
/// NOTE(review): userId_ is written from the worker thread without any
/// synchronization, and nothing here discards a late response that belongs to
/// an older token. A mutex or request generation counter in the class would be
/// needed for that. The worker also captures `this`, so the object must outlive
/// the request (true for the global userManager).
void UserDataManager::FetchUserIdFromApi() {
    if (testingToken_.empty()) {
        hid_handler.log << "Cannot fetch user ID: Testing token is empty" << std::endl;
        return;
    }

    // Snapshot inputs on the calling thread so the worker never reads members
    // that SetTestingToken may be rewriting concurrently.
    const std::string endpoint = TrimTrailingSlash(serverApiAddress_) + "/get_user_id";
    const std::string token = StripLineBreaks(testingToken_);  // mirrors the C# client

    RunDetached("FetchUserIdFromApi", [this, endpoint, token]() {
        cpr::Response r = cpr::Get(
            cpr::Url{ endpoint },
            cpr::Timeout{ 30000 },
            cpr::Header{ { "Authorization", "Bearer " + token } });

        if (r.error) {
            ReportTransportError("FetchUserIdFromApi", r, OnAuthenticationResponse);
            return;
        }
        if (r.status_code < 200 || r.status_code >= 300) {
            ReportHttpFailure("FetchUserIdFromApi", r, OnAuthenticationResponse);
            return;
        }

        std::string sanitized = SanitizeString(ExtractUserId(r.text));
        if (userId_ != sanitized) {
            userId_ = std::move(sanitized);
            hid_handler.log << "User ID updated: " << Utf8ToWide(userId_) << std::endl;
            if (OnUserIdChanged) {
                OnUserIdChanged(userId_);
            }
        }
        if (OnAuthenticationResponse) {
            OnAuthenticationResponse(2);
        }
    });
}

/// Requests <serverApiAddress_>/get_user_files on a detached thread and stores
/// every ".model" name found in the response.
///
/// receivedModelsResponse_ is set once the request completes and cleared again
/// on a transport error. On a 2xx response, OnUserTrainedModelsChanged fires
/// with the new list and OnAuthenticationResponse(2) is reported. Failures
/// report 0, or 1 for an "expired" body, as in FetchUserIdFromApi.
///
/// NOTE(review): userTrainedModels_ and receivedModelsResponse_ are written from
/// the worker thread without synchronization.
void UserDataManager::FetchUserTrainedModels() {
    if (testingToken_.empty()) {
        return;
    }

    const std::string endpoint = TrimTrailingSlash(serverApiAddress_) + "/get_user_files";
    const std::string token = StripLineBreaks(testingToken_);

    RunDetached("FetchUserTrainedModels", [this, endpoint, token]() {
        cpr::Response r = cpr::Get(
            cpr::Url{ endpoint },
            cpr::Timeout{ 30000 },
            cpr::Header{ { "Authorization", "Bearer " + token } });

        receivedModelsResponse_ = true;

        if (r.error) {
            receivedModelsResponse_ = false;
            ReportTransportError("FetchUserTrainedModels", r, OnAuthenticationResponse);
            return;
        }
        if (r.status_code < 200 || r.status_code >= 300) {
            ReportHttpFailure("FetchUserTrainedModels", r, OnAuthenticationResponse);
            return;
        }

        userTrainedModels_ = ParseModelList(r.text);
        if (OnUserTrainedModelsChanged) {
            OnUserTrainedModelsChanged(userTrainedModels_);
        }
        if (OnAuthenticationResponse) {
            OnAuthenticationResponse(2);
        }
    });
}

/// Synchronously downloads the model identified by `modelKey` from
/// <serverApiAddress_>/get_model and writes the body to `outputPath`, creating
/// parent directories as needed.
///
/// isLoadingModel_ is true while the request runs. On failure,
/// modelLoadingError_ describes the cause.
/// @return true only when the file was fully written.
bool UserDataManager::DownloadModel(const std::string& modelKey, const std::string& outputPath) {
    if (testingToken_.empty()) {
        hid_handler.log << "Cannot download model: Testing token is empty" << std::endl;
        return false;
    }

    const std::string endpoint = TrimTrailingSlash(serverApiAddress_) + "/get_model";

    isLoadingModel_ = true;
    modelLoadingError_.clear();

    // Every failure path clears the loading flag and records a message.
    auto fail = [this](const std::string& message) {
        isLoadingModel_ = false;
        modelLoadingError_ = message;
        return false;
    };

    cpr::Response r = cpr::Get(
        cpr::Url{ endpoint },
        cpr::Timeout{ 180000 },
        cpr::Header{ { "Authorization", "Bearer " + StripLineBreaks(testingToken_) } },
        cpr::Parameters{ { "model_s3_uri", modelKey } });

    if (r.error) {
        hid_handler.log << "HTTP error in DownloadModel: "
                        << Utf8ToWide(r.error.message) << std::endl;
        return fail("Failed to download model: network error.");
    }

    if (r.status_code < 200 || r.status_code >= 300) {
        hid_handler.log << "Failed to download model. Status code: " << r.status_code << std::endl;
        hid_handler.log << "Error response: " << Utf8ToWide(r.text) << std::endl;
        return fail("Failed to download model: HTTP " + std::to_string(r.status_code));
    }

    try {
        // NOTE(review): if outputPath is UTF-8 (e.g. built from
        // GetExecutableDirectory), std::filesystem::u8path may be needed on
        // Windows to interpret non-ASCII characters correctly - confirm.
        const std::filesystem::path outPath(outputPath);
        // A bare file name has no parent; create_directories("") reports an error.
        if (outPath.has_parent_path()) {
            std::filesystem::create_directories(outPath.parent_path());
        }

        std::ofstream out(outPath, std::ios::binary);
        out.write(r.text.data(), static_cast<std::streamsize>(r.text.size()));
        out.close();
        // Streams do not throw by default: check that open, write and close succeeded.
        if (!out) {
            hid_handler.log << "Could not write model file: " << Utf8ToWide(outputPath) << std::endl;
            return fail("Failed to download model: could not write file.");
        }
    } catch (const std::exception& ex) {
        hid_handler.log << "Could not write model file: " << Utf8ToWide(ex.what()) << std::endl;
        return fail("Failed to download model: could not write file.");
    } catch (...) {
        return fail("Failed to download model: could not write file.");
    }

    hid_handler.log << "Model downloaded and saved to: " << Utf8ToWide(outputPath) << std::endl;
    isLoadingModel_ = false;
    return true;
}