Add SpinLock to guard NRT cache

nix
Owen Green 5 years ago
parent d0d155913c
commit a2847d96db

@ -53,17 +53,80 @@ namespace impl {
return !weak.owner_before(WeakCacheEntryPointer{}) && !WeakCacheEntryPointer{}.owner_before(weak); return !weak.owner_before(WeakCacheEntryPointer{}) && !WeakCacheEntryPointer{}.owner_before(weak);
} }
// https://rigtorp.se/spinlock/
struct Spinlock {
std::atomic<bool> lock_ = {0};
void lock() noexcept {
for (;;) {
// Optimistically assume the lock is free on the first try
if (!lock_.exchange(true, std::memory_order_acquire)) {
return;
}
// Wait for lock to be released without generating cache misses
while (lock_.load(std::memory_order_relaxed)) {
// Issue X86 PAUSE or ARM YIELD instruction to reduce contention between
// hyper-threads
//__builtin_ia32_pause();
}
}
}
public: bool tryLock() noexcept {
static WeakCacheEntryPointer get(index id) // First do a relaxed load to check if lock is free in order to prevent
// unnecessary cache misses if someone does while(!try_lock())
return !lock_.load(std::memory_order_relaxed) &&
!lock_.exchange(true, std::memory_order_acquire);
}
void unlock() noexcept {
lock_.store(false, std::memory_order_release);
}
};
//RAII for above
struct ScopedSpinlock
{
ScopedSpinlock(Spinlock& _l) noexcept: mLock{_l}
{
mLock.lock();
}
~ScopedSpinlock() { mLock.unlock(); }
private:
Spinlock& mLock;
};
static Spinlock mSpinlock;
// shouldn't be called without at least *thinking* about getting spin lock first
static inline WeakCacheEntryPointer unsafeGet(index id)
{ {
auto lookup = mCache.find(id); auto lookup = mCache.find(id);
return lookup == mCache.end() ? WeakCacheEntryPointer() : lookup->second; return lookup == mCache.end() ? WeakCacheEntryPointer() : lookup->second;
} }
public:
static WeakCacheEntryPointer get(index id)
{
ScopedSpinlock{mSpinlock};
return unsafeGet(id);
}
static WeakCacheEntryPointer tryGet(index id)
{
if(mSpinlock.tryLock())
{
auto ret = unsafeGet(id);
mSpinlock.unlock();
return ret;
}
return WeakCacheEntryPointer{};
}
static WeakCacheEntryPointer add(index id, const Params& params) static WeakCacheEntryPointer add(index id, const Params& params)
{ {
ScopedSpinlock{mSpinlock};
if(isNull(get(id))) if(isNull(get(id)))
{ {
auto result = mCache.emplace(id, auto result = mCache.emplace(id,
@ -80,6 +143,7 @@ namespace impl {
static void remove(index id) static void remove(index id)
{ {
ScopedSpinlock{mSpinlock};
mCache.erase(id); mCache.erase(id);
} }
@ -260,6 +324,7 @@ namespace impl {
if(auto ptr = get(NRTCommand::mID).lock()) if(auto ptr = get(NRTCommand::mID).lock())
{ {
Result r; Result r;
mRecord = ptr;
auto& client = ptr->mClient; auto& client = ptr->mClient;
ProcessState s = client.checkProgress(r); ProcessState s = client.checkProgress(r);
if (s == ProcessState::kDone || s == ProcessState::kDoneStillProcessing) if (s == ProcessState::kDone || s == ProcessState::kDoneStillProcessing)
@ -293,7 +358,7 @@ namespace impl {
bool stage3(World* world) bool stage3(World* world)
{ {
if(auto ptr = get(NRTCommand::mID).lock()) if(auto ptr = mRecord.lock())
{ {
auto& params = ptr->mParams; auto& params = ptr->mParams;
params.template forEachParamType<BufferT, AssignBuffer>(world); params.template forEachParamType<BufferT, AssignBuffer>(world);
@ -319,6 +384,7 @@ namespace impl {
} }
bool mSuccess; bool mSuccess;
WeakCacheEntryPointer mRecord;
}; };
@ -338,7 +404,6 @@ namespace impl {
auto launchCompletionFromNRT = [](FifoMsg* inmsg) auto launchCompletionFromNRT = [](FifoMsg* inmsg)
{ {
auto runCompletion = [](FifoMsg* msg){ auto runCompletion = [](FifoMsg* msg){
// std::cout << "In FIFOMsg\n";
Context* c = static_cast<Context*>(msg->mData); Context* c = static_cast<Context*>(msg->mData);
World* world = c->mWorld; World* world = c->mWorld;
index id = c->mID; index id = c->mID;
@ -401,7 +466,8 @@ namespace impl {
bool stage2(World* world) bool stage2(World* world)
{ {
if(auto ptr = get(NRTCommand::mID).lock()) mRecord = get(NRTCommand::mID);
if(auto ptr = mRecord.lock())
{ {
auto& params = ptr->mParams; auto& params = ptr->mParams;
@ -458,7 +524,7 @@ namespace impl {
//Only for blocking execution //Only for blocking execution
bool stage3(World* world) //rt bool stage3(World* world) //rt
{ {
if(auto ptr = get(NRTCommand::mID).lock()) if(auto ptr = mRecord.lock())
{ {
ptr->mParams.template forEachParamType<BufferT, AssignBuffer>(world); ptr->mParams.template forEachParamType<BufferT, AssignBuffer>(world);
// NRTCommand::sendReply(world, name(), mResult.ok()); // NRTCommand::sendReply(world, name(), mResult.ok());
@ -502,6 +568,7 @@ namespace impl {
char* mCompletionMessage{nullptr}; char* mCompletionMessage{nullptr};
Params mParams; Params mParams;
bool mOverwriteParams{false}; bool mOverwriteParams{false};
WeakCacheEntryPointer mRecord;
}; };
struct CommandProcessNew: public NRTCommand struct CommandProcessNew: public NRTCommand
@ -713,13 +780,18 @@ namespace impl {
if (0 == mCounter++) if (0 == mCounter++)
{ {
index id = static_cast<index>(mInBuf[0][0]); index id = static_cast<index>(mInBuf[0][0]);
if(auto ptr = get(id).lock()) if(auto ptr = tryGet(id).lock())
{ {
mInit = true;
if(ptr->mClient.done()) mDone = 1; if(ptr->mClient.done()) mDone = 1;
out0(0) = static_cast<float>(ptr->mClient.progress()); out0(0) = static_cast<float>(ptr->mClient.progress());
} }
else else
{
if(!mInit)
std::cout << "WARNING: No " << Wrapper::getName() << " with ID " << id << std::endl; std::cout << "WARNING: No " << Wrapper::getName() << " with ID " << id << std::endl;
else mDone = 1;
}
} }
mCounter %= mInterval; mCounter %= mInterval;
} }
@ -727,6 +799,7 @@ namespace impl {
private: private:
index mInterval; index mInterval;
index mCounter{0}; index mCounter{0};
bool mInit{false};
}; };
@ -760,18 +833,18 @@ namespace impl {
if(mID == -1) mID = count(); if(mID == -1) mID = count();
auto cmd = NonRealTime::rtalloc<CommandNew>(mWorld,mID,mWorld, mControlsIterator, this); auto cmd = NonRealTime::rtalloc<CommandNew>(mWorld,mID,mWorld, mControlsIterator, this);
runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr); runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr);
mInst = get(mID); // mInst = get(mID);
set_calc_function<NRTTriggerUnit, &NRTTriggerUnit::next>(); set_calc_function<NRTTriggerUnit, &NRTTriggerUnit::next>();
Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1); Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1);
} }
~NRTTriggerUnit() ~NRTTriggerUnit()
{ {
if(auto ptr = mInst.lock()) // if(auto ptr = mInst.lock())
{ // {
auto cmd = NonRealTime::rtalloc<CommandFree>(mWorld,mID); auto cmd = NonRealTime::rtalloc<CommandFree>(mWorld,mID);
runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr); runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr);
} // }
} }
void next(int) void next(int)
@ -801,12 +874,13 @@ namespace impl {
} }
else else
{ {
if(auto ptr = get(mID).lock()) if(auto ptr = tryGet(mID).lock())
{ {
mInit = true;
auto& client = ptr->mClient; auto& client = ptr->mClient;
mDone = ptr->mDone; mDone = ptr->mDone;
out0(0) = mDone ? 1 : static_cast<float>(client.progress()); out0(0) = mDone ? 1 : static_cast<float>(client.progress());
} } else mDone = mInit;
} }
// } // }
// else printNotFound(id); // else printNotFound(id);
@ -821,6 +895,7 @@ namespace impl {
index mRunCount{0}; index mRunCount{0};
WeakCacheEntryPointer mInst; WeakCacheEntryPointer mInst;
Params mParams; Params mParams;
bool mInit{false};
}; };
struct NRTModelQueryUnit: SCUnit struct NRTModelQueryUnit: SCUnit
@ -845,20 +920,40 @@ namespace impl {
//Offset controls by 1 to account for ID //Offset controls by 1 to account for ID
: mControls{mInBuf + ControlOffset(),ControlSize()} : mControls{mInBuf + ControlOffset(),ControlSize()}
{ {
index id = static_cast<index>(in0(1)); mID = static_cast<index>(in0(1));
mInst = get(id); init();
// mInst = get(id);
// if(auto ptr = mInst.lock())
// {
// auto& client = ptr->mClient;
// mDelegate.init(*this,client,mControls);
set_calc_function<NRTModelQueryUnit, &NRTModelQueryUnit::next>();
Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1);
// }else printNotFound(mID);
}
void init()
{
if(mSpinlock.tryLock())
{
mInit = false;
mInst = unsafeGet(mID);
if(auto ptr = mInst.lock()) if(auto ptr = mInst.lock())
{ {
auto& client = ptr->mClient; auto& client = ptr->mClient;
mDelegate.init(*this,client,mControls); mDelegate.init(*this,client,mControls);
set_calc_function<NRTModelQueryUnit, &NRTModelQueryUnit::next>(); mInit = true;
Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1); }//else printNotFound(mID);
}else printNotFound(id); mSpinlock.unlock();
}
} }
void next(int) void next(int)
{ {
index id = static_cast<index>(in0(1)); index id = static_cast<index>(in0(1));
if(mID != id) init();
if(!mInit) return;
if(auto ptr = mInst.lock()) if(auto ptr = mInst.lock())
{ {
auto& client = ptr->mClient; auto& client = ptr->mClient;
@ -873,6 +968,7 @@ namespace impl {
FloatControlsIter mControls; FloatControlsIter mControls;
index mID; index mID;
WeakCacheEntryPointer mInst; WeakCacheEntryPointer mInst;
bool mInit{false};
}; };
@ -988,6 +1084,10 @@ namespace impl {
template<typename Client, typename Wrapper> template<typename Client, typename Wrapper>
typename NonRealTime<Client, Wrapper>::Cache NonRealTime<Client,Wrapper>::mCache{}; typename NonRealTime<Client, Wrapper>::Cache NonRealTime<Client,Wrapper>::mCache{};
template<typename Client, typename Wrapper>
typename NonRealTime<Client, Wrapper>::Spinlock NonRealTime<Client,Wrapper>::mSpinlock{};
} }
} }
} }

Loading…
Cancel
Save