From a2847d96db8c8f8935ca3d5115421a0a68145c6f Mon Sep 17 00:00:00 2001 From: Owen Green Date: Wed, 5 May 2021 16:36:12 +0100 Subject: [PATCH] Add SpinLock to guard NRT cache --- include/wrapper/NonRealtime.hpp | 148 ++++++++++++++++++++++++++------ 1 file changed, 124 insertions(+), 24 deletions(-) diff --git a/include/wrapper/NonRealtime.hpp b/include/wrapper/NonRealtime.hpp index 9983c5a..bdcc8b8 100644 --- a/include/wrapper/NonRealtime.hpp +++ b/include/wrapper/NonRealtime.hpp @@ -53,17 +53,80 @@ namespace impl { return !weak.owner_before(WeakCacheEntryPointer{}) && !WeakCacheEntryPointer{}.owner_before(weak); } - + // https://rigtorp.se/spinlock/ + struct Spinlock { + std::atomic lock_ = {0}; + + void lock() noexcept { + for (;;) { + // Optimistically assume the lock is free on the first try + if (!lock_.exchange(true, std::memory_order_acquire)) { + return; + } + // Wait for lock to be released without generating cache misses + while (lock_.load(std::memory_order_relaxed)) { + // Issue X86 PAUSE or ARM YIELD instruction to reduce contention between + // hyper-threads + //__builtin_ia32_pause(); + } + } + } + + bool tryLock() noexcept { + // First do a relaxed load to check if lock is free in order to prevent + // unnecessary cache misses if someone does while(!try_lock()) + return !lock_.load(std::memory_order_relaxed) && + !lock_.exchange(true, std::memory_order_acquire); + } + + void unlock() noexcept { + lock_.store(false, std::memory_order_release); + } + }; + + //RAII for above + struct ScopedSpinlock + { + ScopedSpinlock(Spinlock& _l) noexcept: mLock{_l} + { + mLock.lock(); + } + ~ScopedSpinlock() { mLock.unlock(); } + private: + Spinlock& mLock; + }; + + static Spinlock mSpinlock; + + // shouldn't be called without at least *thinking* about getting spin lock first + static inline WeakCacheEntryPointer unsafeGet(index id) + { + auto lookup = mCache.find(id); + return lookup == mCache.end() ? WeakCacheEntryPointer() : lookup->second; + } public: static WeakCacheEntryPointer get(index id) { - auto lookup = mCache.find(id); - return lookup == mCache.end() ? WeakCacheEntryPointer() : lookup->second; + ScopedSpinlock{mSpinlock}; + return unsafeGet(id); } - + + static WeakCacheEntryPointer tryGet(index id) + { + if(mSpinlock.tryLock()) + { + auto ret = unsafeGet(id); + mSpinlock.unlock(); + return ret; + } + return WeakCacheEntryPointer{}; + } + + static WeakCacheEntryPointer add(index id, const Params& params) { + ScopedSpinlock{mSpinlock}; if(isNull(get(id))) { auto result = mCache.emplace(id, @@ -80,6 +143,7 @@ namespace impl { static void remove(index id) { + ScopedSpinlock{mSpinlock}; mCache.erase(id); } @@ -260,6 +324,7 @@ namespace impl { if(auto ptr = get(NRTCommand::mID).lock()) { Result r; + mRecord = ptr; auto& client = ptr->mClient; ProcessState s = client.checkProgress(r); if (s == ProcessState::kDone || s == ProcessState::kDoneStillProcessing) @@ -293,7 +358,7 @@ namespace impl { bool stage3(World* world) { - if(auto ptr = get(NRTCommand::mID).lock()) + if(auto ptr = mRecord.lock()) { auto& params = ptr->mParams; params.template forEachParamType(world); @@ -319,6 +384,7 @@ namespace impl { } bool mSuccess; + WeakCacheEntryPointer mRecord; }; @@ -338,7 +404,6 @@ namespace impl { auto launchCompletionFromNRT = [](FifoMsg* inmsg) { auto runCompletion = [](FifoMsg* msg){ - // std::cout << "In FIFOMsg\n"; Context* c = static_cast(msg->mData); World* world = c->mWorld; index id = c->mID; @@ -401,7 +466,8 @@ namespace impl { bool stage2(World* world) { - if(auto ptr = get(NRTCommand::mID).lock()) + mRecord = get(NRTCommand::mID); + if(auto ptr = mRecord.lock()) { auto& params = ptr->mParams; @@ -458,7 +524,7 @@ namespace impl { //Only for blocking execution bool stage3(World* world) //rt { - if(auto ptr = get(NRTCommand::mID).lock()) + if(auto ptr = mRecord.lock()) { ptr->mParams.template forEachParamType(world); // NRTCommand::sendReply(world, name(), mResult.ok()); @@ -502,6 +568,7 @@ namespace impl { char* mCompletionMessage{nullptr}; Params mParams; bool mOverwriteParams{false}; + WeakCacheEntryPointer mRecord; }; struct CommandProcessNew: public NRTCommand @@ -713,13 +780,18 @@ namespace impl { if (0 == mCounter++) { index id = static_cast(mInBuf[0][0]); - if(auto ptr = get(id).lock()) + if(auto ptr = tryGet(id).lock()) { + mInit = true; if(ptr->mClient.done()) mDone = 1; out0(0) = static_cast(ptr->mClient.progress()); } else - std::cout << "WARNING: No " << Wrapper::getName() << " with ID " << id << std::endl; + { + if(!mInit) + std::cout << "WARNING: No " << Wrapper::getName() << " with ID " << id << std::endl; + else mDone = 1; + } } mCounter %= mInterval; } @@ -727,6 +799,7 @@ namespace impl { private: index mInterval; index mCounter{0}; + bool mInit{false}; }; @@ -760,18 +833,18 @@ namespace impl { if(mID == -1) mID = count(); auto cmd = NonRealTime::rtalloc(mWorld,mID,mWorld, mControlsIterator, this); runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr); - mInst = get(mID); +// mInst = get(mID); set_calc_function(); Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1); } ~NRTTriggerUnit() { - if(auto ptr = mInst.lock()) - { +// if(auto ptr = mInst.lock()) +// { auto cmd = NonRealTime::rtalloc(mWorld,mID); runAsyncCommand(mWorld, cmd, nullptr, 0, nullptr); - } +// } } void next(int) @@ -801,12 +874,13 @@ namespace impl { } else { - if(auto ptr = get(mID).lock()) + if(auto ptr = tryGet(mID).lock()) { + mInit = true; auto& client = ptr->mClient; mDone = ptr->mDone; out0(0) = mDone ? 1 : static_cast(client.progress()); - } + } else mDone = mInit; } // } // else printNotFound(id); @@ -821,6 +895,7 @@ namespace impl { index mRunCount{0}; WeakCacheEntryPointer mInst; Params mParams; + bool mInit{false}; }; struct NRTModelQueryUnit: SCUnit @@ -845,20 +920,40 @@ namespace impl { //Offset controls by 1 to account for ID : mControls{mInBuf + ControlOffset(),ControlSize()} { - index id = static_cast(in0(1)); - mInst = get(id); - if(auto ptr = mInst.lock()) - { - auto& client = ptr->mClient; - mDelegate.init(*this,client,mControls); + mID = static_cast(in0(1)); + init(); +// mInst = get(id); +// if(auto ptr = mInst.lock()) +// { +// auto& client = ptr->mClient; +// mDelegate.init(*this,client,mControls); set_calc_function(); Wrapper::getInterfaceTable()->fClearUnitOutputs(this, 1); - }else printNotFound(id); +// }else printNotFound(mID); + } + + void init() + { + if(mSpinlock.tryLock()) + { + mInit = false; + mInst = unsafeGet(mID); + if(auto ptr = mInst.lock()) + { + auto& client = ptr->mClient; + mDelegate.init(*this,client,mControls); + mInit = true; + }//else printNotFound(mID); + mSpinlock.unlock(); + } } void next(int) { + index id = static_cast(in0(1)); + if(mID != id) init(); + if(!mInit) return; if(auto ptr = mInst.lock()) { auto& client = ptr->mClient; @@ -873,6 +968,7 @@ namespace impl { FloatControlsIter mControls; index mID; WeakCacheEntryPointer mInst; + bool mInit{false}; }; @@ -987,7 +1083,11 @@ namespace impl { template typename NonRealTime::Cache NonRealTime::mCache{}; - + + template + typename NonRealTime::Spinlock NonRealTime::mSpinlock{}; + + } } }