use hex::FromHexError;
use sha2::Digest;
use std::{
borrow::{Borrow, Cow},
fmt,
};
/// A 32-byte digest, the value type returned by the `digest` methods and the
/// `*_hash` helper functions in this file.
pub type Hash = [u8; 32];
/// A label attached to a node of the hash tree.
///
/// The label is an arbitrary byte string held in `Storage`, which can be any
/// container that exposes its contents as `&[u8]`.
#[derive(Clone, Hash, Ord, PartialOrd, Eq, PartialEq)]
pub struct Label<Storage: AsRef<[u8]>>(Storage);

impl<Storage: AsRef<[u8]>> Label<Storage> {
    /// Builds a label from a byte slice, converting the slice into `Storage`.
    pub fn from_bytes<'a>(v: &'a [u8]) -> Self
    where
        &'a [u8]: Into<Storage>,
    {
        Label(Into::into(v))
    }

    /// Builds a label from another label whose storage converts into `Storage`.
    pub fn from_label<StorageB: AsRef<[u8]> + Into<Storage>>(s: Label<StorageB>) -> Self {
        let Label(inner) = s;
        Label(inner.into())
    }

    /// Returns the raw bytes of the label.
    pub fn as_bytes(&self) -> &[u8] {
        let Label(storage) = self;
        storage.as_ref()
    }

    /// Number of characters produced by `write_hex` (two per byte).
    fn hex_len(&self) -> usize {
        2 * self.as_bytes().len()
    }

    /// Writes the label to `f` as upper-case hexadecimal, two digits per byte.
    fn write_hex(&self, f: &mut impl fmt::Write) -> fmt::Result {
        for byte in self.as_bytes() {
            write!(f, "{:02X}", byte)?;
        }
        Ok(())
    }
}
impl<Storage: AsRef<[u8]>> From<Storage> for Label<Storage> {
fn from(s: Storage) -> Self {
Self(s)
}
}
impl<const N: usize> From<[u8; N]> for Label<Vec<u8>> {
fn from(s: [u8; N]) -> Self {
Self(s.into())
}
}
impl<'a, const N: usize> From<&'a [u8; N]> for Label<Vec<u8>> {
fn from(s: &'a [u8; N]) -> Self {
Self(s.as_slice().into())
}
}
impl<'a> From<&'a [u8]> for Label<Vec<u8>> {
fn from(s: &'a [u8]) -> Self {
Self(s.into())
}
}
impl<'a> From<&'a str> for Label<Vec<u8>> {
fn from(s: &'a str) -> Self {
Self(s.as_bytes().into())
}
}
impl From<String> for Label<Vec<u8>> {
fn from(s: String) -> Self {
Self(s.into())
}
}
impl<'a, const N: usize> From<&'a [u8; N]> for Label<&'a [u8]> {
fn from(s: &'a [u8; N]) -> Self {
Self(s.as_slice())
}
}
impl<'a> From<&'a str> for Label<&'a [u8]> {
fn from(s: &'a str) -> Self {
Self(s.as_bytes())
}
}
impl<'a, const N: usize> From<&'a [u8; N]> for Label<Cow<'a, [u8]>> {
fn from(s: &'a [u8; N]) -> Self {
Self(s.as_slice().into())
}
}
impl<'a> From<&'a [u8]> for Label<Cow<'a, [u8]>> {
fn from(s: &'a [u8]) -> Self {
Self(s.into())
}
}
impl<'a> From<&'a str> for Label<Cow<'a, [u8]>> {
fn from(s: &'a str) -> Self {
Self(s.as_bytes().into())
}
}
/// Human-oriented rendering of a label.
///
/// A label made only of ASCII graphic characters is printed as a quoted
/// string; anything else is printed as `0x` followed by the label's `Debug`
/// (hexadecimal) form.
impl<Storage: AsRef<[u8]>> fmt::Display for Label<Storage> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `Some(text)` only when the bytes are valid UTF-8 and every character
        // is an ASCII graphic character.
        let printable = std::str::from_utf8(self.as_bytes())
            .ok()
            .filter(|text| text.chars().all(|c| c.is_ascii_graphic()));

        if let Some(text) = printable {
            return write!(f, "\"{}\"", text);
        }

        f.write_str("0x")?;
        fmt::Debug::fmt(self, f)
    }
}
impl<Storage: AsRef<[u8]>> fmt::Debug for Label<Storage> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.write_hex(f)
}
}
/// Lets a label be used wherever a byte slice reference is accepted.
impl<Storage: AsRef<[u8]>> AsRef<[u8]> for Label<Storage> {
    fn as_ref(&self) -> &[u8] {
        self.as_bytes()
    }
}
/// Outcome of looking up a path of labels that is expected to end at a leaf
/// (see `HashTree::lookup_path`).
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum LookupResult<'tree> {
/// The path does not lead to a value: a label along the path was not found,
/// or the path ended on an empty node.
Absent,
/// The answer cannot be determined, because the lookup ran into a pruned node.
Unknown,
/// The path ended on a leaf; carries the leaf's bytes, borrowed from the tree.
Found(&'tree [u8]),
/// The path was exhausted on a labeled or fork node, i.e. not on a leaf.
Error,
}
/// Outcome of looking up a path of labels that may end at any node
/// (see `HashTree::lookup_subtree`).
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum SubtreeLookupResult<Storage: AsRef<[u8]>> {
/// A label along the path was not found.
Absent,
/// The answer cannot be determined, because the lookup ran into a pruned node.
Unknown,
/// The path was fully consumed; carries an owned copy of the node reached.
Found(HashTree<Storage>),
}
/// A hash tree, represented by its root node.
///
/// `Storage` is the byte container used for labels and leaf contents.
#[derive(Clone, PartialEq, Eq)]
pub struct HashTree<Storage: AsRef<[u8]>> {
// Root node of the tree; visible inside the crate only.
pub(crate) root: HashTreeNode<Storage>,
}
/// Debug rendering as a `HashTree { root: .. }` struct, delegating to the
/// root node's own `Debug` implementation.
impl<Storage: AsRef<[u8]>> fmt::Debug for HashTree<Storage> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("HashTree");
        builder.field("root", &self.root);
        builder.finish()
    }
}
#[allow(dead_code)]
impl<Storage: AsRef<[u8]>> HashTree<Storage> {
    /// Returns the digest of the whole tree, i.e. the digest of its root node.
    #[inline]
    pub fn digest(&self) -> Hash {
        let Self { root } = self;
        root.digest()
    }

    /// Follows `path` (a sequence of labels) from the root and reports what it
    /// ends on; see [`LookupResult`] for the possible outcomes.
    pub fn lookup_path<P>(&self, path: P) -> LookupResult<'_>
    where
        P: IntoIterator,
        P::Item: AsRef<[u8]>,
    {
        let mut segments = path.into_iter();
        self.root.lookup_path(&mut segments)
    }
}
impl<Storage: Clone + AsRef<[u8]>> HashTree<Storage> {
    /// Follows `path` (a sequence of borrowed labels) from the root and, when
    /// the whole path resolves, returns an owned copy of the node reached.
    pub fn lookup_subtree<'p, P, I>(&self, path: P) -> SubtreeLookupResult<Storage>
    where
        P: IntoIterator<Item = &'p I>,
        I: ?Sized + AsRef<[u8]> + 'p,
    {
        let mut segments = path.into_iter().map(|segment| segment.borrow());
        self.root.lookup_subtree(&mut segments)
    }

    /// Lists the label path of every leaf in the tree.
    ///
    /// Empty and pruned nodes contribute no paths.
    pub fn list_paths(&self) -> Vec<Vec<Label<Storage>>> {
        // The root has no labels above it, so the walk starts with an empty prefix.
        let no_prefix = Vec::new();
        self.root.list_paths(&no_prefix)
    }
}
impl<Storage: AsRef<[u8]>> AsRef<HashTreeNode<Storage>> for HashTree<Storage> {
fn as_ref(&self) -> &HashTreeNode<Storage> {
&self.root
}
}
impl<Storage: AsRef<[u8]>> From<HashTree<Storage>> for HashTreeNode<Storage> {
fn from(tree: HashTree<Storage>) -> HashTreeNode<Storage> {
tree.root
}
}
/// Creates a tree consisting of a single empty node.
#[inline]
pub fn empty<Storage: AsRef<[u8]>>() -> HashTree<Storage> {
    let root = HashTreeNode::Empty();
    HashTree { root }
}
/// Creates a tree whose root is a fork with `left` and `right` as its two children.
#[inline]
pub fn fork<Storage: AsRef<[u8]>>(
    left: HashTree<Storage>,
    right: HashTree<Storage>,
) -> HashTree<Storage> {
    let children = Box::new((left.root, right.root));
    HashTree {
        root: HashTreeNode::Fork(children),
    }
}
/// Creates a tree whose root is a labeled node: `label` attached to the root of `node`.
#[inline]
pub fn label<Storage: AsRef<[u8]>, L: Into<Label<Storage>>, N: Into<HashTree<Storage>>>(
    label: L,
    node: N,
) -> HashTree<Storage> {
    // Convert the label first and the subtree second, matching argument order.
    let name: Label<Storage> = label.into();
    let subtree: HashTree<Storage> = node.into();
    HashTree {
        root: HashTreeNode::Labeled(name, Box::new(subtree.root)),
    }
}
/// Creates a tree consisting of a single leaf holding `leaf` converted into `Storage`.
#[inline]
pub fn leaf<Storage: AsRef<[u8]>, L: Into<Storage>>(leaf: L) -> HashTree<Storage> {
    let payload: Storage = leaf.into();
    HashTree {
        root: HashTreeNode::Leaf(payload),
    }
}
/// Creates a tree consisting of a single pruned node that carries the given digest.
#[inline]
pub fn pruned<Storage: AsRef<[u8]>, C: Into<Hash>>(content: C) -> HashTree<Storage> {
    let digest: Hash = content.into();
    HashTree {
        root: HashTreeNode::Pruned(digest),
    }
}
/// Creates a pruned tree from a hex-encoded digest.
///
/// # Errors
///
/// Returns the `FromHexError` reported by `hex::decode_to_slice` when
/// `content` cannot be decoded into the 32-byte digest buffer.
#[inline]
pub fn pruned_from_hex<Storage: AsRef<[u8]>, C: AsRef<str>>(
    content: C,
) -> Result<HashTree<Storage>, FromHexError> {
    let mut digest: Hash = [0u8; 32];
    hex::decode_to_slice(content.as_ref(), &mut digest).map(|()| pruned(digest))
}
/// Private result of searching a single node (and the forks below it) for one label.
#[derive(Debug)]
enum LookupLabelResult<'node, Storage: AsRef<[u8]>> {
/// The label is not present here: the node is neither labeled, fork nor pruned,
/// or a fork search went past the left side and stopped short on the right.
Absent,
/// A pruned node was met, so presence of the label cannot be decided.
Unknown,
/// The searched label orders before the label of the node that was inspected.
Less,
/// The searched label orders after the label of the node that was inspected.
Greater,
/// The label matched; carries the node stored under that label.
Found(&'node HashTreeNode<Storage>),
}
/// A single node of a hash tree.
#[derive(Clone, PartialEq, Eq)]
pub enum HashTreeNode<Storage: AsRef<[u8]>> {
/// A node with no content.
Empty(),
/// An inner node with a left and a right child, boxed together as a pair.
Fork(Box<(Self, Self)>),
/// A child node stored under a label.
Labeled(Label<Storage>, Box<Self>),
/// A leaf holding raw bytes.
Leaf(Storage),
/// A node whose content was removed; only its digest is kept, and `digest()`
/// returns it unchanged.
Pruned(Hash),
}
/// Compact debug rendering of a node.
///
/// Labels and leaf contents are shown as a quoted string when they consist
/// only of ASCII graphic characters, as `0x…` hex when they are at most 32
/// bytes long, and as a byte count otherwise.
impl<Storage: AsRef<[u8]>> fmt::Debug for HashTreeNode<Storage> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Prints `v` in the most readable of the three forms described above.
        fn readable_print(f: &mut fmt::Formatter<'_>, v: &[u8]) -> fmt::Result {
            let printable = std::str::from_utf8(v)
                .ok()
                .filter(|text| text.chars().all(|c| c.is_ascii_graphic()));

            if let Some(text) = printable {
                write!(f, "\"{}\"", text)
            } else if v.len() <= 32 {
                write!(f, "0x{}", hex::encode(v))
            } else {
                write!(f, "{} bytes", v.len())
            }
        }

        match self {
            Self::Empty() => f.write_str("Empty"),
            Self::Fork(children) => {
                let (left, right) = &**children;
                f.debug_tuple("Fork").field(left).field(right).finish()
            }
            Self::Leaf(payload) => {
                f.write_str("Leaf(")?;
                readable_print(f, payload.as_ref())?;
                f.write_str(")")
            }
            Self::Labeled(name, child) => {
                f.write_str("Label(")?;
                readable_print(f, name.as_bytes())?;
                f.write_str(", ")?;
                // Reuse the caller's formatter so its flags reach the child.
                fmt::Debug::fmt(child, f)?;
                f.write_str(")")
            }
            Self::Pruned(digest) => write!(f, "Pruned({})", hex::encode(&digest[..])),
        }
    }
}
impl<Storage: AsRef<[u8]>> HashTreeNode<Storage> {
#[inline]
fn domain_sep(&self, hasher: &mut sha2::Sha256) {
let domain_sep = match self {
HashTreeNode::Empty() => "ic-hashtree-empty",
HashTreeNode::Fork(_) => "ic-hashtree-fork",
HashTreeNode::Labeled(_, _) => "ic-hashtree-labeled",
HashTreeNode::Leaf(_) => "ic-hashtree-leaf",
HashTreeNode::Pruned(_) => return,
};
hasher.update([domain_sep.len() as u8]);
hasher.update(domain_sep.as_bytes());
}
#[inline]
pub fn digest(&self) -> Hash {
let mut hasher = sha2::Sha256::new();
self.domain_sep(&mut hasher);
match self {
HashTreeNode::Empty() => {}
HashTreeNode::Fork(nodes) => {
hasher.update(nodes.0.digest());
hasher.update(nodes.1.digest());
}
HashTreeNode::Labeled(label, node) => {
hasher.update(label.as_bytes());
hasher.update(node.digest());
}
HashTreeNode::Leaf(bytes) => {
hasher.update(bytes.as_ref());
}
HashTreeNode::Pruned(digest) => {
return *digest;
}
}
hasher.finalize().into()
}
fn lookup_label(&self, label: &[u8]) -> LookupLabelResult<Storage> {
match self {
HashTreeNode::Labeled(l, node) => match label.cmp(l.as_bytes()) {
std::cmp::Ordering::Greater => LookupLabelResult::Greater,
std::cmp::Ordering::Equal => LookupLabelResult::Found(node.as_ref()),
std::cmp::Ordering::Less => LookupLabelResult::Less,
},
HashTreeNode::Fork(nodes) => {
let left_label = nodes.0.lookup_label(label);
match left_label {
LookupLabelResult::Greater => {
let right_label = nodes.1.lookup_label(label);
match right_label {
LookupLabelResult::Less => LookupLabelResult::Absent,
result => result,
}
}
LookupLabelResult::Unknown => {
let right_label = nodes.1.lookup_label(label);
match right_label {
LookupLabelResult::Less => LookupLabelResult::Unknown,
result => result,
}
}
result => result,
}
}
HashTreeNode::Pruned(_) => LookupLabelResult::Unknown,
_ => LookupLabelResult::Absent,
}
}
fn lookup_path(&self, path: &mut dyn Iterator<Item = impl AsRef<[u8]>>) -> LookupResult<'_> {
use HashTreeNode::*;
use LookupLabelResult as LLR;
use LookupResult::*;
match (
path.next()
.map(|segment| self.lookup_label(segment.as_ref())),
self,
) {
(Some(LLR::Found(node)), _) => node.lookup_path(path),
(None, Leaf(v)) => Found(v.as_ref()),
(None, Empty()) => Absent,
(None, Pruned(_)) => Unknown,
(None, Labeled(_, _) | Fork(_)) => Error,
(Some(LLR::Unknown), _) => Unknown,
(Some(LLR::Absent | LLR::Greater | LLR::Less), _) => Absent,
}
}
}
impl<Storage: Clone + AsRef<[u8]>> HashTreeNode<Storage> {
fn lookup_subtree(
&self,
path: &mut dyn Iterator<Item = impl AsRef<[u8]>>,
) -> SubtreeLookupResult<Storage> {
use LookupLabelResult as LLR;
use SubtreeLookupResult::*;
match path
.next()
.map(|segment| self.lookup_label(segment.as_ref()))
{
Some(LLR::Found(node)) => node.lookup_subtree(path),
Some(LLR::Unknown) => Unknown,
Some(LLR::Absent | LLR::Greater | LLR::Less) => Absent,
None => Found(HashTree {
root: self.to_owned(),
}),
}
}
fn list_paths(&self, path: &Vec<Label<Storage>>) -> Vec<Vec<Label<Storage>>> {
match self {
HashTreeNode::Empty() => vec![],
HashTreeNode::Fork(nodes) => {
[nodes.0.list_paths(path), nodes.1.list_paths(path)].concat()
}
HashTreeNode::Leaf(_) => vec![path.clone()],
HashTreeNode::Labeled(l, node) => {
let mut path = path.clone();
path.push(l.clone());
node.list_paths(&path)
}
HashTreeNode::Pruned(_) => vec![],
}
}
}
// Serde support for `Label`, `HashTreeNode` and `HashTree`.
// Only compiled when the `serde` feature is enabled.
#[cfg(feature = "serde")]
mod serde_impl {
use std::{borrow::Cow, fmt, marker::PhantomData};
use crate::serde_impl::{CowStorage, SliceStorage, Storage, VecStorage};
use super::{HashTree, HashTreeNode, Label};
use serde::{
de::{self, SeqAccess, Visitor},
ser::SerializeSeq,
Deserialize, Deserializer, Serialize, Serializer,
};
use serde_bytes::Bytes;
/// Serializes a label as an upper-case hex string for human-readable formats
/// and as raw bytes otherwise.
impl<Storage: AsRef<[u8]>> Serialize for Label<Storage> {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
if serializer.is_human_readable() {
// Pre-size the buffer: `hex_len` is two characters per label byte.
let mut s = String::with_capacity(self.hex_len());
// Formatting into a `String` is not expected to fail, hence the `unwrap`.
self.write_hex(&mut s).unwrap();
s.serialize(serializer)
} else {
serializer.serialize_bytes(self.0.as_ref())
}
}
}
/// Deserializes a label through `serde_bytes`.
// NOTE(review): unlike `Serialize`, this path does not branch on
// `is_human_readable()` and performs no hex decoding, so the hex string emitted
// for human-readable formats does not appear to round-trip — confirm intent.
impl<'de, Storage: AsRef<[u8]>> Deserialize<'de> for Label<Storage>
where
Storage: serde_bytes::Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
serde_bytes::deserialize(deserializer).map(Self)
}
}
/// Serializes a node as a sequence whose first element is a numeric tag:
/// `[0]` empty, `[1, left, right]` fork, `[2, label, subtree]` labeled,
/// `[3, bytes]` leaf, `[4, digest]` pruned.
impl<Storage: AsRef<[u8]>> Serialize for HashTreeNode<Storage> {
fn serialize<S>(
&self,
serializer: S,
) -> Result<<S as Serializer>::Ok, <S as Serializer>::Error>
where
S: Serializer,
{
match self {
HashTreeNode::Empty() => {
let mut seq = serializer.serialize_seq(Some(1))?;
seq.serialize_element(&0u8)?;
seq.end()
}
HashTreeNode::Fork(tree) => {
let mut seq = serializer.serialize_seq(Some(3))?;
seq.serialize_element(&1u8)?;
seq.serialize_element(&tree.0)?;
seq.serialize_element(&tree.1)?;
seq.end()
}
HashTreeNode::Labeled(label, tree) => {
let mut seq = serializer.serialize_seq(Some(3))?;
seq.serialize_element(&2u8)?;
// The label is always written as bytes here, bypassing `Label`'s own
// `Serialize` impl (and therefore its hex form).
seq.serialize_element(Bytes::new(label.as_bytes()))?;
seq.serialize_element(&tree)?;
seq.end()
}
HashTreeNode::Leaf(leaf_bytes) => {
let mut seq = serializer.serialize_seq(Some(2))?;
seq.serialize_element(&3u8)?;
seq.serialize_element(Bytes::new(leaf_bytes.as_ref()))?;
seq.end()
}
HashTreeNode::Pruned(digest) => {
let mut seq = serializer.serialize_seq(Some(2))?;
seq.serialize_element(&4u8)?;
seq.serialize_element(Bytes::new(digest))?;
seq.end()
}
}
}
}
/// Visitor that decodes the tagged-sequence encoding of a node.
///
/// `S` selects the storage flavour of the resulting node via the crate's
/// `Storage` trait (defined in `crate::serde_impl`, not shown here).
struct HashTreeNodeVisitor<S>(PhantomData<S>);
impl<'de, S: Storage> Visitor<'de> for HashTreeNodeVisitor<S>
where
HashTreeNode<S::Value<'de>>: Deserialize<'de>,
{
type Value = HashTreeNode<S::Value<'de>>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str(
"HashTree encoded as a sequence of the form \
hash-tree ::= [0] | [1 hash-tree hash-tree] | [2 bytes hash-tree] | [3 bytes] | [4 hash]",
)
}
/// Reads the tag, then the elements that tag requires, and rejects any
/// trailing element with an `invalid_length` error.
fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
where
V: SeqAccess<'de>,
{
// The first element is always the numeric variant tag.
let tag: u8 = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(0, &self))?;
match tag {
// Empty: no further elements allowed.
0 => {
if let Some(de::IgnoredAny) = seq.next_element()? {
return Err(de::Error::invalid_length(2, &self));
}
Ok(HashTreeNode::Empty())
}
// Fork: exactly two child nodes.
1 => {
let left = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let right = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
if let Some(de::IgnoredAny) = seq.next_element()? {
return Err(de::Error::invalid_length(4, &self));
}
Ok(HashTreeNode::Fork(Box::new((left, right))))
}
// Labeled: label bytes followed by the subtree.
2 => {
let label = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
let subtree = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(2, &self))?;
if let Some(de::IgnoredAny) = seq.next_element()? {
return Err(de::Error::invalid_length(4, &self));
}
// `S::convert` turns the decoded bytes into the target storage type
// (presumably; its definition lives outside this file).
Ok(HashTreeNode::Labeled(
Label(S::convert(label)),
Box::new(subtree),
))
}
// Leaf: a single byte string.
3 => {
let bytes = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
if let Some(de::IgnoredAny) = seq.next_element()? {
return Err(de::Error::invalid_length(3, &self));
}
Ok(HashTreeNode::Leaf(S::convert(bytes)))
}
// Pruned: a byte string that must convert into a `Hash` (32 bytes).
4 => {
let digest_bytes: &serde_bytes::Bytes = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(1, &self))?;
if let Some(de::IgnoredAny) = seq.next_element()? {
return Err(de::Error::invalid_length(3, &self));
}
// A slice of the wrong length fails the conversion into the fixed-size array.
let digest =
std::convert::TryFrom::try_from(digest_bytes.as_ref()).map_err(|_| {
de::Error::invalid_length(digest_bytes.len(), &"Expected digest blob")
})?;
Ok(HashTreeNode::Pruned(digest))
}
_ => Err(de::Error::custom(format!(
"Unknown tag: {}, expected the tag to be one of {{0, 1, 2, 3, 4}}",
tag
))),
}
}
}
/// Deserializes a node that owns its bytes (`Vec<u8>` storage).
impl<'de> Deserialize<'de> for HashTreeNode<Vec<u8>> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(HashTreeNodeVisitor::<VecStorage>(PhantomData))
}
}
/// Deserializes a node whose bytes are slices tied to the `'de` lifetime.
impl<'de> Deserialize<'de> for HashTreeNode<&'de [u8]> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(HashTreeNodeVisitor::<SliceStorage>(PhantomData))
}
}
/// Deserializes a node that stores its bytes as `Cow<'de, [u8]>`.
impl<'de> Deserialize<'de> for HashTreeNode<Cow<'de, [u8]>> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_seq(HashTreeNodeVisitor::<CowStorage>(PhantomData))
}
}
/// A tree serializes exactly as its root node.
impl<Storage: AsRef<[u8]>> serde::Serialize for HashTree<Storage> {
fn serialize<S>(
&self,
serializer: S,
) -> Result<<S as serde::Serializer>::Ok, <S as serde::Serializer>::Error>
where
S: serde::Serializer,
{
self.root.serialize(serializer)
}
}
/// A tree deserializes by decoding its root node, for any storage type whose
/// node type is itself deserializable.
impl<'de, Storage: AsRef<[u8]>> serde::Deserialize<'de> for HashTree<Storage>
where
HashTreeNode<Storage>: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(HashTree {
root: HashTreeNode::deserialize(deserializer)?,
})
}
}
}
/// Computes the digest of a fork node directly from the digests of its two children.
pub fn fork_hash(l: &Hash, r: &Hash) -> Hash {
    let mut hasher = domain_sep("ic-hashtree-fork");
    hasher.update(l);
    hasher.update(r);
    hasher.finalize().into()
}
/// Computes the digest of a leaf node holding `data`.
pub fn leaf_hash(data: &[u8]) -> Hash {
    let mut hasher = domain_sep("ic-hashtree-leaf");
    hasher.update(data);
    let digest: Hash = hasher.finalize().into();
    digest
}
/// Computes the digest of a labeled node from its label bytes and the digest of its child.
pub fn labeled_hash(label: &[u8], content_hash: &Hash) -> Hash {
    let mut hasher = domain_sep("ic-hashtree-labeled");
    hasher.update(label);
    hasher.update(content_hash);
    hasher.finalize().into()
}
/// Starts a SHA-256 hasher primed with a domain separator: one byte holding
/// the length of `s`, followed by the bytes of `s`.
fn domain_sep(s: &str) -> sha2::Sha256 {
    let mut hasher = sha2::Sha256::new();
    // The length is stored in a single byte (`as u8` cast).
    hasher.update([s.len() as u8]);
    hasher.update(s.as_bytes());
    hasher
}
#[cfg(test)]
mod tests;