This commit is contained in:
2025-12-01 19:36:26 +05:30
parent 4a6630384b
commit bd9d75e2a2
4 changed files with 67 additions and 18 deletions

View File

@ -34,7 +34,7 @@ namespace IACore
if (!workerCount) if (!workerCount)
workerCount = std::max((UINT32) 2, std::thread::hardware_concurrency() - 2); workerCount = std::max((UINT32) 2, std::thread::hardware_concurrency() - 2);
for (UINT32 i = 0; i < workerCount; i++) for (UINT32 i = 0; i < workerCount; i++)
s_scheduleWorkers.emplace_back(AsyncOps::ScheduleWorkerLoop); s_scheduleWorkers.emplace_back(AsyncOps::ScheduleWorkerLoop, i + 1);
} }
VOID AsyncOps::TerminateScheduler() VOID AsyncOps::TerminateScheduler()
@ -57,7 +57,8 @@ namespace IACore
s_scheduleWorkers.clear(); s_scheduleWorkers.clear();
} }
VOID AsyncOps::ScheduleTask(IN Function<VOID()> task, IN Schedule *schedule, IN Priority priority) VOID AsyncOps::ScheduleTask(IN Function<VOID(IN WorkerID workerID)> task, IN TaskTag tag, IN Schedule *schedule,
IN Priority priority)
{ {
IA_ASSERT(s_scheduleWorkers.size() && "Scheduler must be initialized before calling this function"); IA_ASSERT(s_scheduleWorkers.size() && "Scheduler must be initialized before calling this function");
@ -65,13 +66,36 @@ namespace IACore
{ {
ScopedLock lock(s_queueMutex); ScopedLock lock(s_queueMutex);
if (priority == Priority::High) if (priority == Priority::High)
s_highPriorityQueue.emplace_back(ScheduledTask{IA_MOVE(task), schedule}); s_highPriorityQueue.emplace_back(ScheduledTask{tag, schedule, IA_MOVE(task)});
else else
s_normalPriorityQueue.emplace_back(ScheduledTask{IA_MOVE(task), schedule}); s_normalPriorityQueue.emplace_back(ScheduledTask{tag, schedule, IA_MOVE(task)});
} }
s_wakeCondition.notify_one(); s_wakeCondition.notify_one();
} }
VOID AsyncOps::CancelTasksOfTag(IN TaskTag tag)
{
ScopedLock lock(s_queueMutex);
auto cancelFromQueue = [&](Deque<ScheduledTask> &queue) {
for (auto it = queue.begin(); it != queue.end(); /* no increment here */)
{
if (it->Tag == tag)
{
if (it->ScheduleHandle->Counter.fetch_sub(1) == 1)
it->ScheduleHandle->Counter.notify_all();
it = queue.erase(it);
}
else
++it;
}
};
cancelFromQueue(s_highPriorityQueue);
cancelFromQueue(s_normalPriorityQueue);
}
VOID AsyncOps::WaitForScheduleCompletion(IN Schedule *schedule) VOID AsyncOps::WaitForScheduleCompletion(IN Schedule *schedule)
{ {
IA_ASSERT(s_scheduleWorkers.size() && "Scheduler must be initialized before calling this function"); IA_ASSERT(s_scheduleWorkers.size() && "Scheduler must be initialized before calling this function");
@ -97,7 +121,7 @@ namespace IACore
} }
if (foundTask) if (foundTask)
{ {
task.Task(); task.Task(MainThreadWorkerID);
if (task.ScheduleHandle->Counter.fetch_sub(1) == 1) if (task.ScheduleHandle->Counter.fetch_sub(1) == 1)
task.ScheduleHandle->Counter.notify_all(); task.ScheduleHandle->Counter.notify_all();
} }
@ -110,7 +134,12 @@ namespace IACore
} }
} }
VOID AsyncOps::ScheduleWorkerLoop(IN StopToken stopToken) AsyncOps::WorkerID AsyncOps::GetWorkerCount()
{
return static_cast<WorkerID>(s_scheduleWorkers.size() + 1); // +1 for MainThread (Work Stealing)
}
VOID AsyncOps::ScheduleWorkerLoop(IN StopToken stopToken, IN WorkerID workerID)
{ {
while (!stopToken.stop_requested()) while (!stopToken.stop_requested())
{ {
@ -141,7 +170,7 @@ namespace IACore
} }
if (foundTask) if (foundTask)
{ {
task.Task(); task.Task(workerID);
if (task.ScheduleHandle->Counter.fetch_sub(1) == 1) if (task.ScheduleHandle->Counter.fetch_sub(1) == 1)
task.ScheduleHandle->Counter.notify_all(); task.ScheduleHandle->Counter.notify_all();
} }

View File

@ -27,7 +27,7 @@ namespace IACore
for (const auto &h : headers) for (const auto &h : headers)
{ {
std::string key = HttpClient::HeaderTypeToString(h.first); // Your existing helper std::string key = HttpClient::HeaderTypeToString(h.first);
out.emplace(key, h.second); out.emplace(key, h.second);
if (h.first == HttpClient::EHeaderType::CONTENT_TYPE) if (h.first == HttpClient::EHeaderType::CONTENT_TYPE)
@ -53,8 +53,8 @@ namespace IACore
{ {
auto httpHeaders = BuildHeaders(headers, defaultContentType); auto httpHeaders = BuildHeaders(headers, defaultContentType);
static_cast<httplib::Client*>(m_client)->enable_server_certificate_verification(false); static_cast<httplib::Client *>(m_client)->enable_server_certificate_verification(false);
auto res = static_cast<httplib::Client*>(m_client)->Get(path.c_str(), httpHeaders); auto res = static_cast<httplib::Client *>(m_client)->Get(path.c_str(), httpHeaders);
if (res) if (res)
{ {
@ -62,7 +62,7 @@ namespace IACore
if (res->status >= 200 && res->status < 300) if (res->status >= 200 && res->status < 300)
return res->body; return res->body;
else else
return MakeUnexpected(std::format("HTTP Error {}", res->status)); return MakeUnexpected(std::format("HTTP Error {} : {}", res->status, res->body));
} }
return MakeUnexpected(std::format("Network Error: {}", httplib::to_string(res.error()))); return MakeUnexpected(std::format("Network Error: {}", httplib::to_string(res.error())));
@ -76,11 +76,13 @@ namespace IACore
String contentType = defaultContentType; String contentType = defaultContentType;
if (httpHeaders.count("Content-Type")) if (httpHeaders.count("Content-Type"))
{ {
contentType = httpHeaders.find("Content-Type")->second; const auto t = httpHeaders.find("Content-Type");
contentType = t->second;
httpHeaders.erase(t);
} }
static_cast<httplib::Client*>(m_client)->enable_server_certificate_verification(false); static_cast<httplib::Client *>(m_client)->enable_server_certificate_verification(false);
auto res = static_cast<httplib::Client*>(m_client)->Post(path.c_str(), httpHeaders, body, contentType.c_str()); auto res = static_cast<httplib::Client *>(m_client)->Post(path.c_str(), httpHeaders, body, contentType.c_str());
if (res) if (res)
{ {
@ -88,7 +90,7 @@ namespace IACore
if (res->status >= 200 && res->status < 300) if (res->status >= 200 && res->status < 300)
return res->body; return res->body;
else else
return MakeUnexpected(std::format("HTTP Error {}", res->status)); return MakeUnexpected(std::format("HTTP Error {} : {}", res->status, res->body));
} }
return MakeUnexpected(std::format("Network Error: {}", httplib::to_string(res.error()))); return MakeUnexpected(std::format("Network Error: {}", httplib::to_string(res.error())));
@ -202,4 +204,9 @@ namespace IACore
return ""; return "";
} }
} }
BOOL HttpClient::IsSuccessResponseCode(IN EResponseCode code)
{
return (INT32) code >= 200 && (INT32) code < 300;
}
} // namespace IACore } // namespace IACore

View File

@ -23,6 +23,11 @@ namespace IACore
class AsyncOps class AsyncOps
{ {
public: public:
using TaskTag = UINT64;
using WorkerID = UINT16;
STATIC CONSTEXPR WorkerID MainThreadWorkerID = 0;
enum class Priority : UINT8 enum class Priority : UINT8
{ {
High, High,
@ -38,20 +43,26 @@ namespace IACore
STATIC VOID InitializeScheduler(IN UINT8 workerCount = 0); STATIC VOID InitializeScheduler(IN UINT8 workerCount = 0);
STATIC VOID TerminateScheduler(); STATIC VOID TerminateScheduler();
STATIC VOID ScheduleTask(IN Function<VOID()> task, IN Schedule *schedule, STATIC VOID ScheduleTask(IN Function<VOID(IN WorkerID workerID)> task, IN TaskTag tag, IN Schedule *schedule,
IN Priority priority = Priority::Normal); IN Priority priority = Priority::Normal);
STATIC VOID CancelTasksOfTag(IN TaskTag tag);
STATIC VOID WaitForScheduleCompletion(IN Schedule *schedule); STATIC VOID WaitForScheduleCompletion(IN Schedule *schedule);
STATIC VOID RunTask(IN Function<VOID()> task); STATIC VOID RunTask(IN Function<VOID()> task);
STATIC WorkerID GetWorkerCount();
private: private:
struct ScheduledTask struct ScheduledTask
{ {
Function<VOID()> Task{}; TaskTag Tag{};
Schedule *ScheduleHandle{}; Schedule *ScheduleHandle{};
Function<VOID(IN WorkerID workerID)> Task{};
}; };
STATIC VOID ScheduleWorkerLoop(IN StopToken stopToken); STATIC VOID ScheduleWorkerLoop(IN StopToken stopToken, IN WorkerID workerID);
private: private:
STATIC Mutex s_queueMutex; STATIC Mutex s_queueMutex;

View File

@ -158,6 +158,8 @@ namespace IACore
STATIC String HeaderTypeToString(IN EHeaderType type); STATIC String HeaderTypeToString(IN EHeaderType type);
STATIC Header CreateHeader(IN EHeaderType key, IN CONST String &value); STATIC Header CreateHeader(IN EHeaderType key, IN CONST String &value);
STATIC BOOL IsSuccessResponseCode(IN EResponseCode code);
public: public:
EResponseCode LastResponseCode() EResponseCode LastResponseCode()
{ {