# pylint: disable=missing-type-doc
"""Thread safe datastore."""
from contextlib import contextmanager
import threading
from pymodbus.datastore.store import BaseModbusDataBlock
class ContextWrapper:
    """Adapt a pair of enter/leave callables to the context manager protocol.

    Example::

        with ContextWrapper(enter, leave):
            do_something()
    """

    def __init__(self, enter=None, leave=None, factory=None):
        """Initialize.

        :param enter: Optional callable invoked with no arguments on entry
        :param leave: Optional callable invoked with no arguments on exit
        :param factory: Optional callable whose result is bound by ``as``;
            when omitted (or falsy) the wrapper itself is bound
        """
        self._enter = enter
        self._leave = leave
        self._factory = factory

    def __enter__(self):
        """Run the enter callback, if any, and return the ``as`` target.

        :returns: ``factory()`` when a factory was supplied, otherwise self
        """
        # The callback is stored as ``_enter``; there is no ``enter``
        # attribute, so testing ``self.enter`` raised AttributeError.
        if self._enter:
            self._enter()
        return self if not self._factory else self._factory()

    def __exit__(self, *args):
        """Run the leave callback, if any.

        Returns None, so an exception raised inside the ``with`` body
        keeps propagating.
        """
        if self._leave:
            self._leave()
class ReadWriteLock:
    """Reader/writer lock that is biased towards waiting writers.

    Any number of readers may hold the lock at once; a writer holds it
    exclusively.  Waiting writers are kept in a queue and the head of that
    queue is the one notified on release.  A reader that arrives while a
    writer is active or queued ahead of the readers waits, which keeps a
    steady stream of readers from starving writers.  Read order is not
    guaranteed.

    All waits use a one second timeout, so waiters also re-check the lock
    state periodically without having been notified.

    TODO:

    * allow user to choose between read/write/random biasing
      - currently write biased
      - read biased allow N readers in queue
      - random is 50/50 choice of next
    """

    def __init__(self):
        """Initialize a new instance of the ReadWriteLock."""
        # Waiting parties in arrival order: one Condition per waiting
        # writer, plus at most one shared entry (read_condition) that
        # stands for all waiting readers.
        self.queue = []
        self.lock = threading.Lock()  # guards every attribute below
        self.read_condition = threading.Condition(self.lock)  # shared by readers
        self.readers = 0  # number of readers currently holding the lock
        self.writer = False  # True while a writer holds the lock

    def __is_pending_writer(self):
        """Tell whether a reader has to wait.

        :returns: A truthy value when a writer currently holds the lock or
            the head of the queue is a waiting writer
        """
        return self.writer or (
            self.queue and (self.queue[0] != self.read_condition)
        )

    def acquire_reader(self):
        """Notify the lock that a new reader is requesting the underlying resource."""
        with self.lock:
            if self.__is_pending_writer():
                # A single queue entry represents every waiting reader.
                if self.read_condition not in self.queue:
                    self.queue.append(self.read_condition)
                while self.__is_pending_writer():
                    self.read_condition.wait(1)
                # The first reader through takes the shared entry off the head.
                if self.queue and self.read_condition == self.queue[0]:
                    self.queue.pop(0)
            self.readers += 1

    def acquire_writer(self):
        """Notify the lock that a new writer is requesting the underlying resource."""
        with self.lock:
            if self.writer or self.readers:
                # Each waiting writer gets its own condition so it can be
                # woken individually.
                condition = threading.Condition(self.lock)
                self.queue.append(condition)
                try:
                    while self.writer or self.readers:
                        condition.wait(1)
                finally:
                    # Remove this writer's own entry, not whatever is at
                    # the head.  The timed wait lets a writer further back
                    # wake first; popping the head would then discard
                    # another waiter's entry and leave this one stuck in
                    # the queue, blocking readers indefinitely.
                    self.queue.remove(condition)
            self.writer = True

    def release_reader(self):
        """Notify the lock that an existing reader is finished with the underlying resource."""
        with self.lock:
            self.readers = max(0, self.readers - 1)  # never drop below zero
            if not self.readers and self.queue:
                self.queue[0].notify_all()  # last reader out wakes the queue head

    def release_writer(self):
        """Notify the lock that an existing writer is finished with the underlying resource."""
        with self.lock:
            self.writer = False
            if self.queue:
                self.queue[0].notify_all()  # wake whoever is first in line
            else:
                self.read_condition.notify_all()  # otherwise wake any readers

    @contextmanager
    def get_reader_lock(self):
        """Wrap some code with a reader lock using the python context manager protocol::

            with rwlock.get_reader_lock():
                do_read_operation()
        """
        # Acquire outside the try block: if acquisition fails there is
        # nothing to release, and releasing anyway would corrupt the count.
        self.acquire_reader()
        try:
            yield self
        finally:
            self.release_reader()

    @contextmanager
    def get_writer_lock(self):
        """Wrap some code with a writer lock using the python context manager protocol::

            with rwlock.get_writer_lock():
                do_write_operation()
        """
        # Acquire outside the try block so that a failed acquisition does
        # not clear a writer flag held by another thread.
        self.acquire_writer()
        try:
            yield self
        finally:
            self.release_writer()
class ThreadSafeDataBlock(BaseModbusDataBlock):
    """Decorator that guards an existing data block with a reader/writer lock.

    A user injects an existing data block, which can then be operated on
    from multiple concurrent threads: ``validate`` and ``getValues`` run
    under the reader lock, ``setValues`` under the writer lock.

    The lock is placed around the data block rather than the manager, the
    stated rationale being that this gives fewer sources of contention
    (writes can occur to slave 0x01 while reads occur to slave 0x02).
    """

    def __init__(self, block):
        """Initialize a new thread safe decorator.

        :param block: The block to decorate
        """
        self.block = block
        self.rwlock = ReadWriteLock()

    def validate(self, address, count=1):
        """Check whether the request is in range.

        :param address: The starting address
        :param count: The number of values to test for
        :returns: Whatever the wrapped block's ``validate`` returns
        """
        with self.rwlock.get_reader_lock():
            in_range = self.block.validate(address, count)
        return in_range

    def getValues(self, address, count=1):
        """Return the requested values of the datastore.

        :param address: The starting address
        :param count: The number of values to retrieve
        :returns: Whatever the wrapped block's ``getValues`` returns
        """
        with self.rwlock.get_reader_lock():
            fetched = self.block.getValues(address, count)
        return fetched

    def setValues(self, address, values):
        """Set the requested values of the datastore.

        :param address: The starting address
        :param values: The new values to be set
        :returns: Whatever the wrapped block's ``setValues`` returns
        """
        with self.rwlock.get_writer_lock():
            outcome = self.block.setValues(address, values)
        return outcome
if __name__ == "__main__":  # pylint: disable=too-complex
    # Small stress demo: 50 reader threads and 2 writer threads hammer one
    # ReadWriteLock until either counter passes its finish value.

    class AtomicCounter:
        """Integer counter whose increments are serialised by a lock."""

        def __init__(self, **kwargs):
            """Accept optional ``start`` (default 0) and ``finish`` (default 1000)."""
            self.lock = threading.Lock()
            self.finish = kwargs.get("finish", 1000)
            self.counter = kwargs.get("start", 0)

        def increment(self, count=1):
            """Add ``count`` to the counter while holding the lock."""
            with self.lock:
                self.counter += count

        def is_running(self):
            """Report whether the counter has not yet passed ``finish``."""
            return self.finish >= self.counter

    locker = ReadWriteLock()
    readers = AtomicCounter()
    writers = AtomicCounter()

    def still_running():
        """Report whether both counters are still within their limits."""
        return writers.is_running() and readers.is_running()

    def read():
        """Count one read per acquisition of the reader lock."""
        while still_running():
            with locker.get_reader_lock():
                readers.increment()

    def write():
        """Count one write per acquisition of the writer lock."""
        while still_running():
            with locker.get_writer_lock():
                writers.increment()

    rthreads = [threading.Thread(target=read) for _ in range(50)]
    wthreads = [threading.Thread(target=write) for _ in range(2)]
    workers = rthreads + wthreads
    for t in workers:
        t.start()
    for t in workers:
        t.join()
    print(f"readers[{readers.counter}] writers[{writers.counter}]")