1
Fork 0

First public release

This commit is contained in:
Tommi Virtanen 2025-03-27 13:45:43 -06:00 committed by Tommi Virtanen
parent da7aba533f
commit 5e1ec711c4
457 changed files with 27082 additions and 0 deletions

6
.cargo/mutants.toml Normal file
View file

@ -0,0 +1,6 @@
exclude_re = [
# test tool, not part of system under test
"<impl arbitrary::Arbitrary<",
# we do not test debug output, at least for now
"<impl std::fmt::Debug ",
]

20
.config/nextest.toml Normal file
View file

@ -0,0 +1,20 @@
[profile.default]
slow-timeout = { period = "5s", terminate-after = 2 }
# We're likely primarily I/O bound, even in tests.
# This might change if we add durability cheating features for tests, or move to primarily testing against storage abstractions.
test-threads = 20
[[profile.default.overrides]]
# The SQLite sqllogictest test suite has tens of thousands of queries in a single "test", give them more time to complete.
# Expected completion time is <30 seconds per test, but leave a generous margin.
filter = "package(=kanto) and kind(=test) and binary(=sqllogictest_sqlite)"
threads-required = 4
slow-timeout = { period = "1m", terminate-after = 5 }
[profile.valgrind]
slow-timeout = { period = "5m", terminate-after = 2 }
[profile.default-miri]
slow-timeout = { period = "5m", terminate-after = 2 }
# Miri cannot handle arbitrary C/C++ code.
default-filter = "not rdeps(rocky)"

5
.gitignore vendored Normal file
View file

@ -0,0 +1,5 @@
/target/
/result
/dev-db/
# valgrind artifacts
vgcore.*

4588
Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

227
Cargo.toml Normal file
View file

@ -0,0 +1,227 @@
[workspace]
resolver = "3"
members = ["crates/*", "crates/*/fuzz", "task/*"]
default-members = ["crates/*"]
[workspace.package]
# explicitly set a dummy version so nix build tasks using the workspace don't need to each set it
version = "0.0.0"
authors = ["Tommi Virtanen <tv@eagain.net>"]
license = "Apache-2.0"
homepage = "https://kantodb.com"
repository = "https://git.kantodb.com/kantodb/kantodb"
edition = "2024"
rust-version = "1.85.0"
[workspace.dependencies]
anyhow = "1.0.89"
arbitrary = { version = "1.4.1", features = [
"derive",
] } # sync major version against [`libfuzzer-sys`] re-export or fuzzing will break.
arbitrary-arrow = { version = "0.1.0", path = "crates/arbitrary-arrow" }
arrow = "54.2.0"
async-trait = "0.1.83"
clap = { version = "4.5.19", features = ["derive", "env"] }
dashmap = "6.1.0"
datafusion = "46.0.1"
duplicate = "2.0.0"
futures = "0.3.30"
futures-lite = "2.3.0"
gat-lending-iterator = "0.1.6"
hex = "0.4.3"
ignore = "0.4.23"
indexmap = "2.7.0" # sync major version against [`rkyv`] feature `indexmap-2`
kanto = { version = "0.1.0", path = "crates/kanto" }
kanto-backend-rocksdb = { version = "0.1.0", path = "crates/backend-rocksdb" }
kanto-index-format-v1 = { version = "0.1.0", path = "crates/index-format-v1" }
kanto-key-format-v1 = { version = "0.1.0", path = "crates/key-format-v1" }
kanto-meta-format-v1 = { version = "0.1.0", path = "crates/meta-format-v1" }
kanto-protocol-postgres = { version = "0.1.0", path = "crates/protocol-postgres" }
kanto-record-format-v1 = { version = "0.1.0", path = "crates/record-format-v1" }
kanto-testutil = { version = "0.1.0", path = "crates/testutil" }
kanto-tunables = { version = "0.1.0", path = "crates/tunables" }
libc = "0.2.153"
libfuzzer-sys = "0.4.9"
librocksdb-sys = { version = "0.17.1", features = ["lz4", "static", "zstd"] }
libtest-mimic = "0.8.1"
maybe-tracing = { version = "0.1.0", path = "crates/maybe-tracing" }
parking_lot = "0.12.3"
paste = "1.0.15"
pgwire = "0.28.0"
regex = "1.10.6"
rkyv = { version = "0.8.10", features = ["unaligned", "indexmap-2"] }
rkyv_util = { version = "0.1.0-alpha.1" }
rocky = { version = "0.1.0", path = "crates/rocky" }
smallvec = { version = "1.13.2", features = ["union"] }
sqllogictest = "0.28.0"
tempfile = "3.10.1"
test-log = { version = "0.2.16", default-features = false, features = [
"trace",
] }
thiserror = "2.0.6"
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", default-features = false, features = [
"env-filter",
"fmt",
"json",
] }
tree-sitter = "0.23.0"
tree-sitter-md = { version = "0.3.2", features = ["parser"] }
tree-sitter-rust = "0.23.0"
unindent = "0.2.3"
walkdir = "2.5.0"
[workspace.lints.rust]
absolute_paths_not_starting_with_crate = "warn"
ambiguous_negative_literals = "warn"
closure_returning_async_block = "warn"
dead_code = "warn"
deprecated_safe_2024 = "deny"
elided_lifetimes_in_paths = "warn"
explicit_outlives_requirements = "warn"
future_incompatible = { level = "warn", priority = -1 }
if_let_rescope = "warn"
impl_trait_redundant_captures = "warn"
keyword_idents_2024 = "warn"
let_underscore_drop = "warn"
macro_use_extern_crate = "deny"
meta_variable_misuse = "warn"
missing_abi = "deny"
missing_copy_implementations = "allow" # TODO switch this to warn #urgency/low
missing_debug_implementations = "allow" # TODO switch this to warn #urgency/low
missing_docs = "warn"
missing_unsafe_on_extern = "deny"
non_ascii_idents = "forbid"
redundant_imports = "warn"
redundant_lifetimes = "warn"
rust_2018_idioms = { level = "warn", priority = -1 }
trivial_casts = "warn"
trivial_numeric_casts = "warn"
unexpected_cfgs = { level = "allow", check-cfg = ["cfg(kani)"] }
unit_bindings = "warn"
unnameable_types = "warn"
unreachable_pub = "warn"
unsafe_attr_outside_unsafe = "deny"
unsafe_code = "deny"
unsafe_op_in_unsafe_fn = "deny"
unused_crate_dependencies = "allow" # TODO false positives for packages with bin/lib/integration tests; desirable because this catches things `cargo-shear` does not <https://github.com/rust-lang/rust/issues/95513> #waiting #ecosystem/rust #severity/low #urgency/low
unused_extern_crates = "warn"
unused_import_braces = "warn"
unused_lifetimes = "warn"
unused_macro_rules = "warn"
unused_qualifications = "warn"
unused_results = "warn"
variant_size_differences = "warn"
[workspace.lints.clippy]
all = { level = "warn", priority = -2 }
allow_attributes = "warn"
arithmetic_side_effects = "warn"
as_conversions = "warn"
as_underscore = "warn"
assertions_on_result_states = "deny"
cargo = { level = "warn", priority = -1 }
cast_lossless = "warn"
cast_possible_truncation = "warn"
cfg_not_test = "warn"
complexity = { level = "warn", priority = -1 }
dbg_macro = "warn"
default_numeric_fallback = "warn"
disallowed_script_idents = "deny"
else_if_without_else = "warn"
empty_drop = "deny"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exhaustive_enums = "warn"
exhaustive_structs = "warn"
exit = "deny"
expect_used = "allow" # TODO switch this to warn
field_scoped_visibility_modifiers = "allow" # TODO switch this to warn #urgency/medium
float_cmp_const = "warn"
fn_to_numeric_cast_any = "deny"
format_push_string = "warn"
host_endian_bytes = "deny"
if_not_else = "allow"
if_then_some_else_none = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
integer_division = "warn"
iter_over_hash_type = "warn"
large_include_file = "warn"
let_underscore_must_use = "warn"
let_underscore_untyped = "warn"
lossy_float_literal = "deny"
manual_is_power_of_two = "warn"
map_err_ignore = "warn"
map_unwrap_or = "allow"
match_like_matches_macro = "allow"
mem_forget = "deny"
missing_asserts_for_indexing = "warn"
missing_const_for_fn = "warn"
missing_errors_doc = "allow" # TODO write docstrings, remove this #urgency/medium
missing_inline_in_public_items = "allow" # TODO consider warn #urgency/low #performance
missing_panics_doc = "allow" # TODO write docstrings, remove this #urgency/medium
mixed_read_write_in_expression = "deny"
module_name_repetitions = "warn"
modulo_arithmetic = "warn"
multiple_crate_versions = "allow" # TODO i wish there was a way to fix this <https://github.com/rust-lang/rust-clippy/issues/9756> #urgency/low
multiple_unsafe_ops_per_block = "deny"
mutex_atomic = "warn"
non_ascii_literal = "deny"
non_zero_suggestions = "warn"
panic = "warn"
panic_in_result_fn = "warn" # note, this allows `debug_assert` but not `assert`
partial_pub_fields = "warn"
pathbuf_init_then_push = "warn"
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
print_stderr = "warn"
print_stdout = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "deny"
redundant_type_annotations = "warn"
ref_patterns = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
same_name_method = "deny"
semicolon_inside_block = "warn"
separated_literal_suffix = "warn"
set_contains_or_insert = "warn"
significant_drop_in_scrutinee = "warn"
similar_names = "allow" # too eager
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_slice = "warn"
string_to_string = "warn"
style = { level = "warn", priority = -1 }
suspicious = { level = "warn", priority = -1 }
suspicious_xor_used_as_pow = "warn"
tests_outside_test_module = "allow" # TODO to switch this to warn <https://github.com/rust-lang/rust-clippy/issues/11024> #waiting #urgency/low #severity/low
todo = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "allow" # TODO write safety docs, switch this to warn #urgency/low
unimplemented = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unnecessary_self_imports = "deny"
unreachable = "deny"
unused_result_ok = "warn"
unused_trait_names = "warn"
unwrap_in_result = "warn"
unwrap_used = "warn"
use_debug = "warn"
verbose_file_reads = "warn"
wildcard_imports = "warn"
[workspace.metadata.spellcheck]
config = "cargo-spellcheck.toml"
[workspace.metadata.crane]
name = "kantodb"
[profile.release]
lto = true

202
LICENSE Normal file
View file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

24
cargo-spellcheck.toml Normal file
View file

@ -0,0 +1,24 @@
# TODO this is too painful with commented-out code, but commented-out code is a code smell of its own
#dev_comments = true
[Hunspell]
# TODO setting config overrides defaults <https://github.com/drahnr/cargo-spellcheck/issues/293#issuecomment-2397728638> #ecosystem/cargo-spellcheck #dev #severity/low
use_builtin = true
# Use built-in dictionary only, for reproducibility.
skip_os_lookups = true
search_dirs = ["."]
extra_dictionaries = ["hunspell.dict"]
# TODO "C++" is tokenized wrong <https://github.com/drahnr/cargo-spellcheck/issues/272> #ecosystem/cargo-spellcheck #dev #severity/low
tokenization_splitchars = "\",;:!?#(){}[]|/_-+='`&@§¶…"
[Hunspell.quirks]
transform_regex = [
# 2025Q1 style references to dates.
"^(20\\d\\d)Q([1234])$",
# version numbers; tokenization splits off everything at the dot so be conservative
"^v[012]$",
]

25
clippy.toml Normal file
View file

@ -0,0 +1,25 @@
allow-dbg-in-tests = true
allow-expect-in-tests = true
allow-indexing-slicing-in-tests = true
allow-panic-in-tests = true
allow-print-in-tests = true
allow-unwrap-in-tests = true
avoid-breaking-exported-api = false # TODO let clippy suggest more radical changes; change when approaching stable #urgency/low #severity/medium
disallowed-methods = [
{ path = "std::iter::Iterator::for_each", reason = "prefer `for` for side-effects" },
{ path = "std::iter::Iterator::try_for_each", reason = "prefer `for` for side-effects" },
{ path = "std::option::Option::map_or", reason = "prefer `map(..).unwrap_or(..)` for legibility" },
{ path = "std::option::Option::map_or_else", reason = "prefer `map(..).unwrap_or_else(..)` for legibility" },
{ path = "std::result::Result::map_or", reason = "prefer `map(..).unwrap_or(..)` for legibility" },
{ path = "std::result::Result::map_or_else", reason = "prefer `map(..).unwrap_or_else(..)` for legibility" },
]
doc-valid-idents = [
"..",
"CapnProto",
"DataFusion",
"KantoDB",
"RocksDB",
"SQLite",
]
semicolon-inside-block-ignore-singleline = true
warn-on-all-wildcard-imports = true

View file

@ -0,0 +1,28 @@
[package]
name = "arbitrary-arrow"
version = "0.1.0"
description = "Fuzzing support for Apache Arrow"
keywords = ["arrow", "fuzzing", "testing"]
categories = ["development-tools::testing"]
authors.workspace = true
license.workspace = true
repository.workspace = true
publish = false # TODO merge upstream or publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[package.metadata.cargo-shear]
ignored = ["test-log"]
[dependencies]
arbitrary = { workspace = true }
arrow = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
getrandom = "0.3.1"
maybe-tracing = { workspace = true }
test-log = { workspace = true, features = ["trace"] }
[lints]
workspace = true

View file

@ -0,0 +1,4 @@
/target/
/corpus/
/artifacts/
/coverage/

View file

@ -0,0 +1,24 @@
[package]
name = "arbitrary-arrow-fuzz"
version = "0.0.0"
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false
edition.workspace = true
rust-version.workspace = true
[package.metadata]
cargo-fuzz = true
[dependencies]
arbitrary-arrow = { workspace = true }
arrow = { workspace = true, features = ["prettyprint"] }
libfuzzer-sys = { workspace = true }
[[bin]]
name = "record_batch_pretty"
path = "fuzz_targets/fuzz_record_batch_pretty.rs"
test = false
doc = false
bench = false

View file

@ -0,0 +1,33 @@
#![no_main]
use arbitrary_arrow::arbitrary;
use libfuzzer_sys::fuzz_target;
#[derive(Debug, arbitrary::Arbitrary)]
struct Input {
record_batches: Box<[arbitrary_arrow::RecordBatchBuilder]>,
}
fn is_all_printable(s: &str) -> bool {
const PRINTABLE: std::ops::RangeInclusive<u8> = 0x20..=0x7e;
s.bytes().all(|b| b == b'\n' || PRINTABLE.contains(&b))
}
fuzz_target!(|input: Input| {
let Input { record_batches } = input;
let record_batches = record_batches
.into_iter()
.map(|builder| builder.build())
.collect::<Box<_>>();
let display = arrow::util::pretty::pretty_format_batches(&record_batches).unwrap();
let output = format!("{display}");
if !is_all_printable(&output) {
for line in output.lines() {
println!("{line:#?}");
}
// TODO this would be desirable, but it's not true yet
// panic!("non-printable output: {b:#?}");
}
});

View file

@ -0,0 +1,498 @@
use std::sync::Arc;
use arbitrary::Arbitrary as _;
use arbitrary::Unstructured;
use arrow::array::Array;
use arrow::array::BooleanBufferBuilder;
use arrow::array::FixedSizeBinaryArray;
use arrow::array::GenericByteBuilder;
use arrow::array::GenericByteViewBuilder;
use arrow::array::PrimitiveArray;
use arrow::buffer::Buffer;
use arrow::buffer::MutableBuffer;
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowNativeType as _;
use arrow::datatypes::ArrowTimestampType;
use arrow::datatypes::BinaryViewType;
use arrow::datatypes::ByteArrayType;
use arrow::datatypes::ByteViewType;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::GenericBinaryType;
use arrow::datatypes::GenericStringType;
use arrow::datatypes::IntervalUnit;
use arrow::datatypes::StringViewType;
use arrow::datatypes::TimeUnit;
/// Make an arbitrary [`PrimitiveArray<T>`].
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_primitive_array<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
) -> Result<PrimitiveArray<T>, arbitrary::Error>
where
T: arrow::datatypes::ArrowPrimitiveType,
{
// Arrow primitive arrays are by definition safe to fill with random bytes.
let width = T::Native::get_byte_width();
let values = {
let capacity = width
.checked_mul(row_count)
.ok_or(arbitrary::Error::IncorrectFormat)?;
let mut buf = MutableBuffer::from_len_zeroed(capacity);
debug_assert_eq!(buf.as_slice().len(), capacity);
unstructured.fill_buffer(buf.as_slice_mut())?;
buf.into()
};
let nulls = if nullable {
let mut builder = BooleanBufferBuilder::new(row_count);
builder.resize(row_count);
debug_assert_eq!(builder.as_slice().len(), row_count.div_ceil(8));
unstructured.fill_buffer(builder.as_slice_mut())?;
let buffer = NullBuffer::new(builder.finish());
Some(buffer)
} else {
None
};
let builder = PrimitiveArray::<T>::new(values, nulls);
Ok(builder)
}
#[tracing::instrument(skip(unstructured), ret)]
fn arbitrary_primitive_array_dyn<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: arrow::datatypes::ArrowPrimitiveType,
{
let array: PrimitiveArray<T> = arbitrary_primitive_array(unstructured, nullable, row_count)?;
Ok(Arc::new(array))
}
/// Make an arbitrary [`PrimitiveArray`] with any [`ArrowTimestampType`], in the given time zone.
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_timestamp_array<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
time_zone: Option<&Arc<str>>,
) -> Result<PrimitiveArray<T>, arbitrary::Error>
where
T: ArrowTimestampType,
{
let array = arbitrary_primitive_array::<T>(unstructured, nullable, row_count)?
.with_timezone_opt(time_zone.cloned());
Ok(array)
}
#[tracing::instrument(skip(unstructured), ret)]
fn arbitrary_timestamp_array_dyn<'a, T>(
unstructured: &mut Unstructured<'a>,
field: &Field,
row_count: usize,
time_zone: Option<&Arc<str>>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ArrowTimestampType,
{
let array =
arbitrary_timestamp_array::<T>(unstructured, field.is_nullable(), row_count, time_zone)?;
Ok(Arc::new(array))
}
/// Make an arbitrary [`GenericByteArray`](arrow::array::GenericByteArray), such as [`StringArray`](arrow::array::StringArray) or [`BinaryArray`](arrow::array::BinaryArray).
///
/// # Examples
///
/// ```rust
/// use arbitrary_arrow::arbitrary::Arbitrary as _;
/// use arbitrary_arrow::arbitrary_byte_array;
/// use arrow::datatypes::GenericStringType;
/// # let unstructured = &mut arbitrary_arrow::arbitrary::Unstructured::new(b"fake entropy");
/// let row_count = 10;
/// let array = arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
/// Option::<String>::arbitrary(unstructured)
/// })?;
/// # Ok::<(), arbitrary_arrow::arbitrary::Error>(())
/// ```
#[tracing::instrument(skip(generate), ret)]
pub fn arbitrary_byte_array<'a, T, V>(
row_count: usize,
generate: impl FnMut() -> Result<Option<V>, arbitrary::Error>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ByteArrayType,
V: AsRef<T::Native>,
{
let mut builder = GenericByteBuilder::<T>::with_capacity(row_count, 0);
for result in std::iter::repeat_with(generate).take(row_count) {
let value = result?;
builder.append_option(value);
}
let array = builder.finish();
Ok(Arc::new(array))
}
/// Make an arbitrary [`GenericByteViewArray`](arrow::array::GenericByteViewArray), that is either a [`StringViewArray`](arrow::array::StringViewArray) or a [`BinaryViewArray`](arrow::array::BinaryViewArray).
///
/// # Examples
///
/// ```rust
/// use arbitrary_arrow::arbitrary_byte_view_array;
/// use arrow::datatypes::StringViewType;
///
/// use crate::arbitrary_arrow::arbitrary::Arbitrary as _;
/// # let unstructured = &mut arbitrary_arrow::arbitrary::Unstructured::new(b"fake entropy");
/// let row_count = 10;
/// let array = arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
/// String::arbitrary(unstructured).map(|s| Some(s))
/// })?;
/// # Ok::<(), arbitrary_arrow::arbitrary::Error>(())
/// ```
#[tracing::instrument(skip(generate), ret)]
pub fn arbitrary_byte_view_array<'a, T, V>(
row_count: usize,
generate: impl FnMut() -> Result<Option<V>, arbitrary::Error>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ByteViewType,
V: AsRef<T::Native>,
{
let mut builder = GenericByteViewBuilder::<T>::with_capacity(row_count);
for result in std::iter::repeat_with(generate).take(row_count) {
let value = result?;
builder.append_option(value);
}
let array = builder.finish();
Ok(Arc::new(array))
}
/// Make an [arbitrary] array of the type described by a [`Field`] of an Arrow [`Schema`](arrow::datatypes::Schema).
///
/// Implemented as a helper function, not via the [`Arbitrary`](arbitrary::Arbitrary) trait, because we need to construct arrays with specific data type, nullability, and row count.
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_array<'a>(
unstructured: &mut Unstructured<'a>,
field: &Field,
row_count: usize,
) -> Result<Arc<dyn Array>, arbitrary::Error> {
let array: Arc<dyn Array> = match field.data_type() {
DataType::Null => Arc::new(arrow::array::NullArray::new(row_count)),
DataType::Boolean => {
if field.is_nullable() {
let vec = std::iter::repeat_with(|| unstructured.arbitrary::<Option<bool>>())
.take(row_count)
.collect::<Result<Vec<_>, _>>()?;
Arc::new(arrow::array::BooleanArray::from(vec))
} else {
let vec = std::iter::repeat_with(|| unstructured.arbitrary::<bool>())
.take(row_count)
.collect::<Result<Vec<_>, _>>()?;
Arc::new(arrow::array::BooleanArray::from(vec))
}
}
DataType::Int8 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int8Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int16 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt8 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt8Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt16 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt32 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt64 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float16 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Timestamp(TimeUnit::Second, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampSecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Millisecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampMillisecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Microsecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampMicrosecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Nanosecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampNanosecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Date32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Date32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Date64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Date64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Time32(TimeUnit::Second) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time32SecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time32(TimeUnit::Millisecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time32MillisecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time32(TimeUnit::Nanosecond | TimeUnit::Microsecond)
| DataType::Time64(TimeUnit::Second | TimeUnit::Millisecond) => {
return Err(arbitrary::Error::IncorrectFormat);
}
DataType::Time64(TimeUnit::Microsecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time64MicrosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time64(TimeUnit::Nanosecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time64NanosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Duration(TimeUnit::Second) => arbitrary_primitive_array_dyn::<
arrow::datatypes::DurationSecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Duration(TimeUnit::Millisecond) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::DurationMillisecondType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Duration(TimeUnit::Microsecond) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::DurationMicrosecondType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Duration(TimeUnit::Nanosecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::DurationNanosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Interval(IntervalUnit::YearMonth) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalYearMonthType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Interval(IntervalUnit::DayTime) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalDayTimeType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalMonthDayNanoType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Binary => {
if field.is_nullable() {
arbitrary_byte_array::<GenericBinaryType<i32>, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericBinaryType<i32>, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::FixedSizeBinary(size) => {
// annoyingly similar-but-different to `PrimitiveArray`
let values = {
let size =
usize::try_from(*size).map_err(|_error| arbitrary::Error::IncorrectFormat)?;
let capacity = size
.checked_mul(row_count)
.ok_or(arbitrary::Error::IncorrectFormat)?;
let mut buf = MutableBuffer::from_len_zeroed(capacity);
unstructured.fill_buffer(buf.as_slice_mut())?;
Buffer::from(buf)
};
let nulls = if field.is_nullable() {
let mut builder = BooleanBufferBuilder::new(row_count);
unstructured.fill_buffer(builder.as_slice_mut())?;
let buffer = NullBuffer::new(builder.finish());
Some(buffer)
} else {
None
};
let array = FixedSizeBinaryArray::new(*size, values, nulls);
Arc::new(array)
}
DataType::LargeBinary => {
if field.is_nullable() {
arbitrary_byte_array::<GenericBinaryType<i64>, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericBinaryType<i64>, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::BinaryView => {
if field.is_nullable() {
arbitrary_byte_view_array::<BinaryViewType, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_view_array::<BinaryViewType, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
// TODO i haven't found a good way to unify these branches #dev
//
// - string vs binary
// - large vs small
// - array vs view
DataType::Utf8 => {
if field.is_nullable() {
arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::LargeUtf8 => {
if field.is_nullable() {
arbitrary_byte_array::<GenericStringType<i64>, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericStringType<i64>, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::Utf8View => {
if field.is_nullable() {
arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::Decimal128(precision, scale) => {
let array = arbitrary_primitive_array::<arrow::datatypes::Decimal128Type>(
unstructured,
field.is_nullable(),
row_count,
)?
.with_precision_and_scale(*precision, *scale)
.map_err(|_error| arbitrary::Error::IncorrectFormat)?;
Arc::new(array)
}
DataType::Decimal256(precision, scale) => {
let array = arbitrary_primitive_array::<arrow::datatypes::Decimal256Type>(
unstructured,
field.is_nullable(),
row_count,
)?
.with_precision_and_scale(*precision, *scale)
.map_err(|_error| arbitrary::Error::IncorrectFormat)?;
Arc::new(array)
}
// TODO generate arrays for more complex datatypes
DataType::List(_)
| DataType::ListView(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_)
| DataType::LargeListView(_)
| DataType::Struct(_)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::Map(_, _)
| DataType::RunEndEncoded(_, _) => return Err(arbitrary::Error::IncorrectFormat),
};
Ok(array)
}

View file

@ -0,0 +1,164 @@
use std::sync::Arc;
use arbitrary::Arbitrary;
use arrow::datatypes::DataType;
use arrow::datatypes::TimeUnit;
use crate::FieldBuilder;
use crate::FieldsBuilder;
use crate::IntervalUnitBuilder;
use crate::TimeUnitBuilder;
// TODO generate more edge values
// Arrow has many places where it will panic or error in unexpected places if the input didn't match expectations.
// (Example unexpected error: `arrow::util::pretty::pretty_format_batches` errors when it doesn't recognize a timezone.)
// Making a fuzzer avoid problematic output is not exactly a good idea, so expect these to be relaxed later, maybe?
/// An `i32` that is guaranteed to be `> 0`.
///
/// Arrow will panic on divide by zero if a fixed size structure is size 0.
/// It will most likely go out of bounds for negative numbers.
//
// TODO this should really be enforced by whatever is taking the `i32`, get rid of this
pub struct PositiveNonZeroI32(i32);
impl PositiveNonZeroI32 {
#[expect(missing_docs)]
#[must_use]
pub fn new(input: i32) -> Option<PositiveNonZeroI32> {
(input > 0).then_some(PositiveNonZeroI32(input))
}
}
impl<'a> Arbitrary<'a> for PositiveNonZeroI32 {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let num = u.int_in_range(1i32..=i32::MAX)?;
PositiveNonZeroI32::new(num).ok_or(arbitrary::Error::IncorrectFormat)
}
}
/// Make an [arbitrary] [`arrow::datatypes::DataType`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](DataTypeBuilder::build) method to get the final [`DataType`].
/// This is a technical limitation because we can't implement the trait [`arbitrary::Arbitrary`] for [`DataType`] in this unrelated crate.
/// TODO this can be cleaned up if merged into Arrow upstream #ecosystem/arrow
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum DataTypeBuilder {
Null,
Boolean,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64,
Timestamp(
TimeUnitBuilder,
Option<
// TODO maybe generate proper time zone names and offsets, at least most of the time
// TODO maybe split into feature `chrono-tz` and not set; without it, only do offsets?
Arc<str>,
>,
),
Date32,
Date64,
// only generate valid combinations
Time32Seconds,
Time32Milliseconds,
Time64Microseconds,
Time64Nanoseconds,
Duration(TimeUnitBuilder),
Interval(IntervalUnitBuilder),
Binary,
FixedSizeBinary(PositiveNonZeroI32),
LargeBinary,
BinaryView,
Utf8,
LargeUtf8,
Utf8View,
List(Box<FieldBuilder>),
ListView(Box<FieldBuilder>),
FixedSizeList(Box<FieldBuilder>, PositiveNonZeroI32),
LargeList(Box<FieldBuilder>),
LargeListView(Box<FieldBuilder>),
Struct(Box<FieldsBuilder>),
// TODO Union
// TODO maybe enforce `DataType::is_dictionary_key_type`?
Dictionary(Box<DataTypeBuilder>, Box<DataTypeBuilder>),
Decimal128(u8, i8),
Decimal256(u8, i8),
// TODO enforce Map shape, key hashability
// TODO Map(Box<FieldBuilder>, bool),
// TODO run_ends must be an integer type
// TODO RunEndEncoded(Box<FieldBuilder>, Box<FieldBuilder>),
}
impl DataTypeBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> DataType {
match self {
DataTypeBuilder::Null => DataType::Null,
DataTypeBuilder::Boolean => DataType::Boolean,
DataTypeBuilder::Int8 => DataType::Int8,
DataTypeBuilder::Int16 => DataType::Int16,
DataTypeBuilder::Int32 => DataType::Int32,
DataTypeBuilder::Int64 => DataType::Int64,
DataTypeBuilder::UInt8 => DataType::UInt8,
DataTypeBuilder::UInt16 => DataType::UInt16,
DataTypeBuilder::UInt32 => DataType::UInt32,
DataTypeBuilder::UInt64 => DataType::UInt64,
DataTypeBuilder::Float16 => DataType::Float16,
DataTypeBuilder::Float32 => DataType::Float32,
DataTypeBuilder::Float64 => DataType::Float64,
DataTypeBuilder::Timestamp(time_unit, time_zone) => {
DataType::Timestamp(time_unit.build(), time_zone)
}
DataTypeBuilder::Date32 => DataType::Date32,
DataTypeBuilder::Date64 => DataType::Date64,
DataTypeBuilder::Time32Seconds => DataType::Time32(TimeUnit::Second),
DataTypeBuilder::Time32Milliseconds => DataType::Time32(TimeUnit::Millisecond),
DataTypeBuilder::Time64Microseconds => DataType::Time64(TimeUnit::Microsecond),
DataTypeBuilder::Time64Nanoseconds => DataType::Time64(TimeUnit::Nanosecond),
DataTypeBuilder::Duration(builder) => DataType::Duration(builder.build()),
DataTypeBuilder::Interval(builder) => DataType::Interval(builder.build()),
DataTypeBuilder::Binary => DataType::Binary,
DataTypeBuilder::FixedSizeBinary(size) => DataType::FixedSizeBinary(size.0),
DataTypeBuilder::LargeBinary => DataType::LargeBinary,
DataTypeBuilder::BinaryView => DataType::BinaryView,
DataTypeBuilder::Utf8 => DataType::Utf8,
DataTypeBuilder::LargeUtf8 => DataType::LargeUtf8,
DataTypeBuilder::Utf8View => DataType::Utf8View,
DataTypeBuilder::List(field_builder) => DataType::List(Arc::new(field_builder.build())),
DataTypeBuilder::ListView(field_builder) => {
DataType::ListView(Arc::new(field_builder.build()))
}
DataTypeBuilder::FixedSizeList(field_builder, size) => {
DataType::FixedSizeList(Arc::new(field_builder.build()), size.0)
}
DataTypeBuilder::LargeList(field_builder) => {
DataType::LargeList(Arc::new(field_builder.build()))
}
DataTypeBuilder::LargeListView(field_builder) => {
DataType::LargeListView(Arc::new(field_builder.build()))
}
DataTypeBuilder::Struct(fields_builder) => DataType::Struct(fields_builder.build()),
// TODO Union
DataTypeBuilder::Dictionary(key_type, value_type) => {
DataType::Dictionary(Box::new(key_type.build()), Box::new(value_type.build()))
}
// TODO make only valid precision and scale values? now they're discarded in `arbitrary_array` <https://docs.rs/datafusion/latest/datafusion/common/arrow/datatypes/fn.validate_decimal_precision_and_scale.html>
DataTypeBuilder::Decimal128(precision, scale) => DataType::Decimal128(precision, scale),
DataTypeBuilder::Decimal256(precision, scale) => DataType::Decimal256(precision, scale),
// TODO generate arbitrary Map
// TODO generate arbitrary RunEndEncoded
}
}
}

View file

@ -0,0 +1,55 @@
use arbitrary::Arbitrary;
use arrow::datatypes::Field;
use arrow::datatypes::Fields;
use crate::DataTypeBuilder;
/// Make an [arbitrary] [`Field`] for an Arrow [`Schema`](arrow::datatypes::Schema).
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct FieldBuilder {
pub name: String,
pub data_type: DataTypeBuilder,
/// Ignored (always set to `true`) when `data_type` is [[`arrow::datatypes::DataType::Null`]].
pub nullable: bool,
// TODO dict_id: i64,
// TODO dict_is_ordered: bool,
// TODO metadata: HashMap<String, String>,
}
impl FieldBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Field {
let Self {
name,
data_type,
nullable,
} = self;
let data_type = data_type.build();
let nullable = nullable || data_type == arrow::datatypes::DataType::Null;
Field::new(name, data_type, nullable)
}
}
/// Make a [arbitrary] [`Fields`] for an Arrow [`Schema`](arrow::datatypes::Schema).
///
/// This is just a vector of [`Field`] (see [`FieldBuilder`]), but Arrow uses a separate type for it.
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct FieldsBuilder {
pub builders: Vec<FieldBuilder>,
}
impl FieldsBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Fields {
let Self { builders } = self;
let fields = builders
.into_iter()
.map(FieldBuilder::build)
.collect::<Vec<_>>();
Fields::from(fields)
}
}

View file

@ -0,0 +1,23 @@
use arbitrary::Arbitrary;
use arrow::datatypes::IntervalUnit;
/// Make an [arbitrary] [`IntervalUnit`].
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum IntervalUnitBuilder {
YearMonth,
DayTime,
MonthDayNano,
}
impl IntervalUnitBuilder {
#[must_use]
#[expect(missing_docs)]
pub const fn build(self) -> IntervalUnit {
match self {
IntervalUnitBuilder::YearMonth => IntervalUnit::YearMonth,
IntervalUnitBuilder::DayTime => IntervalUnit::DayTime,
IntervalUnitBuilder::MonthDayNano => IntervalUnit::MonthDayNano,
}
}
}

View file

@ -0,0 +1,66 @@
//! Create [arbitrary] [Arrow](arrow) data structures, for fuzz testing.
//!
//! This could become part of the upstream [`arrow`] crate in the future.
#![allow(clippy::exhaustive_enums)]
#![allow(clippy::exhaustive_structs)]
pub use arbitrary;
pub use arrow;
mod array;
mod data_type;
mod field;
mod interval_unit;
mod record_batch;
mod schema;
mod time_unit;
pub use array::*;
pub use data_type::*;
pub use field::*;
pub use interval_unit::*;
pub use record_batch::*;
pub use schema::*;
pub use time_unit::*;
#[cfg(test)]
mod tests {
use arbitrary::Arbitrary as _;
use super::*;
#[maybe_tracing::test]
#[cfg_attr(miri, ignore = "too slow")]
fn random() {
let mut entropy = vec![0u8; 4096];
getrandom::fill(&mut entropy).unwrap();
let mut unstructured = arbitrary::Unstructured::new(&entropy);
match RecordBatchBuilder::arbitrary(&mut unstructured) {
Ok(builder) => {
let record_batch = builder.build();
println!("{record_batch:?}");
}
Err(arbitrary::Error::IncorrectFormat) => {
// ignore silently
}
Err(error) => panic!("unexpected error: {error}"),
};
}
#[maybe_tracing::test]
fn fixed() {
// A non-randomized version that runs faster under Miri.
let fake_entropy = vec![0u8; 4096];
let mut unstructured = arbitrary::Unstructured::new(&fake_entropy);
match RecordBatchBuilder::arbitrary(&mut unstructured) {
Ok(builder) => {
let record_batch = builder.build();
println!("{record_batch:?}");
}
Err(arbitrary::Error::IncorrectFormat) => {
panic!("unexpected error, fiddle with fake entropy to avoid")
}
Err(error) => panic!("unexpected error: {error}"),
};
}
}

View file

@ -0,0 +1,93 @@
use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use crate::arbitrary_array;
use crate::SchemaBuilder;
/// Make an [arbitrary] [`RecordBatch`] with the given Arrow [`Schema`](arrow::datatypes::Schema).
pub fn record_batch_with_schema(
u: &mut arbitrary::Unstructured<'_>,
schema: SchemaRef,
) -> arbitrary::Result<RecordBatch> {
// TODO correct type for the below
let row_count = u.arbitrary_len::<[u8; 8]>()?;
let arrays = schema
.fields()
.iter()
.map(|field| arbitrary_array(u, field, row_count))
.collect::<Result<Vec<_>, _>>()?;
debug_assert_eq!(schema.fields().len(), arrays.len());
debug_assert!(arrays.iter().all(|array| array.len() == row_count));
debug_assert!(
arrays
.iter()
.zip(schema.fields().iter())
.all(|(array, field)| array.data_type() == field.data_type()),
"wrong array data type: {arrays:?} != {fields:?}",
arrays = arrays.iter().map(|a| a.data_type()).collect::<Vec<_>>(),
fields = schema
.fields()
.iter()
.map(|f| f.data_type())
.collect::<Vec<_>>(),
);
let options = arrow::array::RecordBatchOptions::new().with_row_count(Some(row_count));
#[expect(clippy::expect_used, reason = "not production code")]
let record_batch =
RecordBatch::try_new_with_options(schema, arrays, &options).expect("internal error");
Ok(record_batch)
}
/// Make an [arbitrary] [`RecordBatch`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](RecordBatchBuilder::build) method to get the final [`RecordBatch`].
/// This is a technical limitation because we can't implement the trait [`arbitrary::Arbitrary`] for [`RecordBatch`] in this unrelated crate.
/// This can be cleaned up if merged into Arrow upstream.
#[derive(Debug)]
#[expect(missing_docs)]
pub struct RecordBatchBuilder {
pub record_batch: RecordBatch,
}
/// Implement [`arbitrary::Arbitrary`] manually.
/// [`RecordBatch`] has constraints that are hard to express via deriving [`Arbitrary`](arbitrary::Arbitrary):
///
/// - number of schema fields must match number of array
/// - each array must match schema field type
/// - all arrays must have the same row count
impl<'a> arbitrary::Arbitrary<'a> for RecordBatchBuilder {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let schema = SchemaBuilder::arbitrary(u)?.build();
let schema = Arc::new(schema);
let record_batch = record_batch_with_schema(u, schema)?;
Ok(RecordBatchBuilder { record_batch })
}
fn size_hint(depth: usize) -> (usize, Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
}
fn try_size_hint(
depth: usize,
) -> arbitrary::Result<(usize, Option<usize>), arbitrary::MaxRecursionReached> {
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
let size_hint = SchemaBuilder::try_size_hint(depth)?;
// TODO match on data_type and count the amount of entropy needed; but we don't have a row count yet
Ok(size_hint)
})
}
}
impl RecordBatchBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> RecordBatch {
let Self { record_batch } = self;
record_batch
}
}

View file

@ -0,0 +1,27 @@
use arbitrary::Arbitrary;
use arrow::datatypes::Schema;
use crate::FieldsBuilder;
/// Make an [arbitrary] [`Schema`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](SchemaBuilder::build) method to get the final [`Schema`].
/// This is a technical limitation because we can't implement the trait [`arbitrary::Arbitrary`] for [`Schema`] in this unrelated crate.
/// This can be cleaned up if merged into Arrow upstream.
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct SchemaBuilder {
pub fields: FieldsBuilder,
// TODO metadata: HashMap<String, String>
}
impl SchemaBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Schema {
let Self { fields } = self;
let fields = fields.build();
Schema::new(fields)
}
}

View file

@ -0,0 +1,25 @@
use arbitrary::Arbitrary;
use arrow::datatypes::TimeUnit;
/// Make an [arbitrary] [`TimeUnit`].
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum TimeUnitBuilder {
Second,
Millisecond,
Microsecond,
Nanosecond,
}
impl TimeUnitBuilder {
#[must_use]
#[expect(missing_docs)]
pub const fn build(self) -> TimeUnit {
match self {
TimeUnitBuilder::Second => TimeUnit::Second,
TimeUnitBuilder::Millisecond => TimeUnit::Millisecond,
TimeUnitBuilder::Microsecond => TimeUnit::Microsecond,
TimeUnitBuilder::Nanosecond => TimeUnit::Nanosecond,
}
}
}

View file

@ -0,0 +1,38 @@
[package]
name = "kanto-backend-rocksdb"
version = "0.1.0"
description = "RocksDB backend for the KantoDB SQL database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
async-trait = { workspace = true }
dashmap = { workspace = true }
datafusion = { workspace = true }
futures-lite = { workspace = true }
gat-lending-iterator = { workspace = true }
kanto = { workspace = true }
kanto-index-format-v1 = { workspace = true }
kanto-key-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-record-format-v1 = { workspace = true }
kanto-tunables = { workspace = true }
maybe-tracing = { workspace = true }
rkyv = { workspace = true }
rkyv_util = { workspace = true }
rocky = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
kanto-testutil = { workspace = true }
test-log = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,652 @@
use std::path::Path;
use std::sync::Arc;
use datafusion::sql::sqlparser;
use gat_lending_iterator::LendingIterator as _;
use kanto::parquet::data_type::AsBytes as _;
use kanto::KantoError;
use kanto_meta_format_v1::IndexId;
use kanto_meta_format_v1::SequenceId;
use kanto_meta_format_v1::TableId;
// TODO `concat_bytes!` <https://github.com/rust-lang/rust/issues/87555> #waiting #ecosystem/rust #dev
#[expect(
clippy::indexing_slicing,
clippy::expect_used,
reason = "TODO const functions are awkward"
)]
const fn concat_bytes_to_array<const N: usize>(sources: &[&[u8]]) -> [u8; N] {
let mut result = [0u8; N];
let mut dst = 0;
let mut src = 0;
while src < sources.len() {
let mut offset = 0;
while offset < sources[src].len() {
result[dst] = sources[src][offset];
assert!(dst < N, "overflowing destination arrow");
dst = dst
.checked_add(1)
.expect("constant calculation must be correct");
offset = offset
.checked_add(1)
.expect("constant calculation must be correct");
}
src = src
.checked_add(1)
.expect("constant calculation must be correct");
}
assert!(dst == N, "did not fill the whole array");
result
}
// We reserve a global keyspace in RocksDB, in the default column family, keys beginning with `1:u64_le`.
const fn make_system_key(ident: u64) -> [u8; 16] {
concat_bytes_to_array(&[&1u64.to_be_bytes(), &ident.to_be_bytes()])
}
const FORMAT_MESSAGE_KEY: [u8; 16] = make_system_key(1);
const FORMAT_MESSAGE_CONTENT: &[u8] =
b"KantoDB database file\nSee <https://kantodb.com/> for more.\n";
const FORMAT_COOKIE_KEY: [u8; 16] = make_system_key(2);
trait RocksDbBackend {
/// Random data used to identify format.
///
/// Use pure entropy, don't get cute.
const FORMAT_COOKIE: [u8; 32];
/// Caller promises `FORMAT_COOKIE` has been checked already.
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError>;
}
/// Create a new RocksDB database directory.
pub fn create<P: AsRef<Path>>(path: P) -> Result<impl kanto::Backend, KantoError> {
let path = path.as_ref();
create_inner(path)
}
fn create_inner(path: &Path) -> Result<Database, KantoError> {
let mut opts = rocky::Options::new();
opts.create_if_missing(true);
// TODO let caller pass in TransactionDbOptions
let txdb_opts = rocky::TransactionDbOptions::new();
let column_families = vec![];
let rocksdb = rocky::Database::open(path, &opts, &txdb_opts, column_families).map_err(
|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "n19nx5m9yc5br",
error: std::io::Error::other(rocksdb_error),
},
)?;
// TODO there's a race condition between creating the RocksDB files and getting to brand it as "ours"; how to properly brand a RocksDB database as belonging to this application
// TODO check for existence first? what if create is run with existing rocksdb
{
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
let tx = rocksdb.transaction_begin(&write_options, &tx_options);
let cf_default = rocksdb
.get_column_family("default")
.ok_or(KantoError::Internal {
code: "y4i5oznqrpr4o",
error: "trouble using default column family".into(),
})?;
{
let key = &FORMAT_MESSAGE_KEY;
let value = FORMAT_MESSAGE_CONTENT;
tx.put(&cf_default, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "fgkhwk4rbf5io",
error: std::io::Error::other(rocksdb_error),
})?;
}
{
let key = &FORMAT_COOKIE_KEY;
let value = &Database::FORMAT_COOKIE;
tx.put(&cf_default, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "yepp5upj9o79k",
error: std::io::Error::other(rocksdb_error),
})?;
}
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "wgkft87xp4kn6",
error: std::io::Error::other(rocksdb_error),
})?;
}
Database::open_from_rocksdb_v1(rocksdb)
}
/// Create a new *temporary* RocksDB database.
///
/// This is mostly useful for tests.
pub fn create_temp() -> Result<(tempfile::TempDir, Database), KantoError> {
let dir = tempfile::tempdir().map_err(|io_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "9syfcddzphoqs",
error: io_error,
})?;
let db = create_inner(dir.path())?;
Ok((dir, db))
}
/// Open an existing RocksDB database directory.
pub fn open(path: &Path) -> Result<impl kanto::Backend + use<>, KantoError> {
// RocksDB has a funny dance about opening the database.
// They want you to list the column families that exist in the database (in order to configure them), but no extra ones (yes, you can do create if missing, but then you don't get to configure them).
// Load the "options" file from the database directory to figure out what column families exist.
// <https://github.com/facebook/rocksdb/wiki/Column-Families>
let latest_options =
rocky::LatestOptions::load_latest_options(path).map_err(|rocksdb_error| {
KantoError::Io {
op: kanto::error::Op::Init,
code: "he3ddmfxzymiw",
error: std::io::Error::other(rocksdb_error),
}
})?;
let rocksdb = {
let opts = latest_options.clone_db_options();
// TODO let caller pass in TransactionDbOptions
let txdb_opts = rocky::TransactionDbOptions::new();
let column_families = latest_options
.iter_column_families()
.collect::<Result<Vec<_>, _>>()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "umshco6957zj6",
error: std::io::Error::other(rocksdb_error),
})?;
rocky::Database::open(path, &opts, &txdb_opts, column_families)
}
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "gzubzoig3be86",
error: std::io::Error::other(rocksdb_error),
})?;
open_from_rocksdb(rocksdb)
}
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError> {
let tx = {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
rocksdb.transaction_begin(&write_options, &tx_options)
};
let cf_default = rocksdb
.get_column_family("default")
.ok_or(KantoError::Internal {
code: "sit1685wkky9k",
error: "trouble using default column family".into(),
})?;
{
let key = FORMAT_MESSAGE_KEY;
let want = FORMAT_MESSAGE_CONTENT;
let read_opts = rocky::ReadOptions::new();
let got = tx
.get_pinned(&cf_default, key, &read_opts)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "wfa4s57mk83py",
error: std::io::Error::other(rocksdb_error),
})?
.ok_or(KantoError::Init {
code: "ypjsnkcony3kh",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
})?;
if got.as_bytes() != want {
let error = KantoError::Init {
code: "fty1mzfs1fw6q",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
};
return Err(error);
}
}
let format_cookie = {
let read_opts = rocky::ReadOptions::new();
let key = FORMAT_COOKIE_KEY;
tx.get_pinned(&cf_default, key, &read_opts)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "d87g8itsmm7ag",
error: std::io::Error::other(rocksdb_error),
})?
.ok_or(KantoError::Init {
code: "9ogxqz1u5jboq",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
})?
};
tx.rollback().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "sdrh4rkxxt8ta",
error: std::io::Error::other(rocksdb_error),
})?;
if format_cookie == Database::FORMAT_COOKIE.as_bytes() {
Database::open_from_rocksdb(rocksdb)
} else {
let error = KantoError::Init {
code: "amsgxu6h1yj9r",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
};
Err(error)
}
}
/// A [`kanto::Backend`] that stores data in RocksDB.
#[derive(Clone)]
pub struct Database {
pub(crate) inner: Arc<InnerDatabase>,
}
pub(crate) struct ColumnFamilies {
pub name_defs: Arc<rocky::ColumnFamily>,
pub table_defs: Arc<rocky::ColumnFamily>,
pub index_defs: Arc<rocky::ColumnFamily>,
pub sequence_defs: Arc<rocky::ColumnFamily>,
// accessed only by the library abstraction
pub sequences: Arc<rocky::ColumnFamily>,
pub records: Arc<rocky::ColumnFamily>,
pub unique_indexes: Arc<rocky::ColumnFamily>,
pub multi_indexes: Arc<rocky::ColumnFamily>,
}
pub(crate) const SEQUENCE_SEQ_ID: SequenceId = SequenceId::MIN;
pub(crate) struct SystemSequences {
pub(crate) table_seq_id: SequenceId,
pub(crate) index_seq_id: SequenceId,
}
pub(crate) struct SystemData {
pub(crate) sequences: SystemSequences,
}
pub(crate) struct InnerDatabase {
pub(crate) rocksdb: rocky::Database,
pub(crate) sequence_tracker: crate::sequence::SequenceTracker,
pub(crate) column_families: ColumnFamilies,
pub(crate) system: SystemData,
}
impl Database {
fn ensure_cf_exist(
rocksdb: &rocky::Database,
cf_name: &str,
) -> Result<Arc<rocky::ColumnFamily>, rocky::RocksDbError> {
if let Some(cf) = rocksdb.get_column_family(cf_name) {
Ok(cf)
} else {
let opts = rocky::Options::new();
rocksdb.create_column_family(cf_name, &opts)
}
}
#[must_use]
pub(crate) const fn make_rocksdb_sequence_def_key(sequence_id: SequenceId) -> [u8; 8] {
sequence_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_table_def_key(table_id: TableId) -> [u8; 8] {
table_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_index_def_key(index_id: IndexId) -> [u8; 8] {
index_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_sequence_key(sequence_id: SequenceId) -> [u8; 8] {
sequence_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_record_key_prefix(table_id: TableId) -> [u8; 8] {
table_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_record_key_stop_before(table_id: TableId) -> [u8; 8] {
let table_id_num = table_id.get_u64();
debug_assert!(table_id_num < u64::MAX);
let stop_before = table_id_num.saturating_add(1);
stop_before.to_be_bytes()
}
pub(crate) fn make_rocksdb_record_key(table_id: TableId, row_key: &[u8]) -> Box<[u8]> {
// TODO later: optimize allocations and minimize copying
// TODO use types to make sure we don't confuse row_key vs rocksdb_key
let rocksdb_key_prefix = Self::make_rocksdb_record_key_prefix(table_id);
let mut k = Vec::with_capacity(rocksdb_key_prefix.len().saturating_add(row_key.len()));
k.extend_from_slice(&rocksdb_key_prefix);
k.extend_from_slice(row_key);
k.into_boxed_slice()
}
/// Create a new session with this backend registered as the default catalog.
///
/// This is mostly useful for tests.
#[must_use]
pub fn test_session(&self) -> kanto::Session {
let backend: Box<dyn kanto::Backend> = Box::new(self.clone());
kanto::Session::test_session(backend)
}
// TODO ugly api, experiment switching into builder style so all the inner helpers can be methods on a builder?
fn init_database(
rocksdb: &rocky::Database,
column_families: &ColumnFamilies,
sequence_tracker: &crate::sequence::SequenceTracker,
) -> Result<SystemData, KantoError> {
let tx = {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
rocksdb.transaction_begin(&write_options, &tx_options)
};
Self::init_sequence_of_sequences(&tx, &column_families.sequence_defs)?;
let table_seq_id = Self::init_sequence_def(
&tx,
&column_families.sequence_defs,
sequence_tracker,
"tables",
)?;
let index_seq_id = Self::init_sequence_def(
&tx,
&column_families.sequence_defs,
sequence_tracker,
"indexes",
)?;
let system_sequences = SystemSequences {
table_seq_id,
index_seq_id,
};
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "u375qjb7o3jyh",
error: std::io::Error::other(rocksdb_error),
})?;
let system_data = SystemData {
sequences: system_sequences,
};
Ok(system_data)
}
fn init_sequence_of_sequences(
tx: &rocky::Transaction,
sequence_defs_cf: &rocky::ColumnFamily,
) -> Result<(), KantoError> {
let sequence_def_key = Self::make_rocksdb_sequence_def_key(SEQUENCE_SEQ_ID);
let read_opts = rocky::ReadOptions::new();
let exists = tx
.get_for_update_pinned(
sequence_defs_cf,
sequence_def_key,
&read_opts,
rocky::Exclusivity::Exclusive,
)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "316dpxkcf8ruy",
error: std::io::Error::other(rocksdb_error),
})?
.is_some();
if exists {
// incremental updates would go here
Ok(())
} else {
const KANTO_INTERNAL_CATALOG_NAME: &str = "_kanto_internal";
const KANTO_INTERNAL_SCHEMA_NAME: &str = "kanto";
const KANTO_INTERNAL_SEQUENCE_OF_SEQUENCES_NAME: &str = "sequences";
let sequence_def = kanto_meta_format_v1::sequence_def::SequenceDef {
sequence_id: SEQUENCE_SEQ_ID,
catalog: KANTO_INTERNAL_CATALOG_NAME.to_owned(),
schema: KANTO_INTERNAL_SCHEMA_NAME.to_owned(),
sequence_name: KANTO_INTERNAL_SEQUENCE_OF_SEQUENCES_NAME.to_owned(),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::BoxedError>(&sequence_def).map_err(
|rkyv_error| KantoError::Internal {
code: "5hwja993uzb86",
error: Box::new(kanto::error::Message {
message: "failed to serialize sequence def",
error: Box::new(rkyv_error),
}),
},
)?;
tx.put(sequence_defs_cf, sequence_def_key, bytes)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "67soy6858owxa",
error: std::io::Error::other(rocksdb_error),
})?;
Ok(())
}
}
fn init_sequence_def(
tx: &rocky::Transaction,
sequence_defs_cf: &rocky::ColumnFamily,
sequence_tracker: &crate::sequence::SequenceTracker,
sequence_name: &str,
) -> Result<SequenceId, KantoError> {
const KANTO_INTERNAL_CATALOG_NAME: &str = "_kanto_internal";
const KANTO_INTERNAL_SCHEMA_NAME: &str = "kanto";
// Look for existing sequence_def.
// TODO should do lookup via index
{
let read_opts = rocky::ReadOptions::new();
let mut iter = tx.iter(read_opts, sequence_defs_cf);
while let Some(entry) =
iter.next()
.transpose()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "9rw1dttbwqxnc",
error: std::io::Error::other(rocksdb_error),
})?
{
let bytes = entry.value();
let sequence_def = rkyv::access::<
rkyv::Archived<kanto_meta_format_v1::sequence_def::SequenceDef>,
rkyv::rancor::BoxedError,
>(bytes)
.map_err(|rkyv_error| KantoError::Corrupt {
op: kanto::error::Op::Init,
code: "k41yd9xmpqxba",
error: Box::new(kanto::error::CorruptError::Rkyv { error: rkyv_error }),
})?;
{
if sequence_def.catalog != KANTO_INTERNAL_CATALOG_NAME {
continue;
}
if sequence_def.schema != KANTO_INTERNAL_SCHEMA_NAME {
continue;
}
if sequence_def.sequence_name != sequence_name {
continue;
}
}
// Found an existing record!
// For now, we don't bother updating it.. maybe we should.
let sequence_id = sequence_def.sequence_id.to_native();
return Ok(sequence_id);
}
}
// Not found
let sequence_id = {
let seq_id_num = sequence_tracker.inc(SEQUENCE_SEQ_ID)?;
SequenceId::new(seq_id_num).ok_or_else(|| KantoError::Corrupt {
op: kanto::error::Op::Init,
code: "sz8cgjb6jpyur",
error: Box::new(kanto::error::CorruptError::InvalidSequenceId {
sequence_id: seq_id_num,
sequence_def_debug: String::new(),
}),
})?
};
let sequence_def = kanto_meta_format_v1::sequence_def::SequenceDef {
sequence_id,
catalog: KANTO_INTERNAL_CATALOG_NAME.to_owned(),
schema: KANTO_INTERNAL_SCHEMA_NAME.to_owned(),
sequence_name: sequence_name.to_owned(),
};
let bytes =
rkyv::to_bytes::<rkyv::rancor::BoxedError>(&sequence_def).map_err(|rkyv_error| {
KantoError::Internal {
code: "88sauprbx8d8e",
error: Box::new(kanto::error::Message {
message: "failed to serialize sequence def",
error: Box::new(rkyv_error),
}),
}
})?;
let sequence_def_key = Self::make_rocksdb_sequence_def_key(sequence_id);
tx.put(sequence_defs_cf, sequence_def_key, bytes)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "kw743n8j4f4a6",
error: std::io::Error::other(rocksdb_error),
})?;
Ok(sequence_id)
}
fn open_from_rocksdb_v1(rocksdb: rocky::Database) -> Result<Self, KantoError> {
let ensure_cf = |cf_name: &str| -> Result<Arc<rocky::ColumnFamily>, KantoError> {
Database::ensure_cf_exist(&rocksdb, cf_name).map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "m5w7fccb54wr6",
error: std::io::Error::other(rocksdb_error),
})
};
let column_families = ColumnFamilies {
name_defs: ensure_cf("name_defs")?,
table_defs: ensure_cf("table_defs")?,
index_defs: ensure_cf("index_defs")?,
sequence_defs: ensure_cf("sequence_defs")?,
sequences: ensure_cf("sequences")?,
records: ensure_cf("records")?,
unique_indexes: ensure_cf("unique_indexes")?,
multi_indexes: ensure_cf("multi_indexes")?,
};
let sequence_tracker = crate::sequence::SequenceTracker::new(
rocksdb.clone(),
column_families.sequences.clone(),
);
let system_data = Self::init_database(&rocksdb, &column_families, &sequence_tracker)?;
let inner = Arc::new(InnerDatabase {
rocksdb,
sequence_tracker,
column_families,
system: system_data,
});
Ok(Self { inner })
}
}
impl RocksDbBackend for Database {
const FORMAT_COOKIE: [u8; 32] = [
0x82, 0x48, 0x1b, 0xa3, 0xa1, 0xa3, 0x78, 0xae, 0x0f, 0x04, 0x12, 0x81, 0xf2, 0x0b, 0x63,
0xf8, 0x7d, 0xc6, 0x47, 0xcc, 0x80, 0xb0, 0x70, 0xc3, 0xb2, 0xa6, 0xf3, 0x60, 0xcb, 0x66,
0xda, 0x5f,
];
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError> {
Self::open_from_rocksdb_v1(rocksdb)
}
}
#[async_trait::async_trait]
impl kanto::Backend for Database {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn start_transaction(
&self,
access_mode: sqlparser::ast::TransactionAccessMode,
isolation_level: sqlparser::ast::TransactionIsolationLevel,
) -> Result<Box<dyn kanto::Transaction>, KantoError> {
match access_mode {
sqlparser::ast::TransactionAccessMode::ReadWrite => Ok(()),
sqlparser::ast::TransactionAccessMode::ReadOnly => {
let error = KantoError::UnimplementedSql {
code: "bhetj3emnc8n1",
sql_syntax: "START TRANSACTION READ ONLY",
};
Err(error)
}
}?;
#[expect(
clippy::match_same_arms,
reason = "TODO rustfmt fails to format with comments in OR patterns <https://github.com/rust-lang/rustfmt/issues/6491> #waiting #ecosystem/rust #severity/low #urgency/low"
)]
match isolation_level {
// TODO advocate to reorder the enum in order of increasing isolation; we normally write matches in order of the enum, but make an exception here #ecosystem/sqlparser-rs
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::ReadUncommitted => Ok(()),
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::ReadCommitted => Ok(()),
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::RepeatableRead => Ok(()),
// This is what we actually do, with RocksDB snapshots.
sqlparser::ast::TransactionIsolationLevel::Snapshot => Ok(()),
sqlparser::ast::TransactionIsolationLevel::Serializable => {
let error = KantoError::UnimplementedSql {
code: "dy1ch15p7n3ze",
sql_syntax: "START TRANSACTION ISOLATION LEVEL SERIALIZABLE",
};
Err(error)
}
}?;
let transaction = crate::Transaction::new(self);
Ok(Box::new(transaction))
}
fn clone_box(&self) -> Box<dyn kanto::Backend> {
Box::new(self.clone())
}
}
impl PartialEq for Database {
fn eq(&self, other: &Self) -> bool {
// Use the pointer inside the `self.inner` `Arc` as our identity, as the outer value gets cloned.
std::ptr::eq(self.inner.as_ref(), other.inner.as_ref())
}
}
impl Eq for Database {}
impl std::hash::Hash for Database {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Use the pointer inside the `self.inner` `Arc` as our identity, as the outer value gets cloned.
let ptr = <*const _>::from(self.inner.as_ref());
#[expect(clippy::as_conversions)]
let num = ptr as usize;
state.write_usize(num);
}
}
impl std::fmt::Debug for Database {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Database").finish_non_exhaustive()
}
}

View file

@ -0,0 +1,19 @@
//! `kanto_backend_rocksdb` provides a [KantoDB](https://kantodb.com/) database [backend](kanto::Backend) using [RocksDB](https://rocksdb.org/), delegating snapshot isolation and transaction handling fully to it.
//!
//! For more on RocksDB, see <https://github.com/facebook/rocksdb/wiki>
mod database;
mod savepoint;
mod sequence;
mod table;
mod table_provider;
mod table_scan;
mod transaction;
pub use crate::database::create;
pub use crate::database::create_temp;
pub use crate::database::open;
pub use crate::database::Database;
pub(crate) use crate::savepoint::Savepoint;
pub(crate) use crate::savepoint::WeakSavepoint;
pub(crate) use crate::transaction::Transaction;

View file

@ -0,0 +1,30 @@
# Does not parse in generic dialect
statement error ^SQL error: ParserError\("Expected: equals sign or TO, found: IO at Line: 1, Column: \d+"\) \(sql/xatq3n1zk65go\)$
SET STATISTICS IO ON;
statement ok
SET sql_dialect='mssql';
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS IO ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS IO OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS PROFILE ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS PROFILE OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS TIME ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS TIME OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS XML ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS XML OFF;

View file

@ -0,0 +1,50 @@
use std::sync::Arc;
use kanto::KantoError;
use crate::Transaction;
pub(crate) struct Savepoint {
inner: Arc<InnerSavepoint>,
}
pub(crate) struct WeakSavepoint {
inner: std::sync::Weak<InnerSavepoint>,
}
struct InnerSavepoint {
transaction: Transaction,
}
impl Savepoint {
pub(crate) fn new(transaction: Transaction) -> Savepoint {
let inner = InnerSavepoint { transaction };
Savepoint {
inner: Arc::new(inner),
}
}
pub(crate) fn weak(&self) -> WeakSavepoint {
WeakSavepoint {
inner: Arc::downgrade(&self.inner),
}
}
}
#[async_trait::async_trait]
impl kanto::Savepoint for Savepoint {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn rollback(self: Box<Self>) -> Result<(), KantoError> {
let transaction = self.inner.transaction.clone();
transaction.rollback_to_savepoint(*self).await
}
}
impl PartialEq<WeakSavepoint> for Savepoint {
fn eq(&self, other: &WeakSavepoint) -> bool {
std::sync::Weak::ptr_eq(&Arc::downgrade(&self.inner), &other.inner)
}
}

View file

@ -0,0 +1,208 @@
use std::sync::Arc;
use dashmap::DashMap;
use kanto::KantoError;
use kanto_meta_format_v1::SequenceId;
use crate::Database;
#[derive(Debug)]
struct SeqState {
// TODO optimize with atomics, put only the reserve-next-chunk part behind locking (here `DashMap`).
// TODO could be preemptively reserve next chunk in a separate thread / tokio task / etc?
cur: u64,
max: u64,
}
// For best results, there should be only one `SequenceTracker` per `kanto_backend_rocksdb::Database`.
pub(crate) struct SequenceTracker {
// Reserving sequence number chunks is explicitly done outside any transactions.
// You wouldn't want two concurrent transactions get the same ID and conflict later.
// Use RocksDB directly, not via `crate::Database`, to avoid a reference cycle.
rocksdb: rocky::Database,
column_family: Arc<rocky::ColumnFamily>,
cache: DashMap<SequenceId, SeqState>,
}
impl SequenceTracker {
pub(crate) fn new(rocksdb: rocky::Database, column_family: Arc<rocky::ColumnFamily>) -> Self {
SequenceTracker {
rocksdb,
column_family,
cache: DashMap::new(),
}
}
#[must_use]
const fn encode_value(seq: u64) -> [u8; 8] {
seq.to_le_bytes()
}
#[must_use]
fn decode_value(buf: &[u8]) -> Option<u64> {
let input: [u8; 8] = buf.try_into().ok()?;
let num = u64::from_le_bytes(input);
Some(num)
}
#[maybe_tracing::instrument(skip(self), ret, err)]
fn reserve_chunk(&self, seq_id: SequenceId) -> Result<SeqState, KantoError> {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
let tx = self.rocksdb.transaction_begin(&write_options, &tx_options);
let key = Database::make_rocksdb_sequence_key(seq_id);
// TODO should we loop on transaction conflicts
let value = {
let read_opts = rocky::ReadOptions::new();
tx.get_for_update_pinned(
&self.column_family,
key,
&read_opts,
rocky::Exclusivity::Exclusive,
)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "p1w8mbq5acd6n",
error: std::io::Error::other(rocksdb_error),
})
}?;
let new_state = match value {
None => SeqState {
cur: 1,
max: kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE,
},
Some(bytes) => {
let bytes = bytes.as_ref();
tracing::trace!(?bytes, "loaded");
let cur = Self::decode_value(bytes).ok_or_else(|| KantoError::Corrupt {
// TODO repeating seq_id in two places in the error
// TODO unify seq_id -> sequence_id
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "cusj4jcx8qsih",
error: Box::new(kanto::error::CorruptError::SequenceData {
sequence_id: seq_id,
}),
})?;
let max = cur.saturating_add(kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE);
SeqState { cur, max }
}
};
tracing::trace!(?new_state);
{
let new_max = new_state.max.checked_add(1).ok_or(KantoError::Execution {
// TODO repeating seq_id in two places in the error
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "wwnmhoex9cx3q",
error: Box::new(kanto::error::ExecutionError::SequenceExhausted {
sequence_id: seq_id,
}),
})?;
let value = Self::encode_value(new_max);
tx.put(&self.column_family, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "iyiqedewzpopw",
error: std::io::Error::other(rocksdb_error),
})?;
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "5y5h6iy5nttws",
error: std::io::Error::other(rocksdb_error),
})?;
}
Ok(new_state)
}
#[maybe_tracing::instrument(skip(self), ret, err)]
pub(crate) fn inc(&self, seq_id: SequenceId) -> Result<u64, KantoError> {
match self.cache.entry(seq_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let state = entry.get_mut();
state.cur = state.cur.checked_add(1).ok_or(KantoError::Execution {
// TODO repeating seq_id in two places in the error
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "5jsjk93zjr1sq",
error: Box::new(kanto::error::ExecutionError::SequenceExhausted {
sequence_id: seq_id,
}),
})?;
if state.cur > state.max {
*state = self.reserve_chunk(seq_id)?;
}
Ok(state.cur)
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let state = self.reserve_chunk(seq_id)?;
let cur = state.cur;
let _state_mut = entry.insert(state);
Ok(cur)
}
}
}
#[maybe_tracing::instrument(skip(self), ret, err)]
pub(crate) fn inc_n(
&self,
seq_id: SequenceId,
n: usize,
) -> Result<kanto::arrow::array::UInt64Array, KantoError> {
let mut builder = kanto::arrow::array::UInt64Builder::with_capacity(n);
// TODO optimize this
for _row_idx in 0..n {
let seq = self.inc(seq_id)?;
builder.append_value(seq);
}
let array = builder.finish();
Ok(array)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[maybe_tracing::test]
fn seq_simple() {
let (_dir, backend) = crate::create_temp().unwrap();
let column_family = backend.inner.rocksdb.get_column_family("default").unwrap();
{
let tracker =
SequenceTracker::new(backend.inner.rocksdb.clone(), column_family.clone());
let seq_a = SequenceId::new(42u64).unwrap();
let seq_b = SequenceId::new(13u64).unwrap();
assert_eq!(tracker.inc(seq_a).unwrap(), 1);
assert_eq!(tracker.inc(seq_a).unwrap(), 2);
assert_eq!(tracker.inc(seq_b).unwrap(), 1);
assert_eq!(tracker.inc(seq_a).unwrap(), 3);
drop(tracker);
}
{
let tracker =
SequenceTracker::new(backend.inner.rocksdb.clone(), column_family.clone());
let seq_a = SequenceId::new(42u64).unwrap();
let seq_b = SequenceId::new(13u64).unwrap();
assert_eq!(
tracker.inc(seq_a).unwrap(),
kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE + 1
);
assert_eq!(
tracker.inc(seq_b).unwrap(),
kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE + 1
);
drop(tracker);
}
}
}

View file

@ -0,0 +1,70 @@
use std::sync::Arc;
use kanto_meta_format_v1::TableId;
pub(crate) struct RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) inner: Arc<InnerRocksDbTable<TableDefSource>>,
}
impl<TableDefSource> Clone for RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn clone(&self) -> Self {
RocksDbTable {
inner: self.inner.clone(),
}
}
}
pub(crate) struct InnerRocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) tx: crate::Transaction,
pub(crate) data_cf: Arc<rocky::ColumnFamily>,
pub(crate) table_id: TableId,
pub(crate) table_def:
rkyv_util::owned::OwnedArchive<kanto_meta_format_v1::table_def::TableDef, TableDefSource>,
}
impl<TableDefSource> RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) fn new(
tx: crate::Transaction,
data_cf: Arc<rocky::ColumnFamily>,
table_id: TableId,
table_def: rkyv_util::owned::OwnedArchive<
kanto_meta_format_v1::table_def::TableDef,
TableDefSource,
>,
) -> Self {
let inner = Arc::new(InnerRocksDbTable {
tx,
data_cf,
table_id,
table_def,
});
RocksDbTable { inner }
}
pub(crate) fn open_table_provider(
&self,
) -> crate::table_provider::RocksDbTableProvider<TableDefSource> {
crate::table_provider::RocksDbTableProvider::new(self.clone())
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTable").finish()
}
}

View file

@ -0,0 +1,84 @@
use std::sync::Arc;
pub(crate) struct RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table: crate::table::RocksDbTable<TableDefSource>,
}
impl<TableDefSource> RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) const fn new(
table: crate::table::RocksDbTable<TableDefSource>,
) -> RocksDbTableProvider<TableDefSource> {
RocksDbTableProvider { table }
}
pub(crate) fn name(&self) -> datafusion::sql::ResolvedTableReference {
datafusion::sql::ResolvedTableReference {
catalog: Arc::from(self.table.inner.table_def.catalog.as_str()),
schema: Arc::from(self.table.inner.table_def.schema.as_str()),
table: Arc::from(self.table.inner.table_def.table_name.as_str()),
}
}
}
#[async_trait::async_trait]
impl<TableDefSource> datafusion::catalog::TableProvider for RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone + Send + Sync + 'static,
{
// We choose not to implement `insert_into` here, since we'd need to implement `UPDATE` and `DELETE` independently anyway.
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn table_type(&self) -> datafusion::logical_expr::TableType {
datafusion::logical_expr::TableType::Base
}
#[maybe_tracing::instrument(ret)]
fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef {
// TODO maybe cache on first use, or always precompute #performance
kanto_meta_format_v1::table_def::make_arrow_schema(
self.table
.inner
.table_def
.fields
.values()
.filter(|field| field.is_live()),
)
}
#[maybe_tracing::instrument(skip(_session), ret, err)]
async fn scan(
&self,
_session: &dyn datafusion::catalog::Session,
projection: Option<&Vec<usize>>,
filters: &[datafusion::prelude::Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn datafusion::physical_plan::ExecutionPlan>, datafusion::error::DataFusionError>
{
assert!(
filters.is_empty(),
"filters should be empty until we declare we can support them"
);
let exec_plan = crate::table_scan::RocksDbTableScan::new(self.table.clone(), projection)?;
Ok(Arc::new(exec_plan))
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTableProvider")
.field("table", &self.table)
.finish()
}
}

View file

@ -0,0 +1,399 @@
use std::collections::HashSet;
use std::sync::Arc;
use gat_lending_iterator::LendingIterator as _;
use kanto::KantoError;
use kanto_meta_format_v1::table_def::ArchivedTableDef;
use kanto_meta_format_v1::FieldId;
use crate::Database;
pub(crate) struct RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
inner: Arc<InnerRocksDbTableScan<TableDefSource>>,
}
impl<TableDefSource> Clone for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn clone(&self) -> Self {
RocksDbTableScan {
inner: self.inner.clone(),
}
}
}
struct InnerRocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table: crate::table::RocksDbTable<TableDefSource>,
wanted_field_ids: Arc<HashSet<FieldId>>,
df_plan_properties: datafusion::physical_plan::PlanProperties,
}
impl<TableDefSource> RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
#[maybe_tracing::instrument(ret, err)]
fn make_wanted_field_ids(
table_def: &ArchivedTableDef,
projection: Option<&Vec<usize>>,
) -> Result<HashSet<FieldId>, KantoError> {
// It seems [`TableProvider::scan`] required handling of `projection` is just limiting what columns to produce (for optimization), while something higher up reorders and duplicates the output columns as required.
//
// Whether a `TableProvider` decides to implement that optimization or not , the produced `RecordBatch` and the return value of `ExecutionPlan::schema` must agree.
//
// We're doing projection by `HashSet` which makes the results always come out in table column order, which could break `SELECT col2, col1 FROM mytable`.
// We do have test coverage, but that might be unsatisfactory.
//
// TODO is projection guaranteed to be in-order? #ecosystem/datafusion
debug_assert!(projection.is_none_or(|p| p.is_sorted()));
// DataFusion only knows about live fields, so only look at those.
let live_fields = table_def
.fields
.values()
.filter(|field| field.is_live())
.collect::<Box<[_]>>();
let wanted_field_ids = if let Some(projection) = projection {
projection.iter()
.map(|col_idx| {
live_fields.get(*col_idx)
.map(|f| f.field_id.to_native())
.ok_or_else(|| KantoError::Internal {
code: "6x3tyc9c4iq6k",
// TODO why is rustfmt broken here #ecosystem/rust
error: format!("datafusion projection column index out of bounds: {col_idx} not in 0..{len}", len=live_fields.len()).into(),
})
})
.collect::<Result<HashSet<FieldId>, _>>()?
} else {
live_fields
.iter()
.map(|field| field.field_id.to_native())
.collect::<HashSet<FieldId>>()
};
Ok(wanted_field_ids)
}
pub(crate) fn new(
table: crate::table::RocksDbTable<TableDefSource>,
projection: Option<&Vec<usize>>,
) -> Result<Self, KantoError> {
let wanted_field_ids = Self::make_wanted_field_ids(&table.inner.table_def, projection)?;
let wanted_field_ids = Arc::new(wanted_field_ids);
// TODO cache somewhere explicitly; for now we're riding on `self.inner.eq_properties`
let schema = {
let fields = table
.inner
.table_def
.fields
.values()
// guaranteed to only contain live fields
.filter(|field| wanted_field_ids.contains(&field.field_id.to_native()));
#[cfg(debug_assertions)]
let fields = fields.inspect(|field| debug_assert!(field.is_live()));
kanto_meta_format_v1::table_def::make_arrow_schema(fields)
};
let df_plan_properties = {
let eq_properties = datafusion::physical_expr::EquivalenceProperties::new(schema);
let partitioning = datafusion::physical_plan::Partitioning::UnknownPartitioning(1);
let emission_type =
datafusion::physical_plan::execution_plan::EmissionType::Incremental;
let boundedness = datafusion::physical_plan::execution_plan::Boundedness::Bounded;
datafusion::physical_plan::PlanProperties::new(
eq_properties,
partitioning,
emission_type,
boundedness,
)
};
let inner = Arc::new(InnerRocksDbTableScan {
table,
wanted_field_ids,
df_plan_properties,
});
Ok(RocksDbTableScan { inner })
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTableScan")
.field("table", &self.inner.table)
.finish()
}
}
impl<TableDefSource> datafusion::physical_plan::display::DisplayAs
for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt_as(
&self,
t: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match t {
datafusion::physical_plan::DisplayFormatType::Default
| datafusion::physical_plan::DisplayFormatType::Verbose => f
.debug_struct("RocksDbTableScan")
.field("table", &self.inner.table.inner.table_id)
.finish(),
}
}
}
impl<TableDefSource> datafusion::physical_plan::ExecutionPlan for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone + Send + Sync + 'static,
{
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
#[maybe_tracing::instrument(ret)]
fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
&self.inner.df_plan_properties
}
#[maybe_tracing::instrument(ret)]
fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef {
self.inner.df_plan_properties.eq_properties.schema().clone()
}
#[maybe_tracing::instrument(ret)]
fn children(&self) -> Vec<&Arc<(dyn datafusion::physical_plan::ExecutionPlan + 'static)>> {
vec![]
}
#[maybe_tracing::instrument(ret, err)]
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn datafusion::physical_plan::ExecutionPlan>>,
) -> Result<Arc<dyn datafusion::physical_plan::ExecutionPlan>, datafusion::error::DataFusionError>
{
// TODO i don't really understand why ExecutionPlan children (aka inputs) can be replaced at this stage (seems to be blurring query optimization passes vs execution time) but we're a table scan, we have no inputs! #ecosystem/datafusion
//
// DataFusion implementations are all over the place: silent ignore, return error, `unimplemented!`, `unreachable!`; some check len of new children, some assume.
// Then there's `with_new_children_if_necessary` helper that does check len.
if children.is_empty() {
// do nothing
Ok(self)
} else {
tracing::error!(?children, "unsupported: with_new_children");
let error = KantoError::Internal {
code: "4fsybg7tj9th1",
error: "unsupported: RocksDbTableScan with_new_children".into(),
};
Err(error.into())
}
}
#[maybe_tracing::instrument(skip(context), err)]
fn execute(
&self,
partition: usize,
context: Arc<datafusion::execution::context::TaskContext>,
) -> Result<
datafusion::physical_plan::SendableRecordBatchStream,
datafusion::error::DataFusionError,
> {
let schema = self.schema();
// we could support multiple partitions (find good keys to partition by). but we do not do that at this time
if partition > 0 {
let empty = datafusion::physical_plan::EmptyRecordBatchStream::new(schema);
return Ok(Box::pin(empty));
}
let batch_size = context.session_config().batch_size();
let table_id = self.inner.table.inner.table_id;
let table_key_prefix = Database::make_rocksdb_record_key_prefix(table_id);
// TODO once we support filters or partitioning, start and prefix might no longer be the same
let table_key_start = table_key_prefix;
// TODO once we support filters or partitioning, stop_before gets more complex
let table_key_stop_before = Database::make_rocksdb_record_key_stop_before(table_id);
tracing::trace!(
?table_key_start,
?table_key_stop_before,
"execute table scan"
);
let iter = {
let read_opts = self.inner.table.inner.tx.make_read_options();
self.inner.table.inner.tx.inner.db_tx.range(
read_opts,
&self.inner.table.inner.data_cf,
table_key_start..table_key_stop_before,
)
};
let record_reader = kanto_record_format_v1::RecordReader::new_from_table_def(
self.inner.table.inner.table_def.clone(),
self.inner.wanted_field_ids.clone(),
batch_size,
);
let stream = RocksDbTableScanStream {
table_id,
batch_size,
iter,
record_reader,
};
let stream =
datafusion::physical_plan::stream::RecordBatchStreamAdapter::new(schema, stream);
Ok(Box::pin(stream))
}
#[maybe_tracing::instrument(ret, err)]
fn statistics(
&self,
) -> Result<datafusion::physical_plan::Statistics, datafusion::error::DataFusionError> {
let schema = self.schema();
let column_statistics: Vec<_> = schema
.fields
.iter()
.map(|_field| datafusion::physical_plan::ColumnStatistics {
null_count: datafusion::common::stats::Precision::Absent,
max_value: datafusion::common::stats::Precision::Absent,
min_value: datafusion::common::stats::Precision::Absent,
sum_value: datafusion::common::stats::Precision::Absent,
distinct_count: datafusion::common::stats::Precision::Absent,
})
.collect();
let statistics = datafusion::physical_plan::Statistics {
num_rows: datafusion::common::stats::Precision::Absent,
total_byte_size: datafusion::common::stats::Precision::Absent,
column_statistics,
};
Ok(statistics)
}
}
struct RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table_id: kanto_meta_format_v1::TableId,
batch_size: usize,
iter: rocky::Iter,
record_reader: kanto_record_format_v1::RecordReader<
rkyv_util::owned::OwnedArchive<kanto_meta_format_v1::table_def::TableDef, TableDefSource>,
>,
}
impl<TableDefSource> RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fill_record_reader_sync(&mut self) -> Result<(), KantoError> {
while let Some(entry) =
self.iter
.next()
.transpose()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "nfiih93dhser6",
error: std::io::Error::other(rocksdb_error),
})?
{
let key = entry.key();
let value = entry.value();
tracing::trace!(?key, ?value, "iterator saw");
self.record_reader
.push_record(value)
.map_err(|record_error| {
use kanto_record_format_v1::error::RecordError;
match record_error {
RecordError::Corrupt { error } => KantoError::Corrupt {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "sf9f9qxnszuxr",
error: Box::new(kanto::error::CorruptError::InvalidRecord { error }),
},
error @ (RecordError::InvalidSchema { .. }
| RecordError::InvalidData { .. }
| RecordError::Internal { .. }
| RecordError::InternalArrow { .. }) => KantoError::Internal {
code: "f6j8ckqoua6aa",
error: Box::new(error),
},
}
})?;
if self.record_reader.num_records() >= self.batch_size {
break;
}
}
Ok(())
}
fn build_record_batch(&mut self) -> Result<kanto::arrow::array::RecordBatch, KantoError> {
let record_batch = self.record_reader.build().map_err(|record_error| {
use kanto_record_format_v1::error::RecordError;
match record_error {
RecordError::Corrupt { error } => KantoError::Corrupt {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "wm5sejkimmi94",
error: Box::new(kanto::error::CorruptError::InvalidRecord { error }),
},
error @ (RecordError::InvalidSchema { .. }
| RecordError::InvalidData { .. }
| RecordError::Internal { .. }
| RecordError::InternalArrow { .. }) => KantoError::Internal {
code: "qbqhmms3ghm16",
error: Box::new(error),
},
}
})?;
Ok(record_batch)
}
}
impl<TableDefSource> Unpin for RocksDbTableScanStream<TableDefSource> where
TableDefSource: rkyv_util::owned::StableBytes + Clone
{
}
impl<TableDefSource> futures_lite::Stream for RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
type Item = Result<kanto::arrow::array::RecordBatch, datafusion::error::DataFusionError>;
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if let Err(error) = self.fill_record_reader_sync() {
return std::task::Poll::Ready(Some(Err(error.into())));
}
if self.record_reader.num_records() == 0 {
return std::task::Poll::Ready(None);
}
match self.build_record_batch() {
Ok(record_batch) => std::task::Poll::Ready(Some(Ok(record_batch))),
Err(error) => std::task::Poll::Ready(Some(Err(error.into()))),
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,141 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
// TODO write helpers to make these tests as simple as sqllogictest.
// Using Rust is better for e.g. running two sessions concurrently, and the result verification is better here.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_simple() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql(
"SELECT * FROM kanto.information_schema.tables WHERE table_schema<>'information_schema';",
)
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+---------------+--------------+------------+------------+
| table_catalog | table_schema | table_name | table_type |
+---------------+--------------+------------+------------+
| kanto | myschema | mytable | BASE TABLE |
+---------------+--------------+------------+------------+
",
&batches,
);
let stream = session
.sql("SELECT column_name, data_type FROM myschema.information_schema.columns WHERE table_catalog='kanto' AND table_schema='myschema' AND table_name='mytable';")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+-------------+-----------+
| column_name | data_type |
+-------------+-----------+
| mycolumn | Int64 |
+-------------+-----------+
",
&batches,
);
drop(session);
}
Ok(())
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_without_primary_key_uses_rowid() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
{
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql("SELECT column_name, data_type FROM myschema.information_schema.columns WHERE table_catalog='kanto' AND table_schema='myschema' AND table_name='mytable';")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+-------------+-----------+
| column_name | data_type |
+-------------+-----------+
| mycolumn | Int64 |
+-------------+-----------+
",
&batches,
);
}
drop(session);
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_or_replace_not_implemented() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
let result = session
.sql("CREATE OR REPLACE TABLE myschema.mytable (mycolumn BIGINT);")
.await;
match result {
Ok(_record_batch_stream) => panic!("unexpected success"),
Err(datafusion::error::DataFusionError::External(error)) => {
let error = error
.downcast::<kanto::KantoError>()
.expect("wrong kind of error");
assert!(
matches!(
*error,
kanto::KantoError::UnimplementedSql {
code: "65so99qr1hyck",
sql_syntax: "CREATE OR REPLACE TABLE"
}
),
"wrong error: {error:?}",
);
}
Err(error) => panic!("wrong error: {error:?}"),
}
}

View file

@ -0,0 +1,66 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::stream::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn insert_select() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE kanto.myschema.mytable (a BIGINT NOT NULL, b BIGINT, PRIMARY KEY(a));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql("INSERT kanto.myschema.mytable (a, b) VALUES (42, 7), (13, 34);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql("SELECT * FROM kanto.myschema.mytable ORDER BY a;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+----+----+
| a | b |
+----+----+
| 13 | 34 |
| 42 | 7 |
+----+----+
",
&batches,
);
}
Ok(())
}

View file

@ -0,0 +1,50 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
// TODO write helpers to make these tests as simple as sqllogictest.
// Using Rust is better for e.g. running two sessions concurrently, and the result verification is better here.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn reopen_database() -> Result<(), String> {
let (dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
drop(backend);
let backend = kanto_backend_rocksdb::open(dir.path()).expect("database open");
{
let mut session = kanto::Session::test_session(Box::new(backend));
let stream = session
.sql("INSERT myschema.mytable (mycolumn) VALUES (42);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
Ok(())
}

View file

@ -0,0 +1,29 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn select_constant() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
let stream = session
.sql("SELECT 'Hello, world' AS greeting;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+--------------+
| greeting |
+--------------+
| Hello, world |
+--------------+
",
&batches,
);
Ok(())
}

View file

@ -0,0 +1,184 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::stream::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn transaction_implicit_rollback_create_table() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
{
let stream = session
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session.sql("ROLLBACK;").await.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql(
"SELECT * FROM kanto.information_schema.tables WHERE table_schema<>'information_schema';",
)
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn transaction_changes_not_concurrently_visible() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session_a = backend.test_session();
let mut session_b = backend.test_session();
{
let stream = session_a
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a
.sql("INSERT INTO kanto.myschema.mytable (mycolumn) VALUES (42), (13);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_b
.sql("SELECT * FROM kanto.myschema.mytable;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_b
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a.sql("COMMIT;").await.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
// Snapshot isolation -> won't see values even if they are committed.
let stream = session_b
.sql("SELECT * FROM kanto.myschema.mytable;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
}

View file

@ -0,0 +1,21 @@
[package]
name = "kanto-index-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
datafusion = { workspace = true }
kanto-key-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
rkyv = { workspace = true }
thiserror = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,529 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
//!
//! The entry point to this module is [`Index::new`].
use std::borrow::Cow;
use std::sync::Arc;
use datafusion::arrow::array::Array;
use datafusion::arrow::record_batch::RecordBatch;
use kanto_meta_format_v1::index_def::ArchivedExpr;
use kanto_meta_format_v1::index_def::ArchivedIndexDef;
use kanto_meta_format_v1::index_def::ArchivedOrderByExpr;
use kanto_meta_format_v1::table_def::ArchivedFieldDef;
use kanto_meta_format_v1::table_def::ArchivedTableDef;
use kanto_meta_format_v1::FieldId;
/// Initialization-time errors that are about the schema, not about the data.
#[derive(thiserror::Error, Debug)]
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
pub enum IndexInitError {
#[expect(missing_docs)]
#[error(
"index expression field not found: {field_id}\nindex_def={index_def_debug:#?}\ntable_def={table_def_debug:#?}"
)]
IndexExprFieldIdNotFound {
field_id: FieldId,
index_def_debug: String,
table_def_debug: String,
},
}
/// Execution-time errors that are about the input data.
#[derive(thiserror::Error, Debug)]
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
pub enum IndexExecutionError {
#[expect(missing_docs)]
#[error("indexed field missing from record batch: {field_name}")]
IndexedFieldNotInBatch { field_name: String },
}
/// Used for indexes where there can only be one value.
///
/// This happens with
///
/// - `UNIQUE` `NULLS NOT DISTINCT` index
/// - `UNIQUE` `NULLS DISTINCT` index where all indexed fields are `NOT NULL`
#[derive(rkyv::Archive, rkyv::Serialize)]
struct IndexValueUnique<'data> {
#[rkyv(with = rkyv::with::AsOwned)]
row_key: Cow<'data, [u8]>,
// TODO index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)
}
/// Used for indexes where there may be multiple values.
/// In that case, the row key is stored in the key, allowing multiple values for the "same" index key.
///
/// This happens with
///
/// - `NOT UNIQUE` index
/// - `UNIQUE` `NULLS DISTINCT` index where indexed field can be `NULL`
#[derive(rkyv::Archive, rkyv::Serialize)]
struct IndexValueMulti<'data> {
/// Duplicated from key for ease of access.
// TODO remove?
#[rkyv(with = rkyv::with::AsOwned)]
row_key: Cow<'data, [u8]>,
// TODO index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)
}
fn make_index_key_source_arrays(
field_defs: &[&ArchivedFieldDef],
record_batch: &RecordBatch,
) -> Result<Vec<Arc<dyn Array>>, IndexExecutionError> {
field_defs
.iter()
.map(|field| {
let want_name = &field.field_name;
let array = record_batch.column_by_name(want_name).ok_or_else(|| {
IndexExecutionError::IndexedFieldNotInBatch {
field_name: want_name.as_str().to_owned(),
}
})?;
Ok(array.clone())
})
.collect::<Result<Vec<_>, _>>()
}
/// Library for making index keys and values.
#[expect(clippy::exhaustive_enums)]
pub enum Index<'def> {
/// See [`UniqueIndex`].
Unique(UniqueIndex<'def>),
/// See [`MultiIndex`].
Multi(MultiIndex<'def>),
}
impl<'def> Index<'def> {
fn get_field<'table>(
table_def: &'table ArchivedTableDef,
index_def: &ArchivedIndexDef,
order_by_expr: &ArchivedOrderByExpr,
) -> Result<&'table ArchivedFieldDef, IndexInitError> {
match &order_by_expr.expr {
ArchivedExpr::Field { field_id } => {
let field = table_def.fields.get(field_id).ok_or_else(|| {
IndexInitError::IndexExprFieldIdNotFound {
field_id: field_id.to_native(),
index_def_debug: format!("{index_def:?}"),
table_def_debug: format!("{table_def:?}"),
}
})?;
Ok(field)
}
}
}
/// Make an index helper for the given `index_def` belonging to the given `table_def`.
///
/// # Examples
///
/// ```rust compile_fail
/// # // TODO `TableDef` is too complex to write out fully #dev #doc
/// let index = kanto_index_format_v1::Index::new(table_def, index_def);
/// match index {
/// kanto_index_format_v1::Index::Unique(unique_index) => {
/// let keys = unique_index.make_unique_index_keys(&record_batch);
/// }
/// kanto_index_format_v1::Index::Multi(multi_index) => { ... }
/// }
/// ```
pub fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
) -> Result<Index<'def>, IndexInitError> {
debug_assert_eq!(index_def.table, table_def.table_id);
let field_defs = index_def
.columns
.iter()
.map(|order_by_expr| Self::get_field(table_def, index_def, order_by_expr))
.collect::<Result<Box<[_]>, _>>()?;
let has_nullable = field_defs.iter().any(|field| field.nullable);
let can_use_unique_index = match &index_def.index_kind {
kanto_meta_format_v1::index_def::ArchivedIndexKind::Unique(unique_index_options) => {
!unique_index_options.nulls_distinct || !has_nullable
}
kanto_meta_format_v1::index_def::ArchivedIndexKind::Multi => false,
};
if can_use_unique_index {
Ok(Index::Unique(UniqueIndex::new(
table_def, index_def, field_defs,
)))
} else {
Ok(Index::Multi(MultiIndex::new(
table_def, index_def, field_defs,
)))
}
}
}
/// A unique index (with no complications).
/// Each key stores exactly one value, which refers to a single row key.
pub struct UniqueIndex<'def> {
_table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
}
impl<'def> UniqueIndex<'def> {
#[must_use]
const fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
) -> UniqueIndex<'def> {
UniqueIndex {
_table_def: table_def,
index_def,
field_defs,
}
}
/// Make index keys for the `record_batch`.
pub fn make_unique_index_keys(
&self,
record_batch: &RecordBatch,
) -> Result<datafusion::arrow::array::BinaryArray, IndexExecutionError> {
let index_key_source_arrays = make_index_key_source_arrays(&self.field_defs, record_batch)?;
let keys = kanto_key_format_v1::make_keys(&index_key_source_arrays);
Ok(keys)
}
/// Make the value to store in the index for this `row_key`.
#[must_use]
pub fn index_value(&self, row_key: &[u8]) -> Box<[u8]> {
assert!(
self.index_def.include.is_empty(),
"index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)"
);
assert!(
self.index_def.predicate.is_none(),
"partial indexes (#ref/unimplemented_sql/9134bk3fe98x6)"
);
let value = IndexValueUnique {
row_key: Cow::from(row_key),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::Panic>(&value).unwrap();
bytes.into_boxed_slice()
}
}
/// A multi-value index, where multiple records have the same key prefix.
/// To read record references from the index, a prefix scan is needed.
pub struct MultiIndex<'def> {
_table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
}
impl<'def> MultiIndex<'def> {
#[must_use]
const fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
) -> MultiIndex<'def> {
MultiIndex {
_table_def: table_def,
index_def,
field_defs,
}
}
/// Make index keys for the `record_batch` with matching `row_keys`.
pub fn make_multi_index_keys(
&self,
record_batch: &RecordBatch,
row_keys: &datafusion::arrow::array::BinaryArray,
) -> Result<datafusion::arrow::array::BinaryArray, IndexExecutionError> {
let mut index_key_source_arrays =
make_index_key_source_arrays(&self.field_defs, record_batch)?;
index_key_source_arrays.push(Arc::new(row_keys.clone()));
let keys = kanto_key_format_v1::make_keys(&index_key_source_arrays);
Ok(keys)
}
/// Make the value to store in the index for this `row_key`.
#[must_use]
pub fn index_value(&self, row_key: &[u8]) -> Box<[u8]> {
assert!(
self.index_def.include.is_empty(),
"index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)"
);
assert!(
self.index_def.predicate.is_none(),
"partial indexes (#ref/unimplemented_sql/9134bk3fe98x6)"
);
let value = IndexValueMulti {
row_key: Cow::from(row_key),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::Panic>(&value).unwrap();
bytes.into_boxed_slice()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[expect(clippy::too_many_lines)]
fn unique_index() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(2).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(3).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: kanto_meta_format_v1::table_def::TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let index_def = kanto_meta_format_v1::index_def::IndexDef {
index_id: kanto_meta_format_v1::IndexId::new(42).unwrap(),
table: table_def.table_id,
index_name: "myindex".to_owned(),
columns: vec![
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[1].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[0].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
],
index_kind: kanto_meta_format_v1::index_def::IndexKind::Unique(
kanto_meta_format_v1::index_def::UniqueIndexOptions {
nulls_distinct: false,
},
),
include: vec![],
predicate: None,
};
let owned_archived_index_def = kanto_meta_format_v1::owned_archived(&index_def).unwrap();
let index = Index::new(&owned_archived_table_def, &owned_archived_index_def).unwrap();
// TODO `debug_assert_matches!` <https://doc.rust-lang.org/std/assert_matches/macro.debug_assert_matches.html>, <https://github.com/rust-lang/rust/issues/82775> #waiting #ecosystem/rust #severity/low #dev
assert!(matches!(index, Index::Unique(..)));
let unique_index = match index {
Index::Unique(unique_index) => unique_index,
Index::Multi(_) => panic!("index was supposed to be unique"),
};
{
let a: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt16Array::from(vec![3, 4, 5]));
let c: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt32Array::from(vec![
11, 10, 12,
]));
let record_batch = RecordBatch::try_from_iter([("a", a), ("c", c)]).unwrap();
let keys = unique_index.make_unique_index_keys(&record_batch).unwrap();
let got = keys
.into_iter()
// TODO `BinaryArray` non-nullable iteration: the `Option` here is ugly #ref/4b6891jiy4t5a #ecosystem/datafusion #severity/medium #urgency/low
.map(|option| option.unwrap())
.collect::<Vec<_>>();
assert_eq!(
&got,
&[
&[
1u8, 0, 0, 0, 11, // c
1u8, 0, 3, // a
],
&[
1u8, 0, 0, 0, 10, // c
1u8, 0, 4, // a
],
&[
1u8, 0, 0, 0, 12, // c
1u8, 0, 5, // a
],
],
);
}
{
const FAKE_ROW_KEY: &[u8] = b"fakerowkey";
let bytes = unique_index.index_value(FAKE_ROW_KEY);
let got_index_value =
rkyv::access::<rkyv::Archived<IndexValueUnique<'_>>, rkyv::rancor::Panic>(&bytes)
.unwrap();
assert_eq!(got_index_value.row_key.as_ref(), FAKE_ROW_KEY);
}
}
#[test]
#[expect(clippy::too_many_lines)]
fn multi_index() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(2).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(3).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: kanto_meta_format_v1::table_def::TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let index_def = kanto_meta_format_v1::index_def::IndexDef {
index_id: kanto_meta_format_v1::IndexId::new(42).unwrap(),
table: table_def.table_id,
index_name: "myindex".to_owned(),
columns: vec![
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[1].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[0].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
],
index_kind: kanto_meta_format_v1::index_def::IndexKind::Multi,
include: vec![],
predicate: None,
};
let owned_archived_index_def = kanto_meta_format_v1::owned_archived(&index_def).unwrap();
let index = Index::new(&owned_archived_table_def, &owned_archived_index_def).unwrap();
// TODO `debug_assert_matches!` <https://doc.rust-lang.org/std/assert_matches/macro.debug_assert_matches.html>, <https://github.com/rust-lang/rust/issues/82775> #waiting #ecosystem/rust #severity/low #dev
assert!(matches!(index, Index::Multi(..)));
let multi_index = match index {
Index::Unique(_) => panic!("index was supposed to be multi"),
Index::Multi(multi_index) => multi_index,
};
{
let a: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt16Array::from(vec![3, 4, 5]));
let c: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt32Array::from(vec![
11, 10, 12,
]));
let record_batch = RecordBatch::try_from_iter([("a", a.clone()), ("c", c)]).unwrap();
let row_keys =
datafusion::arrow::array::BinaryArray::from_vec(vec![b"foo", b"bar", b"quux"]);
let keys = multi_index
.make_multi_index_keys(&record_batch, &row_keys)
.unwrap();
let got = keys
.into_iter()
// TODO `BinaryArray` non-nullable iteration: the `Option` here is ugly #ref/4b6891jiy4t5a #ecosystem/datafusion #severity/medium #urgency/low
.map(|option| option.unwrap())
.collect::<Vec<_>>();
assert_eq!(
&got,
&[
&[
1u8, 0, 0, 0, 11, // c
1, 0, 3, // a
2, b'f', b'o', b'o', 0, 0, 0, 0, 0, 3 // row key
],
&[
1u8, 0, 0, 0, 10, // c
1, 0, 4, // a
2, b'b', b'a', b'r', 0, 0, 0, 0, 0, 3 // row key
],
&[
1u8, 0, 0, 0, 12, // c
1, 0, 5, // a
2, b'q', b'u', b'u', b'x', 0, 0, 0, 0, 4 // row key
],
],
);
}
{
const FAKE_ROW_KEY: &[u8] = b"fakerowkey";
let bytes = multi_index.index_value(FAKE_ROW_KEY);
let got_index_value =
rkyv::access::<rkyv::Archived<IndexValueMulti<'_>>, rkyv::rancor::Panic>(&bytes)
.unwrap();
assert_eq!(got_index_value.row_key.as_ref(), FAKE_ROW_KEY);
}
}
}

51
crates/kanto/Cargo.toml Normal file
View file

@ -0,0 +1,51 @@
[package]
name = "kanto"
version = "0.1.0"
description = "KantoDB SQL database as a library"
keywords = ["kantodb", "sql", "database"]
categories = ["database-implementations"]
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
async-trait = { workspace = true }
datafusion = { workspace = true }
futures-lite = { workspace = true }
kanto-index-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-record-format-v1 = { workspace = true }
kanto-tunables = { workspace = true }
maybe-tracing = { workspace = true }
parking_lot = { workspace = true }
rkyv = { workspace = true }
smallvec = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
hex = { workspace = true }
kanto-backend-rocksdb = { workspace = true }
libtest-mimic = { workspace = true }
regex = { workspace = true }
sqllogictest = { workspace = true }
tokio = { workspace = true }
tracing-subscriber = { workspace = true }
walkdir = { workspace = true }
[lints]
workspace = true
[[test]]
name = "sqllogictest"
path = "tests/sqllogictest.rs"
harness = false
[[test]]
name = "sqllogictest_sqlite"
path = "tests/sqllogictest_sqlite.rs"
harness = false

View file

@ -0,0 +1,80 @@
use datafusion::sql::sqlparser;
use crate::KantoError;
use crate::Transaction;
/// A database backend stores & manages data, implementing almost all things related to the data it stores, such as low-level data formats, transactions, indexes, and so on.
#[async_trait::async_trait]
pub trait Backend: Send + Sync + std::fmt::Debug + DynEq + DynHash + std::any::Any {
/// Return the `Backend` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Check equality between self and other.
fn is_equal(&self, other: &dyn Backend) -> bool {
self.dyn_eq(other.as_any())
}
/// Start a new transaction.
/// All data manipulation and access happens inside a transaction.
///
/// Transactions do not nest, but savepoints can be used to same effect.
async fn start_transaction(
&self,
access_mode: sqlparser::ast::TransactionAccessMode,
isolation_level: sqlparser::ast::TransactionIsolationLevel,
) -> Result<Box<dyn Transaction>, KantoError>;
/// Kludge to make a `Box<dyn Backend>` cloneable.
fn clone_box(&self) -> Box<dyn Backend>;
}
impl Clone for Box<dyn Backend> {
fn clone(&self) -> Box<dyn Backend> {
self.clone_box()
}
}
/// To implement this helper trait, derive or implement [`Eq`] for your concrete type.
#[expect(missing_docs)]
pub trait DynEq {
fn dyn_eq(&self, other: &dyn std::any::Any) -> bool;
}
impl<T: PartialEq + Eq + 'static> DynEq for T {
fn dyn_eq(&self, other: &dyn std::any::Any) -> bool {
if let Some(other) = other.downcast_ref::<Self>() {
self == other
} else {
false
}
}
}
/// To implement this helper trait, derive or implement [`Hash`] for your concrete type.
#[expect(missing_docs)]
pub trait DynHash {
fn dyn_hash(&self, state: &mut dyn std::hash::Hasher);
}
impl<T: std::hash::Hash> DynHash for T {
fn dyn_hash(&self, mut state: &mut dyn std::hash::Hasher) {
self.hash(&mut state);
}
}
impl std::hash::Hash for dyn Backend {
fn hash<H>(&self, state: &mut H)
where
H: std::hash::Hasher,
{
self.dyn_hash(state);
}
}
impl PartialEq for dyn Backend {
fn eq(&self, other: &Self) -> bool {
self.is_equal(other)
}
}
impl Eq for dyn Backend {}

View file

@ -0,0 +1,7 @@
//! Defaults for things that can be changed in configuration or at runtime.
/// Default SQL catalog name when omitted.
pub const DEFAULT_CATALOG_NAME: &str = "kanto";
/// Default SQL schema name when omitted.
pub const DEFAULT_SCHEMA_NAME: &str = "public";

494
crates/kanto/src/error.rs Normal file
View file

@ -0,0 +1,494 @@
//! Error types for KantoDB.
use datafusion::sql::sqlparser;
use kanto_meta_format_v1::FieldId;
use kanto_meta_format_v1::IndexId;
use kanto_meta_format_v1::SequenceId;
use kanto_meta_format_v1::TableId;
#[derive(Debug)]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum Op {
Init,
CreateTable,
CreateIndex,
IterTables,
LookupName {
name_ref: datafusion::sql::ResolvedTableReference,
},
SaveName {
name_ref: datafusion::sql::ResolvedTableReference,
},
LoadTableDef {
table_id: TableId,
},
IterIndexes {
table_id: TableId,
},
LoadIndexDef {
index_id: IndexId,
},
TableScan {
table_id: TableId,
},
LoadRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
InsertRecords {
table_id: TableId,
},
InsertRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
UpdateRecords {
table_id: TableId,
},
UpdateRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
DeleteRecords {
table_id: TableId,
},
DeleteRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
InsertIndex {
table_id: TableId,
index_id: IndexId,
},
DeleteIndex {
table_id: TableId,
index_id: IndexId,
},
SequenceUpdate {
sequence_id: SequenceId,
},
Rollback,
Commit,
}
impl std::fmt::Display for Op {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Op::Init => {
write!(f, "initializing")
}
Op::IterTables => {
write!(f, "iterate table definitions")
}
Op::LookupName { name_ref } => {
write!(f, "lookup name definition: {name_ref}")
}
Op::SaveName { name_ref } => {
write!(f, "save name definition: {name_ref}")
}
Op::LoadTableDef { table_id } => {
write!(f, "load table definition: {table_id}")
}
Op::IterIndexes { table_id } => {
write!(f, "iterate index definitions: table_id={table_id}")
}
Op::LoadIndexDef { index_id } => {
write!(f, "load index definition: {index_id}")
}
Op::TableScan { table_id } => {
write!(f, "table scan: table_id={table_id}")
}
Op::CreateTable => {
write!(f, "create table")
}
Op::CreateIndex => {
write!(f, "create index")
}
Op::LoadRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "load record: table_id={table_id}")
}
Op::InsertRecords { table_id } => {
write!(f, "insert records: table_id={table_id}")
}
Op::InsertRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "insert record: table_id={table_id}")
}
Op::UpdateRecords { table_id } => {
write!(f, "update record: table_id={table_id}")
}
Op::UpdateRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "update record: table_id={table_id}")
}
Op::DeleteRecords { table_id } => {
write!(f, "delete records: table_id={table_id}")
}
Op::DeleteRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "delete record: table_id={table_id}")
}
Op::InsertIndex { table_id, index_id } => {
write!(
f,
"insert into index: table_id={table_id}, index_id={index_id}"
)
}
Op::DeleteIndex { table_id, index_id } => {
write!(
f,
"delete from index: table_id={table_id}, index_id={index_id}"
)
}
Op::SequenceUpdate { sequence_id } => {
write!(f, "update sequence: sequence_id={sequence_id}")
}
Op::Rollback => {
write!(f, "rollback")
}
Op::Commit => {
write!(f, "commit")
}
}
}
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum CorruptError {
#[error("invalid table id: {table_id}")]
InvalidTableId { table_id: u64 },
#[error("invalid field id: {field_id}\ntable_def={table_def_debug:#?}")]
InvalidFieldId {
field_id: u64,
table_def_debug: String,
},
#[error("primary key field not found: {field_id}\ntable_def={table_def_debug:#?}")]
PrimaryKeyFieldIdNotFound {
field_id: FieldId,
table_def_debug: String,
},
#[error(
"primary key field missing from record batch: {field_id} {field_name:?}\ntable_def={table_def_debug:#?}\nrecord_batch_schema={record_batch_schema_debug}"
)]
PrimaryKeyFieldMissingFromRecordBatch {
field_id: FieldId,
field_name: String,
table_def_debug: String,
record_batch_schema_debug: String,
},
#[error("invalid table definition: {table_id}\ntable_def={table_def_debug:#?}")]
InvalidTableDef {
table_id: TableId,
table_def_debug: String,
},
#[error("invalid index id: index_id={index_id}\nindex_def={index_def_debug:#?}")]
InvalidIndexId {
index_id: u64,
index_def_debug: String,
},
#[error(
"invalid sequence id: sequence_id={sequence_id}\nsequence_def={sequence_def_debug:#?}"
)]
InvalidSequenceId {
sequence_id: u64,
sequence_def_debug: String,
},
#[error("corrupt sequence: sequence_id={sequence_id}")]
SequenceData { sequence_id: SequenceId },
#[error("invalid UTF-8 in {access}: {error}")]
InvalidUtf8 {
access: &'static str,
#[source]
error: std::str::Utf8Error,
},
#[error("invalid record value: {error}")]
InvalidRecord {
#[source]
error: kanto_record_format_v1::error::RecordCorruptError,
},
#[error("index error: {error}")]
Index {
#[source]
error: kanto_index_format_v1::IndexInitError,
},
#[error("table definition validate: {error}")]
Validate {
#[source]
error: kanto_meta_format_v1::table_def::TableDefValidateError,
},
#[error("rkyv error: {error}")]
Rkyv {
#[source]
error: rkyv::rancor::BoxedError,
},
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum InitError {
#[error("unrecognized database format")]
UnrecognizedDatabaseFormat,
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum ExecutionError {
#[error("table exists already: {table_ref}")]
TableExists {
table_ref: datafusion::sql::ResolvedTableReference,
},
#[error("sequence exhausted: sequence_id={sequence_id}")]
SequenceExhausted { sequence_id: SequenceId },
#[error("table definition has too many fields")]
TooManyFields,
#[error("conflict on primary key")]
ConflictOnPrimaryKey,
#[error("conflict on unique index")]
ConflictOnUniqueIndex { index_id: IndexId },
#[error("not nullable: column_name={column_name}")]
NotNullable { column_name: String },
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum PlanError {
#[error("transaction is aborted")]
TransactionAborted,
#[error("not in a transaction")]
NoTransaction,
#[error("savepoint not found: {savepoint_name}")]
SavepointNotFound {
savepoint_name: sqlparser::ast::Ident,
},
#[error("unrecognized sql dialect: {dialect_name}")]
UnrecognizedSqlDialect { dialect_name: String },
#[error("catalog not found: {catalog_name}")]
CatalogNotFound { catalog_name: String },
#[error("table not found: {table_ref}")]
TableNotFound {
table_ref: datafusion::sql::ResolvedTableReference,
},
#[error("column not found: {table_ref}.{column_ident}")]
ColumnNotFound {
table_ref: datafusion::sql::ResolvedTableReference,
column_ident: sqlparser::ast::Ident,
},
#[error("invalid table name: {table_name}")]
InvalidTableName {
table_name: sqlparser::ast::ObjectName,
},
#[error("primary key cannot be nullable: {column_name}")]
PrimaryKeyNullable { column_name: String },
#[error("primary key declared more than once")]
PrimaryKeyTooMany,
#[error("primary key referred to twice in table definition")]
PrimaryKeyDuplicate,
#[error("index must have name")]
IndexNameRequired,
#[error("invalid index name: {index_name}")]
InvalidIndexName {
index_name: sqlparser::ast::ObjectName,
},
#[error("index and table must be in same catalog and schema")]
IndexInDifferentSchema,
#[error("cannot set an unknown variable: {variable_name}")]
SetUnknownVariable {
variable_name: sqlparser::ast::ObjectName,
},
#[error("SET number of values does not match number of variables")]
SetWrongNumberOfValues,
#[error("SET value is wrong type: expected {expected} got {sql}", sql=format!("{}", .value).escape_default())]
SetValueIncorrectType {
expected: &'static str,
value: sqlparser::ast::Expr,
},
#[error("cannot combine ROLLBACK AND [NO] CHAIN and TO SAVEPOINT")]
RollbackAndChainToSavepoint,
#[error("cannot use CREATE TABLE .. ON COMMIT without TEMPORARY")]
CreateTableOnCommitWithoutTemporary,
#[error("cannot use CREATE {{ GLOBAL | LOCAL }} TABLE without TEMPORARY")]
CreateTableGlobalOrLocalWithoutTemporary,
#[error("cannot set nullable value to non-nullable column: column_name={column_name}")]
NotNullable { column_name: String },
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum KantoError {
#[error("io_error/{code}: {op}: {error}")]
Io {
op: Op,
code: &'static str,
#[source]
error: std::io::Error,
},
#[error("data corruption/{code}: {op}: {error}")]
Corrupt {
op: Op,
code: &'static str,
#[source]
error: Box<CorruptError>,
},
#[error("initialization error/{code}: {error}")]
Init {
code: &'static str,
error: InitError,
},
#[error("sql/{code}: {error}")]
Sql {
code: &'static str,
#[source]
error: sqlparser::parser::ParserError,
},
#[error("plan/{code}: {error}")]
Plan {
code: &'static str,
#[source]
error: Box<PlanError>,
},
#[error("execution/{code}: {op}: {error}")]
Execution {
op: Op,
code: &'static str,
#[source]
error: Box<ExecutionError>,
},
#[error("unimplemented_sql/{code}: {sql_syntax}")]
UnimplementedSql {
code: &'static str,
sql_syntax: &'static str,
},
#[error("unimplemented/{code}: {message}")]
Unimplemented {
code: &'static str,
message: &'static str,
},
#[error("not_supported_sql/{code}: {sql_syntax}")]
NotSupportedSql {
code: &'static str,
sql_syntax: &'static str,
},
#[error("internal_error/{code}: {error}")]
Internal {
code: &'static str,
error: Box<dyn std::error::Error + Send + Sync>,
},
}
impl From<KantoError> for datafusion::error::DataFusionError {
fn from(error: KantoError) -> datafusion::error::DataFusionError {
match error {
KantoError::Io { .. } => {
datafusion::error::DataFusionError::IoError(std::io::Error::other(error))
}
KantoError::Sql { code, error } => {
// TODO ugly way to smuggle extra information
datafusion::error::DataFusionError::SQL(error, Some(format!(" (sql/{code})")))
}
KantoError::Plan { .. } => datafusion::error::DataFusionError::Plan(error.to_string()),
KantoError::Execution { .. } => {
datafusion::error::DataFusionError::Execution(error.to_string())
}
_ => datafusion::error::DataFusionError::External(Box::from(error)),
}
}
}
/// Error type to prefix a message to any error.
#[expect(clippy::exhaustive_structs)]
#[derive(thiserror::Error, Debug)]
#[error("{message}: {error}")]
pub struct Message {
/// Message to show before the wrapped error.
pub message: &'static str,
/// Underlying error.
#[source]
pub error: Box<dyn std::error::Error + Send + Sync>,
}
#[cfg(test)]
mod tests {
use super::*;
#[expect(
clippy::missing_const_for_fn,
reason = "TODO `const fn` to avoid clippy noise <https://github.com/rust-lang/rust-clippy/issues/13938> #waiting #ecosystem/rust"
)]
#[test]
fn impl_error_send_sync_static() {
const fn assert_error<E: std::error::Error + Send + Sync + 'static>() {}
assert_error::<KantoError>();
}
}

27
crates/kanto/src/lib.rs Normal file
View file

@ -0,0 +1,27 @@
//! Common functionality and data types for the Kanto database system.
mod backend;
pub mod defaults;
pub mod error;
mod session;
mod settings;
mod transaction;
mod transaction_context;
pub mod util;
// Re-exports of dependencies exposed by our public APIs.
pub use async_trait;
pub use datafusion;
pub use datafusion::arrow;
pub use datafusion::parquet;
pub use datafusion::sql::sqlparser;
pub use crate::backend::Backend;
pub use crate::backend::DynEq;
pub use crate::backend::DynHash;
pub use crate::error::KantoError;
pub use crate::session::Session;
pub(crate) use crate::settings::Settings;
pub use crate::transaction::Savepoint;
pub use crate::transaction::Transaction;
pub(crate) use crate::transaction_context::TransactionContext;

2250
crates/kanto/src/session.rs Normal file

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,8 @@
use std::sync::Arc;
use datafusion::sql::sqlparser;
#[derive(Debug, Default)]
pub(crate) struct Settings {
pub(crate) sql_dialect: Option<Arc<dyn sqlparser::dialect::Dialect + Send + Sync>>,
}

View file

@ -0,0 +1,102 @@
use std::sync::Arc;
use datafusion::sql::sqlparser;
use futures_lite::Stream;
use crate::KantoError;
/// Savepoint is a marker in a timeline of changes made within one transaction that the transaction can be rolled back to.
#[async_trait::async_trait]
pub trait Savepoint: Send + Sync {
/// Return the `Savepoint` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Rollback to this savepoint, and release it.
///
/// This has to take a boxed self because dynamic traits cannot consume self in Rust (<https://stackoverflow.com/questions/46620790/how-to-call-a-method-that-consumes-self-on-a-boxed-trait-object>)
async fn rollback(self: Box<Self>) -> Result<(), KantoError>;
}
/// A database transaction is the primary way of manipulating and accessing data stores in a [`kanto::Backend`](crate::Backend).
#[async_trait::async_trait]
pub trait Transaction: Send + Sync {
/// Return the `Transaction` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Create a [`Savepoint`] at this point in time.
async fn savepoint(&self) -> Result<Box<dyn Savepoint>, KantoError>;
/// Finish the transaction, either committing or abandoning the changes made in it.
async fn finish(
&self,
conclusion: &datafusion::logical_expr::TransactionConclusion,
) -> Result<(), KantoError>;
/// Find a table (that may have been altered in this same transaction).
async fn lookup_table(
&self,
table_ref: &datafusion::sql::ResolvedTableReference,
) -> Result<Option<Arc<dyn datafusion::catalog::TableProvider>>, KantoError>;
/// List all tables.
// TODO i wish Rust had good generators.
// This is really a `[(schema, [table])]`, but we'll keep it flat for simplicity.
#[expect(
clippy::type_complexity,
reason = "TODO would be nice to simplify the iterator item type somehow, but it is a fallible named-items iterator (#ref/i5rfyyqimjyxs) #severity/low #urgency/low"
)]
fn stream_tables(
&self,
) -> Box<
dyn Stream<
Item = Result<
(
datafusion::sql::ResolvedTableReference,
Box<dyn datafusion::catalog::TableProvider>,
),
KantoError,
>,
> + Send,
>;
/// Create a table.
async fn create_table(
&self,
table_ref: datafusion::sql::ResolvedTableReference,
create_table: sqlparser::ast::CreateTable,
) -> Result<(), KantoError>;
/// Create an index.
async fn create_index(
&self,
table_ref: datafusion::sql::ResolvedTableReference,
index_name: Arc<str>,
unique: bool,
nulls_distinct: bool,
columns: Vec<sqlparser::ast::OrderByExpr>,
) -> Result<(), KantoError>;
/// Insert records.
async fn insert(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
/// Delete records.
async fn delete(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
/// Update records.
async fn update(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
}

View file

@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::sql::sqlparser;
use futures_lite::StreamExt as _;
use smallvec::SmallVec;
use crate::session::ImplicitOrExplicit;
use crate::util::trigger::Trigger;
use crate::Backend;
use crate::KantoError;
use crate::Settings;
use crate::Transaction;
enum SavepointKind {
NestedTx,
NamedSavepoint { name: sqlparser::ast::Ident },
}
pub(crate) struct SavepointState {
kind: SavepointKind,
savepoints: HashMap<Box<dyn Backend>, Box<dyn crate::Savepoint>>,
settings: Settings,
}
impl SavepointState {
pub(crate) async fn rollback(self) -> Result<(), KantoError> {
for backend_savepoint in self.savepoints.into_values() {
// On error, drop the remaining nested transaction savepoints.
// Some of the backend transactions may be in an unclear state; caller is expected to not use this transaction any more.
backend_savepoint.rollback().await?;
}
Ok(())
}
}
#[derive(Clone)]
pub(crate) struct TransactionContext {
inner: Arc<InnerTransactionContext>,
}
struct InnerTransactionContext {
implicit: ImplicitOrExplicit,
aborted: Trigger,
transactions: HashMap<Box<dyn Backend>, Box<dyn Transaction>>,
/// Blocking mutex, held only briefly for pure data manipulation with no I/O or such.
savepoints:
parking_lot::Mutex<SmallVec<[SavepointState; kanto_tunables::TYPICAL_MAX_TRANSACTIONS]>>,
}
impl TransactionContext {
pub(crate) fn new(
implicit: ImplicitOrExplicit,
transactions: HashMap<Box<dyn Backend>, Box<dyn Transaction>>,
) -> TransactionContext {
let inner = Arc::new(InnerTransactionContext {
implicit,
aborted: Trigger::new(),
transactions,
savepoints: parking_lot::Mutex::new(SmallVec::new()),
});
TransactionContext { inner }
}
pub(crate) fn set_aborted(&self) {
self.inner.aborted.trigger();
}
pub(crate) fn is_aborted(&self) -> bool {
self.inner.aborted.get()
}
pub(crate) fn is_implicit(&self) -> bool {
match self.inner.implicit {
ImplicitOrExplicit::Implicit => true,
ImplicitOrExplicit::Explicit => false,
}
}
pub(crate) fn get_transaction<'a>(
&'a self,
backend: &dyn Backend,
) -> Result<&'a dyn Transaction, KantoError> {
let transaction = self
.inner
.transactions
.get(backend)
.ok_or(KantoError::Internal {
code: "jnrmoeu8rpqow",
error: "backend had no transaction".into(),
})?;
Ok(transaction.as_ref())
}
pub(crate) fn transactions(&self) -> impl Iterator<Item = (&dyn Backend, &dyn Transaction)> {
self.inner
.transactions
.iter()
.map(|(backend, tx)| (backend.as_ref(), tx.as_ref()))
}
pub(crate) async fn savepoint(
&self,
savepoint_name: Option<sqlparser::ast::Ident>,
) -> Result<(), KantoError> {
let savepoints = futures_lite::stream::iter(&self.inner.transactions)
.then(async |(backend, tx)| {
let savepoint = tx.savepoint().await?;
Ok((backend.clone(), savepoint))
})
.try_collect()
.await?;
let kind = if let Some(name) = savepoint_name {
SavepointKind::NamedSavepoint { name }
} else {
SavepointKind::NestedTx
};
let savepoint_state = SavepointState {
kind,
savepoints,
settings: Settings::default(),
};
let mut guard = self.inner.savepoints.lock();
guard.push(savepoint_state);
Ok(())
}
pub(crate) fn pop_to_last_nested_tx(&self) -> Option<SavepointState> {
let mut guard = self.inner.savepoints.lock();
guard
.iter()
.rposition(|savepoint| matches!(savepoint.kind, SavepointKind::NestedTx))
.and_then(|idx| guard.drain(idx..).next())
}
pub(crate) fn pop_to_savepoint(
&self,
savepoint_name: &sqlparser::ast::Ident,
) -> Option<SavepointState> {
let mut guard = self.inner.savepoints.lock();
guard
.iter()
.rposition(|savepoint| {
matches!(&savepoint.kind, SavepointKind::NamedSavepoint { name }
if name == savepoint_name)
})
.and_then(|idx| guard.drain(idx..).next())
}
pub(crate) fn get_setting<T>(&self, getter: impl Fn(&Settings) -> Option<T>) -> Option<T> {
let savepoints = self.inner.savepoints.lock();
let local_var = savepoints
.iter()
.rev()
.find_map(|savepoint_state| getter(&savepoint_state.settings));
local_var
}
}

16
crates/kanto/src/util.rs Normal file
View file

@ -0,0 +1,16 @@
//! Utility functions for implementing [`kanto::Backend`](crate::Backend).
use std::sync::Arc;
mod sql_type_to_field_type;
pub(crate) mod trigger;
pub use self::sql_type_to_field_type::sql_type_to_field_type;
/// Return a stream that has no rows and an empty schema.
#[must_use]
pub fn no_rows_empty_schema() -> datafusion::physical_plan::SendableRecordBatchStream {
let schema = Arc::new(datafusion::arrow::datatypes::Schema::empty());
let data = datafusion::physical_plan::EmptyRecordBatchStream::new(schema);
Box::pin(data)
}

View file

@ -0,0 +1,265 @@
//! Utility functions for implementing [`kanto::Backend`](crate::Backend).
use datafusion::sql::sqlparser;
use kanto_meta_format_v1::field_type::FieldType;
use crate::KantoError;
/// Convert a SQL data type to a KantoDB field type.
#[expect(clippy::too_many_lines, reason = "large match")]
pub fn sql_type_to_field_type(
sql_data_type: &sqlparser::ast::DataType,
) -> Result<FieldType, KantoError> {
// TODO this duplicates unexported `convert_data_type` called by [`SqlToRel::build_schema`](https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.build_schema), that one does sql->dftype
use sqlparser::ast::DataType as SqlType;
match sql_data_type {
SqlType::Character(_) | SqlType::Char(_) | SqlType::FixedString(_) => {
Err(KantoError::UnimplementedSql {
code: "hdhf8ygp9zna4",
sql_syntax: "type { CHARACTER | CHAR | FIXEDSTRING }",
})
}
SqlType::CharacterVarying(_)
| SqlType::CharVarying(_)
| SqlType::Varchar(_)
| SqlType::Nvarchar(_)
| SqlType::Text
| SqlType::String(_)
| SqlType::TinyText
| SqlType::MediumText
| SqlType::LongText => Ok(FieldType::String),
SqlType::Uuid => Err(KantoError::UnimplementedSql {
code: "531ax7gb73ce1",
sql_syntax: "type UUID",
}),
SqlType::CharacterLargeObject(_length)
| SqlType::CharLargeObject(_length)
| SqlType::Clob(_length) => Err(KantoError::UnimplementedSql {
code: "ew7uhufhkzj9w",
sql_syntax: "type { CHARACTER LARGE OBJECT | CHAR LARGE OBJECT | CLOB }",
}),
SqlType::Binary(_length) => Err(KantoError::UnimplementedSql {
code: "i16wqwmie17eg",
sql_syntax: "type BINARY",
}),
SqlType::Varbinary(_)
| SqlType::Blob(_)
| SqlType::Bytes(_)
| SqlType::Bytea
| SqlType::TinyBlob
| SqlType::MediumBlob
| SqlType::LongBlob => Ok(FieldType::Binary),
SqlType::Numeric(_) | SqlType::Decimal(_) | SqlType::Dec(_) => {
Err(KantoError::UnimplementedSql {
code: "uoakmzwax448h",
sql_syntax: "type { NUMERIC | DECIMAL | DEC }",
})
}
SqlType::BigNumeric(_) | SqlType::BigDecimal(_) => Err(KantoError::UnimplementedSql {
code: "an57g4d4uzwew",
sql_syntax: "type { BIGNUMERIC | BIGDECIMAL } (BigQuery syntax)",
}),
SqlType::Float(Some(precision)) if (1..=24).contains(precision) => Ok(FieldType::F32),
SqlType::Float(None)
| SqlType::Double(sqlparser::ast::ExactNumberInfo::None)
| SqlType::DoublePrecision
| SqlType::Float64
| SqlType::Float8 => Ok(FieldType::F64),
SqlType::Float(Some(precision)) if (25..=53).contains(precision) => Ok(FieldType::F64),
SqlType::Float(Some(_precision)) => Err(KantoError::UnimplementedSql {
code: "xfwbr486jaw5q",
sql_syntax: "type FLOAT(precision)",
}),
SqlType::TinyInt(None) => Ok(FieldType::I8),
SqlType::TinyInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "hueii8mjgk8qq",
sql_syntax: "type TINYINT(precision)",
}),
SqlType::UnsignedTinyInt(None) | SqlType::UInt8 => Ok(FieldType::U8),
SqlType::UnsignedTinyInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "ewt6tneh7ecps",
sql_syntax: "type UNSIGNED TINYINT(precision)",
}),
SqlType::Int2(None) | SqlType::SmallInt(None) | SqlType::Int16 => Ok(FieldType::I16),
SqlType::UnsignedInt2(None) | SqlType::UnsignedSmallInt(None) | SqlType::UInt16 => {
Ok(FieldType::U16)
}
SqlType::Int2(Some(_))
| SqlType::UnsignedInt2(Some(_))
| SqlType::SmallInt(Some(_))
| SqlType::UnsignedSmallInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "g57gd8499t4hg",
sql_syntax: "type { (UNSIGNED) INT2(precision) | (UNSIGNED) SMALLINT(precision) }",
}),
SqlType::MediumInt(_) | SqlType::UnsignedMediumInt(_) => {
Err(KantoError::UnimplementedSql {
code: "m7snzsrqj68sw",
sql_syntax: "type { MEDIUMINT (UNSIGNED) }",
})
}
SqlType::Int(None) | SqlType::Integer(None) | SqlType::Int4(None) | SqlType::Int32 => {
Ok(FieldType::I32)
}
SqlType::Int(Some(_)) | SqlType::Integer(Some(_)) | SqlType::Int4(Some(_)) => {
Err(KantoError::UnimplementedSql {
code: "tiptbb5d8buxa",
sql_syntax: "type { INTEGER(precision) | INT(precision) | INT4(precision) }",
})
}
SqlType::Int8(None) | SqlType::Int64 | SqlType::BigInt(None) => Ok(FieldType::I64),
SqlType::Int128 => Err(KantoError::UnimplementedSql {
code: "4sg3mftgxu9j4",
sql_syntax: "type INT128 (ClickHouse syntax)",
}),
SqlType::Int256 => Err(KantoError::UnimplementedSql {
code: "ex9mnxxipbqqs",
sql_syntax: "type INT256 (ClickHouse syntax)",
}),
SqlType::UnsignedInt(None)
| SqlType::UnsignedInt4(None)
| SqlType::UnsignedInteger(None)
| SqlType::UInt32 => Ok(FieldType::U32),
SqlType::UnsignedInt(Some(_))
| SqlType::UnsignedInt4(Some(_))
| SqlType::UnsignedInteger(Some(_)) => Err(KantoError::UnimplementedSql {
code: "fzt1q8mfgkr6n",
sql_syntax: "type { UNSIGNED INTEGER(precision) | UNSIGNED INT(precision) | UNSIGNED INT4(precision) }",
}),
SqlType::UInt64 | SqlType::UnsignedBigInt(None) | SqlType::UnsignedInt8(None) => {
Ok(FieldType::U64)
}
SqlType::UInt128 => Err(KantoError::UnimplementedSql {
code: "riob7j6jahpte",
sql_syntax: "type UINT128 (ClickHouse syntax)",
}),
SqlType::UInt256 => Err(KantoError::UnimplementedSql {
code: "q6tzjzd7xf3ga",
sql_syntax: "type UINT256 (ClickHouse syntax)",
}),
SqlType::BigInt(Some(_)) | SqlType::Int8(Some(_)) => Err(KantoError::UnimplementedSql {
code: "88fzntq3mz14g",
sql_syntax: "type { BIGINT(precision) | INT8(precision) }",
}),
SqlType::UnsignedBigInt(Some(_)) | SqlType::UnsignedInt8(Some(_)) => {
Err(KantoError::UnimplementedSql {
code: "sze33j7qzt91s",
sql_syntax: "type { UNSIGNED BIGINT(precision) | UNSIGNED INT8(precision) }",
})
}
SqlType::Float4 | SqlType::Float32 | SqlType::Real => Ok(FieldType::F32),
SqlType::Double(sqlparser::ast::ExactNumberInfo::Precision(precision))
if (25..=53).contains(precision) =>
{
Ok(FieldType::F64)
}
SqlType::Double(sqlparser::ast::ExactNumberInfo::Precision(_precision)) => {
Err(KantoError::UnimplementedSql {
code: "wo43sei9ubpon",
sql_syntax: "type DOUBLE(precision)",
})
}
SqlType::Double(sqlparser::ast::ExactNumberInfo::PrecisionAndScale(_precision, _scale)) => {
Err(KantoError::NotSupportedSql {
code: "zgxdduxydzpoq",
sql_syntax: "type DOUBLE(precision, scale) (MySQL syntax, deprecated)",
})
}
SqlType::Bool | SqlType::Boolean => Ok(FieldType::Boolean),
SqlType::Date => Err(KantoError::UnimplementedSql {
code: "1r5ez1z8j7ryo",
sql_syntax: "type DATE",
}),
SqlType::Date32 => Err(KantoError::UnimplementedSql {
code: "781y5fxih7pnn",
sql_syntax: "type DATE32",
}),
SqlType::Time(_, _) => Err(KantoError::UnimplementedSql {
code: "bppzhg7ck7xhr",
sql_syntax: "type TIME",
}),
SqlType::Datetime(_) => Err(KantoError::UnimplementedSql {
code: "dojz85ngwo5wr",
sql_syntax: "type DATETIME (MySQL syntax)",
}),
SqlType::Datetime64(_, _) => Err(KantoError::UnimplementedSql {
code: "oak3s5g3rnutq",
sql_syntax: "type DATETIME (ClickHouse syntax)",
}),
SqlType::Timestamp(_, _) => Err(KantoError::UnimplementedSql {
code: "njad51sbxzwnn",
sql_syntax: "type TIMESTAMP",
}),
SqlType::Interval => Err(KantoError::UnimplementedSql {
code: "5wxynwyfua69s",
sql_syntax: "type INTERVAL",
}),
SqlType::JSON | SqlType::JSONB => Err(KantoError::UnimplementedSql {
code: "abcw878jr35qc",
sql_syntax: "type { JSON | JSONB }",
}),
SqlType::Regclass => Err(KantoError::UnimplementedSql {
code: "quzjzcsusttgc",
sql_syntax: "type REGCLASS (Postgres syntax)",
}),
SqlType::Bit(_) | SqlType::BitVarying(_) => Err(KantoError::UnimplementedSql {
code: "88uneott9owc1",
sql_syntax: "type BIT (VARYING)",
}),
SqlType::Custom(..) => Err(KantoError::UnimplementedSql {
code: "dt315zutjqpzc",
sql_syntax: "custom type",
}),
SqlType::Array(_) => Err(KantoError::UnimplementedSql {
code: "hnmqj3qbtutn4",
sql_syntax: "type ARRAY",
}),
SqlType::Map(_, _) => Err(KantoError::UnimplementedSql {
code: "7ts8zgnafnp7g",
sql_syntax: "type MAP",
}),
SqlType::Tuple(_) => Err(KantoError::UnimplementedSql {
code: "x7qwkwdhdznxe",
sql_syntax: "type TUPLE",
}),
SqlType::Nested(_) => Err(KantoError::UnimplementedSql {
code: "fa6nofmo1s49g",
sql_syntax: "type NESTED",
}),
SqlType::Enum(_, _) => Err(KantoError::UnimplementedSql {
code: "1geqh5qxdeoko",
sql_syntax: "type ENUM",
}),
SqlType::Set(_) => Err(KantoError::UnimplementedSql {
code: "hm53sq3gco3sw",
sql_syntax: "type SET",
}),
SqlType::Struct(_, _) => Err(KantoError::UnimplementedSql {
code: "nsapujxjjk9hg",
sql_syntax: "type STRUCT",
}),
SqlType::Union(_) => Err(KantoError::UnimplementedSql {
code: "oownfhnfj5b9r",
sql_syntax: "type UNION",
}),
SqlType::Nullable(_) => Err(KantoError::UnimplementedSql {
code: "ra5kspnstauu6",
sql_syntax: "type NULLABLE (ClickHouse syntax)",
}),
SqlType::LowCardinality(_) => Err(KantoError::UnimplementedSql {
code: "78exey8xjz3yk",
sql_syntax: "type LOWCARDINALITY",
}),
SqlType::Unspecified => Err(KantoError::UnimplementedSql {
code: "916n48obgiqh6",
sql_syntax: "type UNSPECIFIED",
}),
SqlType::Trigger => Err(KantoError::UnimplementedSql {
code: "zf5jx1ykoc5s1",
sql_syntax: "type TRIGGER",
}),
SqlType::AnyType => Err(KantoError::NotSupportedSql {
code: "bwjrzsbzg59gy",
sql_syntax: "type ANY TYPE (BigQuery syntax)",
}),
}
}

View file

@ -0,0 +1,29 @@
/// Trigger keeps track of whether something happened.
/// Once triggered, it will always remain triggered.
pub(crate) struct Trigger {
triggered: std::sync::atomic::AtomicBool,
}
impl Trigger {
#[must_use]
pub(crate) const fn new() -> Trigger {
Trigger {
triggered: std::sync::atomic::AtomicBool::new(false),
}
}
pub(crate) fn trigger(&self) {
self.triggered
.store(true, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get(&self) -> bool {
self.triggered.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl Default for Trigger {
fn default() -> Self {
Trigger::new()
}
}

View file

@ -0,0 +1,626 @@
#![expect(
clippy::print_stderr,
clippy::unimplemented,
clippy::unwrap_used,
clippy::expect_used,
reason = "test code"
)]
use std::collections::HashSet;
use std::ffi::OsStr;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use futures_lite::stream::StreamExt as _;
use kanto::arrow::array::AsArray as _;
struct TestDatabase(Arc<tokio::sync::RwLock<kanto::Session>>);
impl sqllogictest::MakeConnection for TestDatabase {
type Conn = TestSession;
type MakeFuture =
std::future::Ready<Result<TestSession, <Self::Conn as sqllogictest::AsyncDB>::Error>>;
#[maybe_tracing::instrument(skip(self))]
fn make(&mut self) -> Self::MakeFuture {
std::future::ready(Ok(TestSession(self.0.clone())))
}
}
struct TestSession(Arc<tokio::sync::RwLock<kanto::Session>>);
fn convert_f64_to_sqllogic_string(
f: f64,
expectation_type: &sqllogictest::DefaultColumnType,
) -> String {
match expectation_type {
// Sqlite's tests use implicit conversion to integer by marking result columns as integers.
sqllogictest::DefaultColumnType::Integer => {
let s = format!("{f:.0}");
if s == "-0" {
"0".to_owned()
} else {
s
}
}
sqllogictest::DefaultColumnType::FloatingPoint | sqllogictest::DefaultColumnType::Any => {
// 3 digits of precision is what SQLite sqllogictest seems to expect.
let s = format!("{f:.3}");
if s == "-0.000" {
"0.000".to_owned()
} else {
s
}
}
sqllogictest::DefaultColumnType::Text => {
format!("<unexpected SLT column expectation for Float64: {expectation_type:?}>")
}
}
}
fn convert_binary_to_sqllogic_strings<'arr>(
iter: impl Iterator<Item = Option<&'arr [u8]>> + 'arr,
) -> Box<dyn Iterator<Item = String> + 'arr> {
let iter = iter.map(|opt_bytes| {
opt_bytes
.map(|b| {
let mut hex = String::new();
hex.push_str(r"X'");
hex.push_str(&hex::encode(b));
hex.push('\'');
hex
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
#[expect(clippy::too_many_lines, reason = "large match")]
fn convert_array_to_sqllogic_strings<'arr>(
array: &'arr Arc<dyn datafusion::arrow::array::Array>,
expectation_type: sqllogictest::DefaultColumnType,
data_type: &datafusion::arrow::datatypes::DataType,
) -> Box<dyn Iterator<Item = String> + 'arr> {
match data_type {
datafusion::arrow::datatypes::DataType::Int8 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int8Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt8 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt8Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Int32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int32Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt32Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Int64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int64Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt64Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Float32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Float32Type>()
.unwrap();
let iter = array.iter().map(move |opt_num| {
opt_num
.map(|f| convert_f64_to_sqllogic_string(f64::from(f), &expectation_type))
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Float64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Float64Type>()
.unwrap();
let iter = array.iter().map(move |opt_num| {
opt_num
.map(|f| convert_f64_to_sqllogic_string(f, &expectation_type))
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Utf8 => {
let array = array.as_string_opt::<i32>().unwrap();
let iter = array.iter().map(|opt_str| {
opt_str
.map(|s| {
if s.is_empty() {
"(empty)".to_owned()
} else {
let printable = b' '..b'~';
let t = s
.bytes()
.map(|c| {
if printable.contains(&c) {
c.into()
} else {
'@'
}
})
.collect::<String>();
t
}
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Utf8View => {
let array = array.as_string_view_opt().unwrap();
let iter = array.iter().map(|opt_str| {
opt_str
.map(|s| {
if s.is_empty() {
"(empty)".to_owned()
} else {
let printable = b' '..b'~';
let t = s
.bytes()
.map(|c| {
if printable.contains(&c) {
c.into()
} else {
'@'
}
})
.collect::<String>();
t
}
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Binary => {
let array = array.as_binary_opt::<i32>().unwrap();
convert_binary_to_sqllogic_strings(array.iter())
}
datafusion::arrow::datatypes::DataType::BinaryView => {
let array = array.as_binary_view_opt().unwrap();
convert_binary_to_sqllogic_strings(array.iter())
}
datafusion::arrow::datatypes::DataType::Boolean => {
let array = array.as_boolean_opt().unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
_ => unimplemented!("sqllogictest: data type not implemented: {data_type:?}"),
}
}
fn is_hidden(entry: &walkdir::DirEntry) -> bool {
entry.file_name().to_string_lossy().starts_with('.')
}
#[async_trait::async_trait]
impl sqllogictest::AsyncDB for TestSession {
type Error = datafusion::error::DataFusionError;
type ColumnType = sqllogictest::DefaultColumnType;
#[maybe_tracing::instrument(span_name = "TestSession::run", skip(self), err)]
async fn run(
&mut self,
sql: &str,
) -> Result<sqllogictest::DBOutput<Self::ColumnType>, Self::Error> {
tracing::trace!("acquiring session lock");
let mut guard = self.0.write().await;
tracing::trace!("running sql");
let stream = guard.sql(sql).await?;
tracing::trace!("done with sql");
let types: Vec<Self::ColumnType> = stream
.schema()
.fields()
.iter()
.map(|field| {
// TODO this is wrong, we need the datatype formatting info from the `*.slt` <https://github.com/risinglightdb/sqllogictest-rs/issues/227> #ecosystem/sqllogictest-rs
match field.data_type() {
datafusion::arrow::datatypes::DataType::Int32
| datafusion::arrow::datatypes::DataType::Int64
| datafusion::arrow::datatypes::DataType::UInt32
| datafusion::arrow::datatypes::DataType::UInt64 => {
sqllogictest::column_type::DefaultColumnType::Integer
}
// TODO recognize more data types
_ => sqllogictest::column_type::DefaultColumnType::Any,
}
})
.collect();
let mut rows: Vec<Vec<String>> = Vec::new();
tracing::trace!("collecting");
let batches = stream.try_collect::<_, _, Vec<_>>().await?;
tracing::trace!("collect done");
for batch in batches {
let mut value_per_column = batch
.columns()
.iter()
.enumerate()
.map(|(column_idx, array)| {
debug_assert!(column_idx < types.len());
#[expect(clippy::indexing_slicing, reason = "test code")]
let type_ = types[column_idx].clone();
convert_array_to_sqllogic_strings(
array,
type_,
batch.schema().field(column_idx).data_type(),
)
})
.collect::<Vec<_>>();
loop {
let row: Vec<_> = value_per_column
.iter_mut()
.map_while(Iterator::next)
.collect();
if row.is_empty() {
break;
}
rows.push(row);
}
}
tracing::trace!("prepping output");
let output = sqllogictest::DBOutput::Rows { types, rows };
drop(guard);
Ok(output)
}
fn engine_name(&self) -> &'static str {
"kanto"
}
#[maybe_tracing::instrument()]
async fn sleep(dur: Duration) {
let _wat: () = tokio::time::sleep(dur).await;
}
#[maybe_tracing::instrument(skip(_command), ret, err(level=tracing::Level::WARN))]
async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
unimplemented!("security hazard");
}
async fn shutdown(&mut self) -> () {
// nothing
}
}
fn is_datafusion_limitation(error: &dyn std::error::Error) -> Option<&'static str> {
// TODO i'm having serious trouble unpacking this `dyn Error`.
// do ugly string matching instead.
//
// TODO note these as ecosystem limitations, report upstream
let datafusion_limitations = [
(
// TODO DataFusion non-unique column names <https://github.com/apache/arrow-datafusion/issues/6543>, <https://github.com/apache/arrow-datafusion/issues/8379> #ecosystem/datafusion #waiting #severity/high #urgency/medium
"duplicate column names",
regex::Regex::new(r"Projections require unique expression names but the expression .* the same name").unwrap()
),
(
// TODO DataFusion non-unique column names <https://github.com/apache/arrow-datafusion/issues/6543>, <https://github.com/apache/arrow-datafusion/issues/8379> #ecosystem/datafusion #waiting #severity/high #urgency/medium
"duplicate column names (qual vs unqual)",
regex::Regex::new(r"Schema error: Schema contains qualified field name .* and unqualified field name .* which would be ambiguous").unwrap(),
),
(
"convert to null",
regex::Regex::new(r"Error during planning: Cannot automatically convert .* to Null").unwrap(),
),
(
// TODO DataFusion does not support `avg(DISTINCT)`: <https://github.com/apache/arrow-datafusion/issues/2408> #ecosystem/datafusion #waiting #severity/medium #urgency/low
"AVG(DISTINCT)",
regex::Regex::new(r"Execution error: avg\(DISTINCT\) aggregations are not available").unwrap(),
),
(
"non-aggregate values in projection",
regex::Regex::new(r"Error during planning: Projection references non-aggregate values: Expression .* could not be resolved from available columns: .*").unwrap(),
),
(
"AVG(NULL)",
regex::Regex::new(r"Error during planning: No function matches the given name and argument types 'AVG\(Null\)'. You might need to add explicit type casts.").unwrap(),
),
(
"Correlated column is not allowed in predicate",
regex::Regex::new(r"Error during planning: Correlated column is not allowed in predicate:").unwrap(),
),
(
"HAVING references non-aggregate",
regex::Regex::new(r"Error during planning: HAVING clause references non-aggregate values: Expression .* could not be resolved from available columns: .*").unwrap(),
),
(
"scalar subquery",
regex::Regex::new(r"This feature is not implemented: Physical plan does not support logical expression ScalarSubquery\(<subquery>\)").unwrap(),
),
];
let error_message = error.to_string();
datafusion_limitations
.iter()
.find(|(_explanation, re)| re.is_match(&error_message))
.map(|(explanation, _re)| *explanation)
}
#[maybe_tracing::instrument(skip(tester), ret)]
fn execute_sqllogic_test<D: sqllogictest::AsyncDB, M: sqllogictest::MakeConnection<Conn = D>>(
tester: &mut sqllogictest::Runner<D, M>,
path: &Path,
kludge_datafusion_limitations: bool,
) -> Result<(), libtest_mimic::Failed> {
let script = std::fs::read_to_string(path)?;
let script = {
// TODO sqllogictest-rs fails to parse comment at end of line <https://github.com/risinglightdb/sqllogictest-rs/issues/122> #waiting #ecosystem/sqllogictest-rs
let trailing_comment = regex::Regex::new(r"(?mR)#.*$").unwrap();
trailing_comment.replace_all(&script, "")
};
let name = path.as_os_str().to_string_lossy();
tester
.run_script_with_name(&script, Arc::from(name))
.or_else(|error| match error.kind() {
sqllogictest::TestErrorKind::Fail {
sql: _,
err,
kind: sqllogictest::RecordKind::Query,
} => {
if kludge_datafusion_limitations {
if let Some(explanation) = is_datafusion_limitation(&err) {
eprintln!("Ignoring known DataFusion limitation: {explanation}");
return Ok(());
}
}
Err(error)
}
sqllogictest::TestErrorKind::QueryResultMismatch {
sql: _,
expected,
actual,
} => {
{
// TODO hash mismatch likely due to expected column type API is wrong way around: <https://github.com/risinglightdb/sqllogictest-rs/issues/227> #ecosystem/sqllogictest-rs #severity/high #urgency/medium
let re = regex::Regex::new(r"^\d+ values hashing to [0-9a-f]+$").unwrap();
if re.is_match(&expected) && re.is_match(&actual) {
return Ok(());
}
}
Err(error)
}
_ => Err(error),
})?;
Ok(())
}
// kludge around <https://github.com/risinglightdb/sqllogictest-rs/issues/108>
#[maybe_tracing::instrument(skip(_normalizer), ret)]
fn sql_result_validator(
_normalizer: sqllogictest::Normalizer,
actual: &[Vec<String>],
expected: &[String],
) -> bool {
// TODO expected column type API is wrong way around: <https://github.com/risinglightdb/sqllogictest-rs/issues/227>
fn kludge_compare<S: AsRef<str>>(got: &[S], want: &[String]) -> bool {
fn kludge_string_compare<S: AsRef<str>>(got: S, want: &str) -> bool {
let got = got.as_ref();
got == want || got == format!("{want}.000")
}
got.len() == want.len()
&& got
.iter()
.zip(want)
.all(|(got, want)| kludge_string_compare(got, want))
}
let column_wise = || {
let actual: Vec<_> = actual.iter().map(|v| v.join(" ")).collect();
tracing::trace!(?actual, "column-wise joined");
actual
};
let value_wise = || {
// SQLite "value-wise" compatibility.
let actual: Vec<_> = actual.iter().flatten().map(String::as_str).collect();
tracing::trace!(?actual, "value-wise flattened");
actual
};
kludge_compare(&value_wise(), expected) || kludge_compare(&column_wise()[..], expected)
}
fn uses_create_table_primary_key(path: &Path) -> bool {
// TODO `CREATE TABLE with column option { PRIMARY KEY | UNIQUE }` (#ref/unimplemented_sql/hwcjakao83bue) #severity/high #urgency/medium
let re_oneline =
regex::Regex::new(r"(?m)^CREATE TABLE tab\d+\(pk INTEGER PRIMARY KEY,").unwrap();
let re_wrapped =
regex::Regex::new(r"(?m)^CREATE TABLE t\d+\(\n\s+[a-z0-9]+ INTEGER PRIMARY KEY,").unwrap();
let script = std::fs::read_to_string(path).unwrap();
re_oneline.is_match(&script) || re_wrapped.is_match(&script)
}
fn uses_create_view(path: &Path) -> bool {
// TODO `CREATE VIEW` (#ref/unimplemented_sql/b9m5uhu9pnnsw)
let re = regex::Regex::new(r"(?m)^CREATE VIEW ").unwrap();
let script = std::fs::read_to_string(path).unwrap();
re.is_match(&script)
}
pub(crate) fn setup_tests(
slt_root: impl AsRef<Path>,
slt_ext: &str,
kludge_datafusion_limitations: bool,
kludge_sqlite_tests: bool,
) -> Vec<libtest_mimic::Trial> {
let known_broken = HashSet::from([
"rocksdb/evidence/slt_lang_update", // TODO cannot update row_id tables yet (#ref/unimplemented/i61jpodrazscs)
"rocksdb/evidence/slt_lang_droptable", // TODO `DROP TABLE` (#ref/unimplemented_sql/47hnz51gohsx6)
"rocksdb/evidence/slt_lang_dropindex", // TODO `DROP INDEX` (#ref/unimplemented_sql/4kgkg4jhqhfrw)
]);
find_slt_files(&slt_root, slt_ext)
.flat_map(|path| {
[("rocksdb", kanto_backend_rocksdb::create_temp)].map(|(backend_name, backend_fn)| {
let path_pretty = path
.strip_prefix(&slt_root)
.expect("found slt file is not under root")
.with_extension("")
.to_string_lossy()
.into_owned();
let name = format!("{backend_name}/{path_pretty}");
libtest_mimic::Trial::test(&name, {
let path = path.clone();
move || {
let (_dir, backend) = backend_fn().expect("cannot setup backend");
let backend = Box::new(backend);
test_sql_session(
path,
backend_name,
backend,
kludge_datafusion_limitations,
if kludge_sqlite_tests {
sqllogictest::ResultMode::ValueWise
} else {
sqllogictest::ResultMode::RowWise
},
)
}
})
.with_ignored_flag(
kludge_sqlite_tests
&& (known_broken.contains(name.as_str())
|| uses_create_table_primary_key(&path)
|| uses_create_view(&path)),
)
})
})
.collect::<Vec<_>>()
}
pub(crate) fn test_sql_session(
path: impl AsRef<Path>,
backend_name: &str,
backend: Box<dyn kanto::Backend>,
kludge_datafusion_limitations: bool,
result_mode: sqllogictest::ResultMode,
) -> Result<(), libtest_mimic::Failed> {
// `test_log::test` doesn't work because of the function argument, set up tracing manually
{
let subscriber = ::tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG")
.unwrap_or_else(|_| format!("kanto=debug,kanto_backend_{backend_name}=debug")),
))
.compact()
.with_test_writer()
.finish();
let _ignore_error = tracing::subscriber::set_global_default(subscriber);
}
let reactor = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
// actual test logic
let session = kanto::Session::test_session(backend);
let test_database = TestDatabase(Arc::new(tokio::sync::RwLock::new(session)));
let mut tester = sqllogictest::Runner::new(test_database);
tester.add_label("datafusion");
tester.add_label("kantodb");
tester.with_hash_threshold(8);
tester.with_validator(sql_result_validator);
reactor.block_on(async {
{
let output = tester
.apply_record(sqllogictest::Record::Control(
sqllogictest::Control::ResultMode(result_mode),
))
.await;
debug_assert!(matches!(output, sqllogictest::RecordOutput::Nothing));
}
execute_sqllogic_test(&mut tester, path.as_ref(), kludge_datafusion_limitations)
})?;
drop(tester);
reactor.shutdown_background();
Ok(())
}
pub(crate) fn find_slt_files<P, E>(
input_dir: P,
input_extension: E,
) -> impl Iterator<Item = PathBuf>
where
P: AsRef<Path>,
E: AsRef<OsStr>,
{
let input_extension = input_extension.as_ref().to_owned();
walkdir::WalkDir::new(&input_dir)
.sort_by_file_name()
.into_iter()
.filter_entry(|entry| !is_hidden(entry))
.map(|result| result.expect("read dir"))
.filter(|entry| entry.file_type().is_file())
.map(walkdir::DirEntry::into_path)
.filter(move |path| path.extension().is_some_and(|ext| ext == input_extension))
}

View file

@ -0,0 +1,16 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use std::path::PathBuf;
mod slt_util;
fn main() {
let args = libtest_mimic::Arguments::from_args();
let input_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../testdata/sql");
let tests = slt_util::setup_tests(input_dir, "slt", false, false);
libtest_mimic::run(&args, tests).exit();
}

View file

@ -0,0 +1,33 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use std::path::Path;
mod slt_util;
fn main() {
const SQLLOGICTEST_PATH_ENV: &str = "SQLLOGICTEST_PATH";
let args = libtest_mimic::Arguments::from_args();
let tests = if let Some(sqlite_test_path) = std::env::var_os(SQLLOGICTEST_PATH_ENV) {
let input_dir = Path::new(&sqlite_test_path).join("test");
slt_util::setup_tests(input_dir, "test", true, true)
} else {
// We don't know what tests to generate without the env var.
// Generate a dummy test that reports what's missing.
Vec::from([
libtest_mimic::Trial::test("sqlite_tests_not_configured", || {
// TODO replace with our fork that has `skipif` additions
const URL: &str = "https://www.sqlite.org/sqllogictest/";
Err(libtest_mimic::Failed::from(format!(
"set {SQLLOGICTEST_PATH_ENV} to checkout of {URL}"
)))
})
.with_ignored_flag(true),
])
};
libtest_mimic::run(&args, tests).exit();
}

View file

@ -0,0 +1,17 @@
[package]
name = "kanto-key-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
datafusion = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,59 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
use std::sync::Arc;
/// Encode rows from `key_source_arrays` into suitable keys.
#[must_use]
pub fn make_keys(
key_source_arrays: &[Arc<dyn datafusion::arrow::array::Array>],
) -> datafusion::arrow::array::BinaryArray {
let sort_fields = key_source_arrays
.iter()
.map(|array| {
datafusion::arrow::row::SortField::new_with_options(
array.data_type().clone(),
datafusion::arrow::compute::SortOptions {
descending: false,
nulls_first: false,
},
)
})
.collect::<Vec<_>>();
let row_converter = datafusion::arrow::row::RowConverter::new(sort_fields)
.expect("internal error #ea7ht1xeexpgo: misuse of RowConverter::new");
let rows = row_converter
.convert_columns(key_source_arrays)
.expect("internal error #z9i4xb8me68gk: misuse of RowConverter::convert_columns");
// TODO this fallibility sucks; do better when we take over from `arrow-row`
rows.try_into_binary()
.expect("internal error #yq9xnqtzhbn1k: RowConverter must produce BinaryArray")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simple() {
let array: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt64Array::from(vec![42, 13]));
let keys = make_keys(&[array]);
let got = format!(
"{}",
datafusion::arrow::util::pretty::pretty_format_columns("key", &[Arc::new(keys)])
.unwrap(),
);
assert_eq!(
got,
"\
+--------------------+\n\
| key |\n\
+--------------------+\n\
| 01000000000000002a |\n\
| 01000000000000000d |\n\
+--------------------+"
);
}
}

View file

@ -0,0 +1,17 @@
[package]
name = "maybe-tracing"
version = "0.1.0"
description = "Workarounds for `tracing` crate not working with Miri and llvm-cov"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[lib]
proc-macro = true
[lints]
workspace = true

View file

@ -0,0 +1,102 @@
//! Workarounds that temporarily disable `tracing` or `tracing_subscriber`.
//!
//! - `tracing::instrument` breaks code coverage.
//! - `test_log::test` breaks Miri isolation mode.
use proc_macro::Delimiter;
use proc_macro::Group;
use proc_macro::Ident;
use proc_macro::Punct;
use proc_macro::Spacing;
use proc_macro::Span;
use proc_macro::TokenStream;
use proc_macro::TokenTree;
/// Instrument the call with tracing, except when recording code coverage.
///
/// This is a workaround for <https://github.com/tokio-rs/tracing/issues/2082>.
// TODO tracing breaks llvm-cov <https://github.com/tokio-rs/tracing/issues/2082> #waiting #ecosystem/tracing
#[proc_macro_attribute]
pub fn instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
if cfg!(not(coverage)) {
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("tracing", Span::call_site())),
TokenTree::Punct(Punct::new(':', Spacing::Joint)),
TokenTree::Punct(Punct::new(':', Spacing::Alone)),
TokenTree::Ident(Ident::new("instrument", Span::call_site())),
TokenTree::Group(Group::new(Delimiter::Parenthesis, attr)),
]),
)),
]));
}
tokens.extend(item);
tokens
}
/// Define a `#[test]` that has `tracing` configured with `env_logger`, except under Miri.
///
/// This is a wrapper for `test_log::test` that disables itself when run under [Miri](https://github.com/rust-lang/miri/).
/// Miri isolation mode does not support logging wall clock times (<https://github.com/rust-lang/miri/issues/3740>).
// TODO tracing breaks miri due to missing wall clock <https://github.com/rust-lang/miri/issues/3740> #waiting #ecosystem/tracing
#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
// can't `if cfg!(miri)` at proc macro time, it seems that part does not run under the miri interpreter
// output `#[cfg_attr(miri, test)]`
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("cfg_attr", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("miri", Span::call_site())),
TokenTree::Punct(Punct::new(',', Spacing::Alone)),
TokenTree::Ident(Ident::new("test", Span::call_site())),
]),
)),
]),
)),
]));
// output `#[cfg_attr(not(miri), test_log::test)]`
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("cfg_attr", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("not", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([TokenTree::Ident(Ident::new(
"miri",
Span::call_site(),
))]),
)),
TokenTree::Punct(Punct::new(',', Spacing::Alone)),
TokenTree::Ident(Ident::new("test_log", Span::call_site())),
TokenTree::Punct(Punct::new(':', Spacing::Joint)),
TokenTree::Punct(Punct::new(':', Spacing::Alone)),
TokenTree::Ident(Ident::new("test", Span::call_site())),
TokenTree::Group(Group::new(Delimiter::Parenthesis, attr)),
]),
)),
]),
)),
]));
tokens.extend(item);
tokens
}

View file

@ -0,0 +1,24 @@
[package]
name = "kanto-meta-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
arbitrary = { workspace = true }
datafusion = { workspace = true }
duplicate = { workspace = true }
indexmap = { workspace = true }
paste = { workspace = true }
rkyv = { workspace = true }
rkyv_util = { workspace = true }
thiserror = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,58 @@
//! Data type of a field in a database record.
/// Data type of a field in a database record.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug))]
#[expect(missing_docs)]
pub enum FieldType {
U64,
U32,
U16,
U8,
I64,
I32,
I16,
I8,
F64,
F32,
String,
Binary,
Boolean,
// TODO more types
// TODO also parameterized, e.g. `VARCHAR(30)`, `FLOAT(23)`
}
impl ArchivedFieldType {
/// If this field is a fixed width primitive Arrow type, what is the width.
#[must_use]
pub fn fixed_arrow_bytes_width(&self) -> Option<usize> {
let data_type = datafusion::arrow::datatypes::DataType::from(self);
data_type.primitive_width()
}
}
#[duplicate::duplicate_item(
T;
[ FieldType ];
[ ArchivedFieldType ];
)]
impl From<&T> for datafusion::arrow::datatypes::DataType {
fn from(field_type: &T) -> Self {
use datafusion::arrow::datatypes::DataType;
match field_type {
T::U64 => DataType::UInt64,
T::U32 => DataType::UInt32,
T::U16 => DataType::UInt16,
T::U8 => DataType::UInt8,
T::I64 => DataType::Int64,
T::I32 => DataType::Int32,
T::I16 => DataType::Int16,
T::I8 => DataType::Int8,
T::F64 => DataType::Float64,
T::F32 => DataType::Float32,
T::String => DataType::Utf8View,
T::Binary => DataType::BinaryView,
T::Boolean => DataType::Boolean,
}
}
}

View file

@ -0,0 +1,92 @@
//! Definition of an index for record lookup and consistency enforcement.
#![expect(
missing_docs,
reason = "waiting <https://github.com/rkyv/rkyv/issues/596> <https://github.com/rkyv/rkyv/issues/597>"
)]
use super::FieldId;
use super::IndexId;
use super::TableId;
/// An expression used in the index key.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub enum Expr {
/// Use a field from the record as index key.
Field {
/// ID of the field to use.
field_id: FieldId,
},
// TODO support more expressions
}
/// Sort order.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum SortOrder {
/// Ascending order.
Ascending,
/// Descending order.
Descending,
}
/// Sort order for values containing `NULL`s.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum NullOrder {
/// Nulls come before all other values.
NullsFirst,
/// Nulls come after all other values.
NullsLast,
}
/// Order the index by an expression.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct OrderByExpr {
/// Expression evaluating to the value stored.
pub expr: Expr,
/// Sort ascending or descending.
pub sort_order: SortOrder,
/// Do `NULL` values come before or after non-`NULL` values.
pub null_order: NullOrder,
}
/// Unique index specific options.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct UniqueIndexOptions {
/// Are `NULL` values considered equal or different.
pub nulls_distinct: bool,
}
/// What kind of an index is this.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum IndexKind {
Unique(UniqueIndexOptions),
Multi,
}
/// Definition of an index for record lookup and consistency enforcement.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct IndexDef {
/// Unique ID of the index.
pub index_id: IndexId,
/// The table this index is for.
pub table: TableId,
/// Name of the index.
///
/// Indexes inherit their catalog and schema from the table they are for.
pub index_name: String,
/// Columns (or expressions) indexed.
pub columns: Vec<OrderByExpr>,
/// Is this a unique index, or can multiple records have the same values for the columns indexed.
pub index_kind: IndexKind,
/// What fields to store in the index record.
pub include: Vec<Expr>,
/// Only include records where this predicate evaluates to true.
pub predicate: Option<Expr>,
}

View file

@ -0,0 +1,62 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
#![allow(
clippy::exhaustive_structs,
clippy::exhaustive_enums,
reason = "exhaustive is ok as the wire format has to be stable anyway"
)]
pub mod field_type;
pub mod index_def;
pub mod name_def;
pub mod sequence_def;
pub mod table_def;
mod util;
pub use datafusion;
pub use datafusion::arrow;
pub use indexmap;
pub use rkyv;
pub use rkyv_util;
use crate::util::idgen::make_id_type;
make_id_type!(
/// ID of a [`FieldDef`](crate::table_def::FieldDef), unique within one [`TableDef`](crate::table_def::TableDef).
pub FieldId
);
make_id_type!(
/// Unique ID of an [`IndexDef`](crate::index_def::IndexDef).
pub IndexId
);
make_id_type!(
/// Unique ID of an [`SequenceDef`](crate::sequence_def::SequenceDef).
pub SequenceId
);
make_id_type!(
/// Unique ID of an [`TableDef`](crate::table_def::TableDef).
pub TableId
);
/// Return an serialized copy of `T` that allows accessing it as `ArchivedT`.
pub fn owned_archived<T>(
value: &T,
) -> Result<rkyv_util::owned::OwnedArchive<T, rkyv::util::AlignedVec>, rkyv::rancor::BoxedError>
where
T: for<'arena> rkyv::Serialize<
rkyv::api::high::HighSerializer<
rkyv::util::AlignedVec,
rkyv::ser::allocator::ArenaHandle<'arena>,
rkyv::rancor::BoxedError,
>,
>,
T::Archived: rkyv::Portable
+ for<'a> rkyv::bytecheck::CheckBytes<
rkyv::api::high::HighValidator<'a, rkyv::rancor::BoxedError>,
>,
{
let bytes = rkyv::to_bytes::<rkyv::rancor::BoxedError>(value)?;
rkyv_util::owned::OwnedArchive::<T, _>::new::<rkyv::rancor::BoxedError>(bytes)
}

View file

@ -0,0 +1,46 @@
//! Named entities.
#![expect(
missing_docs,
reason = "waiting <https://github.com/rkyv/rkyv/issues/596> <https://github.com/rkyv/rkyv/issues/597>"
)]
use super::IndexId;
use super::SequenceId;
use super::TableId;
/// What kind of an entity the name refers to.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub enum NameKind {
/// Name is for a [`TableDef`]((crate::table_def::TableDef)).
Table {
/// ID of the table.
table_id: TableId,
},
/// Name is for a [`IndexDef`](crate::index_def::IndexDef).
Index {
/// ID of the index.
index_id: IndexId,
},
/// Name is for a [`SequenceDef`](crate::sequence_def::SequenceDef).
Sequence {
/// ID of the sequence.
sequence_id: SequenceId,
},
}
/// SQL defines tables, indexes, sequences etc to all be in one namespace.
/// This is the data type all such names resolve to, and refers to the ultimate destination object.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct NameDef {
/// SQL catalog (top-level namespace) the name is in.
pub catalog: String,
/// SQL schema (second-level namespace) the name is in.
pub schema: String,
/// Name of the object.
pub name: String,
/// What kind of an object it is, and a reference to the object.
pub kind: NameKind,
}

View file

@ -0,0 +1,17 @@
//! Sequence definitions.
use super::SequenceId;
/// Definition of a sequence.
///
/// Not much interesting here.
#[derive(rkyv::Archive, rkyv::Serialize)]
#[expect(missing_docs)]
#[rkyv(attr(expect(missing_docs)))]
pub struct SequenceDef {
pub sequence_id: SequenceId,
pub catalog: String,
pub schema: String,
pub sequence_name: String,
// sequences don't seem to need anything interesting stored
}

View file

@ -0,0 +1,404 @@
//! Definition of a table of records.
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
mod primary_keys;
pub use self::primary_keys::PrimaryKeys;
use crate::field_type::FieldType;
use crate::FieldId;
use crate::SequenceId;
use crate::TableId;
super::util::idgen::make_id_type!(
/// Generation counter of a [`TableDef`].
/// Every edit increments this.
pub TableDefGeneration
);
/// Definition of a *row ID* for the table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct RowIdDef {
/// Allocate IDs from this sequence.
pub sequence_id: SequenceId,
}
/// Definition of a *primary key* for the table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct PrimaryKeysDef {
/// Field IDs in `PrimaryKeys` are guaranteed to be present in `TableDef.fields`.
///
/// This is guaranteed to be non-empty and have no duplicates.
pub field_ids: PrimaryKeys,
}
/// What kind of a *row key* this table uses.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug))]
pub enum RowKeyKind {
/// Row keys are assigned from a sequence.
RowId(RowIdDef),
/// Row keys are composed of field values or expressions using the field values.
PrimaryKeys(PrimaryKeysDef),
}
/// Definition of a field in a [`TableDef`].
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct FieldDef {
/// ID of the field, unique within the [`TableDef`].
pub field_id: FieldId,
/// In what generation of [`TableDef`] this field was added.
///
/// If a field is deleted and later a field with the same name is added, that's a separate field with a separate [`FieldId`].
pub added_in: TableDefGeneration,
/// In what generation, if any, of [`TableDef`] this field was added.
pub deleted_in: Option<TableDefGeneration>,
/// Unique among live fields (see [`fields_live_in_latest`](ArchivedTableDef::fields_live_in_latest)).
pub field_name: String,
/// Data type of the database column.
pub field_type: FieldType,
/// Can a value of this field be `NULL`.
///
/// Fields that are in `PrimaryKeys` cannot be nullable.
pub nullable: bool,
/// Default value, as Arrow bytes suitable for the data type.
// TODO support setting default values
// TODO maybe only allow setting this field via a validating helper? probably better to do validation at whatever level would manage `generation` etc.
pub default_value_arrow_bytes: Option<Box<[u8]>>,
}
impl ArchivedFieldDef {
#[must_use]
/// Is the field live at the latest [generation](TableDef::generation).
pub fn is_live(&self) -> bool {
self.deleted_in.is_none()
}
}
impl<'a> arbitrary::Arbitrary<'a> for FieldDef {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let field_id = u.arbitrary()?;
let added_in: TableDefGeneration = u.arbitrary()?;
let deleted_in: Option<TableDefGeneration> = if u.ratio(1u8, 10u8)? {
None
} else {
TableDefGeneration::arbitrary_greater_than(u, added_in)?
};
debug_assert!(deleted_in.is_none_or(|del| added_in < del));
let field_name = u.arbitrary()?;
let field_type = u.arbitrary()?;
let default_value_arrow_bytes: Option<_> = u.arbitrary()?;
let nullable = if deleted_in.is_some() && default_value_arrow_bytes.is_none() {
// must be nullable since it's deleted and has no default value
true
} else {
u.arbitrary()?
};
Ok(FieldDef {
field_id,
added_in,
deleted_in,
field_name,
field_type,
nullable,
default_value_arrow_bytes,
})
}
}
/// Definition of a table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[expect(missing_docs)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct TableDef {
pub table_id: TableId,
pub catalog: String,
pub schema: String,
pub table_name: String,
/// Every update increases generation.
pub generation: TableDefGeneration,
/// What is the primary key of the table: a (possibly composite) key of field values, or an automatically assigned row ID.
pub row_key_kind: RowKeyKind,
/// The fields stored in this table.
///
/// Guaranteed to contain deleted [`FieldDef`]s for as long as they have records using a [`TableDefGeneration`] where the field was live.
//
// TODO `rkyv::collections::btree_map::ArchivedBTreeMap::get` is impossible to use <https://github.com/rkyv/rkyv/issues/585> #waiting #ecosystem/rkyv
// our inserts happen to in increasing `FieldId` order so insertion order and sorted order are the same; still want a sorted container here to guarantee correctness
pub fields: indexmap::IndexMap<FieldId, FieldDef>,
}
impl<'a> arbitrary::Arbitrary<'a> for TableDef {
#[expect(
clippy::unwrap_used,
clippy::unwrap_in_result,
reason = "not production code #severity/low #urgency/low"
)]
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let table_id = u.arbitrary()?;
let catalog = u.arbitrary()?;
let schema = u.arbitrary()?;
let table_name = u.arbitrary()?;
// TODO unique `field_id`
// TODO unique `field_name`
// TODO primary keys cannot be nullable
let fields: indexmap::IndexMap<_, _> = u
.arbitrary_iter::<FieldDef>()?
// awkward callback, don't want to collect to temporary just to map the `Ok` values
.map(|result| result.map(|field| (field.field_id, field)))
.collect::<Result<_, _>>()?;
let generation = {
// choose between generations mentioned in fields and max
let min = fields
.values()
.map(|field| field.added_in)
.chain(fields.values().filter_map(|field| field.deleted_in))
.max()
.map(|generation| generation.get_u64())
.unwrap_or(1);
#[expect(
clippy::range_minus_one,
reason = "Unstructured insists on InclusiveRange"
)]
let num = u.int_in_range(min..=u64::MAX - 1)?;
TableDefGeneration::new(num).unwrap()
};
let row_id = {
let num_primary_keys = u.choose_index(fields.len())?;
if num_primary_keys == 0 {
RowKeyKind::RowId(RowIdDef {
sequence_id: u.arbitrary()?,
})
} else {
// take a random subset in random order
let mut tmp: Vec<_> = fields.keys().collect();
let primary_key_ids = std::iter::from_fn(|| {
if tmp.is_empty() {
None
} else {
Some(u.choose_index(tmp.len()).map(|idx| *tmp.swap_remove(idx)))
}
})
.take(num_primary_keys)
.collect::<Result<Vec<_>, arbitrary::Error>>()?;
let primary_keys = PrimaryKeys::try_from(primary_key_ids).unwrap();
RowKeyKind::PrimaryKeys(PrimaryKeysDef {
field_ids: primary_keys,
})
}
};
let table_def = TableDef {
table_id,
catalog,
schema,
table_name,
generation,
row_key_kind: row_id,
fields,
};
Ok(table_def)
}
}
/// Errors from [`validate`](ArchivedTableDef::validate).
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
#[derive(thiserror::Error, Debug)]
pub enum TableDefValidateError {
#[error("primary key field id not found: {field_id} not in {table_def_debug}")]
PrimaryKeyFieldIdNotFound {
field_id: FieldId,
table_def_debug: String,
},
#[error("primary key field cannot be deleted: {field_id} in {table_def_debug}")]
PrimaryKeyDeleted {
field_id: FieldId,
table_def_debug: String,
},
#[error("primary key field cannot be nullable: {field_id} in {table_def_debug}")]
PrimaryKeyNullable {
field_id: FieldId,
table_def_debug: String,
},
#[error("field id mismatch: {field_id} in {table_def_debug}")]
FieldIdMismatch {
field_id: FieldId,
table_def_debug: String,
},
#[error("field deleted before creation: {field_id} in {table_def_debug}")]
FieldDeletedBeforeCreation {
field_id: FieldId,
table_def_debug: String,
},
#[error("field name cannot be empty: {field_id} in {table_def_debug}")]
FieldNameEmpty {
field_id: FieldId,
table_def_debug: String,
},
#[error("duplicate field name: {field_id} vs {previous_field_id} in {table_def_debug}")]
FieldNameDuplicate {
field_id: FieldId,
previous_field_id: FieldId,
table_def_debug: String,
},
#[error("fixed size field default value has incorrect size: {field_id} in {table_def_debug}")]
FieldDefaultValueLength {
field_id: FieldId,
table_def_debug: String,
field_type_size: usize,
default_value_length: usize,
},
}
impl ArchivedTableDef {
/// Validate the internal consistency of the table definition.
///
/// Returns only the first error noticed.
pub fn validate(&self) -> Result<(), TableDefValidateError> {
// TODO maybe split into helper functions, one per check?
// TODO maybe field.validate()?;
match &self.row_key_kind {
ArchivedRowKeyKind::RowId(_seq_row_id_def) => {
// nothing
}
ArchivedRowKeyKind::PrimaryKeys(primary_keys_row_id_def) => {
for field_id in primary_keys_row_id_def.field_ids.iter() {
let field = self.fields.get(field_id).ok_or_else(|| {
TableDefValidateError::PrimaryKeyFieldIdNotFound {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
}
})?;
if !field.is_live() {
let error = TableDefValidateError::PrimaryKeyDeleted {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if field.nullable {
let error = TableDefValidateError::PrimaryKeyNullable {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
}
}
};
let mut live_field_names = HashMap::new();
for (field_id, field) in self.fields.iter() {
if &field.field_id != field_id {
let error = TableDefValidateError::FieldIdMismatch {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if let Some(deleted_in) = field.deleted_in.as_ref() {
if deleted_in <= &field.added_in {
let error = TableDefValidateError::FieldDeletedBeforeCreation {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
}
if field.field_name.is_empty() {
let error = TableDefValidateError::FieldNameEmpty {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if field.is_live() {
let field_name = field.field_name.as_str();
let entry = live_field_names.entry(field_name);
match entry {
std::collections::hash_map::Entry::Occupied(occupied_entry) => {
let error = TableDefValidateError::FieldNameDuplicate {
field_id: field_id.to_native(),
previous_field_id: *occupied_entry.get(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
std::collections::hash_map::Entry::Vacant(vacant_entry) => {
let _mut_field_id = vacant_entry.insert(field.field_id.to_native());
}
}
}
if let Some(default_value) = field.default_value_arrow_bytes.as_deref() {
if let Some(field_type_size) = field.field_type.fixed_arrow_bytes_width() {
let default_value_length = default_value.len();
if default_value_length != field_type_size {
let error = TableDefValidateError::FieldDefaultValueLength {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
field_type_size,
default_value_length,
};
return Err(error);
}
}
}
}
Ok(())
}
/// Which fields are live in the given generation.
pub fn fields_live_in_gen(
&self,
generation: TableDefGeneration,
) -> impl Iterator<Item = &ArchivedFieldDef> {
self.fields.values().filter(move |field| {
// This generation of row stores this field
let is_in_record = generation >= field.added_in
&& field
.deleted_in
.as_ref()
.is_none_or(|del| generation < *del);
is_in_record
})
}
/// Which fields are live in the latest generation.
pub fn fields_live_in_latest(&self) -> impl Iterator<Item = &ArchivedFieldDef> {
self.fields
.values()
.filter(|field| field.deleted_in.as_ref().is_none())
}
}
/// Make an Arrow [`Schema`](datafusion::arrow::datatypes::Schema) for some fields.
///
/// Using a subset of fields is useful since [`TableDef`] can include deleted fields.
/// See [`ArchivedTableDef::fields_live_in_latest`].
#[must_use]
pub fn make_arrow_schema<'defs>(
fields: impl Iterator<Item = &'defs ArchivedFieldDef>,
) -> datafusion::arrow::datatypes::SchemaRef {
let mut builder = {
let (_min, max) = fields.size_hint();
if let Some(capacity) = max {
datafusion::arrow::datatypes::SchemaBuilder::with_capacity(capacity)
} else {
datafusion::arrow::datatypes::SchemaBuilder::new()
}
};
for field in fields {
let arrow_field = datafusion::arrow::datatypes::Field::new(
field.field_name.to_owned(),
datafusion::arrow::datatypes::DataType::from(&field.field_type),
field.nullable,
);
builder.push(arrow_field);
}
Arc::new(builder.finish())
}

View file

@ -0,0 +1,44 @@
use crate::util::unique_nonempty::UniqueNonEmpty;
use crate::FieldId;
/// Slice of primary key [`FieldId`] that is guaranteed to be non-empty and have no duplicates.
#[derive(rkyv::Archive, rkyv::Serialize, arbitrary::Arbitrary)]
pub struct PrimaryKeys(UniqueNonEmpty<FieldId>);
impl TryFrom<Vec<FieldId>> for PrimaryKeys {
// TODO report why it failed #urgency/low #severity/low
type Error = ();
fn try_from(value: Vec<FieldId>) -> Result<Self, Self::Error> {
let une = UniqueNonEmpty::try_from(value)?;
Ok(Self(une))
}
}
impl std::ops::Deref for PrimaryKeys {
type Target = [FieldId];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::Deref for ArchivedPrimaryKeys {
type Target = [<FieldId as rkyv::Archive>::Archived];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::fmt::Debug for PrimaryKeys {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}
impl std::fmt::Debug for ArchivedPrimaryKeys {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}

View file

@ -0,0 +1,2 @@
pub(crate) mod idgen;
pub(crate) mod unique_nonempty;

View file

@ -0,0 +1,194 @@
macro_rules! make_id_type {
(
$(#[$meta:meta])*
$vis:vis $name:ident
) => {
$(#[$meta])*
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, rkyv::Archive, rkyv::Serialize)]
#[rkyv(derive(Ord, PartialOrd, Eq, PartialEq, Hash, Debug))]
#[rkyv(compare(PartialOrd, PartialEq))]
#[non_exhaustive]
#[rkyv(attr(non_exhaustive))]
$vis struct $name(::std::num::NonZeroU64);
#[allow(clippy::allow_attributes, reason = "generated code")]
// Not all functionality is used by every ID type.
#[allow(dead_code)]
impl $name {
/// Minimum valid value.
$vis const MIN: $name = $name::new(1).expect("constant calculation must be correct");
/// Maximum valid value.
$vis const MAX: $name = $name::new(u64::MAX.saturating_sub(1)).expect("constant calculation must be correct");
/// Construct a new identifier.
/// The values `0` and `u64::MAX` are invalid.
#[must_use]
$vis const fn new(num: u64) -> Option<Self> {
if num == u64::MAX {
None
} else {
// limitation of const functions: no `Option::map`.
match ::std::num::NonZeroU64::new(num) {
Some(nz) => Some(Self(nz)),
None => None,
}
}
}
/// Get the value as a [`u64`].
#[must_use]
$vis const fn get_u64(&self) -> u64 {
self.0.get()
}
/// Get the value as a [`NonZeroU64`](::std::num::NonZeroU64).
#[must_use]
#[allow(dead_code)]
$vis const fn get_non_zero(&self) -> ::std::num::NonZeroU64 {
self.0
}
/// Return the next possible value, or [`None`] if that would be the invalid value [`u64::MAX`].
#[must_use]
$vis const fn next(&self) -> Option<Self> {
// `new` will refuse if we hit `u64::MAX`
Self::new(self.get_u64().saturating_add(1))
}
/// Iterate all valid identifier values.
$vis fn iter_all() -> impl Iterator<Item=$name> {
::std::iter::successors(Some(Self::MIN), |prev: &Self| prev.next())
}
fn arbitrary_greater_than(u: &mut ::arbitrary::Unstructured<'_>, greater_than: Self) -> ::arbitrary::Result<Option<Self>> {
if let Some(min) = greater_than.next() {
let num = u.int_in_range(min.get_u64()..=u64::MAX-1)?;
let id = Self::new(num).unwrap();
Ok(Some(id))
} else {
Ok(None)
}
}
}
impl ::paste::paste!([<Archived $name>]) {
#[doc = concat!(
"Convert the [`Archived",
stringify!($name),
"`] to [`",
stringify!($name),
"`].",
)]
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn to_native(&self) -> $name {
$name(self.0.to_native())
}
/// Get the value as a [`u64`].
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn get_u64(&self) -> u64 {
self.0.get()
}
/// Get the value as a [`NonZeroU64`](::std::num::NonZeroU64).
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn get_non_zero(&self) -> ::std::num::NonZeroU64 {
self.0.to_native()
}
}
#[::duplicate::duplicate_item(
ident;
[$name];
[::paste::paste!([<Archived $name>])];
)]
impl ::std::fmt::Display for ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::result::Result<(), ::std::fmt::Error> {
::std::write!(f, "{}", self.get_u64())
}
}
impl ::std::fmt::Debug for $name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::result::Result<(), ::std::fmt::Error> {
::std::write!(f, "{}({})", stringify!($name), self.get_u64())
}
}
impl ::std::cmp::PartialEq<u64> for $name {
fn eq(&self, other: &u64) -> bool {
self.0.get() == *other
}
}
impl ::std::cmp::PartialEq<$name> for u64 {
fn eq(&self, other: &$name) -> bool {
*self == other.get_u64()
}
}
impl<'a> ::arbitrary::Arbitrary<'a> for $name {
fn arbitrary(u: &mut ::arbitrary::Unstructured<'a>) -> ::arbitrary::Result<Self> {
let num = u.int_in_range(1..=u64::MAX-1)?;
Ok(Self::new(num).unwrap())
}
}
}
}
pub(crate) use make_id_type;
#[cfg(test)]
mod tests {
use std::num::NonZeroU64;
use super::*;
make_id_type!(FooId);
#[test]
fn roundtrip() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(foo_id.get_u64(), 42u64);
assert_eq!(foo_id.get_non_zero(), NonZeroU64::new(42u64).unwrap());
}
#[test]
// TODO `const fn` to avoid clippy noise <https://github.com/rust-lang/rust-clippy/issues/13938> #waiting #ecosystem/rust
const fn copy() {
let foo_id = FooId::new(42).unwrap();
let one = foo_id;
let two = foo_id;
_ = one;
_ = two;
}
#[test]
fn debug() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id:?}"), "FooId(42)");
}
#[test]
fn debug_alternate() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id:#?}"), "FooId(42)");
}
#[test]
fn display() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id}"), "42");
}
#[test]
fn partial_eq_u64() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(foo_id, 42u64);
}
}

View file

@ -0,0 +1,113 @@
/// Slice that is guaranteed to be non-empty and have no duplicates.
#[derive(rkyv::Archive, rkyv::Serialize, arbitrary::Arbitrary)]
pub(crate) struct UniqueNonEmpty<T>(Box<[T]>);
fn is_unique<T>(slice: &[T]) -> bool
where
T: PartialEq,
{
// big-Oh performance is not relevant here, inputs should be smallish
for (idx, i) in slice.iter().enumerate() {
if let Some(tail) = idx
.checked_add(1)
.and_then(|tail_start| slice.get(tail_start..))
{
for j in tail {
if i == j {
return false;
}
}
}
}
true
}
impl<T> TryFrom<Vec<T>> for UniqueNonEmpty<T>
where
T: PartialEq,
{
// TODO report why it failed #urgency/low #severity/low
type Error = ();
fn try_from(value: Vec<T>) -> Result<Self, Self::Error> {
if value.is_empty() {
return Err(());
}
if !is_unique(&value) {
return Err(());
}
let une = Self(value.into_boxed_slice());
Ok(une)
}
}
impl<T> std::ops::Deref for UniqueNonEmpty<T> {
type Target = [T];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<T> std::ops::Deref for ArchivedUniqueNonEmpty<T>
where
T: rkyv::Archive,
{
type Target = [T::Archived];
fn deref(&self) -> &Self::Target {
self.0.get()
}
}
impl<T, Rhs> PartialEq<Rhs> for UniqueNonEmpty<T>
where
T: std::fmt::Debug + PartialEq,
Rhs: AsRef<[T]>,
{
fn eq(&self, other: &Rhs) -> bool {
self.0.as_ref() == other.as_ref()
}
}
impl<T> std::fmt::Debug for UniqueNonEmpty<T>
where
T: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}
impl<T> std::fmt::Debug for ArchivedUniqueNonEmpty<T>
where
T: rkyv::Archive,
T::Archived: std::fmt::Debug,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_try_from_vec_dup() {
let got = UniqueNonEmpty::try_from(vec![3u8, 1, 2, 1]);
assert_eq!(got.err(), Some(()));
}
#[test]
fn test_try_from_vec_empty() {
let got = UniqueNonEmpty::try_from(Vec::<u8>::new());
assert_eq!(got.err(), Some(()));
}
#[test]
fn test_try_from_vec_ok() {
let got = UniqueNonEmpty::try_from(vec![3u8, 1, 2]).unwrap();
assert_eq!(got, vec![3, 1, 2]);
}
}

View file

@ -0,0 +1,25 @@
[package]
name = "kanto-protocol-postgres"
version = "0.1.0"
description = "Postgres wire protocol server for KantoDB"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
async-trait = { workspace = true }
datafusion = { workspace = true }
futures = { workspace = true }
futures-lite = { workspace = true }
kanto = { workspace = true }
maybe-tracing = { workspace = true }
pgwire = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,383 @@
//! Postgres wire protocol support for Kanto.
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use futures::Sink;
use futures::SinkExt as _;
use futures::StreamExt as _;
use futures_lite::stream;
use kanto::arrow::array::AsArray as _;
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::api::results::FieldInfo;
use pgwire::api::results::QueryResponse;
use pgwire::api::results::Response;
use pgwire::api::ClientInfo;
use pgwire::api::Type;
use pgwire::error::PgWireError;
use pgwire::error::PgWireResult;
use pgwire::messages::response::NoticeResponse;
use pgwire::messages::PgWireBackendMessage;
use pgwire::tokio::process_socket;
pub use tokio;
/// Serve the Postgres wire protocol on TCP connections.
#[maybe_tracing::instrument(skip(backend), ret, err)]
pub async fn serve_pg(
// TODO support more than one backend at a time; multiplexing catalogs to backends is still undecided
backend: Box<dyn kanto::Backend>,
tcp_listener: tokio::net::TcpListener,
) -> Result<(), Box<dyn std::error::Error>> {
let mut session_config =
datafusion::execution::context::SessionConfig::new().with_information_schema(true);
let options = session_config.options_mut();
kanto::defaults::DEFAULT_CATALOG_NAME.clone_into(&mut options.catalog.default_catalog);
// TODO a shutdown that's not just `SIGTERM`, for tests etc
// TODO maybe look into Tower and whether it can handle stateful connections
#[expect(clippy::infinite_loop, reason = "for now, we serve TCP until shutdown")]
loop {
let incoming_socket = match tcp_listener.accept().await {
Ok(socket) => socket,
Err(error) => {
// TODO handle transient errors versus reasons for terminating; report transient network errors to caller somehow; see `accept(2)` discussion about network errors
tracing::error!(%error, "TCP accept");
continue;
}
};
let session = kanto::Session::new(
session_config.clone(),
BTreeMap::from([("kanto".to_owned(), backend.clone())]),
);
let session = Arc::new(tokio::sync::Mutex::new(session));
let handler = Arc::new(Handler { session });
// TODO something should collect returns and report errors; maybe use Tower
let _todo_join_handle = tokio::spawn(process_socket(incoming_socket.0, None, handler));
}
}
struct Handler {
session: Arc<tokio::sync::Mutex<kanto::Session>>,
}
impl pgwire::api::PgWireServerHandlers for Handler {
type StartupHandler = StartupHandler;
type SimpleQueryHandler = SimpleQueryHandler;
type ExtendedQueryHandler = pgwire::api::query::PlaceholderExtendedQueryHandler;
type CopyHandler = pgwire::api::copy::NoopCopyHandler;
type ErrorHandler = ErrorHandler;
fn simple_query_handler(&self) -> Arc<Self::SimpleQueryHandler> {
let h = SimpleQueryHandler {
session: self.session.clone(),
};
Arc::new(h)
}
fn extended_query_handler(&self) -> Arc<Self::ExtendedQueryHandler> {
// TODO implement Postgres protocol extended queries
Arc::new(pgwire::api::query::PlaceholderExtendedQueryHandler)
}
fn startup_handler(&self) -> Arc<Self::StartupHandler> {
Arc::new(StartupHandler {})
}
fn copy_handler(&self) -> Arc<Self::CopyHandler> {
// TODO `COPY` (#ref/unimplemented_sql/1fxo667s3wwes)
Arc::new(pgwire::api::copy::NoopCopyHandler)
}
fn error_handler(&self) -> Arc<Self::ErrorHandler> {
Arc::new(ErrorHandler {
session: self.session.clone(),
})
}
}
struct StartupHandler;
#[async_trait]
impl pgwire::api::auth::noop::NoopStartupHandler for StartupHandler {
#[maybe_tracing::instrument(skip(self, client), ret, err)]
async fn post_startup<C>(
&self,
client: &mut C,
_message: pgwire::messages::PgWireFrontendMessage,
) -> PgWireResult<()>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
// TODO authentication etc, move away from `NoopStartupHandler`
client
.send(PgWireBackendMessage::NoticeResponse(NoticeResponse::from(
pgwire::error::ErrorInfo::new(
"NOTICE".to_owned(),
"01000".to_owned(),
"KantoDB ready".to_owned(),
),
)))
.await?;
Ok(())
}
}
struct SimpleQueryHandler {
session: Arc<tokio::sync::Mutex<kanto::Session>>,
}
impl SimpleQueryHandler {
#[expect(clippy::rc_buffer, reason = "pgwire API forces `Arc<Vec<_>>` on us")]
fn make_pg_schema(df_schema: &kanto::arrow::datatypes::Schema) -> Arc<Vec<FieldInfo>> {
// TODO maybe migrate this into a separate `arrow-postgres` crate?
let pg_schema = df_schema
.fields()
.iter()
.map(|field| {
FieldInfo::new(
field.name().clone(),
None,
None,
#[expect(clippy::match_same_arms)]
match field.data_type() {
datafusion::arrow::datatypes::DataType::Null => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Boolean => Type::BOOL,
datafusion::arrow::datatypes::DataType::Int8 => Type::CHAR,
datafusion::arrow::datatypes::DataType::Int16 => Type::INT2,
datafusion::arrow::datatypes::DataType::Int32 => Type::INT4,
datafusion::arrow::datatypes::DataType::Int64 => Type::INT8,
datafusion::arrow::datatypes::DataType::UInt8 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::UInt16 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::UInt32 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::UInt64 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Float16 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Float32 => Type::FLOAT4,
datafusion::arrow::datatypes::DataType::Float64 => Type::FLOAT8,
datafusion::arrow::datatypes::DataType::Timestamp(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Date32 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Date64 => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Time32(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Time64(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Duration(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Interval(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Binary => Type::BYTEA,
datafusion::arrow::datatypes::DataType::FixedSizeBinary(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::LargeBinary => Type::BYTEA,
datafusion::arrow::datatypes::DataType::BinaryView => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Utf8 => Type::TEXT,
datafusion::arrow::datatypes::DataType::LargeUtf8 => Type::TEXT,
datafusion::arrow::datatypes::DataType::Utf8View => Type::TEXT,
datafusion::arrow::datatypes::DataType::List(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::ListView(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::FixedSizeList(_, _) => {
Type::UNKNOWN
}
datafusion::arrow::datatypes::DataType::LargeList(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::LargeListView(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Struct(_) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Union(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Dictionary(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Decimal128(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Decimal256(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::Map(_, _) => Type::UNKNOWN,
datafusion::arrow::datatypes::DataType::RunEndEncoded(_, _) => {
Type::UNKNOWN
}
},
FieldFormat::Text,
)
})
.collect::<Vec<_>>();
Arc::new(pg_schema)
}
#[expect(
clippy::unwrap_used,
clippy::unwrap_in_result,
reason = "downcasts inside match should never fail"
)]
fn encode_column(
encoder: &mut DataRowEncoder,
column: &dyn kanto::arrow::array::Array,
row_idx: usize,
) -> Result<(), PgWireError> {
// TODO maybe migrate this into a separate `arrow-postgres` crate?
match column.data_type() {
datafusion::arrow::datatypes::DataType::Boolean => {
let array = column.as_boolean_opt().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Int8 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Int8Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Int16 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Int16Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Int32 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Int32Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Int64 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Int64Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Float32 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Float32Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Float64 => {
let array = column
.as_primitive_opt::<kanto::arrow::datatypes::Float64Type>()
.unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Binary => {
let array = column.as_binary_opt::<i32>().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::LargeBinary => {
let array = column.as_binary_opt::<i64>().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::BinaryView => {
let array = column.as_binary_view_opt().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Utf8 => {
let array = column.as_string_opt::<i32>().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::LargeUtf8 => {
let array = column.as_string_opt::<i64>().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
datafusion::arrow::datatypes::DataType::Utf8View => {
let array = column.as_string_view_opt().unwrap();
let value = array.value(row_idx);
encoder.encode_field(&value)?;
}
_ => {
let error = PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"ERROR".to_owned(),
// <https://www.postgresql.org/docs/current/errcodes-appendix.html>
"42704".to_owned(),
"unimplemented data type in pg wire protocol".to_owned(),
)));
return Err(error);
}
}
Ok(())
}
}
#[async_trait]
impl pgwire::api::query::SimpleQueryHandler for SimpleQueryHandler {
#[maybe_tracing::instrument(skip(self, _client), err)]
async fn do_query<'a, C>(
&self,
_client: &mut C,
query: &'a str,
) -> PgWireResult<Vec<Response<'a>>>
where
C: ClientInfo + Sink<PgWireBackendMessage> + Unpin + Send + Sync,
C::Error: Debug,
PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
{
let mut session = self.session.lock().await;
let stream = session.sql(query).await.map_err(|error| {
PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"error".to_owned(),
// TODO more fine-grained translation to error codes
"58000".to_owned(),
error.to_string(),
)))
})?;
// TODO differentiate `Response::Query` and `Response::Execution`
let pg_schema = Self::make_pg_schema(&stream.schema());
let data_row_stream = stream.flat_map({
let schema_ref = pg_schema.clone();
move |result| match result {
Err(error) => {
let error = PgWireError::UserError(Box::new(pgwire::error::ErrorInfo::new(
"error".to_owned(),
"foo".to_owned(),
error.to_string(),
)));
stream::once(Err(error)).boxed()
}
Ok(batch) => {
let num_rows = batch.num_rows();
let columns = Arc::new(Vec::from(batch.columns()));
let iter = (0..num_rows).map({
let schema_ref = schema_ref.clone();
let columns = columns.clone();
move |row_idx| {
let mut encoder = DataRowEncoder::new(schema_ref.clone());
for column in columns.iter() {
Self::encode_column(&mut encoder, column, row_idx)?;
}
encoder.finish()
}
});
stream::iter(iter).boxed()
}
}
});
Ok(vec![Response::Query(QueryResponse::new(
pg_schema,
data_row_stream,
))])
}
}
struct ErrorHandler {
// TODO log with some information about which session encountered the problem
#[expect(dead_code)]
session: Arc<tokio::sync::Mutex<kanto::Session>>,
}
impl pgwire::api::ErrorHandler for ErrorHandler {
fn on_error<C>(&self, _client: &C, error: &mut PgWireError)
where
C: ClientInfo,
{
tracing::error!(?error, "pgwire client error");
}
}

View file

@ -0,0 +1,29 @@
[package]
name = "kanto-record-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[package.metadata.cargo-shear]
ignored = ["tracing", "test-log"]
[dependencies]
datafusion = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-tunables = { workspace = true }
maybe-tracing = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
kanto-testutil = { workspace = true }
test-log = { workspace = true }
[lints]
workspace = true

View file

@ -0,0 +1,4 @@
/target/
/corpus/
/artifacts/
/coverage/

View file

@ -0,0 +1,40 @@
[package]
name = "kanto-record-format-v1-fuzz"
version = "0.0.0"
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false
edition.workspace = true
rust-version.workspace = true
[package.metadata]
cargo-fuzz = true
[dependencies]
arbitrary-arrow = { workspace = true }
datafusion = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-record-format-v1 = { workspace = true }
libfuzzer-sys = { workspace = true }
[[bin]]
name = "build_slots"
path = "fuzz_targets/fuzz_build_slots.rs"
test = false
doc = false
bench = false
[[bin]]
name = "batch_to_rows"
path = "fuzz_targets/fuzz_batch_to_rows.rs"
test = false
doc = false
bench = false
[[bin]]
name = "roundtrip"
path = "fuzz_targets/fuzz_roundtrip.rs"
test = false
doc = false
bench = false

View file

@ -0,0 +1,48 @@
#![no_main]
use arbitrary_arrow::arbitrary;
use datafusion::arrow::array::RecordBatch;
use libfuzzer_sys::fuzz_target;
#[derive(Debug)]
struct TableAndRecordBatch {
table_def: kanto_meta_format_v1::table_def::TableDef,
record_batch: RecordBatch,
}
impl<'a> arbitrary::Arbitrary<'a> for TableAndRecordBatch {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let table_def: kanto_meta_format_v1::table_def::TableDef = u.arbitrary()?;
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let record_batch = {
let schema = kanto_meta_format_v1::table_def::make_arrow_schema(
owned_archived_table_def.fields_live_in_latest(),
);
arbitrary_arrow::record_batch_with_schema(u, schema)?
};
Ok(TableAndRecordBatch {
table_def,
record_batch,
})
}
}
#[derive(Debug, arbitrary::Arbitrary)]
struct Input {
table_and_record_batch: TableAndRecordBatch,
}
fuzz_target!(|input: Input| {
let Input {
table_and_record_batch:
TableAndRecordBatch {
table_def,
record_batch,
},
} = input;
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let builder =
kanto_record_format_v1::RecordBuilder::new_from_table_def(&owned_archived_table_def)
.unwrap();
let _records = builder.build_records(&record_batch).unwrap();
});

View file

@ -0,0 +1,22 @@
#![no_main]
use arbitrary_arrow::arbitrary;
use libfuzzer_sys::fuzz_target;
#[derive(arbitrary::Arbitrary, Debug)]
struct Input {
field_defs: Box<[kanto_meta_format_v1::table_def::FieldDef]>,
}
fuzz_target!(|input: Input| {
let Input { field_defs } = input;
let owned_archived_fields = field_defs
.iter()
.map(|field_def| kanto_meta_format_v1::owned_archived(field_def).unwrap())
.collect::<Box<[_]>>();
let archived_fields = owned_archived_fields
.iter()
.map(std::ops::Deref::deref)
.collect::<Vec<_>>();
let _slots = kanto_record_format_v1::only_for_fuzzing::build_slots(&archived_fields).unwrap();
});

View file

@ -0,0 +1,74 @@
#![no_main]
use std::sync::Arc;
use arbitrary_arrow::arbitrary;
use datafusion::arrow::array::RecordBatch;
use libfuzzer_sys::fuzz_target;
#[derive(Debug)]
struct TableAndRecordBatch {
table_def: kanto_meta_format_v1::table_def::TableDef,
record_batch: RecordBatch,
}
impl<'a> arbitrary::Arbitrary<'a> for TableAndRecordBatch {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let table_def: kanto_meta_format_v1::table_def::TableDef = u.arbitrary()?;
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let record_batch = {
let schema = kanto_meta_format_v1::table_def::make_arrow_schema(
owned_archived_table_def.fields_live_in_latest(),
);
arbitrary_arrow::record_batch_with_schema(u, schema)?
};
Ok(TableAndRecordBatch {
table_def,
record_batch,
})
}
}
#[derive(Debug, arbitrary::Arbitrary)]
struct Input {
table_and_record_batch: TableAndRecordBatch,
}
fuzz_target!(|input: Input| {
let Input {
table_and_record_batch:
TableAndRecordBatch {
table_def,
record_batch,
},
} = input;
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
// batch_to_rows
let stored = {
let builder =
kanto_record_format_v1::RecordBuilder::new_from_table_def(&owned_archived_table_def)
.unwrap();
builder.build_records(&record_batch).unwrap()
};
// rows_to_batch
let got = {
let wanted_field_ids = owned_archived_table_def
.fields_live_in_latest()
.map(|field| field.field_id.to_native())
.collect();
let batch_size_hint = 1_000_000;
let mut reader = kanto_record_format_v1::RecordReader::new_from_table_def(
owned_archived_table_def,
Arc::new(wanted_field_ids),
batch_size_hint,
);
for row in stored {
reader.push_record(&row).unwrap();
}
reader.build().unwrap()
};
assert_eq!(record_batch, got);
});

View file

@ -0,0 +1,80 @@
//! Errors returned from this record parsing and building.
//! The main error type is [`RecordError`].
/// Corruption means system encountered invalid data.
/// Corrupted data may reside on disk, or have been corrupted as part of processing.
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
#[expect(clippy::module_name_repetitions)]
#[derive(thiserror::Error, Debug)]
pub enum RecordCorruptError {
#[expect(missing_docs)]
#[error("corrupt record field: truncated")]
Truncated,
#[expect(missing_docs)]
#[error("corrupt record field: impossible table definition generation number")]
InvalidTableDefGeneration,
}
/// Errors related to the schema of the data being processed.
///
/// This is a separate type so that [`RecordBuilder::new_from_table_def`](crate::RecordBuilder::new_from_table_def) can return only this kind of error, guaranteeing it cannot suffer from e.g. corrupted data.
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
#[expect(clippy::module_name_repetitions)]
#[derive(thiserror::Error, Debug)]
pub enum RecordSchemaError {
#[expect(missing_docs)]
#[error("record has too many columns")]
TooLarge,
}
/// Errors related to the input [`RecordBatch`](datafusion::arrow::array::RecordBatch).
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
#[expect(clippy::module_name_repetitions)]
#[derive(thiserror::Error, Debug)]
pub enum RecordDataError {
#[expect(missing_docs)]
#[error("non-nullable input has null data: column_name={column_name}")]
NotNullable { column_name: String },
}
/// Main error type for this library.
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
#[expect(clippy::module_name_repetitions)]
#[derive(thiserror::Error, Debug)]
pub enum RecordError {
#[expect(missing_docs)]
#[error("corrupt record: {error}")]
Corrupt { error: RecordCorruptError },
#[expect(missing_docs)]
#[error("invalid schema for record: {error}")]
InvalidSchema { error: RecordSchemaError },
#[expect(missing_docs)]
#[error("invalid data for record: {error}")]
InvalidData { error: RecordDataError },
#[expect(missing_docs)]
#[error("record: internal error: {message}")]
Internal { message: &'static str },
#[expect(missing_docs)]
#[error("record: internal error with arrow: {error}")]
InternalArrow {
#[source]
error: datafusion::arrow::error::ArrowError,
},
}

View file

@ -0,0 +1,22 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
pub mod error;
mod record_builder;
mod record_reader;
mod slot;
pub use record_builder::RecordBuilder;
pub use record_reader::RecordReader;
// Let fuzzing access internal APIs.
// DO NOT USE FOR OTHER PURPOSES.
//
// It seems impossible to modify the visibility of a module based on features, so create a dummy module for this.
//
// I would like to make this conditional, `#[cfg(fuzzing)]`, but that causes `rust-analyzer` to report compilation error.
#[doc(hidden)]
pub mod only_for_fuzzing {
pub use crate::slot::build_slots;
}

View file

@ -0,0 +1,429 @@
use datafusion::arrow::array::AsArray as _;
use datafusion::arrow::array::RecordBatch;
use crate::error::RecordError;
use crate::slot::Slot;
use crate::slot::Slots;
trait VariableLenData: datafusion::arrow::array::Array {
fn value(&self, row_idx: usize) -> &[u8];
}
impl VariableLenData for datafusion::arrow::array::StringArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx).as_bytes()
}
}
impl VariableLenData for datafusion::arrow::array::LargeStringArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx).as_bytes()
}
}
impl VariableLenData for datafusion::arrow::array::StringViewArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx).as_bytes()
}
}
impl VariableLenData for datafusion::arrow::array::BinaryArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx)
}
}
impl VariableLenData for datafusion::arrow::array::LargeBinaryArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx)
}
}
impl VariableLenData for datafusion::arrow::array::BinaryViewArray {
fn value(&self, row_idx: usize) -> &[u8] {
self.value(row_idx)
}
}
/// Serialize database record values from [`RecordBatch`]es.
pub struct RecordBuilder {
latest_gen: kanto_meta_format_v1::table_def::TableDefGeneration,
min_record_size: usize,
slots: Box<[Slot]>,
}
impl RecordBuilder {
/// Start producing records from record batches for the given table.
pub fn new_from_table_def(
table_def: &kanto_meta_format_v1::table_def::ArchivedTableDef,
) -> Result<RecordBuilder, crate::error::RecordSchemaError> {
let latest_gen = table_def.generation.to_native();
let slots = {
let fields = table_def.fields_live_in_latest().collect::<Vec<_>>();
crate::slot::build_slots(&fields)?
};
let Slots {
min_record_size,
slots,
} = slots;
let builder = RecordBuilder {
latest_gen,
min_record_size,
slots,
};
Ok(builder)
}
fn copy_bitmap(
&self,
iter: impl Iterator<Item = usize>,
array: &datafusion::arrow::array::BooleanArray,
records: &mut [Vec<u8>],
slot: &Slot,
) -> Result<(), RecordError> {
for idx in iter {
debug_assert!(idx < records.len());
let arrow_bytes = if array.value(idx) { &[1u8] } else { &[0u8] };
let record = records.get_mut(idx).ok_or(RecordError::Internal {
message: "record vector out of bounds",
})?;
debug_assert!(record.len() >= self.min_record_size);
slot.set_valid(record)?;
slot.set_arrow_bytes(record, arrow_bytes)?;
}
Ok(())
}
fn copy_fixed_size_values(
&self,
iter: impl Iterator<Item = usize>,
buffer: &[u8],
records: &mut [Vec<u8>],
slot: &Slot,
) -> Result<(), RecordError> {
for idx in iter {
debug_assert!(idx < records.len());
let width = slot.width();
let source_start = idx.saturating_mul(width);
let source_end = source_start.saturating_add(width);
let arrow_bytes = buffer
.get(source_start..source_end)
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
let record = records.get_mut(idx).ok_or(RecordError::Internal {
message: "record vector out of bounds",
})?;
debug_assert!(record.len() >= self.min_record_size);
slot.set_valid(record)?;
slot.set_arrow_bytes(record, arrow_bytes)?;
}
Ok(())
}
fn copy_variable_size_values(
&self,
iter: impl Iterator<Item = usize>,
data: &impl VariableLenData,
records: &mut [Vec<u8>],
slot: &Slot,
) -> Result<(), RecordError> {
for idx in iter {
debug_assert!(idx < records.len());
let record = records.get_mut(idx).ok_or(RecordError::Internal {
message: "record vector out of bounds",
})?;
debug_assert!(record.len() >= self.min_record_size);
slot.set_valid(record)?;
let data = data.value(idx);
slot.set_indirect_data(record, data)?;
}
Ok(())
}
fn copy_variable_size(
&self,
array: &impl VariableLenData,
records: &mut [Vec<u8>],
slot: &Slot,
) -> Result<(), RecordError> {
// TODO beware physical nulls being in unintuitive formats
match array.nulls() {
Some(nulls) => {
self.copy_variable_size_values(nulls.valid_indices(), array, records, slot)
}
None => self.copy_variable_size_values(0..array.len(), array, records, slot),
}
}
/// Build database record values for rows in `record_batch`.
pub fn build_records(&self, record_batch: &RecordBatch) -> Result<Box<[Vec<u8>]>, RecordError> {
debug_assert_eq!(self.slots.len(), record_batch.num_columns());
let mut records = (0..record_batch.num_rows())
.map(|_| {
let mut v = Vec::with_capacity(kanto_tunables::TYPICAL_MAX_RECORD_SIZE);
v.extend_from_slice(&self.latest_gen.get_u64().to_le_bytes());
debug_assert!(
self.min_record_size
>= size_of::<kanto_meta_format_v1::table_def::TableDefGeneration>()
);
v.resize(self.min_record_size, 0u8);
v
})
.collect::<Box<[Vec<u8>]>>();
for (col_idx, (slot, column)) in self.slots.iter().zip(record_batch.columns()).enumerate() {
if !slot.is_nullable() && column.null_count() > 0 {
let df_field = record_batch.schema_ref().field(col_idx);
let error = RecordError::InvalidData {
error: crate::error::RecordDataError::NotNullable {
column_name: df_field.name().to_owned(),
},
};
return Err(error);
}
match column.data_type() {
datafusion::arrow::datatypes::DataType::BinaryView => {
debug_assert!(slot.is_indirect());
let array = column.as_binary_view_opt().ok_or(RecordError::Internal {
message: "could not convert array into binary view",
})?;
self.copy_variable_size(array, &mut records, slot)?;
}
datafusion::arrow::datatypes::DataType::Utf8View => {
debug_assert!(slot.is_indirect());
let array = column.as_string_view_opt().ok_or(RecordError::Internal {
message: "could not convert array into string view",
})?;
self.copy_variable_size(array, &mut records, slot)?;
}
datafusion::arrow::datatypes::DataType::Boolean => {
debug_assert!(!slot.is_indirect());
let boolean_array = column.as_boolean_opt().ok_or(RecordError::Internal {
message: "expected a boolean array",
})?;
// TODO ugly duplication
// TODO beware physical nulls being in unintuitive formats
if let Some(nulls) = column.nulls() {
self.copy_bitmap(nulls.valid_indices(), boolean_array, &mut records, slot)?;
} else {
self.copy_bitmap(0..column.len(), boolean_array, &mut records, slot)?;
}
}
_ => {
debug_assert!(!slot.is_indirect());
let array_data = column.to_data();
let [buffer]: &[_; 1] = array_data.buffers().try_into().map_err(|_error| {
RecordError::Internal {
message: "should have exactly one buffer",
}
})?;
debug_assert_eq!(buffer.len(), column.len().saturating_mul(slot.width()));
// TODO beware physical nulls being in unintuitive formats
if let Some(nulls) = column.nulls() {
self.copy_fixed_size_values(
nulls.valid_indices(),
buffer.as_slice(),
&mut records,
slot,
)?;
} else {
self.copy_fixed_size_values(
0..column.len(),
buffer.as_slice(),
&mut records,
slot,
)?;
}
}
}
// TODO maybe a trait GetArrowBytes?
// except RecordBatch munges everything into dyn Array so that doesn't really help.
// just a utility function then
//
// TODO can we make this vector-directional? one slot for each, then next slot for each; that means returning vec/slice etc not arrow BinaryArray, as they can't dense pack on the fly
// slot.set_arrow_bytes(todo_record);
}
debug_assert_eq!(records.len(), record_batch.num_rows());
Ok(records)
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::array::UInt32Array;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::datatypes::Field;
use datafusion::arrow::datatypes::Schema;
use kanto_meta_format_v1::table_def::TableDefGeneration;
use super::*;
#[maybe_tracing::test]
fn simple() {
let fields = [kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(2).unwrap(),
added_in: TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
}];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let builder = RecordBuilder::new_from_table_def(&owned_archived_table_def).unwrap();
let input = {
let schema = Schema::new(vec![Field::new("a", DataType::UInt32, false)]);
let array_a: ArrayRef = Arc::new(UInt32Array::from(vec![3, 7]));
RecordBatch::try_new(Arc::new(schema), vec![array_a]).unwrap()
};
let records = builder.build_records(&input).unwrap();
assert_eq!(records.len(), 2);
// TODO `clippy::missing_asserts_for_indexing` does not recognize `assert_eq` <https://github.com/rust-lang/rust-clippy/issues/14255> #waiting #ecosystem/rust/clippy
assert!(records.len() == 2);
assert_eq!(
records[0],
[
// generation
13,
0,
0,
0,
0,
0,
0,
0,
// validity
0b0000_0001,
// a
3,
0,
0,
0,
]
);
assert_eq!(
records[1],
[
// generation
13,
0,
0,
0,
0,
0,
0,
0,
// validity
0b0000_0001,
// a
7,
0,
0,
0,
]
);
}
#[maybe_tracing::test]
fn string() {
let fields = [kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(2).unwrap(),
added_in: TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::String,
nullable: false,
default_value_arrow_bytes: None,
}];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let builder = RecordBuilder::new_from_table_def(&owned_archived_table_def).unwrap();
let input = {
let schema = Schema::new(vec![Field::new("a", DataType::Utf8View, false)]);
let array_a: ArrayRef = Arc::new(
datafusion::arrow::array::StringViewArray::from_iter_values(["foo"]),
);
RecordBatch::try_new(Arc::new(schema), vec![array_a]).unwrap()
};
let records = builder.build_records(&input).unwrap();
assert_eq!(records.len(), 1);
assert_eq!(
records[0],
[
// generation
13,
0,
0,
0,
0,
0,
0,
0,
// validity
0b0000_0001,
// a: pointer to string
13,
0,
0,
0,
// string length
3,
0,
0,
0,
// string content
b'f',
b'o',
b'o',
]
);
}
}

View file

@ -0,0 +1,422 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::HashSet;
use std::ops::Deref;
use std::sync::Arc;
use datafusion::arrow::array::RecordBatchOptions;
use datafusion::arrow::record_batch::RecordBatch;
mod array_builder;
mod column_value;
use self::array_builder::ArrayBuilder;
use self::column_value::ColumnValue;
use crate::error::RecordError;
use crate::slot::Slots;
/// Parse database record values into [`RecordBatch`].
pub struct RecordReader<D>
where
D: Deref<Target = kanto_meta_format_v1::table_def::ArchivedTableDef>,
{
table_def: D,
arrow_schema: datafusion::arrow::datatypes::SchemaRef,
/// Which fields to produce.
wanted_field_ids: Arc<HashSet<kanto_meta_format_v1::FieldId>>,
/// Record decoder per generation of table definition, populated as needed.
/// Values are same length as `wanted_fields`, in that order.
versioned: HashMap<kanto_meta_format_v1::table_def::TableDefGeneration, Box<[ColumnValue]>>,
/// `RecordBatch` column builders.
/// The version-specific arrays of `ColumnValue` feed to this shared array of `ArrayBuilder`.
/// Same length as `wanted_fields`, in that order.
array_builders: Box<[ArrayBuilder]>,
/// How many rows were seen.
/// Needed in case no columns were requested.
row_count: usize,
}
impl<D> RecordReader<D>
where
D: Deref<Target = kanto_meta_format_v1::table_def::ArchivedTableDef>,
{
/// Start producing record batches from records for the given table.
///
/// `wanted_field_ids` sets which fields to produce.
///
/// `batch_size_hint` is used to optimize allocations.
#[must_use]
pub fn new_from_table_def(
table_def: D,
wanted_field_ids: Arc<HashSet<kanto_meta_format_v1::FieldId>>,
batch_size_hint: usize,
) -> Self {
debug_assert!(
wanted_field_ids.iter().all(|wanted_field_id| table_def
.fields
.get_with(wanted_field_id, |a, b| a == &b.to_native())
.is_some()),
"`wanted_field_ids` must be in `table_def.fields`",
);
let arrow_schema = kanto_meta_format_v1::table_def::make_arrow_schema(
table_def
.fields
.values()
.filter(|field| wanted_field_ids.contains(&field.field_id.to_native())),
);
let array_builders = table_def
.fields
.values()
.filter(|field| wanted_field_ids.contains(&field.field_id.to_native()))
.map(|field| ArrayBuilder::new(&field.field_type, batch_size_hint))
.collect::<Box<[_]>>();
RecordReader {
table_def,
arrow_schema,
wanted_field_ids,
array_builders,
versioned: HashMap::new(),
row_count: 0,
}
}
/// How many records have been collected in the batch.
#[must_use]
pub const fn num_records(&self) -> usize {
self.row_count
}
#[maybe_tracing::instrument(skip(table_def), err)]
fn make_versioned_column_values(
generation: kanto_meta_format_v1::table_def::TableDefGeneration,
table_def: &kanto_meta_format_v1::table_def::ArchivedTableDef,
wanted_field_ids: &HashSet<kanto_meta_format_v1::FieldId>,
) -> Result<Box<[ColumnValue]>, RecordError> {
// TODO clean up this logic, i don't think we really need this many lookups for something that's fundamentally linear
let fields_in_record = table_def.fields_live_in_gen(generation).collect::<Vec<_>>();
let Slots {
min_record_size: _,
slots,
} = crate::slot::build_slots(&fields_in_record).map_err(|record_schema_error| {
RecordError::InvalidSchema {
error: record_schema_error,
}
})?;
let slot_by_field_id = fields_in_record
.iter()
.map(|field| field.field_id.to_native())
.zip(slots)
.collect::<HashMap<_, _>>();
let column_values: Box<[ColumnValue]> = table_def
.fields
.values()
.filter(|field| wanted_field_ids.contains(&field.field_id.to_native()))
.map(|field| {
let field_id = field.field_id.to_native();
if let Some(slot) = slot_by_field_id.get(&field_id) {
Ok(ColumnValue::Slot(*slot))
} else {
// Default, `NULL`, or error
if let Some(arrow_bytes) = field.default_value_arrow_bytes.as_ref() {
// TODO this copy sucks, can we make ColumnValue borrow from `'defs`
let arrow_bytes = Cow::from(arrow_bytes.as_ref().to_owned());
Ok(ColumnValue::DefaultConstant {
is_valid: true,
arrow_bytes,
})
} else if field.nullable {
let arrow_bytes =
if let Some(fixed_width) = field.field_type.fixed_arrow_bytes_width() {
let zero_bytes = std::iter::repeat(0u8).take(fixed_width).collect();
Cow::Owned(zero_bytes)
} else {
const EMPTY: &[u8] = &[];
Cow::Borrowed(EMPTY)
};
Ok(ColumnValue::DefaultConstant {
is_valid: false,
arrow_bytes,
})
} else {
let error = RecordError::Internal {
message: "field has become required without updating records",
};
Err(error)
}
}
})
.collect::<Result<_, _>>()?;
Ok(column_values)
}
/// Add a database record to the batch.
pub fn push_record(&mut self, record: &[u8]) -> Result<(), RecordError> {
const GEN_SIZE: usize = size_of::<u64>();
let (gen_bytes, _) =
record
.split_first_chunk::<GEN_SIZE>()
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
let gen_num = u64::from_le_bytes(*gen_bytes);
let generation = kanto_meta_format_v1::table_def::TableDefGeneration::new(gen_num).ok_or(
RecordError::Corrupt {
error: crate::error::RecordCorruptError::InvalidTableDefGeneration,
},
)?;
let column_values = {
// TODO awkward workaround for `.entry(gen).or_insert_with` not coping with fallibility
match self.versioned.entry(generation) {
std::collections::hash_map::Entry::Occupied(occupied_entry) => occupied_entry,
std::collections::hash_map::Entry::Vacant(vacant_entry) => vacant_entry
.insert_entry(Self::make_versioned_column_values(
generation,
&self.table_def,
&self.wanted_field_ids,
)?),
}
.into_mut()
};
debug_assert_eq!(column_values.len(), self.array_builders.len());
for (column_value, builder) in column_values.iter_mut().zip(self.array_builders.iter_mut())
{
if column_value.is_valid(record)? {
let arrow_bytes = column_value.get_arrow_bytes(record)?;
builder.append_from_arrow_bytes(arrow_bytes);
} else {
builder.append_null();
}
}
self.row_count = self.row_count.checked_add(1).ok_or(RecordError::Internal {
message: "row count over 4G",
})?;
Ok(())
}
/// Build a [`RecordBatch`] with the records collected so far.
///
/// After this, the [`RecordReader`] is ready for a new batch.
pub fn build(&mut self) -> Result<RecordBatch, RecordError> {
// This does not consume `self`, to enable reuse of `array_builders`.
let column_arrays = self
.array_builders
.iter_mut()
.map(ArrayBuilder::build)
.collect::<Result<Vec<datafusion::arrow::array::ArrayRef>, _>>()?;
let options = RecordBatchOptions::new().with_row_count(Some(self.row_count));
self.row_count = 0;
let record_batch =
RecordBatch::try_new_with_options(self.arrow_schema.clone(), column_arrays, &options)
.map_err(|arrow_error| RecordError::InternalArrow { error: arrow_error })?;
Ok(record_batch)
}
}
#[cfg(test)]
mod tests {
use kanto_meta_format_v1::table_def::TableDefGeneration;
use kanto_testutil::assert_batches_eq;
use super::*;
#[maybe_tracing::test]
fn simple() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(2).unwrap(),
added_in: TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(3).unwrap(),
added_in: TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let wanted_field_ids = [fields[1].field_id, fields[0].field_id]
.into_iter()
.collect::<HashSet<_>>();
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let batch_size_hint = 10usize;
let mut reader = RecordReader::new_from_table_def(
owned_archived_table_def,
Arc::new(wanted_field_ids),
batch_size_hint,
);
reader
.push_record(&[
// generation
9,
0,
0,
0,
0,
0,
0,
0,
// validity
0b0000_0011,
// u16
3,
0,
// u32
7,
0,
0,
0,
])
.unwrap();
let batch = reader.build().unwrap();
assert_batches_eq!(
r"
+---+---+
| a | c |
+---+---+
| 3 | 7 |
+---+---+
",
&[batch],
);
}
#[maybe_tracing::test]
fn string() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(2).unwrap(),
added_in: TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: true,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(3).unwrap(),
added_in: TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "b".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::String,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: kanto_meta_format_v1::FieldId::new(5).unwrap(),
added_in: TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let wanted_field_ids = [fields[1].field_id, fields[0].field_id, fields[2].field_id]
.into_iter()
.collect::<HashSet<_>>();
let wanted_field_ids = Arc::new(wanted_field_ids);
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let batch_size_hint = 10usize;
let mut reader = RecordReader::new_from_table_def(
owned_archived_table_def,
wanted_field_ids,
batch_size_hint,
);
reader
.push_record(&[
// generation
9,
0,
0,
0,
0,
0,
0,
0,
// validity
0b0000_0111,
// u16
4,
0,
// pointer to string
19,
0,
0,
0,
// u32
13,
0,
0,
0,
// string length
3,
0,
0,
0,
// string content
b'f',
b'o',
b'o',
])
.unwrap();
let batch = reader.build().unwrap();
assert_batches_eq!(
r"
+---+-----+----+
| a | b | c |
+---+-----+----+
| 4 | foo | 13 |
+---+-----+----+
",
&[batch],
);
}
}

View file

@ -0,0 +1,216 @@
use std::sync::Arc;
use datafusion::arrow::array::Array;
use crate::error::RecordError;
pub(crate) enum Utf8OrBinary {
Binary,
Utf8,
}
pub(crate) enum ArrayBuilder {
FixedSize {
data_type: datafusion::arrow::datatypes::DataType,
width: usize,
data: datafusion::arrow::buffer::MutableBuffer,
validity: Option<datafusion::arrow::array::BooleanBufferBuilder>,
},
View {
// Arrow does not give us a nice way to access `BinaryArray`/`StringArray` regardless of which one it is.
// Attempting to access a `StringArray` as `BinaryArray` explicitly panics.
// Always construct the array as binary, and convert to string at the end if needed.
utf8_or_binary: Utf8OrBinary,
builder: datafusion::arrow::array::BinaryViewBuilder,
},
Boolean {
validity: datafusion::arrow::array::NullBufferBuilder,
values: datafusion::arrow::array::BooleanBufferBuilder,
},
}
impl ArrayBuilder {
pub(crate) fn new(
field_type: &kanto_meta_format_v1::field_type::ArchivedFieldType,
batch_size_hint: usize,
) -> ArrayBuilder {
use kanto_meta_format_v1::field_type::ArchivedFieldType as FT;
match field_type {
FT::U64 | FT::I64 | FT::F64 => Self::new_fixed_size(field_type, batch_size_hint, 8),
FT::U32 | FT::I32 | FT::F32 => Self::new_fixed_size(field_type, batch_size_hint, 4),
FT::U16 | FT::I16 => Self::new_fixed_size(field_type, batch_size_hint, 2),
FT::U8 | FT::I8 => Self::new_fixed_size(field_type, batch_size_hint, 1),
FT::String => Self::new_view(Utf8OrBinary::Utf8, batch_size_hint),
FT::Binary => Self::new_view(Utf8OrBinary::Binary, batch_size_hint),
FT::Boolean => {
let nulls = datafusion::arrow::array::NullBufferBuilder::new(batch_size_hint);
let values = datafusion::arrow::array::BooleanBufferBuilder::new(batch_size_hint);
ArrayBuilder::Boolean {
validity: nulls,
values,
}
}
}
}
fn new_fixed_size(
field_type: &kanto_meta_format_v1::field_type::ArchivedFieldType,
batch_size_hint: usize,
width: usize,
) -> ArrayBuilder {
let data_type = datafusion::arrow::datatypes::DataType::from(field_type);
let validity = datafusion::arrow::array::BooleanBufferBuilder::new(batch_size_hint);
let capacity = batch_size_hint.saturating_mul(width);
let data = datafusion::arrow::buffer::MutableBuffer::with_capacity(capacity);
ArrayBuilder::FixedSize {
data_type,
width,
data,
validity: Some(validity),
}
}
fn new_view(utf8_or_binary: Utf8OrBinary, batch_size_hint: usize) -> ArrayBuilder {
let builder = datafusion::arrow::array::BinaryViewBuilder::with_capacity(batch_size_hint);
ArrayBuilder::View {
utf8_or_binary,
builder,
}
}
pub(crate) fn append_from_arrow_bytes(&mut self, arrow_bytes: &[u8]) {
match self {
ArrayBuilder::FixedSize {
data_type: _,
width,
data,
validity,
} => {
debug_assert_eq!(arrow_bytes.len(), *width);
data.extend_from_slice(arrow_bytes);
if let Some(validity) = validity {
validity.append(true);
}
}
ArrayBuilder::View {
utf8_or_binary: _,
builder,
} => {
builder.append_value(arrow_bytes);
}
ArrayBuilder::Boolean {
validity: nulls,
values,
} => {
nulls.append_non_null();
debug_assert_eq!(arrow_bytes.len(), 1);
let byte = arrow_bytes
.first()
.expect("internal error #9pj5if5rf6k5c: set boolean got wrong number of bytes");
let value = *byte != 0;
values.append(value);
}
}
}
pub(crate) fn append_null(&mut self) {
match self {
ArrayBuilder::FixedSize {
data_type: _,
width,
data,
validity,
} => {
data.extend_zeros(*width);
if let Some(validity) = validity {
validity.append(false);
} else {
debug_assert!(validity.is_some());
}
}
ArrayBuilder::View {
utf8_or_binary: _,
builder,
} => builder.append_null(),
ArrayBuilder::Boolean {
validity: nulls,
values,
} => {
nulls.append_null();
values.append(false);
}
}
}
pub(crate) fn build(&mut self) -> Result<datafusion::arrow::array::ArrayRef, RecordError> {
// TODO if we do `&mut self` (which is a good idea), we need to be sure to clear out all mutable state here
let array: Arc<dyn Array> = match self {
ArrayBuilder::FixedSize {
data_type,
width,
data,
validity,
} => {
let mutable_buffer = std::mem::take(data);
let buffer = datafusion::arrow::buffer::Buffer::from(mutable_buffer);
let len = buffer.len().div_ceil(*width);
let nulls = validity
.as_mut()
.map(|v| datafusion::arrow::buffer::NullBuffer::new(v.finish()));
let builder = datafusion::arrow::array::ArrayDataBuilder::new(data_type.clone())
.add_buffer(buffer)
.nulls(nulls)
.len(len);
let array_data = builder
.build()
.map_err(|arrow_error| RecordError::InternalArrow { error: arrow_error })?;
#[cfg(debug_assertions)]
array_data
.validate_data()
.map_err(|arrow_error| RecordError::InternalArrow { error: arrow_error })?;
datafusion::arrow::array::make_array(array_data)
}
ArrayBuilder::View {
utf8_or_binary,
builder,
} => {
let array = builder.finish();
match utf8_or_binary {
Utf8OrBinary::Binary => Arc::new(array),
Utf8OrBinary::Utf8 => {
let array = {
if cfg!(debug_assertions) {
array.to_string_view().map_err(|arrow_error| {
RecordError::InternalArrow { error: arrow_error }
})?
} else {
#[expect(unsafe_code)]
unsafe {
array.to_string_view_unchecked()
}
}
};
Arc::new(array)
}
}
}
ArrayBuilder::Boolean {
validity: nulls,
values,
} => {
let array =
datafusion::arrow::array::BooleanArray::new(values.finish(), nulls.finish());
Arc::new(array)
}
};
// TODO `debug_assert_matches!` <https://doc.rust-lang.org/std/assert_matches/macro.debug_assert_matches.html>, <https://github.com/rust-lang/rust/issues/82775> #waiting #ecosystem/rust #severity/low #dev
#[cfg(debug_assertions)]
array
.to_data()
.validate()
.map_err(|arrow_error| RecordError::InternalArrow { error: arrow_error })?;
Ok(array)
}
}

View file

@ -0,0 +1,38 @@
use std::borrow::Cow;
use crate::error::RecordError;
use crate::slot::Slot;
#[derive(Debug)]
pub(crate) enum ColumnValue {
DefaultConstant {
is_valid: bool,
arrow_bytes: Cow<'static, [u8]>,
},
Slot(Slot),
}
impl ColumnValue {
pub(crate) fn is_valid<'data>(&'data self, data: &'data [u8]) -> Result<bool, RecordError> {
match self {
ColumnValue::DefaultConstant {
is_valid,
arrow_bytes: _,
} => Ok(*is_valid),
ColumnValue::Slot(slot) => slot.is_valid(data),
}
}
pub(crate) fn get_arrow_bytes<'data>(
&'data self,
data: &'data [u8],
) -> Result<&'data [u8], RecordError> {
match self {
ColumnValue::DefaultConstant {
is_valid: _,
arrow_bytes,
} => Ok(arrow_bytes),
ColumnValue::Slot(slot) => slot.get_arrow_bytes(data),
}
}
}

View file

@ -0,0 +1,225 @@
#![expect(missing_docs, reason = "not a public API, only exposed for fuzzing")]
use kanto_meta_format_v1::field_type::ArchivedFieldType;
use kanto_meta_format_v1::table_def::ArchivedFieldDef;
use crate::error::RecordError;
pub(crate) type Indirection = u32;
pub(crate) type Length = u32;
const fn needed_validity_bytes(num_fields: usize) -> usize {
num_fields.div_ceil(8)
}
const fn slot_width(field_type: &ArchivedFieldType) -> usize {
use kanto_meta_format_v1::field_type::ArchivedFieldType as FT;
match field_type {
FT::U64 | FT::I64 | FT::F64 => 8,
FT::U32 | FT::I32 | FT::F32 => 4,
FT::U16 | FT::I16 => 2,
FT::U8 | FT::I8 | FT::Boolean => 1,
FT::String | FT::Binary => size_of::<Indirection>(),
}
}
const fn field_is_indirect(field_type: &ArchivedFieldType) -> bool {
use kanto_meta_format_v1::field_type::ArchivedFieldType as FT;
match field_type {
FT::U64
| FT::U32
| FT::U16
| FT::U8
| FT::I64
| FT::I32
| FT::I16
| FT::I8
| FT::F64
| FT::F32
| FT::Boolean => false,
FT::String | FT::Binary => true,
}
}
#[derive(Clone, Copy, Debug)]
#[expect(unnameable_types, reason = "triggered by fuzzing-only export")]
pub struct Slot {
// TODO maybe switch to u32-size blocks for validity bits
valid_offset: usize,
valid_mask: u8,
start: usize,
stop: usize,
is_indirect: bool,
nullable: bool,
}
impl Slot {
#[must_use]
pub const fn is_indirect(&self) -> bool {
self.is_indirect
}
#[must_use]
pub const fn is_nullable(&self) -> bool {
self.nullable
}
#[must_use]
pub const fn width(&self) -> usize {
debug_assert!(self.stop >= self.start);
self.stop.saturating_sub(self.start)
}
pub fn is_valid(&self, record: &[u8]) -> Result<bool, RecordError> {
let byte = record.get(self.valid_offset).ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
let is_valid = (byte & self.valid_mask) != 0;
Ok(is_valid)
}
pub fn set_valid(&self, record: &mut [u8]) -> Result<(), RecordError> {
let byte = record
.get_mut(self.valid_offset)
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
*byte |= self.valid_mask;
Ok(())
}
#[maybe_tracing::instrument(skip(self), err)]
pub fn get_arrow_bytes<'record>(
&self,
record: &'record [u8],
) -> Result<&'record [u8], RecordError> {
let slot = record
.get(self.start..self.stop)
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
if self.is_indirect {
let pos: usize = {
// marking the assert const removes a lint
// <https://stackoverflow.com/questions/72419389/how-to-assert-size-of-usize-to-drop-support-for-incompatible-platforms>
const { assert!(Indirection::BITS <= usize::BITS) };
Indirection::from_le_bytes(slot.try_into().map_err(|_error| {
RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
}
})?)
.try_into()
.map_err(|_error| RecordError::Internal {
message: "indirection must fit in usize",
})?
};
let data = record.get(pos..).ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
let (length_bytes, rest) = {
const LENGTH_SIZE: usize = size_of::<Length>();
data.split_first_chunk::<LENGTH_SIZE>()
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?
};
let length: usize = {
// marking the assert const removes a lint
// <https://stackoverflow.com/questions/72419389/how-to-assert-size-of-usize-to-drop-support-for-incompatible-platforms>
const { assert!(Length::BITS <= usize::BITS) };
let len_u32 = Length::from_le_bytes(*length_bytes);
usize::try_from(len_u32).map_err(|_error| RecordError::Internal {
message: "indirection must fit in usize",
})?
};
let indirect = rest.get(..length).ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
Ok(indirect)
} else {
Ok(slot)
}
}
pub fn set_arrow_bytes(&self, record: &mut [u8], bytes: &[u8]) -> Result<(), RecordError> {
debug_assert!(!self.is_indirect());
let slot = record
.get_mut(self.start..self.stop)
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
debug_assert_eq!(bytes.len(), slot.len());
slot.copy_from_slice(bytes);
Ok(())
}
pub fn set_indirect_data(&self, record: &mut Vec<u8>, data: &[u8]) -> Result<(), RecordError> {
debug_assert!(self.is_indirect());
let pos = Indirection::try_from(record.len()).map_err(|_error| RecordError::Internal {
message: "indirection must fit in given size",
})?;
let len = Length::try_from(data.len()).map_err(|_error| RecordError::Internal {
message: "length must fit in given size",
})?;
record.extend_from_slice(&len.to_le_bytes());
record.extend_from_slice(data);
let slot = record
.get_mut(self.start..self.stop)
.ok_or(RecordError::Corrupt {
error: crate::error::RecordCorruptError::Truncated,
})?;
slot.copy_from_slice(&pos.to_le_bytes());
Ok(())
}
}
#[expect(unnameable_types, reason = "triggered by fuzzing-only export")]
pub struct Slots {
pub min_record_size: usize,
pub slots: Box<[Slot]>,
}
pub fn build_slots(fields: &[&ArchivedFieldDef]) -> Result<Slots, crate::error::RecordSchemaError> {
const GEN_SIZE: usize = size_of::<kanto_meta_format_v1::table_def::TableDefGeneration>();
let validity_bits_size = needed_validity_bytes(fields.len());
let mut slot_offset = GEN_SIZE
.checked_add(validity_bits_size)
.ok_or(crate::error::RecordSchemaError::TooLarge)?;
let slots = fields
.iter()
.enumerate()
.map(|(slot_num, field)| {
#[expect(clippy::integer_division)]
let valid_byte = slot_num / 8;
let valid_offset = GEN_SIZE
.checked_add(valid_byte)
.ok_or(crate::error::RecordSchemaError::TooLarge)?;
let valid_mask = 1 << (slot_num % 8);
let start = slot_offset;
let slot_size = slot_width(&field.field_type);
let stop = start
.checked_add(slot_size)
.ok_or(crate::error::RecordSchemaError::TooLarge)?;
slot_offset = stop;
let is_indirect = field_is_indirect(&field.field_type);
let slot = Slot {
valid_offset,
valid_mask,
start,
stop,
is_indirect,
nullable: field.nullable,
};
Ok(slot)
})
.collect::<Result<_, _>>()?;
Ok(Slots {
min_record_size: slot_offset,
slots,
})
}

30
crates/rocky/Cargo.toml Normal file
View file

@ -0,0 +1,30 @@
[package]
name = "rocky"
version = "0.1.0"
description = "Multi-threaded binding to the RocksDB C++ key-value store library"
keywords = ["rocksdb", "key-value"]
categories = ["database", "api-bindings"]
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
dashmap = { workspace = true }
gat-lending-iterator = { workspace = true }
libc = { workspace = true }
librocksdb-sys = { workspace = true }
maybe-tracing = { workspace = true }
rkyv_util = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
tempfile = { workspace = true }
test-log = { workspace = true }
[lints]
workspace = true

41
crates/rocky/src/cache.rs Normal file
View file

@ -0,0 +1,41 @@
/// Memory cache for RocksDB operation.
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Block-Cache>
pub struct Cache {
raw: RawCache,
}
struct RawCache(std::ptr::NonNull<librocksdb_sys::rocksdb_cache_t>);
unsafe impl Send for RawCache {}
unsafe impl Sync for RawCache {}
#[clippy::has_significant_drop]
impl Drop for RawCache {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_cache_destroy(self.0.as_ptr());
}
}
}
impl Cache {
/// Create a new LRU cache with the given `capacity` in bytes.
#[must_use]
pub fn new_lru(capacity: usize) -> Cache {
// TODO let caller control options
// TODO `rocksdb_cache_create_lru_with_strict_capacity_limit`?
let ptr = unsafe { librocksdb_sys::rocksdb_cache_create_lru(capacity) };
let raw = RawCache(
std::ptr::NonNull::new(ptr)
.expect("Cache::new_lru: rocksdb_cache_create_lru returned null pointer"),
);
Cache { raw }
}
pub(crate) const fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_cache_t> {
&self.raw.0
}
}

View file

@ -0,0 +1,57 @@
use std::sync::Arc;
use crate::KeepAlive;
/// Handle to a column family, which are namespaces for key-value storage.
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Column-Families>
pub struct ColumnFamily {
raw: RawColumnFamilyHandle,
}
struct RawColumnFamilyHandle {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_column_family_handle_t>,
_keepalive_database: KeepAlive<Arc<crate::database::RawTransactionDb>>,
}
unsafe impl Send for RawColumnFamilyHandle {}
unsafe impl Sync for RawColumnFamilyHandle {}
#[clippy::has_significant_drop]
impl Drop for RawColumnFamilyHandle {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_column_family_handle_destroy(self.ptr.as_ptr());
}
}
}
impl ColumnFamily {
pub(crate) const fn from_owned_raw(
rocksdb_raw: Arc<crate::database::RawTransactionDb>,
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_column_family_handle_t>,
) -> ColumnFamily {
let raw = RawColumnFamilyHandle {
ptr,
_keepalive_database: KeepAlive::new(rocksdb_raw),
};
ColumnFamily { raw }
}
pub(crate) const fn as_raw(
&self,
) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_column_family_handle_t> {
&self.raw.ptr
}
}
impl std::fmt::Debug for ColumnFamily {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ColumnFamily")
.field("raw", &self.raw.ptr)
.finish()
}
}

510
crates/rocky/src/cursor.rs Normal file
View file

@ -0,0 +1,510 @@
use crate::ColumnFamily;
use crate::KeepAlive;
use crate::ReadOptions;
use crate::RocksDbError;
use crate::Transaction;
/// Iterate RocksDB key-value pairs.
///
/// Note that this is not a Rust [`Iterator`](std::iter::Iterator), but more like a cursor inside a snapshot of the database.
/// Most uses are probably better off with with `Transaction::iter`
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Iterator>
#[must_use]
pub struct Cursor {
raw: RawCursor,
}
struct RawCursor {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_iterator_t>,
_keepalive_transaction: KeepAlive<Transaction>,
_keepalive_read_opts: KeepAlive<ReadOptions>,
}
unsafe impl Send for RawCursor {}
#[clippy::has_significant_drop]
impl Drop for RawCursor {
#[maybe_tracing::instrument(skip(self))]
fn drop(&mut self) {
unsafe { librocksdb_sys::rocksdb_iter_destroy(self.ptr.as_ptr()) }
}
}
impl Cursor {
/// Create a cursor for scanning key-value pairs in `column_family` as seen by `transaction`.
///
/// Public API is [`Transaction::cursor`], [`Transaction::range`] or [`Transaction::iter`].
#[maybe_tracing::instrument(skip(read_opts))]
pub(crate) fn new(
transaction: Transaction,
read_opts: ReadOptions,
column_family: &ColumnFamily,
) -> Self {
let raw = unsafe {
librocksdb_sys::rocksdb_transaction_create_iterator_cf(
transaction.as_raw().as_ptr(),
read_opts.as_raw().as_ptr(),
column_family.as_raw().as_ptr(),
)
};
let raw = std::ptr::NonNull::new(raw)
.expect("Iter::new: rocksdb_transaction_create_iterator_cf returned null pointer");
let raw = RawCursor {
ptr: raw,
_keepalive_transaction: KeepAlive::new(transaction),
_keepalive_read_opts: KeepAlive::new(read_opts),
};
Cursor { raw }
}
/// Make a new [`Cursor`] and limit it to looking at keys in `range`.
///
/// Explicitly seeking outside of `range` is has undefined semantics but is guaranteed not to crash.
/// See <https://github.com/facebook/rocksdb/wiki/Iterator#iterating-upper-bound-and-lower-bound>.
#[maybe_tracing::instrument(skip(read_opts))]
pub(crate) fn with_range<K, R>(
transaction: Transaction,
mut read_opts: ReadOptions,
column_family: &ColumnFamily,
range: R,
) -> Self
where
K: AsRef<[u8]>,
R: std::ops::RangeBounds<K> + std::fmt::Debug,
{
match range.end_bound() {
std::ops::Bound::Included(end_key) => {
// RocksDB `ReadOptions::set_iterate_upper_bound` is exclusive.
// Compute first possible key after `end_key` and set that as upper bound.
let end_key = end_key.as_ref();
let mut one_past_end = Vec::with_capacity(end_key.len().saturating_add(1));
one_past_end.extend(end_key);
one_past_end.push(0u8);
read_opts.set_iterate_upper_bound(one_past_end);
}
std::ops::Bound::Excluded(end_key) => {
read_opts.set_iterate_upper_bound(end_key);
}
std::ops::Bound::Unbounded => {
// nothing
// TODO caller might have set upper bound already, that's just weird.
// can't really avoid as long as the api takes a ReadOptions in; rocksdb API is awkward shaped for trying to enforce stronger constraints on.
}
};
let mut cursor = Self::new(transaction, read_opts, column_family);
let start = range.start_bound().map(|k| Box::from(k.as_ref()));
match &start {
std::ops::Bound::Included(start) => {
cursor.seek(start);
}
std::ops::Bound::Excluded(prev_key) => {
cursor.seek(prev_key);
if let Some(k) = cursor.key() {
// still valid
if k == prev_key.as_ref() {
// and seek found the exact key; go to next one for exclusive
cursor.seek_next();
}
}
}
std::ops::Bound::Unbounded => cursor.seek_to_first(),
}
cursor
}
/// Is the current cursor position valid?
///
/// A cursor is invalid if a seek went outside of bounds of existing data, range or prefix.
///
/// # Developer note: public API always initializes cursors
///
/// This library should prevent high-level callers from seeing cursors that have been created but have not performed their first seek.
//
// TODO we don't expose prefixes at this time
#[maybe_tracing::instrument(skip(self), ret)]
pub fn is_valid(&self) -> bool {
let valid = unsafe { librocksdb_sys::rocksdb_iter_valid(self.raw.ptr.as_ptr()) };
valid != 0
}
fn get_error(&self) -> Result<(), RocksDbError> {
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_iter_get_error(self.raw.ptr.as_ptr(), &mut err);
}
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}
/// Seek backwards to the first possible value (if any).
#[maybe_tracing::instrument(skip(self), ret)]
pub fn seek_to_first(&mut self) {
unsafe { librocksdb_sys::rocksdb_iter_seek_to_first(self.raw.ptr.as_ptr()) };
}
/// Seek forwards to the last possible value (if any).
#[maybe_tracing::instrument(skip(self), ret)]
pub fn seek_to_last(&mut self) {
unsafe { librocksdb_sys::rocksdb_iter_seek_to_last(self.raw.ptr.as_ptr()) };
}
/// Seek to the given key, or the next existing key after it.
///
/// Seeking outside the set range is undocumented behavior but will *not* lead to a crash.
/// (See [`Transaction::range`], [`ReadOptions::set_iterate_lower_bound`], [`ReadOptions::set_iterate_upper_bound`].)
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Iterator#iterating-upper-bound-and-lower-bound>
#[maybe_tracing::instrument(skip(self, key), fields(key=?key.as_ref()), ret)]
pub fn seek<Key>(&mut self, key: Key)
where
Key: AsRef<[u8]>,
{
let key = key.as_ref();
let k = key.as_ptr().cast::<libc::c_char>();
let klen = libc::size_t::from(key.len());
unsafe {
librocksdb_sys::rocksdb_iter_seek(self.raw.ptr.as_ptr(), k, klen);
};
}
/// Seek to the given key, or the previous existing key before it.
///
/// See [`Cursor::seek`].
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/SeekForPrev>
#[maybe_tracing::instrument(skip(self, key), fields(key=?key.as_ref()), ret)]
pub fn seek_for_prev<Key>(&mut self, key: Key)
where
Key: AsRef<[u8]>,
{
let iter = self.raw.ptr.as_ptr();
let key = key.as_ref();
let k = key.as_ptr().cast::<libc::c_char>();
let klen = libc::size_t::from(key.len());
unsafe {
librocksdb_sys::rocksdb_iter_seek_for_prev(iter, k, klen);
};
}
unsafe fn seek_next_unchecked(&mut self) {
debug_assert!(self.is_valid());
unsafe { librocksdb_sys::rocksdb_iter_next(self.raw.ptr.as_ptr()) }
}
unsafe fn seek_prev_unchecked(&mut self) {
debug_assert!(self.is_valid());
unsafe { librocksdb_sys::rocksdb_iter_prev(self.raw.ptr.as_ptr()) }
}
/// Move cursor to the next key.
#[maybe_tracing::instrument(skip(self))]
pub fn seek_next(&mut self) {
if self.is_valid() {
unsafe { self.seek_next_unchecked() }
}
}
/// Move cursor to the previous key.
#[maybe_tracing::instrument(skip(self))]
pub fn seek_prev(&mut self) {
if self.is_valid() {
unsafe { self.seek_prev_unchecked() }
}
}
/// Only call when Cursor is valid.
unsafe fn key_unchecked(&self) -> &[u8] {
debug_assert!(self.is_valid());
let mut out_len: libc::size_t = 0;
let key_ptr =
unsafe { librocksdb_sys::rocksdb_iter_key(self.raw.ptr.as_ptr(), &mut out_len) };
let key = unsafe { std::slice::from_raw_parts(key_ptr.cast::<u8>(), out_len) };
key
}
/// Only call when Cursor is valid.
unsafe fn value_unchecked(&self) -> &[u8] {
debug_assert!(self.is_valid());
let mut out_len: libc::size_t = 0;
let value_ptr =
unsafe { librocksdb_sys::rocksdb_iter_value(self.raw.ptr.as_ptr(), &mut out_len) };
let value = unsafe { std::slice::from_raw_parts(value_ptr.cast::<u8>(), out_len) };
value
}
/// Get the bytes of the key, if valid.
#[must_use]
pub fn key(&self) -> Option<&[u8]> {
self.is_valid().then(|| unsafe { self.key_unchecked() })
}
/// Get the bytes of the value, if valid.
#[must_use]
pub fn value(&self) -> Option<&[u8]> {
self.is_valid().then(|| unsafe { self.value_unchecked() })
}
/// Convert the [`Cursor`] into an [`gat_lending_iterator::LendingIterator`].
#[must_use]
pub const fn into_iter(self) -> Iter {
Iter::new(self)
}
}
/// Iterator for key-value pairs.
///
/// See [`Transaction::iter`].
pub struct Iter {
cursor: Cursor,
started: bool,
}
impl Iter {
pub(crate) const fn new(cursor: Cursor) -> Iter {
Iter {
cursor,
started: false,
}
}
}
impl gat_lending_iterator::LendingIterator for Iter {
type Item<'a>
= Result<Entry<'a>, RocksDbError>
where
Self: 'a;
fn next(&mut self) -> Option<Self::Item<'_>> {
if self.started {
self.cursor.seek_next();
} else {
self.started = true;
}
if self.cursor.is_valid() {
Some(Ok(Entry { iter: self }))
} else {
match self.cursor.get_error() {
Ok(()) => None,
Err(error) => Some(Err(error)),
}
}
}
}
/// A key-value pair at the current [`Cursor`] (or [`Iter`]) position.
///
/// Holding an [`Entry`] is a guarantee that a key-value pair was found ([`Cursor::is_valid`]) and that reading the key and value will succeed.[^mmap]
///
/// [^mmap]:
/// RocksDB supports an optional `mmap(2)` mode, in which reads are not guaranteed to succeed, but the program will terminate on I/O errors.
/// This library does not support using `mmap(2)`.
pub struct Entry<'cursor> {
iter: &'cursor Iter,
}
impl Entry<'_> {
#[must_use]
/// Get the key at the current location.
pub fn key(&self) -> &[u8] {
unsafe { self.iter.cursor.key_unchecked() }
}
/// Get the value at the current location.
#[must_use]
pub fn value(&self) -> &[u8] {
unsafe { self.iter.cursor.value_unchecked() }
}
}
#[cfg(test)]
mod tests {
use gat_lending_iterator::LendingIterator as _;
use crate::Database;
use crate::Options;
use crate::ReadOptions;
use crate::TransactionDbOptions;
use crate::TransactionOptions;
use crate::WriteOptions;
#[test]
fn cursor() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
tx.put(&cf_one, "a", "one").unwrap();
tx.put(&cf_one, "b", "two").unwrap();
{
let read_options = ReadOptions::new();
let mut cursor = tx.cursor(read_options, &cf_one);
cursor.seek_to_first();
assert!(cursor.is_valid());
let key1 = cursor.key().unwrap();
let value1 = cursor.value().unwrap();
assert_eq!(key1, b"a");
assert_eq!(value1, b"one");
cursor.seek_next();
assert!(cursor.is_valid());
let key2 = cursor.key().unwrap();
let value2 = cursor.value().unwrap();
assert_eq!(key2, b"b");
assert_eq!(value2, b"two");
// this would not compile
// assert_eq!(key1, b"a");
}
drop(db);
}
#[test]
fn iter_explicit() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
tx.put(&cf_one, "a", "one").unwrap();
tx.put(&cf_one, "b", "two").unwrap();
{
let read_options = ReadOptions::new();
let mut iter = tx.range::<&[u8], _>(read_options, &cf_one, ..);
let entry1 = iter
.next()
.unwrap()
.expect("normal database operations should work");
let key1 = entry1.key();
assert_eq!(key1, b"a");
assert_eq!(entry1.value(), b"one");
let entry2 = iter
.next()
.unwrap()
.expect("normal database operations should work");
assert_eq!(entry2.key(), b"b");
assert_eq!(entry2.value(), b"two");
// this would not compile
// assert_eq!(entry1.key(), b"a");
assert!(iter.next().is_none());
}
drop(db);
}
#[track_caller]
fn check_iter<K, R>(inputs: &[K], range: R, expected: &[K])
where
K: AsRef<[u8]> + std::fmt::Debug,
R: std::ops::RangeBounds<K> + std::fmt::Debug,
{
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
for k in inputs {
tx.put(&cf_one, k, "dummy").unwrap();
}
let read_options = ReadOptions::new();
let iter = tx.range(read_options, &cf_one, range);
let got = iter
.map(|result: Result<crate::cursor::Entry<'_>, _>| {
result.map(|entry| entry.key().to_vec())
})
.into_iter()
.collect::<Result<Vec<_>, _>>()
.expect("normal database operations should work");
let want = expected
.iter()
.map(|k| k.as_ref().to_vec())
.collect::<Vec<_>>();
assert_eq!(got, want);
drop(db);
}
#[test]
fn iter_full() {
check_iter(&["a", "b"], .., &["a", "b"]);
}
#[test]
fn iter_range() {
check_iter(&["a", "b", "c"], "a".."b", &["a"]);
check_iter(&["b", "c"], "a".."b2", &["b"]);
check_iter(&["a", "b", "c"], "a2".."b2", &["b"]);
check_iter(&["a", "b", "c"], "c".."c2", &["c"]);
check_iter(&["a", "b", "c"], "c2".."x", &[]);
}
#[test]
fn iter_range_inclusive() {
check_iter(&["a", "b"], "a"..="b2", &["a", "b"]);
check_iter(&["a", "b"], "a"..="b", &["a", "b"]);
check_iter(&["a", "b"], "a"..="a2", &["a"]);
}
#[test]
fn iter_range_from() {
check_iter(&["a", "b", "c"], "a".., &["a", "b", "c"]);
check_iter(&["b", "c"], "a".., &["b", "c"]);
check_iter(&["a", "b", "c"], "a2".., &["b", "c"]);
check_iter(&["a", "b", "c"], "c".., &["c"]);
check_iter(&["a", "b", "c"], "c2".., &[]);
}
#[test]
fn iter_range_to() {
check_iter(&["a", "b"], .."b2", &["a", "b"]);
check_iter(&["a", "b"], .."b", &["a"]);
}
#[test]
fn iter_range_to_inclusive() {
check_iter(&["a", "b"], ..="b2", &["a", "b"]);
check_iter(&["a", "b"], ..="b", &["a", "b"]);
check_iter(&["a", "b"], ..="a2", &["a"]);
}
}

View file

@ -0,0 +1,364 @@
use std::os::unix::ffi::OsStrExt as _;
use std::path::Path;
use std::sync::Arc;
use dashmap::DashMap;
use crate::ColumnFamily;
use crate::KeepAlive;
use crate::Options;
use crate::RocksDbError;
use crate::Snapshot;
use crate::Transaction;
use crate::TransactionDbOptions;
use crate::TransactionOptions;
use crate::WriteOptions;
pub(crate) struct RawTransactionDb {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_transactiondb_t>,
_keepalive_path: KeepAlive<std::ffi::CString>,
}
unsafe impl Send for RawTransactionDb {}
unsafe impl Sync for RawTransactionDb {}
#[clippy::has_significant_drop]
impl Drop for RawTransactionDb {
#[maybe_tracing::instrument(skip(self))]
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_transactiondb_close(self.ptr.as_ptr());
}
}
}
/// An active RocksDB database.
#[derive(Clone)]
pub struct Database {
raw: Arc<RawTransactionDb>,
column_families: DashMap<String, Arc<ColumnFamily>>,
}
/// Information about a column family.
pub struct ColumnFamilyInfo<'parent> {
item: dashmap::mapref::multiple::RefMulti<'parent, String, Arc<ColumnFamily>>,
}
impl ColumnFamilyInfo<'_> {
/// Name of the column family.
#[must_use]
pub fn cf_name(&self) -> &str {
self.item.key()
}
/// Handle for the column family that can be used in RocksDB operations.
#[must_use]
pub fn cf_handle(&self) -> &ColumnFamily {
self.item.value()
}
}
impl Database {
/// Open an already-existing database.
///
/// RocksDB configuration is complex.
/// There's [`Options`], [`TransactionDbOptions`], and then also per-[`ColumnFamily`] [`Options`] (same as the top-level one, but not all fields matter).
/// See [`LatestOptions`](crate::LatestOptions) for how to load previously-used configuration.
//
// TODO maybe make it easier to pass `LatestOptions` here directly
#[expect(clippy::unwrap_in_result)]
#[maybe_tracing::instrument(skip(path, opts, txdb_opts, column_families), fields(path=%path.as_ref().display()), ret, err)]
pub fn open<P, I>(
path: P,
opts: &Options,
txdb_opts: &TransactionDbOptions,
column_families: I,
) -> Result<Database, RocksDbError>
where
P: AsRef<Path>,
I: IntoIterator<Item = (String, Options)>,
{
struct BuildColumnFamily {
cf_name: String,
cf_options: Options,
c_cf_name: std::ffi::CString,
}
let path = path.as_ref();
let c_path = {
let bytes = path.as_os_str().as_bytes();
std::ffi::CString::new(bytes).map_err(|error| RocksDbError::PathContainsNul {
path: path.to_owned(),
error,
})?
};
let column_families = {
let mut column_families = column_families.into_iter().collect::<Vec<_>>();
if column_families.is_empty() {
column_families.push(("default".to_owned(), Options::new()));
}
column_families
};
let column_families = column_families
.into_iter()
.map(
|(cf_name, cf_options)| match std::ffi::CString::new(cf_name.as_str()) {
Err(error) => Err(RocksDbError::ColumnFamilyNameContainsNul { cf_name, error }),
Ok(c_cf_name) => Ok(BuildColumnFamily {
cf_name,
cf_options,
c_cf_name,
}),
},
)
.collect::<Result<Vec<_>, _>>()?;
let column_family_names_ptrs: Vec<_> = column_families
.iter()
.map(|b| b.c_cf_name.as_ptr())
.collect();
debug_assert_eq!(column_families.len(), column_family_names_ptrs.len());
let column_family_options_ptrs: Vec<*const librocksdb_sys::rocksdb_options_t> =
column_families
.iter()
.map(|b| -> *const librocksdb_sys::rocksdb_options_t {
b.cf_options.as_raw().as_ptr()
})
.collect();
debug_assert_eq!(column_families.len(), column_family_options_ptrs.len());
let mut out_column_family_handles: Vec<
*mut librocksdb_sys::rocksdb_column_family_handle_t,
> = std::iter::repeat(std::ptr::null_mut())
.take(column_families.len())
.collect();
debug_assert_eq!(column_families.len(), out_column_family_handles.len());
let ptr = {
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
let num_column_families = libc::c_int::try_from(column_families.len())
.map_err(|_error| RocksDbError::TooManyColumnFamilies)?;
let rocksdb = unsafe {
librocksdb_sys::rocksdb_transactiondb_open_column_families(
opts.as_raw().as_ptr(),
txdb_opts.as_raw().as_ptr(),
c_path.as_ptr(),
num_column_families,
column_family_names_ptrs.as_ptr(),
column_family_options_ptrs.as_ptr(),
out_column_family_handles.as_mut_ptr(),
&mut err,
)
};
if err.is_null() {
Ok(rocksdb)
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}?;
let ptr = std::ptr::NonNull::new(ptr).expect(
"Database::open: rocksdb_transactiondb_open_column_families returned null pointer",
);
let raw = RawTransactionDb {
ptr,
_keepalive_path: KeepAlive::new(c_path),
};
debug_assert!(out_column_family_handles.iter().all(|ptr| !ptr.is_null()));
// We need `Database` first because `ColumnFamily` wants to refcount it, but that means we need to populate the `column_families` `DashMap` through individual inserts and not via `DashMap::from_iter` or `extend` or such.
let database = Database {
raw: Arc::new(raw),
column_families: DashMap::new(),
};
for (cf_name, cf_created) in std::iter::zip(column_families, out_column_family_handles).map(
|(cf_build, c_cf_handle)| {
let BuildColumnFamily { cf_name, .. } = cf_build;
let c_cf_handle = std::ptr::NonNull::new(c_cf_handle)
.expect("Database::open: rocksdb_transactiondb_open_column_families returned null pointer for column family handle");
let column_family = ColumnFamily::from_owned_raw(database.raw.clone(), c_cf_handle);
(cf_name, Arc::new(column_family))
},
) {
let old = database.column_families.insert(cf_name, cf_created);
debug_assert!(old.is_none());
}
Ok(database)
}
pub(crate) fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_transactiondb_t> {
&self.raw.ptr
}
/// List known column families.
///
/// Order of iteration is not guaranteed.
#[maybe_tracing::instrument(skip(self))]
pub fn column_families(&self) -> impl Iterator<Item = ColumnFamilyInfo<'_>> + '_ {
self.column_families
.iter()
.map(|item| ColumnFamilyInfo { item })
}
/// Get the named [`ColumnFamily`].
#[maybe_tracing::instrument(skip(self))]
pub fn get_column_family(&self, cf_name: &str) -> Option<Arc<ColumnFamily>> {
self.column_families
.get(cf_name)
.map(|entry| entry.value().clone())
}
/// Create a [`ColumnFamily`].
#[expect(clippy::unwrap_in_result)]
#[maybe_tracing::instrument(skip(self, cf_opts), err)]
pub fn create_column_family(
&self,
cf_name: &str,
cf_opts: &Options,
) -> Result<Arc<ColumnFamily>, RocksDbError> {
let c_cf_name = std::ffi::CString::new(cf_name).map_err(|error| {
RocksDbError::ColumnFamilyNameContainsNul {
cf_name: cf_name.to_owned(),
error,
}
})?;
let c_cf_opts = cf_opts.as_raw().as_ptr();
let c_cf_handle = {
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
let maybe_handle = unsafe {
librocksdb_sys::rocksdb_transactiondb_create_column_family(
self.raw.ptr.as_ptr(),
c_cf_opts,
c_cf_name.as_ptr(),
&mut err,
)
};
if err.is_null() {
Ok(maybe_handle)
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}?;
let c_cf_handle = std::ptr::NonNull::new(c_cf_handle)
.expect("Database::create_column_family: rocksdb_transactiondb_create_column_family returned null pointer");
let cf_created = Arc::new(ColumnFamily::from_owned_raw(self.raw.clone(), c_cf_handle));
let old = self
.column_families
.insert(cf_name.to_owned(), cf_created.clone());
debug_assert!(
old.is_none(),
"RocksDB was supposed to prevent duplicate column families"
);
Ok(cf_created)
}
/// Begin a [`Transaction`].
#[maybe_tracing::instrument(skip(self, write_options, tx_options))]
pub fn transaction_begin(
&self,
write_options: &WriteOptions,
tx_options: &TransactionOptions,
) -> Transaction {
Transaction::new(self.clone(), write_options, tx_options)
}
/// Take a [`Snapshot`] of the database state.
///
/// This can be used for repeatable read.
#[maybe_tracing::instrument(skip(self))]
pub fn snapshot(&self) -> Snapshot {
Snapshot::create(self.clone())
}
}
impl std::fmt::Debug for Database {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("rocky::Database")
.field("raw", &self.raw.ptr)
.finish_non_exhaustive()
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use crate::Database;
use crate::Options;
use crate::RocksDbError;
use crate::TransactionDbOptions;
#[test]
fn database_create_and_drop() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
drop(db);
}
#[test]
fn database_column_families() {
let dir = tempfile::tempdir().expect("create temp directory");
// Create the database and column families, so we can open with column families.
{
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let _column_family_one = db.create_column_family("one", &Options::new()).unwrap();
let _column_family_two = db.create_column_family("two", &Options::new()).unwrap();
drop(db);
}
// Now open it again, with column families.
let input_column_families = ["default", "one", "two"];
let opts = Options::new();
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(
&dir,
&opts,
&txdb_opts,
input_column_families
.iter()
.map(|name| ((*name).to_owned(), Options::new())),
)
.unwrap();
let got_cf: HashSet<_> = db
.column_families()
.map(|cf| cf.cf_name().to_owned())
.collect();
let want_cf: HashSet<_> = input_column_families
.iter()
.map(ToString::to_string)
.collect();
assert_eq!(got_cf, want_cf);
drop(db);
}
#[test]
fn create_column_family_collision() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let _column_family = db.create_column_family("foo", &Options::new()).unwrap();
let error = db
.create_column_family("foo", &Options::new())
.err()
.unwrap();
assert!(matches!(error, RocksDbError::Other { message }
if message == "Invalid argument: Column family already exists"
));
drop(db);
}
}

40
crates/rocky/src/env.rs Normal file
View file

@ -0,0 +1,40 @@
/// Control how RocksDB interacts with the surrounding environment, for file I/O and such.
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Basic-Operations#environment>
pub struct Env {
raw: RawEnv,
}
struct RawEnv(std::ptr::NonNull<librocksdb_sys::rocksdb_env_t>);
unsafe impl Send for RawEnv {}
unsafe impl Sync for RawEnv {}
#[clippy::has_significant_drop]
impl Drop for RawEnv {
#[maybe_tracing::instrument(skip(self))]
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_env_destroy(self.0.as_ptr());
}
}
}
impl Env {
/// Make a new default environment.
#[maybe_tracing::instrument]
pub fn new() -> Env {
let ptr = unsafe { librocksdb_sys::rocksdb_create_default_env() };
let raw = RawEnv(
std::ptr::NonNull::new(ptr)
.expect("Env::new: rocksdb_create_default_env returned null pointer"),
);
Env { raw }
}
pub(crate) const fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_env_t> {
&self.raw.0
}
}

41
crates/rocky/src/error.rs Normal file
View file

@ -0,0 +1,41 @@
use std::path::PathBuf;
/// Error returned from RocksDB.
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum RocksDbError {
#[error("path contains NUL byte: {error}")]
PathContainsNul {
path: PathBuf,
#[source]
error: std::ffi::NulError,
},
#[error("column family name contains NUL byte: {error}")]
ColumnFamilyNameContainsNul {
cf_name: String,
#[source]
error: std::ffi::NulError,
},
#[error("too many column families")]
TooManyColumnFamilies,
#[error("rocksdb: {message}")]
Other { message: String },
}
impl RocksDbError {
#[must_use]
pub(crate) unsafe fn from_raw(ptr: *mut std::ffi::c_char) -> Self {
let message = {
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
String::from_utf8_lossy(cstr.to_bytes()).into_owned()
};
unsafe {
librocksdb_sys::rocksdb_free(ptr.cast::<libc::c_void>());
}
Self::Other { message }
}
}

View file

@ -0,0 +1,13 @@
/// `KeepAlive` is used to wrap C FFI pointers in a way that lets us keep them alive, but do *nothing else* with them.
///
/// Since we promise that we never actually use the pointer, `KeepAlive` can be marked as `Send` and `Sync`.
pub(crate) struct KeepAlive<T>(T);
unsafe impl<T> Send for KeepAlive<T> {}
unsafe impl<T> Sync for KeepAlive<T> {}
impl<T> KeepAlive<T> {
pub(crate) const fn new(t: T) -> Self {
Self(t)
}
}

View file

@ -0,0 +1,228 @@
use std::os::unix::ffi::OsStrExt as _;
use std::path::Path;
use crate::error::RocksDbError;
use crate::Cache;
use crate::Env;
use crate::KeepAlive;
use crate::Options;
/// Load RocksDB options from disk.
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/RocksDB-Options-File>
pub struct LatestOptions {
db_options: std::ptr::NonNull<librocksdb_sys::rocksdb_options_t>,
list_column_family_names: *mut *mut std::ffi::c_char,
list_column_family_options: *mut *mut librocksdb_sys::rocksdb_options_t,
list_column_family_len: usize,
_keepalive_env: KeepAlive<Env>,
_keepalive_cache: KeepAlive<Cache>,
}
impl LatestOptions {
/// Load RocksDB options from disk.
#[expect(clippy::unwrap_in_result)]
pub fn load_latest_options<P: AsRef<Path>>(path: P) -> Result<LatestOptions, RocksDbError> {
let c_path =
std::ffi::CString::new(path.as_ref().as_os_str().as_bytes()).map_err(|error| {
RocksDbError::PathContainsNul {
path: path.as_ref().to_path_buf(),
error,
}
})?;
// TODO let caller pass in Env
let env = Env::new();
// TODO let caller pass in Cache
let cache = Cache::new_lru(32 * 1024 * 1024);
let mut db_options: *mut librocksdb_sys::rocksdb_options_t = std::ptr::null_mut();
let mut len: usize = 0;
let mut list_column_family_names: *mut *mut std::ffi::c_char = std::ptr::null_mut();
let mut list_column_family_options: *mut *mut librocksdb_sys::rocksdb_options_t =
std::ptr::null_mut();
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_load_latest_options(
c_path.as_ptr(),
env.as_raw().as_ptr(),
false,
cache.as_raw().as_ptr(),
&mut db_options,
&mut len,
&mut list_column_family_names,
&mut list_column_family_options,
&mut err,
);
};
if !err.is_null() {
let error = unsafe { RocksDbError::from_raw(err) };
return Err(error);
}
let db_options = std::ptr::NonNull::new(db_options)
.expect("LatestOptions::load_latest_options: rocksdb_load_latest_options returned null pointer for options");
let latest_options = LatestOptions {
_keepalive_env: KeepAlive::new(env),
_keepalive_cache: KeepAlive::new(cache),
db_options,
list_column_family_names,
list_column_family_options,
list_column_family_len: len,
};
Ok(latest_options)
}
/// Clone the [`Options`] that were used during last database use.
#[must_use]
pub fn clone_db_options(&self) -> Options {
Options::copy_from_raw(&self.db_options)
}
/// Iterate through known column family names and [`Options`].
#[must_use]
pub fn iter_column_families(&self) -> LatestOptionsColumnFamilyIter<'_> {
LatestOptionsColumnFamilyIter::new(self)
}
}
// This drop frees multiple raw pointers at once, and thus is on the container, not on an individual `RawFoo` pointer.
#[clippy::has_significant_drop]
impl Drop for LatestOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_load_latest_options_destroy(
self.db_options.as_ptr(),
self.list_column_family_names,
self.list_column_family_options,
self.list_column_family_len,
);
};
}
}
/// Iterator for column families from [`LatestOptions`].
pub struct LatestOptionsColumnFamilyIter<'parent> {
_keepalive_parent: KeepAlive<&'parent LatestOptions>,
iter: std::iter::Zip<
std::slice::Iter<'parent, *mut i8>,
std::slice::Iter<'parent, *mut librocksdb_sys::rocksdb_options_t>,
>,
}
impl<'parent> LatestOptionsColumnFamilyIter<'parent> {
fn new(parent: &'parent LatestOptions) -> LatestOptionsColumnFamilyIter<'parent> {
let names = unsafe {
std::slice::from_raw_parts(
parent.list_column_family_names,
parent.list_column_family_len,
)
};
let options = unsafe {
std::slice::from_raw_parts(
parent.list_column_family_options,
parent.list_column_family_len,
)
};
assert_eq!(names.len(), options.len());
let iter = names.iter().zip(options);
LatestOptionsColumnFamilyIter {
_keepalive_parent: KeepAlive::new(parent),
iter,
}
}
}
/// Errors that can been iterating column families from [`LatestOptions`].
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum LatestOptionsColumnFamilyIterError {
#[error("column family name is not UTF-8: {name_escaped}: {error}",
name_escaped = .error.as_bytes().escape_ascii().to_string(),
)]
ColumnFamilyNameNotUtf8 {
#[source]
error: std::string::FromUtf8Error,
},
}
impl Iterator for LatestOptionsColumnFamilyIter<'_> {
type Item = Result<(String, Options), LatestOptionsColumnFamilyIterError>;
#[expect(clippy::unwrap_in_result)]
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|(name_ptr, options_ptr)| {
let name = {
let cstr = unsafe { std::ffi::CStr::from_ptr(*name_ptr) };
let bytes = cstr.to_bytes().to_owned();
String::from_utf8(bytes).map_err(|utf8_error| {
LatestOptionsColumnFamilyIterError::ColumnFamilyNameNotUtf8 {
error: utf8_error,
}
})
}?;
let options_ptr = std::ptr::NonNull::new(*options_ptr)
.expect("LatestOptionsColumnFamilyIter::next: rocksdb_load_latest_options returned null pointer for options");
let options = Options::copy_from_raw(&options_ptr);
Ok((name, options))
})
}
}
#[cfg(test)]
mod tests {
use std::collections::HashSet;
use crate::Database;
use crate::LatestOptions;
use crate::Options;
use crate::TransactionDbOptions;
#[test]
fn latest_options_load_open_drop() {
let dir = tempfile::tempdir().expect("create temp directory");
// Create a DB so we have something to call `load_latest_options` on.
{
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let _column_family = db.create_column_family("one", &Options::new()).unwrap();
let _column_family = db.create_column_family("two", &Options::new()).unwrap();
drop(db);
}
let latest_options = LatestOptions::load_latest_options(&dir).unwrap();
// Open DB again, with latest options.
{
let opts = latest_options.clone_db_options();
let txdb_opts = TransactionDbOptions::new();
let column_families = latest_options
.iter_column_families()
.collect::<Result<Vec<_>, _>>()
.expect("list column families");
let db = Database::open(&dir, &opts, &txdb_opts, column_families).unwrap();
let got_cf: HashSet<_> = db
.column_families()
.map(|cf| cf.cf_name().to_owned())
.collect();
let want_cf: HashSet<_> = ["default", "one", "two"]
.iter()
.map(ToString::to_string)
.collect();
assert_eq!(got_cf, want_cf);
drop(db);
}
drop(latest_options);
}
}

68
crates/rocky/src/lib.rs Normal file
View file

@ -0,0 +1,68 @@
//! Rocky is a [RocksDB](https://rocksdb.org/) binding that works well with async and threads.
//!
//! In the [`rocksdb` crate](https://docs.rs/rocksdb/), at least as of v0.21.0 (2023-05), transactions are implemented as borrowing from the main database object.
//! This use of Rust lifetimes and ownership rules minimizes bugs, but also prevents multithreaded and `async` use of it.
//! Similar decisions are all over the codebase.
//! That makes the `rocksdb` crate API safe, but not useful for us.
//! This library was written to enable use cases not possible with their library.
//!
//! This library is a Rust-friendly API for multithreaded and `async` use on top of the low-level [`librocksdb_sys`](https://crates.io/crates/librocksdb-sys) crate.
//!
//! The underlying [RocksDB C++/C library](https://github.com/facebook/rocksdb) is thread-safe, and we expose that power to the caller, using atomic reference counts to ensure memory safety.
//!
//! Whenever the C++ library has an ordering dependency --- e.g. *transaction* must be freed before *database* --- we make one hold a reference to the other.
//! Unfortunately these dependencies are not documented anywhere, and we've had to discover them by experiments and `valgrind`.
//!
//! Note that for purposes of `async`, RocksDB functions are *blocking* and must be treated as such.
//!
//! At this time, this library is heavily biased to only contain features necessary for [KantoDB](https://kantodb.com/).
//! More can be added, but for now, we're e.g. only interested in transactional use, and minimize duplicate APIs such as `put_cf` vs `put`.
//!
//! For background, see
//!
//! - <https://github.com/rust-rocksdb/rust-rocksdb/issues/687>
//! - <https://github.com/rust-rocksdb/rust-rocksdb/issues/407>
//! - <https://github.com/rust-rocksdb/rust-rocksdb/issues/895>
//! - <https://github.com/rust-rocksdb/rust-rocksdb/issues/937>
#![expect(unsafe_code)]
mod cache;
mod column_family;
mod cursor;
mod database;
mod env;
mod error;
mod keepalive;
mod latest_options;
mod options;
mod read_options;
mod snapshot;
mod transaction;
mod transaction_options;
mod transactiondb_options;
mod view;
mod write_options;
pub use crate::cache::Cache;
pub use crate::column_family::ColumnFamily;
pub use crate::cursor::Cursor;
pub use crate::cursor::Entry;
pub use crate::cursor::Iter;
pub use crate::database::ColumnFamilyInfo;
pub use crate::database::Database;
pub use crate::env::Env;
pub use crate::error::RocksDbError;
pub(crate) use crate::keepalive::KeepAlive;
pub use crate::latest_options::LatestOptions;
pub use crate::latest_options::LatestOptionsColumnFamilyIter;
pub use crate::latest_options::LatestOptionsColumnFamilyIterError;
pub use crate::options::Options;
pub use crate::read_options::ReadOptions;
pub use crate::snapshot::Snapshot;
pub use crate::transaction::Exclusivity;
pub use crate::transaction::Transaction;
pub use crate::transaction_options::TransactionOptions;
pub use crate::transactiondb_options::TransactionDbOptions;
pub use crate::view::View;
pub use crate::write_options::WriteOptions;

View file

@ -0,0 +1,72 @@
/// RocksDB database-level configuration.
///
/// See also [`Database::open`](crate::Database::open), [`LatestOptions`](crate::LatestOptions).
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Basic-Operations#rocksdb-options>
/// - <https://github.com/facebook/rocksdb/wiki/Option-String-and-Option-Map>
/// - <https://github.com/facebook/rocksdb/wiki/RocksDB-Options-File>
pub struct Options {
raw: RawOptions,
}
struct RawOptions(std::ptr::NonNull<librocksdb_sys::rocksdb_options_t>);
unsafe impl Send for RawOptions {}
#[clippy::has_significant_drop]
impl Drop for RawOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_options_destroy(self.0.as_ptr());
}
}
}
impl Options {
#[expect(missing_docs)]
#[must_use]
pub fn new() -> Options {
let ptr = unsafe { librocksdb_sys::rocksdb_options_create() };
let raw = RawOptions(
std::ptr::NonNull::new(ptr)
.expect("Options::new: rocksdb_options_create returned null pointer"),
);
Options { raw }
}
pub(crate) const fn from_owned_raw(
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_options_t>,
) -> Options {
let raw = RawOptions(ptr);
Options { raw }
}
pub(crate) fn copy_from_raw(
ptr: &std::ptr::NonNull<librocksdb_sys::rocksdb_options_t>,
) -> Options {
let ptr = unsafe { librocksdb_sys::rocksdb_options_create_copy(ptr.as_ptr()) };
let ptr = std::ptr::NonNull::new(ptr)
.expect("Options::copy_from_raw: rocksdb_options_create_copy returned null pointer");
Self::from_owned_raw(ptr)
}
pub(crate) const fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_options_t> {
&self.raw.0
}
/// Whether to create the database if it does not already exist.
pub fn create_if_missing(&mut self, create_if_missing: bool) {
let arg = libc::c_uchar::from(create_if_missing);
unsafe { librocksdb_sys::rocksdb_options_set_create_if_missing(self.raw.0.as_ptr(), arg) }
}
// TODO whatever options we want to set
}
impl Default for Options {
fn default() -> Self {
Self::new()
}
}

View file

@ -0,0 +1,173 @@
use std::pin::Pin;
use crate::KeepAlive;
use crate::Snapshot;
/// Configuration for read operations.
pub struct ReadOptions {
raw: RawReadOptions,
}
struct RawReadOptions {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_readoptions_t>,
keepalive_snapshot: Option<KeepAlive<Snapshot>>,
// Keepalives for anything stored as a C++ `Slice` in RocksDB `ReadOptions`.
keepalive_lower_bound: Option<KeepAlive<Pin<Box<[u8]>>>>,
keepalive_upper_bound: Option<KeepAlive<Pin<Box<[u8]>>>>,
}
unsafe impl Send for RawReadOptions {}
#[clippy::has_significant_drop]
impl Drop for RawReadOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_readoptions_destroy(self.ptr.as_ptr());
}
}
}
impl ReadOptions {
#[expect(missing_docs)]
#[must_use]
pub fn new() -> ReadOptions {
let ptr = unsafe { librocksdb_sys::rocksdb_readoptions_create() };
let raw = std::ptr::NonNull::new(ptr)
.expect("ReadOptions::new: rocksdb_readoptions_create returned null pointer");
let raw = RawReadOptions {
ptr: raw,
keepalive_snapshot: None,
keepalive_lower_bound: None,
keepalive_upper_bound: None,
};
ReadOptions { raw }
}
pub(crate) const fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_readoptions_t> {
&self.raw.ptr
}
/// Read from given [`Snapshot`].
pub fn set_snapshot(&mut self, snapshot: Snapshot) {
unsafe {
librocksdb_sys::rocksdb_readoptions_set_snapshot(
self.raw.ptr.as_ptr(),
snapshot.as_raw().as_ptr(),
);
}
self.raw.keepalive_snapshot = Some(KeepAlive::new(snapshot));
}
/// Do not look at keys less than this.
///
/// Most callers should prefer [`Transaction::range`](crate::Transaction::range).
pub fn set_iterate_lower_bound<Key>(&mut self, key: Key)
where
Key: AsRef<[u8]>,
{
let key = key.as_ref();
let key: Pin<Box<[u8]>> = Box::into_pin(Box::from(key));
unsafe {
librocksdb_sys::rocksdb_readoptions_set_iterate_lower_bound(
self.raw.ptr.as_ptr(),
key.as_ptr().cast::<libc::c_char>(),
libc::size_t::from(key.len()),
);
};
self.raw.keepalive_lower_bound = Some(KeepAlive::new(key));
}
/// Do not look at keys greater than this.
///
/// Most callers should prefer [`Transaction::range`](crate::Transaction::range).
pub fn set_iterate_upper_bound<Key>(&mut self, key: Key)
where
Key: AsRef<[u8]>,
{
let key = key.as_ref();
let key: Pin<Box<[u8]>> = Box::into_pin(Box::from(key));
unsafe {
librocksdb_sys::rocksdb_readoptions_set_iterate_upper_bound(
self.raw.ptr.as_ptr(),
key.as_ptr().cast::<libc::c_char>(),
libc::size_t::from(key.len()),
);
};
self.raw.keepalive_upper_bound = Some(KeepAlive::new(key));
}
}
impl Default for ReadOptions {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use crate::*;
#[test]
fn test_read_options_iterate_upper_bound_not_met() {
let dir = tempfile::tempdir().expect("create temp directory");
let db = {
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap()
};
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
let column_family = db.get_column_family("default").unwrap();
tx.put(&column_family, "key", "dummy").unwrap();
let mut cursor = {
let mut read_opts = ReadOptions::new();
read_opts.set_iterate_upper_bound("l is after k");
tx.cursor(read_opts, &column_family)
};
cursor.seek_to_first();
// Upper bound should not affect this.
assert!(cursor.is_valid());
assert_eq!(cursor.key().unwrap(), "key".as_bytes());
assert_eq!(cursor.value().unwrap(), "dummy".as_bytes());
drop(db);
}
#[test]
fn test_read_options_iterate_upper_bound_is_met() {
let dir = tempfile::tempdir().expect("create temp directory");
let db = {
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap()
};
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
let column_family = db.get_column_family("default").unwrap();
tx.put(&column_family, "key", "dummy").unwrap();
let mut iter = {
let mut read_opts = ReadOptions::new();
read_opts.set_iterate_upper_bound("a is before k");
tx.cursor(read_opts, &column_family)
};
iter.seek_to_first();
// Upper bound should prevent us from seeing `"key"`.
assert!(!iter.is_valid());
assert_eq!(iter.key(), None);
assert_eq!(iter.value(), None);
drop(db);
}
}

View file

@ -0,0 +1,161 @@
use crate::Database;
use crate::KeepAlive;
use crate::Transaction;
/// A snapshot in time of the database.
///
/// See [`Database::snapshot`], [`ReadOptions::set_snapshot`](crate::ReadOptions::set_snapshot), [`Transaction::get_snapshot`].
pub struct Snapshot {
raw: RawSnapshot,
}
enum SnapshotKind {
// Copy of a pointer from RocksDb transaction, originally created via [`TransactionOptions::set_snapshot`] called with `true`.
// Must not outlive its transaction.
// Drop by `rocksdb_free`.
TxView {
_keepalive_tx: KeepAlive<Transaction>,
},
// Pointer from e.g. [`librocksdb_sys::rocksdb_transactiondb_create_snapshot`].
// Must not outlive its database.
// Drop by [`librocksdb_sys::rocksdb_transactiondb_release_snapshot`], which also needs the database pointer.
Db {
db: Database,
},
}
struct RawSnapshot {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_snapshot_t>,
kind: SnapshotKind,
}
unsafe impl Send for RawSnapshot {}
unsafe impl Sync for RawSnapshot {}
#[clippy::has_significant_drop]
impl Drop for RawSnapshot {
fn drop(&mut self) {
match &self.kind {
SnapshotKind::TxView { _keepalive_tx: _ } => unsafe {
librocksdb_sys::rocksdb_free(self.ptr.as_ptr().cast::<libc::c_void>());
},
SnapshotKind::Db { db } => unsafe {
librocksdb_sys::rocksdb_transactiondb_release_snapshot(
db.as_raw().as_ptr(),
self.ptr.as_ptr(),
);
},
}
}
}
impl Snapshot {
#[must_use]
pub(crate) fn create(db: Database) -> Snapshot {
let ptr =
unsafe { librocksdb_sys::rocksdb_transactiondb_create_snapshot(db.as_raw().as_ptr()) };
let ptr = std::ptr::NonNull::new(ptr.cast_mut()).expect(
"Snapshot::create: rocksdb_transactiondb_create_snapshot returned null pointer",
);
let kind = SnapshotKind::Db { db };
let raw = RawSnapshot { ptr, kind };
Snapshot { raw }
}
#[must_use]
pub(crate) fn from_tx(tx: Transaction) -> Option<Snapshot> {
let ptr = unsafe { librocksdb_sys::rocksdb_transaction_get_snapshot(tx.as_raw().as_ptr()) };
let ptr = std::ptr::NonNull::new(ptr.cast_mut())?;
let kind = SnapshotKind::TxView {
_keepalive_tx: KeepAlive::new(tx),
};
let raw = RawSnapshot { ptr, kind };
Some(Snapshot { raw })
}
pub(crate) const fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_snapshot_t> {
&self.raw.ptr
}
}
#[cfg(test)]
mod tests {
use crate::*;
#[test]
fn get_with_snapshot() {
let dir = tempfile::tempdir().expect("create temp directory");
let db = {
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap()
};
let column_family = db.get_column_family("default").unwrap();
let tx1 = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
let tx2 = {
let write_options = WriteOptions::new();
let mut tx_options = TransactionOptions::new();
tx_options.set_snapshot(true);
db.transaction_begin(&write_options, &tx_options)
};
{
let mut read_opts = ReadOptions::new();
read_opts.set_snapshot(tx2.get_snapshot().unwrap());
let got = tx2.get_pinned(&column_family, "key", &read_opts).unwrap();
assert!(got.is_none());
}
tx1.put(&column_family, "key", "dummy").unwrap();
tx1.commit().unwrap();
{
let mut read_opts = ReadOptions::new();
read_opts.set_snapshot(tx2.get_snapshot().unwrap());
let got = tx2.get_pinned(&column_family, "key", &read_opts).unwrap();
assert!(got.is_none());
}
}
#[test]
fn get_without_snapshot() {
let dir = tempfile::tempdir().expect("create temp directory");
let db = {
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap()
};
let column_family = db.get_column_family("default").unwrap();
let tx1 = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
let tx2 = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
{
let read_opts = ReadOptions::new();
let got = tx2.get_pinned(&column_family, "key", &read_opts).unwrap();
assert!(got.is_none());
}
tx1.put(&column_family, "key", "dummy").unwrap();
tx1.commit().unwrap();
{
let read_opts = ReadOptions::new();
let got = tx2.get_pinned(&column_family, "key", &read_opts).unwrap();
assert!(got.is_some());
}
}
}

View file

@ -0,0 +1,508 @@
use std::sync::Arc;
use crate::error::RocksDbError;
use crate::ColumnFamily;
use crate::Cursor;
use crate::Database;
use crate::Iter;
use crate::KeepAlive;
use crate::ReadOptions;
use crate::Snapshot;
use crate::TransactionOptions;
use crate::View;
use crate::WriteOptions;
/// RocksDB database transaction.
///
/// This is the primary means to interact with the database contents (as far as this library is concerned).
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Transactions>
#[derive(Clone)]
pub struct Transaction {
raw: Arc<RawTransaction>,
}
struct RawTransaction {
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_transaction_t>,
_keepalive_database: KeepAlive<Database>,
}
unsafe impl Send for RawTransaction {}
unsafe impl Sync for RawTransaction {}
#[clippy::has_significant_drop]
impl Drop for RawTransaction {
#[maybe_tracing::instrument(skip(self), fields(ptr=?self.ptr))]
fn drop(&mut self) {
unsafe { librocksdb_sys::rocksdb_transaction_destroy(self.ptr.as_ptr()) }
}
}
/// When a "get for update" operations locks the fetched keys, is that lock shared between multiple readers or not.
///
/// See [`Transaction::get_for_update_pinned`].
//
// TODO i can't find docs for exact semantics (see also <https://github.com/rust-rocksdb/rust-rocksdb/issues/710>) #ecosystem/rocksdb #severity/high #urgency/medium
#[derive(Debug)]
#[expect(clippy::exhaustive_enums)]
#[expect(missing_docs)]
pub enum Exclusivity {
Exclusive,
Shared,
}
impl Exclusivity {
const fn as_raw_u8(&self) -> u8 {
match self {
Exclusivity::Exclusive => 1,
Exclusivity::Shared => 0,
}
}
}
impl Transaction {
#[maybe_tracing::instrument(skip(database, write_options, tx_options))]
pub(crate) fn new(
database: Database,
write_options: &WriteOptions,
tx_options: &TransactionOptions,
) -> Self {
// TODO RocksDB can reuse transactions (avoid free+alloc cost); I haven't found a good enough way to make that safe. #severity/low #urgency/medium #performance
let old_tx = std::ptr::null_mut();
let ptr = unsafe {
librocksdb_sys::rocksdb_transaction_begin(
database.as_raw().as_ptr(),
write_options.as_raw().as_ptr(),
tx_options.as_raw().as_ptr(),
old_tx,
)
};
let ptr = std::ptr::NonNull::new(ptr)
.expect("Transaction::new: rocksdb_transaction_begin returned null pointer");
let raw = RawTransaction {
ptr,
_keepalive_database: KeepAlive::new(database),
};
Self { raw: Arc::new(raw) }
}
pub(crate) fn as_raw(&self) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_transaction_t> {
&self.raw.ptr
}
/// Create a save point that can be rolled back to.
///
/// RocksDB save points are an implicit stack, and the only operation for them is rolling back to the last-created savepoint.
/// This can be repeated to rollback as many savepoints as desired.
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Transactions#save-points>
pub fn savepoint(&self) {
unsafe {
librocksdb_sys::rocksdb_transaction_set_savepoint(self.raw.ptr.as_ptr());
}
}
/// Rollback to the last savepoint.
///
/// See [`Transaction::savepoint`].
pub fn rollback_to_last_savepoint(&self) -> Result<(), RocksDbError> {
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_transaction_rollback_to_savepoint(
self.raw.ptr.as_ptr(),
&mut err,
);
};
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}
/// Write a key-value pair.
#[maybe_tracing::instrument(err)]
pub fn put<Key, Value>(
&self,
column_family: &ColumnFamily,
key: Key,
value: Value,
) -> Result<(), RocksDbError>
where
Key: AsRef<[u8]> + std::fmt::Debug,
Value: AsRef<[u8]> + std::fmt::Debug,
{
let key = key.as_ref();
let value = value.as_ref();
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_transaction_put_cf(
self.raw.ptr.as_ptr(),
column_family.as_raw().as_ptr(),
key.as_ptr().cast::<libc::c_char>(),
libc::size_t::from(key.len()),
value.as_ptr().cast::<libc::c_char>(),
libc::size_t::from(value.len()),
&mut err,
);
};
tracing::trace!(?key, klen = key.len(), ?value, vlen = value.len(), "put");
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}
/// Read the value for given key.
#[maybe_tracing::instrument(skip(read_opts), err)]
pub fn get_pinned<Key>(
&self,
column_family: &ColumnFamily,
key: Key,
read_opts: &ReadOptions,
) -> Result<Option<View>, RocksDbError>
where
Key: AsRef<[u8]> + std::fmt::Debug,
{
let key = key.as_ref();
let key_ptr = key.as_ptr();
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
let ptr = unsafe {
librocksdb_sys::rocksdb_transaction_get_pinned_cf(
self.raw.ptr.as_ptr(),
read_opts.as_raw().as_ptr(),
column_family.as_raw().as_ptr(),
key_ptr.cast::<libc::c_char>(),
key.len(),
&mut err,
)
};
if !err.is_null() {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
} else if let Some(ptr) = std::ptr::NonNull::new(ptr) {
let view = View::from_owned_raw(ptr);
Ok(Some(view))
} else {
Ok(None)
}
}
/// Read the value for given key, locking it to protect against concurrent puts.
#[maybe_tracing::instrument(skip(read_opts), err)]
pub fn get_for_update_pinned<Key>(
&self,
column_family: &ColumnFamily,
key: Key,
read_opts: &ReadOptions,
exclusive: Exclusivity,
) -> Result<Option<View>, RocksDbError>
where
Key: AsRef<[u8]> + std::fmt::Debug,
{
let exclusive = exclusive.as_raw_u8();
let key = key.as_ref();
let key_ptr = key.as_ptr();
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
let ptr = unsafe {
librocksdb_sys::rocksdb_transaction_get_pinned_for_update_cf(
self.raw.ptr.as_ptr(),
read_opts.as_raw().as_ptr(),
column_family.as_raw().as_ptr(),
key_ptr.cast::<libc::c_char>(),
key.len(),
exclusive,
&mut err,
)
};
if !err.is_null() {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
} else if let Some(ptr) = std::ptr::NonNull::new(ptr) {
let view = View::from_owned_raw(ptr);
Ok(Some(view))
} else {
Ok(None)
}
}
/// Delete a key.
///
/// Deleting a non-existent key is a valid successful operation.
#[maybe_tracing::instrument(err)]
pub fn delete<Key>(&self, column_family: &ColumnFamily, key: Key) -> Result<(), RocksDbError>
where
Key: AsRef<[u8]> + std::fmt::Debug,
{
let key = key.as_ref();
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
let tx = self.raw.ptr.as_ptr();
let column_family = column_family.as_raw().as_ptr();
let k = key.as_ptr().cast::<libc::c_char>();
let klen = libc::size_t::from(key.len());
unsafe {
librocksdb_sys::rocksdb_transaction_delete_cf(tx, column_family, k, klen, &mut err);
};
tracing::trace!(?key, klen = key.len(), "delete");
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}
/// Commit this transaction.
///
/// Calling `commit` on an already-finalized [`Transaction`] is an error.
#[maybe_tracing::instrument(ret, err)]
pub fn commit(&self) -> Result<(), RocksDbError> {
{
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_transaction_commit(self.raw.ptr.as_ptr(), &mut err);
}
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}?;
Ok(())
}
/// Rollback this transaction, discarding any changes made.
///
/// Calling `rollback` on an already-finalized [`Transaction`] is an error.
#[maybe_tracing::instrument(skip(self), err)]
pub fn rollback(&self) -> Result<(), RocksDbError> {
{
let mut err: *mut libc::c_char = ::std::ptr::null_mut();
unsafe {
librocksdb_sys::rocksdb_transaction_rollback(self.raw.ptr.as_ptr(), &mut err);
}
if err.is_null() {
Ok(())
} else {
let error = unsafe { RocksDbError::from_raw(err) };
Err(error)
}
}?;
Ok(())
}
/// Returns `Some` only when transaction was created with `TransactionOptions::set_snapshot(true)`.
#[maybe_tracing::instrument(skip(self))]
pub fn get_snapshot(&self) -> Option<Snapshot> {
Snapshot::from_tx(self.clone())
}
/// Create a cursor for scanning key-value pairs in `column_family` as seen by this transaction.
///
/// Simple uses are easier with [`Transaction::range`] or [`Transaction::iter`].
#[maybe_tracing::instrument(skip(read_opts, column_family))]
pub fn cursor(&self, read_opts: ReadOptions, column_family: &ColumnFamily) -> Cursor {
Cursor::new(self.clone(), read_opts, column_family)
}
/// Iterate through all key-value pairs in a column family.
///
/// RocksDB cursors do not match Rust [`Iterator`](std::iter::Iterator) semantics, as the underlying key and value bytes can only be borrowed.
/// The Cursor cannot be moved until current borrows have ended.
/// Instead, we use [`gat_lending_iterator::LendingIterator`], which matches this behavior.
///
/// Also note that the iterator items are fallible, as it may be doing I/O; see associated type `Item` in [`Iter`].
/// The example uses [`Option::transpose`] to go from `Option<Result>` to `Result<Option>`, but alternatively you could choose to say something like `let entry = entry?;` inside the loop.
///
/// ```rust
/// # let dir = tempfile::tempdir().unwrap();
/// # let mut options = rocky::Options::new();
/// # options.create_if_missing(true);
/// # let txdb_opts = rocky::TransactionDbOptions::new();
/// # let db = rocky::Database::open(dir, &options, &txdb_opts, []).unwrap();
/// # let write_options = rocky::WriteOptions::new();
/// # let tx_options = rocky::TransactionOptions::new();
/// # let tx = db.transaction_begin(&write_options, &tx_options);
/// # let column_family = db.create_column_family("dummy_cf", &options).unwrap();
/// use gat_lending_iterator::LendingIterator as _;
///
/// let read_options = rocky::ReadOptions::new();
/// let mut iter = tx.iter(read_options, &column_family);
/// while let Some(entry) = iter.next().transpose()? {
/// let key = entry.key();
/// let value = entry.value();
/// println!("{key:?}={value:?}");
/// }
/// # Ok::<(), rocky::RocksDbError>(())
/// ```
//
// TODO rustdoc cannot link to `<Iter as LendingIterator>::Item` <https://github.com/rust-lang/rust/issues/74563> #ecosystem/rust/rustdoc #doc #waiting
#[expect(clippy::iter_not_returning_iterator)]
#[maybe_tracing::instrument(skip(read_opts, column_family))]
pub fn iter(&self, read_opts: ReadOptions, column_family: &ColumnFamily) -> Iter {
self.range::<&[u8], _>(read_opts, column_family, ..)
}
/// Iterate through a range of key-value pairs in a column family.
///
/// Note: Iterating the full range (`..`, that is [`std::ops::RangeFull`]) requires type annotations.
/// Most uses of full ranges are better off calling [`Transaction::iter`].
///
/// ```rust
/// # let dir = tempfile::tempdir().unwrap();
/// # let mut options = rocky::Options::new();
/// # options.create_if_missing(true);
/// # let txdb_opts = rocky::TransactionDbOptions::new();
/// # let db = rocky::Database::open(dir, &options, &txdb_opts, []).unwrap();
/// # let write_options = rocky::WriteOptions::new();
/// # let tx_options = rocky::TransactionOptions::new();
/// # let tx = db.transaction_begin(&write_options, &tx_options);
/// # let read_options = rocky::ReadOptions::new();
/// # let column_family = db.create_column_family("dummy_cf", &options).unwrap();
/// let mut iter = tx.range::<&[u8], _>(read_options, &column_family, ..);
/// ```
#[maybe_tracing::instrument(skip(read_opts, column_family, range))]
pub fn range<K, R>(
&self,
read_opts: ReadOptions,
column_family: &ColumnFamily,
range: R,
) -> Iter
where
K: AsRef<[u8]>,
R: std::ops::RangeBounds<K> + std::fmt::Debug,
{
let cursor = Cursor::with_range(self.clone(), read_opts, column_family, range);
cursor.into_iter()
}
}
impl std::fmt::Debug for Transaction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("rocky::Transaction")
.field("raw", &self.raw.ptr)
.finish()
}
}
#[cfg(test)]
mod tests {
use crate::Database;
use crate::Options;
use crate::ReadOptions;
use crate::RocksDbError;
use crate::TransactionDbOptions;
use crate::TransactionOptions;
use crate::WriteOptions;
#[test]
fn transaction_simple() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
tx.put(&cf_one, "testkey", "testvalue").unwrap();
tx.commit().unwrap();
// Intentionally don't drop `tx` first here, in an effort to trigger more memory corruption bugs for valgrind to catch.
// drop(tx);
drop(db);
}
/// Deleting a non-existent key is not an error.
#[test]
fn delete_not_exist() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
tx.delete(&cf_one, "testkey").unwrap();
tx.delete(&cf_one, "testkey").unwrap();
tx.delete(&cf_one, "testkey").unwrap();
tx.commit().unwrap();
// Intentionally don't drop `tx` first here, in an effort to trigger more memory corruption bugs for valgrind to catch.
// drop(tx);
drop(db);
}
#[test]
fn transaction_commit_and_keep_using() {
let dir = tempfile::tempdir().expect("create temp directory");
let mut opts = Options::new();
opts.create_if_missing(true);
let txdb_opts = TransactionDbOptions::new();
let db = Database::open(&dir, &opts, &txdb_opts, vec![]).unwrap();
let cf_one = db.create_column_family("one", &Options::new()).unwrap();
let tx = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
tx.commit().unwrap();
// You can add writes to an already-closed transaction.
tx.put(&cf_one, "testkey", "testvalue").unwrap();
// They do not become visible to other transactions.
{
let tx2 = {
let write_options = WriteOptions::new();
let tx_options = TransactionOptions::new();
db.transaction_begin(&write_options, &tx_options)
};
let read_opts = ReadOptions::new();
let got = tx2.get_pinned(&cf_one, "testkey", &read_opts).unwrap();
assert!(got.is_none());
drop(tx2);
}
// Trying to commit again gets an error.
let error = tx.commit().err().unwrap();
assert!(matches!(error, RocksDbError::Other { message }
if message == "Invalid argument: Transaction has already been committed."
));
// Trying to rollback gets an error.
let error = tx.rollback().err().unwrap();
// lol, stringy errors that don't even match from one instance to another
assert!(matches!(error, RocksDbError::Other { message }
if message == "Invalid argument: This transaction has already been committed."
));
// Intentionally don't drop `tx` first here, in an effort to trigger more memory corruption bugs for valgrind to catch.
// drop(tx);
drop(db);
}
}

View file

@ -0,0 +1,55 @@
/// Configure a [`Transaction`](crate::Transaction).
///
/// See [`Database::transaction_begin`](crate::Database::transaction_begin).
pub struct TransactionOptions {
raw: RawTransactionOptions,
}
struct RawTransactionOptions(std::ptr::NonNull<librocksdb_sys::rocksdb_transaction_options_t>);
#[clippy::has_significant_drop]
impl Drop for RawTransactionOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_transaction_options_destroy(self.0.as_ptr());
}
}
}
impl TransactionOptions {
#[expect(missing_docs)]
#[must_use]
pub fn new() -> TransactionOptions {
let ptr = unsafe { librocksdb_sys::rocksdb_transaction_options_create() };
let raw = RawTransactionOptions(std::ptr::NonNull::new(ptr).expect(
"TransactionOptions::new: rocksdb_transaction_options_create returned null pointer",
));
TransactionOptions { raw }
}
pub(crate) const fn as_raw(
&self,
) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_transaction_options_t> {
&self.raw.0
}
/// Whether to use a snapshot from transaction start for conflict detection, or the time when a key was written.
///
/// For repeatable reads, see [`ReadOptions::set_snapshot`](crate::ReadOptions::set_snapshot).
///
/// # Resources
///
/// - <https://github.com/facebook/rocksdb/wiki/Transactions#setting-a-snapshot>
pub fn set_snapshot(&mut self, snapshot: bool) {
let arg = libc::c_uchar::from(snapshot);
unsafe {
librocksdb_sys::rocksdb_transaction_options_set_set_snapshot(self.raw.0.as_ptr(), arg);
};
}
}
impl Default for TransactionOptions {
fn default() -> Self {
Self::new()
}
}

View file

@ -0,0 +1,55 @@
/// Configure the transactional [`Database`](crate::Database).
pub struct TransactionDbOptions {
raw: RawTransactionDbOptions,
}
struct RawTransactionDbOptions(std::ptr::NonNull<librocksdb_sys::rocksdb_transactiondb_options_t>);
#[clippy::has_significant_drop]
impl Drop for RawTransactionDbOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_transactiondb_options_destroy(self.0.as_ptr());
}
}
}
impl TransactionDbOptions {
#[expect(missing_docs)]
#[must_use]
pub fn new() -> TransactionDbOptions {
let ptr = unsafe { librocksdb_sys::rocksdb_transactiondb_options_create() };
let raw = RawTransactionDbOptions(std::ptr::NonNull::new(ptr).expect(
"TransactionDbOptions::new: rocksdb_transactiondb_options_create returned null pointer",
));
TransactionDbOptions { raw }
}
pub(crate) const fn as_raw(
&self,
) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_transactiondb_options_t> {
&self.raw.0
}
fn duration_to_millis(duration: std::time::Duration) -> i64 {
let large = duration.as_millis();
i64::try_from(large).unwrap_or(i64::MAX)
}
/// Set a lock timeout.
pub fn set_transaction_lock_timeout(&mut self, timeout: Option<std::time::Duration>) {
let timeout_millisec = timeout.map(Self::duration_to_millis).unwrap_or(-1i64);
unsafe {
librocksdb_sys::rocksdb_transactiondb_options_set_transaction_lock_timeout(
self.raw.0.as_ptr(),
timeout_millisec,
);
}
}
}
impl Default for TransactionDbOptions {
fn default() -> Self {
Self::new()
}
}

79
crates/rocky/src/view.rs Normal file
View file

@ -0,0 +1,79 @@
use std::fmt::Debug;
use std::ops::Deref;
use std::sync::Arc;
/// View of bytes held alive by a RocksDB snapshot.
#[derive(Clone)]
pub struct View {
// TODO advocate for `rkyv_util` to implement `StableBytes` for `Arc<T: StableBytes>`, remove the Arc & Clone from here #ecosystem/rkyv #severity/low #urgency/low #dev
raw: Arc<RawPinnableSlice>,
}
struct RawPinnableSlice(std::ptr::NonNull<librocksdb_sys::rocksdb_pinnableslice_t>);
unsafe impl Send for RawPinnableSlice {}
unsafe impl Sync for RawPinnableSlice {}
#[clippy::has_significant_drop]
impl Drop for RawPinnableSlice {
fn drop(&mut self) {
unsafe { librocksdb_sys::rocksdb_pinnableslice_destroy(self.0.as_ptr()) };
}
}
impl View {
#[must_use]
pub(crate) fn from_owned_raw(
ptr: std::ptr::NonNull<librocksdb_sys::rocksdb_pinnableslice_t>,
) -> View {
let raw = RawPinnableSlice(ptr);
Self { raw: Arc::new(raw) }
}
/// Get the contents as bytes.
#[must_use]
pub fn as_bytes(&self) -> &[u8] {
let mut out_len: libc::size_t = 0;
let value = unsafe {
librocksdb_sys::rocksdb_pinnableslice_value(self.raw.0.as_ptr(), &mut out_len)
}
.cast::<u8>();
let slice = unsafe { std::slice::from_raw_parts(value, out_len) };
slice
}
}
impl Deref for View {
type Target = [u8];
fn deref(&self) -> &[u8] {
self.as_bytes()
}
}
impl PartialEq for View {
fn eq(&self, other: &View) -> bool {
self.as_bytes() == other.as_bytes()
}
}
impl Eq for View {}
impl PartialEq<&[u8]> for View {
fn eq(&self, other: &&[u8]) -> bool {
self.as_bytes() == *other
}
}
impl Debug for View {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "View({0:?})", &**self)
}
}
unsafe impl rkyv_util::owned::StableBytes for View {
fn bytes(&self) -> &[u8] {
self
}
}

View file

@ -0,0 +1,40 @@
/// Configuration for write operations.
pub struct WriteOptions {
raw: RawWriteOptions,
}
struct RawWriteOptions(std::ptr::NonNull<librocksdb_sys::rocksdb_writeoptions_t>);
#[clippy::has_significant_drop]
impl Drop for RawWriteOptions {
fn drop(&mut self) {
unsafe {
librocksdb_sys::rocksdb_writeoptions_destroy(self.0.as_ptr());
}
}
}
impl WriteOptions {
#[expect(missing_docs)]
#[must_use]
pub fn new() -> WriteOptions {
let ptr = unsafe { librocksdb_sys::rocksdb_writeoptions_create() };
let raw = RawWriteOptions(
std::ptr::NonNull::new(ptr)
.expect("WriteOptions::new: rocksdb_writeoptions_create returned null pointer"),
);
WriteOptions { raw }
}
pub(crate) const fn as_raw(
&self,
) -> &std::ptr::NonNull<librocksdb_sys::rocksdb_writeoptions_t> {
&self.raw.0
}
}
impl Default for WriteOptions {
fn default() -> Self {
Self::new()
}
}

Some files were not shown because too many files have changed in this diff Show more