#![warn(missing_docs)]
#![allow(clippy::mutex_atomic)]
mod cached;
mod thread_id;
mod unreachable;
#[allow(deprecated)]
pub use cached::{CachedIntoIter, CachedIterMut, CachedThreadLocal};
use std::cell::UnsafeCell;
use std::fmt;
use std::iter::FusedIterator;
use std::mem;
use std::mem::MaybeUninit;
use std::panic::UnwindSafe;
use std::ptr;
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize, Ordering};
use std::sync::Mutex;
use thread_id::Thread;
use unreachable::UncheckedResultExt;
// Width of a pointer (and of `usize`) on the target, in bits.
#[cfg(target_pointer_width = "16")]
const POINTER_WIDTH: u8 = 16;
#[cfg(target_pointer_width = "32")]
const POINTER_WIDTH: u8 = 32;
#[cfg(target_pointer_width = "64")]
const POINTER_WIDTH: u8 = 64;

// Number of bucket slots in a `ThreadLocal`. Buckets 0 and 1 hold one entry each and every
// later bucket doubles, so `POINTER_WIDTH + 1` buckets add up to 2^POINTER_WIDTH entries.
const BUCKETS: usize = POINTER_WIDTH as usize + 1;
/// Per-thread storage: each thread that calls [`ThreadLocal::get_or`] (or a sibling method)
/// gets its own element of type `T`, addressed by the id returned from `thread_id::get()`.
///
/// Elements live in up to `BUCKETS` lazily allocated arrays ("buckets"). Bucket 0 and
/// bucket 1 hold one entry each, and every later bucket holds twice as many as the previous.
pub struct ThreadLocal<T: Send> {
// Pointers to the bucket arrays; each stays null until that bucket is allocated.
buckets: [AtomicPtr<Entry<T>>; BUCKETS],
// Count of entries that have been marked present (incremented by `insert`).
values: AtomicUsize,
// Serialises bucket allocation in `insert` so a bucket is allocated at most once.
lock: Mutex<()>,
}
/// A single slot inside a bucket.
struct Entry<T> {
    /// Whether `value` currently holds an initialised `T`.
    present: AtomicBool,
    /// Storage for the element; only initialised while `present` is true.
    value: UnsafeCell<MaybeUninit<T>>,
}

impl<T> Drop for Entry<T> {
    /// Drops the stored element if the slot is initialised.
    ///
    /// `MaybeUninit` never drops its contents on its own. Without this impl, freeing a bucket
    /// would release the memory but skip `T`'s destructor. That leaks whatever the elements own
    /// when a `ThreadLocal` is dropped or cleared, or when an `IntoIter` is dropped before it
    /// is exhausted.
    fn drop(&mut self) {
        if *self.present.get_mut() {
            // SAFETY: `present == true` means the value was written and has not been moved out.
            // The code that takes the value out (`IntoIter::next`) clears `present` first.
            unsafe { ptr::drop_in_place((*self.value.get()).as_mut_ptr()) };
        }
    }
}
// SAFETY: through `&ThreadLocal`, `get`/`get_or*` only hand out `&T` for the slot addressed
// by the calling thread's `thread_id::get()`. `iter` hands out `&T` for other threads' slots,
// but only when `T: Sync`. Bucket allocation is serialised by `lock`, and slots are
// published with Release stores that readers pair with Acquire loads. `T: Send` is required
// because `iter_mut`/`into_iter` give access to (or move out) elements on whichever thread
// holds the `ThreadLocal`.
unsafe impl<T: Send> Sync for ThreadLocal<T> {}
impl<T: Send> Default for ThreadLocal<T> {
fn default() -> ThreadLocal<T> {
ThreadLocal::new()
}
}
impl<T: Send> Drop for ThreadLocal<T> {
    /// Frees every bucket that was allocated.
    ///
    /// Each bucket is rebuilt into the boxed slice it came from (see `allocate_bucket`) and
    /// dropped. The slice length must match the allocation: buckets 0 and 1 hold one entry,
    /// and each later bucket doubles.
    fn drop(&mut self) {
        let mut len = 1;
        for (idx, slot) in self.buckets.iter_mut().enumerate() {
            let entries = *slot.get_mut();
            if !entries.is_null() {
                // SAFETY: a non-null bucket pointer was produced by `allocate_bucket` with
                // exactly `len` entries (the same size sequence `with_capacity` uses; `insert`
                // is assumed to get matching sizes from `thread_id`). We hold `&mut self`, so
                // no other references into the bucket exist.
                unsafe {
                    drop(Box::from_raw(std::slice::from_raw_parts_mut(entries, len)));
                }
            }
            // Bucket 1 repeats bucket 0's size; only from then on do sizes double.
            if idx != 0 {
                len <<= 1;
            }
        }
    }
}
impl<T: Send> ThreadLocal<T> {
pub fn new() -> ThreadLocal<T> {
Self::with_capacity(2)
}
pub fn with_capacity(capacity: usize) -> ThreadLocal<T> {
let allocated_buckets = capacity
.checked_sub(1)
.map(|c| usize::from(POINTER_WIDTH) - (c.leading_zeros() as usize) + 1)
.unwrap_or(0);
let mut buckets = [ptr::null_mut(); BUCKETS];
let mut bucket_size = 1;
for (i, bucket) in buckets[..allocated_buckets].iter_mut().enumerate() {
*bucket = allocate_bucket::<T>(bucket_size);
if i != 0 {
bucket_size <<= 1;
}
}
ThreadLocal {
buckets: unsafe { mem::transmute(buckets) },
values: AtomicUsize::new(0),
lock: Mutex::new(()),
}
}
pub fn get(&self) -> Option<&T> {
let thread = thread_id::get();
self.get_inner(thread)
}
pub fn get_or<F>(&self, create: F) -> &T
where
F: FnOnce() -> T,
{
unsafe {
self.get_or_try(|| Ok::<T, ()>(create()))
.unchecked_unwrap_ok()
}
}
pub fn get_or_try<F, E>(&self, create: F) -> Result<&T, E>
where
F: FnOnce() -> Result<T, E>,
{
let thread = thread_id::get();
match self.get_inner(thread) {
Some(x) => Ok(x),
None => Ok(self.insert(thread, create()?)),
}
}
fn get_inner(&self, thread: Thread) -> Option<&T> {
let bucket_ptr =
unsafe { self.buckets.get_unchecked(thread.bucket) }.load(Ordering::Acquire);
if bucket_ptr.is_null() {
return None;
}
unsafe {
let entry = &*bucket_ptr.add(thread.index);
if (&entry.present as *const _ as *const bool).read() {
Some(&*(&*entry.value.get()).as_ptr())
} else {
None
}
}
}
#[cold]
fn insert(&self, thread: Thread, data: T) -> &T {
let _guard = self.lock.lock().unwrap();
let bucket_atomic_ptr = unsafe { self.buckets.get_unchecked(thread.bucket) };
let bucket_ptr: *const _ = bucket_atomic_ptr.load(Ordering::Acquire);
let bucket_ptr = if bucket_ptr.is_null() {
let bucket_ptr = allocate_bucket(thread.bucket_size);
bucket_atomic_ptr.store(bucket_ptr, Ordering::Release);
bucket_ptr
} else {
bucket_ptr
};
drop(_guard);
let entry = unsafe { &*bucket_ptr.add(thread.index) };
let value_ptr = entry.value.get();
unsafe { value_ptr.write(MaybeUninit::new(data)) };
entry.present.store(true, Ordering::Release);
self.values.fetch_add(1, Ordering::Release);
unsafe { &*(&*value_ptr).as_ptr() }
}
pub fn iter(&self) -> Iter<'_, T>
where
T: Sync,
{
Iter {
thread_local: self,
raw: RawIter::new(),
}
}
pub fn iter_mut(&mut self) -> IterMut<T> {
IterMut {
thread_local: self,
raw: RawIter::new(),
}
}
pub fn clear(&mut self) {
*self = ThreadLocal::new();
}
}
impl<T: Send> IntoIterator for ThreadLocal<T> {
    type Item = T;
    type IntoIter = IntoIter<T>;

    /// Consumes the `ThreadLocal`, yielding every thread's element by value.
    fn into_iter(self) -> Self::IntoIter {
        let raw = RawIter::new();
        IntoIter {
            thread_local: self,
            raw,
        }
    }
}
impl<'a, T: Send + Sync> IntoIterator for &'a ThreadLocal<T> {
type Item = &'a T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T: Send> IntoIterator for &'a mut ThreadLocal<T> {
type Item = &'a mut T;
type IntoIter = IterMut<'a, T>;
fn into_iter(self) -> IterMut<'a, T> {
self.iter_mut()
}
}
impl<T: Send + Default> ThreadLocal<T> {
    /// Returns the calling thread's element, creating it with `T::default()` if missing.
    pub fn get_or_default(&self) -> &T {
        self.get_or(T::default)
    }
}
impl<T: Send + fmt::Debug> fmt::Debug for ThreadLocal<T> {
    /// Shows only the calling thread's element (as returned by `get`), not other threads'.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let local_data = self.get();
        write!(f, "ThreadLocal {{ local_data: {:?} }}", local_data)
    }
}
// Explicitly marks `ThreadLocal<T>` as `UnwindSafe` whenever its element type is.
impl<T: Send + UnwindSafe> UnwindSafe for ThreadLocal<T> {}
/// Cursor over the bucket/slot layout, shared by `Iter`, `IterMut` and `IntoIter`.
#[derive(Debug)]
struct RawIter {
// Number of present entries returned so far.
yielded: usize,
// Index of the bucket currently being scanned.
bucket: usize,
// Number of entries in the current bucket (1, 1, 2, 4, ...).
bucket_size: usize,
// Next slot to inspect within the current bucket.
index: usize,
}
impl RawIter {
    /// Creates a cursor positioned at the first slot of bucket 0.
    #[inline]
    fn new() -> Self {
        Self {
            yielded: 0,
            bucket: 0,
            bucket_size: 1,
            index: 0,
        }
    }

    /// Advances to the next present entry and returns a shared reference to its value.
    ///
    /// May run while other threads insert elements; entries added behind the cursor are not
    /// visited.
    fn next<'a, T: Send + Sync>(&mut self, thread_local: &'a ThreadLocal<T>) -> Option<&'a T> {
        while self.bucket < BUCKETS {
            // SAFETY: `self.bucket < BUCKETS` is checked by the loop condition.
            let bucket = unsafe { thread_local.buckets.get_unchecked(self.bucket) };
            // Acquire pairs with the Release store in `insert` that publishes a freshly
            // allocated bucket. With Relaxed, the entries' initialisation would not be ordered
            // before the reads below, which is a data race (cf. RUSTSEC-2022-0006).
            let bucket = bucket.load(Ordering::Acquire);

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    // SAFETY: `index < bucket_size`, the allocated length of this bucket.
                    let entry = unsafe { &*bucket.add(self.index) };
                    self.index += 1;
                    // Acquire pairs with the Release store of `present` in `insert`, making
                    // the written value visible.
                    if entry.present.load(Ordering::Acquire) {
                        self.yielded += 1;
                        // SAFETY: `present == true` means the value is initialised.
                        return Some(unsafe { &*(&*entry.value.get()).as_ptr() });
                    }
                }
            }

            self.next_bucket();
        }
        None
    }

    /// Advances to the next present entry, with exclusive access to the `ThreadLocal`.
    ///
    /// Stops once `yielded` equals `values`. The loop itself has no bucket bound, so this
    /// relies on `values` matching the number of present entries.
    fn next_mut<'a, T: Send>(
        &mut self,
        thread_local: &'a mut ThreadLocal<T>,
    ) -> Option<&'a mut Entry<T>> {
        if *thread_local.values.get_mut() == self.yielded {
            return None;
        }

        loop {
            // SAFETY: at least one present entry remains (checked above), so the cursor
            // reaches it before running past the last bucket.
            let bucket = unsafe { thread_local.buckets.get_unchecked_mut(self.bucket) };
            let bucket = *bucket.get_mut();

            if !bucket.is_null() {
                while self.index < self.bucket_size {
                    // SAFETY: `index < bucket_size`, the allocated length of this bucket.
                    let entry = unsafe { &mut *bucket.add(self.index) };
                    self.index += 1;
                    if *entry.present.get_mut() {
                        self.yielded += 1;
                        return Some(entry);
                    }
                }
            }

            self.next_bucket();
        }
    }

    /// Moves to slot 0 of the next bucket. Bucket 1 has the same size as bucket 0; every
    /// later bucket doubles.
    #[inline]
    fn next_bucket(&mut self) {
        if self.bucket != 0 {
            self.bucket_size <<= 1;
        }
        self.bucket += 1;
        self.index = 0;
    }

    /// Size hint for `Iter`, which may run concurrently with insertions.
    fn size_hint<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        let total = thread_local.values.load(Ordering::Acquire);
        // `insert` marks an entry present before incrementing `values`, so this iterator may
        // already have yielded an element that is not counted yet. Saturate instead of
        // underflowing (a debug panic, or a bogus huge lower bound in release).
        (total.saturating_sub(self.yielded), None)
    }

    /// Exact size hint for `IterMut`/`IntoIter`, which have exclusive access to the
    /// `ThreadLocal`.
    fn size_hint_frozen<T: Send>(&self, thread_local: &ThreadLocal<T>) -> (usize, Option<usize>) {
        // No concurrent writers exist while an exclusive iterator is alive, so a Relaxed load
        // observes the final count.
        let remaining = thread_local.values.load(Ordering::Relaxed) - self.yielded;
        (remaining, Some(remaining))
    }
}
/// Iterator over shared references to every thread's element, created by
/// [`ThreadLocal::iter`].
#[derive(Debug)]
pub struct Iter<'a, T: Send + Sync> {
thread_local: &'a ThreadLocal<T>,
raw: RawIter,
}
impl<'a, T: Send + Sync> Iterator for Iter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
self.raw.next(self.thread_local)
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.raw.size_hint(self.thread_local)
}
}
impl<T: Send + Sync> FusedIterator for Iter<'_, T> {}
/// Iterator over mutable references to every thread's element, created by
/// [`ThreadLocal::iter_mut`].
pub struct IterMut<'a, T: Send> {
thread_local: &'a mut ThreadLocal<T>,
raw: RawIter,
}
impl<'a, T: Send> Iterator for IterMut<'a, T> {
type Item = &'a mut T;
fn next(&mut self) -> Option<&'a mut T> {
self.raw
.next_mut(self.thread_local)
.map(|entry| unsafe { &mut *(&mut *entry.value.get()).as_mut_ptr() })
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.raw.size_hint_frozen(self.thread_local)
}
}
impl<T: Send> ExactSizeIterator for IterMut<'_, T> {}
impl<T: Send> FusedIterator for IterMut<'_, T> {}
impl<'a, T: Send + fmt::Debug> fmt::Debug for IterMut<'a, T> {
    /// Shows only the cursor state; the elements themselves are not printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("IterMut");
        builder.field("raw", &self.raw);
        builder.finish()
    }
}
/// Owning iterator over every thread's element, created by `ThreadLocal::into_iter`.
#[derive(Debug)]
pub struct IntoIter<T: Send> {
thread_local: ThreadLocal<T>,
raw: RawIter,
}
impl<T: Send> Iterator for IntoIter<T> {
    type Item = T;

    fn next(&mut self) -> Option<T> {
        let entry = self.raw.next_mut(&mut self.thread_local)?;
        // Mark the slot empty before moving the value out of it.
        *entry.present.get_mut() = false;
        // SAFETY: `next_mut` only returns present (initialised) entries. Swapping in
        // `uninit()` moves the value out exactly once.
        let slot = unsafe { &mut *entry.value.get() };
        Some(unsafe { std::mem::replace(slot, MaybeUninit::uninit()).assume_init() })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.raw.size_hint_frozen(&self.thread_local)
    }
}

// The iterator owns the `ThreadLocal`, so the element count is fixed and the hint is exact.
impl<T: Send> ExactSizeIterator for IntoIter<T> {}
impl<T: Send> FusedIterator for IntoIter<T> {}
/// Allocates a bucket of `size` empty entries (`present == false`, value uninitialised) and
/// returns a raw pointer to its first entry.
///
/// The memory is a `Box<[Entry<T>]>` of exactly `size` elements. It must be released by
/// rebuilding a boxed slice of the same length from this pointer.
fn allocate_bucket<T>(size: usize) -> *mut Entry<T> {
    let entries: Box<[Entry<T>]> = std::iter::repeat_with(|| Entry {
        present: AtomicBool::new(false),
        value: UnsafeCell::new(MaybeUninit::uninit()),
    })
    .take(size)
    .collect();
    Box::into_raw(entries) as *mut Entry<T>
}
#[cfg(test)]
mod tests {
    use super::ThreadLocal;
    use std::cell::RefCell;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::Relaxed;
    use std::sync::Arc;
    use std::thread;

    /// Builds a shared factory that returns 0 on its first call, 1 on the second, and so on.
    fn make_create() -> Arc<dyn Fn() -> usize + Send + Sync> {
        let next_value = AtomicUsize::new(0);
        Arc::new(move || next_value.fetch_add(1, Relaxed))
    }

    /// Repeated lookups on one thread run the factory once and keep returning its result.
    #[test]
    fn same_thread() {
        let create = make_create();
        let mut tls = ThreadLocal::new();
        assert_eq!(None, tls.get());
        assert_eq!("ThreadLocal { local_data: None }", format!("{:?}", &tls));

        for _ in 0..3 {
            assert_eq!(0, *tls.get_or(|| create()));
            assert_eq!(Some(&0), tls.get());
        }
        assert_eq!("ThreadLocal { local_data: Some(0) }", format!("{:?}", &tls));

        tls.clear();
        assert_eq!(None, tls.get());
    }

    /// A second thread gets its own element without disturbing the first thread's.
    #[test]
    fn different_thread() {
        let create = make_create();
        let tls = Arc::new(ThreadLocal::new());
        assert_eq!(None, tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
        assert_eq!(Some(&0), tls.get());

        let worker = {
            let tls = Arc::clone(&tls);
            let create = Arc::clone(&create);
            thread::spawn(move || {
                assert_eq!(None, tls.get());
                assert_eq!(1, *tls.get_or(|| create()));
                assert_eq!(Some(&1), tls.get());
            })
        };
        worker.join().unwrap();

        assert_eq!(Some(&0), tls.get());
        assert_eq!(0, *tls.get_or(|| create()));
    }

    /// All three iterator flavours visit the elements inserted by three different threads.
    #[test]
    fn iter() {
        fn sorted(mut values: Vec<i32>) -> Vec<i32> {
            values.sort_unstable();
            values
        }

        let tls = Arc::new(ThreadLocal::new());
        tls.get_or(|| Box::new(1));

        let outer = Arc::clone(&tls);
        thread::spawn(move || {
            outer.get_or(|| Box::new(2));
            let inner = Arc::clone(&outer);
            thread::spawn(move || {
                inner.get_or(|| Box::new(3));
            })
            .join()
            .unwrap();
            drop(outer);
        })
        .join()
        .unwrap();

        let mut tls = Arc::try_unwrap(tls).unwrap();
        assert_eq!(vec![1, 2, 3], sorted(tls.iter().map(|x| **x).collect()));
        assert_eq!(vec![1, 2, 3], sorted(tls.iter_mut().map(|x| **x).collect()));
        assert_eq!(vec![1, 2, 3], sorted(tls.into_iter().map(|x| *x).collect()));
    }

    /// `ThreadLocal` is `Sync` even when `T` itself is not.
    #[test]
    fn is_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<ThreadLocal<String>>();
        assert_sync::<ThreadLocal<RefCell<String>>>();
    }
}