
First public release

Tommi Virtanen 2025-03-27 13:45:43 -06:00 committed by Tommi Virtanen
parent da7aba533f
commit 5e1ec711c4
457 changed files with 27082 additions and 0 deletions

6
.cargo/mutants.toml Normal file

@ -0,0 +1,6 @@
exclude_re = [
# test tool, not part of system under test
"<impl arbitrary::Arbitrary<",
# we do not test debug output, at least for now
"<impl std::fmt::Debug ",
]
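These `exclude_re` patterns are matched against the function names cargo-mutants reports; for trait impls those names start roughly like `<impl arbitrary::Arbitrary<…> for SomeType>`. As a purely illustrative sketch (the type `Fake` is hypothetical, not from this repository), the first pattern is aimed at hand-written fuzzer plumbing such as:

```rust
// Hypothetical `Arbitrary` impl of the kind the first pattern above excludes
// from mutation testing: it only feeds fuzzers, it is not the system under test.
use arbitrary::Arbitrary;
use arbitrary::Unstructured;

struct Fake(u8);

impl<'a> Arbitrary<'a> for Fake {
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        Ok(Fake(u.arbitrary()?))
    }
}
```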

20
.config/nextest.toml Normal file

@ -0,0 +1,20 @@
[profile.default]
slow-timeout = { period = "5s", terminate-after = 2 }
# We're likely primarily I/O bound, even in tests.
# This might change if we add durability cheating features for tests, or move to primarily testing against storage abstractions.
test-threads = 20
[[profile.default.overrides]]
# The SQLite sqllogictest test suite has tens of thousands of queries in a single "test"; give them more time to complete.
# Expected completion time is <30 seconds per test, but leave a generous margin.
filter = "package(=kanto) and kind(=test) and binary(=sqllogictest_sqlite)"
threads-required = 4
slow-timeout = { period = "1m", terminate-after = 5 }
[profile.valgrind]
slow-timeout = { period = "5m", terminate-after = 2 }
[profile.default-miri]
slow-timeout = { period = "5m", terminate-after = 2 }
# Miri cannot handle arbitrary C/C++ code.
default-filter = "not rdeps(rocky)"

5
.gitignore vendored Normal file

@ -0,0 +1,5 @@
/target/
/result
/dev-db/
# valgrind artifacts
vgcore.*

4588
Cargo.lock generated Normal file

File diff suppressed because it is too large.

227
Cargo.toml Normal file

@ -0,0 +1,227 @@
[workspace]
resolver = "3"
members = ["crates/*", "crates/*/fuzz", "task/*"]
default-members = ["crates/*"]
[workspace.package]
# explicitly set a dummy version so nix build tasks using the workspace don't each need to set it
version = "0.0.0"
authors = ["Tommi Virtanen <tv@eagain.net>"]
license = "Apache-2.0"
homepage = "https://kantodb.com"
repository = "https://git.kantodb.com/kantodb/kantodb"
edition = "2024"
rust-version = "1.85.0"
[workspace.dependencies]
anyhow = "1.0.89"
arbitrary = { version = "1.4.1", features = [
"derive",
] } # sync major version against [`libfuzzer-sys`] re-export or fuzzing will break.
arbitrary-arrow = { version = "0.1.0", path = "crates/arbitrary-arrow" }
arrow = "54.2.0"
async-trait = "0.1.83"
clap = { version = "4.5.19", features = ["derive", "env"] }
dashmap = "6.1.0"
datafusion = "46.0.1"
duplicate = "2.0.0"
futures = "0.3.30"
futures-lite = "2.3.0"
gat-lending-iterator = "0.1.6"
hex = "0.4.3"
ignore = "0.4.23"
indexmap = "2.7.0" # sync major version against [`rkyv`] feature `indexmap-2`
kanto = { version = "0.1.0", path = "crates/kanto" }
kanto-backend-rocksdb = { version = "0.1.0", path = "crates/backend-rocksdb" }
kanto-index-format-v1 = { version = "0.1.0", path = "crates/index-format-v1" }
kanto-key-format-v1 = { version = "0.1.0", path = "crates/key-format-v1" }
kanto-meta-format-v1 = { version = "0.1.0", path = "crates/meta-format-v1" }
kanto-protocol-postgres = { version = "0.1.0", path = "crates/protocol-postgres" }
kanto-record-format-v1 = { version = "0.1.0", path = "crates/record-format-v1" }
kanto-testutil = { version = "0.1.0", path = "crates/testutil" }
kanto-tunables = { version = "0.1.0", path = "crates/tunables" }
libc = "0.2.153"
libfuzzer-sys = "0.4.9"
librocksdb-sys = { version = "0.17.1", features = ["lz4", "static", "zstd"] }
libtest-mimic = "0.8.1"
maybe-tracing = { version = "0.1.0", path = "crates/maybe-tracing" }
parking_lot = "0.12.3"
paste = "1.0.15"
pgwire = "0.28.0"
regex = "1.10.6"
rkyv = { version = "0.8.10", features = ["unaligned", "indexmap-2"] }
rkyv_util = { version = "0.1.0-alpha.1" }
rocky = { version = "0.1.0", path = "crates/rocky" }
smallvec = { version = "1.13.2", features = ["union"] }
sqllogictest = "0.28.0"
tempfile = "3.10.1"
test-log = { version = "0.2.16", default-features = false, features = [
"trace",
] }
thiserror = "2.0.6"
tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] }
tracing = "0.1.40"
tracing-subscriber = { version = "0.3.18", default-features = false, features = [
"env-filter",
"fmt",
"json",
] }
tree-sitter = "0.23.0"
tree-sitter-md = { version = "0.3.2", features = ["parser"] }
tree-sitter-rust = "0.23.0"
unindent = "0.2.3"
walkdir = "2.5.0"
[workspace.lints.rust]
absolute_paths_not_starting_with_crate = "warn"
ambiguous_negative_literals = "warn"
closure_returning_async_block = "warn"
dead_code = "warn"
deprecated_safe_2024 = "deny"
elided_lifetimes_in_paths = "warn"
explicit_outlives_requirements = "warn"
future_incompatible = { level = "warn", priority = -1 }
if_let_rescope = "warn"
impl_trait_redundant_captures = "warn"
keyword_idents_2024 = "warn"
let_underscore_drop = "warn"
macro_use_extern_crate = "deny"
meta_variable_misuse = "warn"
missing_abi = "deny"
missing_copy_implementations = "allow" # TODO switch this to warn #urgency/low
missing_debug_implementations = "allow" # TODO switch this to warn #urgency/low
missing_docs = "warn"
missing_unsafe_on_extern = "deny"
non_ascii_idents = "forbid"
redundant_imports = "warn"
redundant_lifetimes = "warn"
rust_2018_idioms = { level = "warn", priority = -1 }
trivial_casts = "warn"
trivial_numeric_casts = "warn"
unexpected_cfgs = { level = "allow", check-cfg = ["cfg(kani)"] }
unit_bindings = "warn"
unnameable_types = "warn"
unreachable_pub = "warn"
unsafe_attr_outside_unsafe = "deny"
unsafe_code = "deny"
unsafe_op_in_unsafe_fn = "deny"
unused_crate_dependencies = "allow" # TODO false positives for packages with bin/lib/integration tests; desirable because this catches things `cargo-shear` does not <https://github.com/rust-lang/rust/issues/95513> #waiting #ecosystem/rust #severity/low #urgency/low
unused_extern_crates = "warn"
unused_import_braces = "warn"
unused_lifetimes = "warn"
unused_macro_rules = "warn"
unused_qualifications = "warn"
unused_results = "warn"
variant_size_differences = "warn"
[workspace.lints.clippy]
all = { level = "warn", priority = -2 }
allow_attributes = "warn"
arithmetic_side_effects = "warn"
as_conversions = "warn"
as_underscore = "warn"
assertions_on_result_states = "deny"
cargo = { level = "warn", priority = -1 }
cast_lossless = "warn"
cast_possible_truncation = "warn"
cfg_not_test = "warn"
complexity = { level = "warn", priority = -1 }
dbg_macro = "warn"
default_numeric_fallback = "warn"
disallowed_script_idents = "deny"
else_if_without_else = "warn"
empty_drop = "deny"
empty_enum_variants_with_brackets = "warn"
empty_structs_with_brackets = "warn"
error_impl_error = "warn"
exhaustive_enums = "warn"
exhaustive_structs = "warn"
exit = "deny"
expect_used = "allow" # TODO switch this to warn
field_scoped_visibility_modifiers = "allow" # TODO switch this to warn #urgency/medium
float_cmp_const = "warn"
fn_to_numeric_cast_any = "deny"
format_push_string = "warn"
host_endian_bytes = "deny"
if_not_else = "allow"
if_then_some_else_none = "warn"
indexing_slicing = "warn"
infinite_loop = "warn"
integer_division = "warn"
iter_over_hash_type = "warn"
large_include_file = "warn"
let_underscore_must_use = "warn"
let_underscore_untyped = "warn"
lossy_float_literal = "deny"
manual_is_power_of_two = "warn"
map_err_ignore = "warn"
map_unwrap_or = "allow"
match_like_matches_macro = "allow"
mem_forget = "deny"
missing_asserts_for_indexing = "warn"
missing_const_for_fn = "warn"
missing_errors_doc = "allow" # TODO write docstrings, remove this #urgency/medium
missing_inline_in_public_items = "allow" # TODO consider warn #urgency/low #performance
missing_panics_doc = "allow" # TODO write docstrings, remove this #urgency/medium
mixed_read_write_in_expression = "deny"
module_name_repetitions = "warn"
modulo_arithmetic = "warn"
multiple_crate_versions = "allow" # TODO I wish there was a way to fix this <https://github.com/rust-lang/rust-clippy/issues/9756> #urgency/low
multiple_unsafe_ops_per_block = "deny"
mutex_atomic = "warn"
non_ascii_literal = "deny"
non_zero_suggestions = "warn"
panic = "warn"
panic_in_result_fn = "warn" # note, this allows `debug_assert` but not `assert`
partial_pub_fields = "warn"
pathbuf_init_then_push = "warn"
pedantic = { level = "warn", priority = -1 }
perf = { level = "warn", priority = -1 }
print_stderr = "warn"
print_stdout = "warn"
pub_without_shorthand = "warn"
rc_buffer = "warn"
rc_mutex = "deny"
redundant_type_annotations = "warn"
ref_patterns = "warn"
renamed_function_params = "warn"
rest_pat_in_fully_bound_structs = "warn"
same_name_method = "deny"
semicolon_inside_block = "warn"
separated_literal_suffix = "warn"
set_contains_or_insert = "warn"
significant_drop_in_scrutinee = "warn"
similar_names = "allow" # too eager
str_to_string = "warn"
string_add = "warn"
string_lit_chars_any = "warn"
string_slice = "warn"
string_to_string = "warn"
style = { level = "warn", priority = -1 }
suspicious = { level = "warn", priority = -1 }
suspicious_xor_used_as_pow = "warn"
tests_outside_test_module = "allow" # TODO switch this to warn <https://github.com/rust-lang/rust-clippy/issues/11024> #waiting #urgency/low #severity/low
todo = "warn"
try_err = "warn"
undocumented_unsafe_blocks = "allow" # TODO write safety docs, switch this to warn #urgency/low
unimplemented = "warn"
unnecessary_safety_comment = "warn"
unnecessary_safety_doc = "warn"
unnecessary_self_imports = "deny"
unreachable = "deny"
unused_result_ok = "warn"
unused_trait_names = "warn"
unwrap_in_result = "warn"
unwrap_used = "warn"
use_debug = "warn"
verbose_file_reads = "warn"
wildcard_imports = "warn"
[workspace.metadata.spellcheck]
config = "cargo-spellcheck.toml"
[workspace.metadata.crane]
name = "kantodb"
[profile.release]
lto = true

202
LICENSE Normal file

@ -0,0 +1,202 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

24
cargo-spellcheck.toml Normal file

@ -0,0 +1,24 @@
# TODO this is too painful with commented-out code, but commented-out code is a code smell of its own
#dev_comments = true
[Hunspell]
# TODO setting config overrides defaults <https://github.com/drahnr/cargo-spellcheck/issues/293#issuecomment-2397728638> #ecosystem/cargo-spellcheck #dev #severity/low
use_builtin = true
# Use built-in dictionary only, for reproducibility.
skip_os_lookups = true
search_dirs = ["."]
extra_dictionaries = ["hunspell.dict"]
# TODO "C++" is tokenized wrong <https://github.com/drahnr/cargo-spellcheck/issues/272> #ecosystem/cargo-spellcheck #dev #severity/low
tokenization_splitchars = "\",;:!?#(){}[]|/_-+='`&@§¶…"
[Hunspell.quirks]
transform_regex = [
# 2025Q1 style references to dates.
"^(20\\d\\d)Q([1234])$",
# version numbers; tokenization splits off everything at the dot so be conservative
"^v[012]$",
]
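The two `transform_regex` entries are ordinary regexes; the sketch below (not part of the spellcheck tooling, using the `regex` crate that is already a workspace dependency) only illustrates which tokens they accept:

```rust
// Illustration only: check what the transform_regex patterns above accept.
fn main() {
    let quarter = regex::Regex::new(r"^(20\d\d)Q([1234])$").unwrap();
    assert!(quarter.is_match("2025Q1"));
    assert!(!quarter.is_match("2025Q5"));

    let version = regex::Regex::new(r"^v[012]$").unwrap();
    assert!(version.is_match("v2"));
    assert!(!version.is_match("v3"));
}
```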

25
clippy.toml Normal file

@ -0,0 +1,25 @@
allow-dbg-in-tests = true
allow-expect-in-tests = true
allow-indexing-slicing-in-tests = true
allow-panic-in-tests = true
allow-print-in-tests = true
allow-unwrap-in-tests = true
avoid-breaking-exported-api = false # TODO let clippy suggest more radical changes; change when approaching stable #urgency/low #severity/medium
disallowed-methods = [
{ path = "std::iter::Iterator::for_each", reason = "prefer `for` for side-effects" },
{ path = "std::iter::Iterator::try_for_each", reason = "prefer `for` for side-effects" },
{ path = "std::option::Option::map_or", reason = "prefer `map(..).unwrap_or(..)` for legibility" },
{ path = "std::option::Option::map_or_else", reason = "prefer `map(..).unwrap_or_else(..)` for legibility" },
{ path = "std::result::Result::map_or", reason = "prefer `map(..).unwrap_or(..)` for legibility" },
{ path = "std::result::Result::map_or_else", reason = "prefer `map(..).unwrap_or_else(..)` for legibility" },
]
doc-valid-idents = [
"..",
"CapnProto",
"DataFusion",
"KantoDB",
"RocksDB",
"SQLite",
]
semicolon-inside-block-ignore-singleline = true
warn-on-all-wildcard-imports = true
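The `disallowed-methods` entries above encode a readability preference rather than a correctness rule. A minimal sketch (the function `print_all` is hypothetical) of the style they steer toward:

```rust
// Hypothetical example of the preferred style: a `for` loop for side effects
// instead of `Iterator::for_each`.
fn print_all(items: &[String]) {
    // Disallowed by the config above:
    // items.iter().for_each(|item| println!("{item}"));

    // Preferred:
    for item in items {
        println!("{item}");
    }
}
```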


@ -0,0 +1,28 @@
[package]
name = "arbitrary-arrow"
version = "0.1.0"
description = "Fuzzing support for Apache Arrow"
keywords = ["arrow", "fuzzing", "testing"]
categories = ["development-tools::testing"]
authors.workspace = true
license.workspace = true
repository.workspace = true
publish = false # TODO merge upstream or publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[package.metadata.cargo-shear]
ignored = ["test-log"]
[dependencies]
arbitrary = { workspace = true }
arrow = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
getrandom = "0.3.1"
maybe-tracing = { workspace = true }
test-log = { workspace = true, features = ["trace"] }
[lints]
workspace = true


@ -0,0 +1,4 @@
/target/
/corpus/
/artifacts/
/coverage/


@ -0,0 +1,24 @@
[package]
name = "arbitrary-arrow-fuzz"
version = "0.0.0"
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false
edition.workspace = true
rust-version.workspace = true
[package.metadata]
cargo-fuzz = true
[dependencies]
arbitrary-arrow = { workspace = true }
arrow = { workspace = true, features = ["prettyprint"] }
libfuzzer-sys = { workspace = true }
[[bin]]
name = "record_batch_pretty"
path = "fuzz_targets/fuzz_record_batch_pretty.rs"
test = false
doc = false
bench = false


@ -0,0 +1,33 @@
#![no_main]
use arbitrary_arrow::arbitrary;
use libfuzzer_sys::fuzz_target;
#[derive(Debug, arbitrary::Arbitrary)]
struct Input {
record_batches: Box<[arbitrary_arrow::RecordBatchBuilder]>,
}
fn is_all_printable(s: &str) -> bool {
const PRINTABLE: std::ops::RangeInclusive<u8> = 0x20..=0x7e;
s.bytes().all(|b| b == b'\n' || PRINTABLE.contains(&b))
}
fuzz_target!(|input: Input| {
let Input { record_batches } = input;
let record_batches = record_batches
.into_iter()
.map(|builder| builder.build())
.collect::<Box<_>>();
let display = arrow::util::pretty::pretty_format_batches(&record_batches).unwrap();
let output = format!("{display}");
if !is_all_printable(&output) {
for line in output.lines() {
println!("{line:#?}");
}
// TODO this would be desirable, but it's not true yet
// panic!("non-printable output: {output:#?}");
}
});


@ -0,0 +1,498 @@
use std::sync::Arc;
use arbitrary::Arbitrary as _;
use arbitrary::Unstructured;
use arrow::array::Array;
use arrow::array::BooleanBufferBuilder;
use arrow::array::FixedSizeBinaryArray;
use arrow::array::GenericByteBuilder;
use arrow::array::GenericByteViewBuilder;
use arrow::array::PrimitiveArray;
use arrow::buffer::Buffer;
use arrow::buffer::MutableBuffer;
use arrow::buffer::NullBuffer;
use arrow::datatypes::ArrowNativeType as _;
use arrow::datatypes::ArrowTimestampType;
use arrow::datatypes::BinaryViewType;
use arrow::datatypes::ByteArrayType;
use arrow::datatypes::ByteViewType;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::GenericBinaryType;
use arrow::datatypes::GenericStringType;
use arrow::datatypes::IntervalUnit;
use arrow::datatypes::StringViewType;
use arrow::datatypes::TimeUnit;
/// Make an arbitrary [`PrimitiveArray<T>`].
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_primitive_array<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
) -> Result<PrimitiveArray<T>, arbitrary::Error>
where
T: arrow::datatypes::ArrowPrimitiveType,
{
// Arrow primitive arrays are by definition safe to fill with random bytes.
let width = T::Native::get_byte_width();
let values = {
let capacity = width
.checked_mul(row_count)
.ok_or(arbitrary::Error::IncorrectFormat)?;
let mut buf = MutableBuffer::from_len_zeroed(capacity);
debug_assert_eq!(buf.as_slice().len(), capacity);
unstructured.fill_buffer(buf.as_slice_mut())?;
buf.into()
};
let nulls = if nullable {
let mut builder = BooleanBufferBuilder::new(row_count);
builder.resize(row_count);
debug_assert_eq!(builder.as_slice().len(), row_count.div_ceil(8));
unstructured.fill_buffer(builder.as_slice_mut())?;
let buffer = NullBuffer::new(builder.finish());
Some(buffer)
} else {
None
};
let builder = PrimitiveArray::<T>::new(values, nulls);
Ok(builder)
}
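A minimal usage sketch for `arbitrary_primitive_array`, assuming the crate-root re-export that `lib.rs` provides via `pub use array::*`; the byte source, element type, and row count are illustrative choices, not from the original code:

```rust
// Sketch: build a non-nullable Int32 array of 8 rows from fixed bytes.
use arbitrary::Unstructured;
use arbitrary_arrow::arbitrary_primitive_array;
use arrow::array::Array as _;
use arrow::array::PrimitiveArray;
use arrow::datatypes::Int32Type;

fn example() -> Result<(), arbitrary::Error> {
    let mut unstructured = Unstructured::new(&[0xA5; 64]);
    let array: PrimitiveArray<Int32Type> =
        arbitrary_primitive_array(&mut unstructured, false, 8)?;
    assert_eq!(array.len(), 8);
    Ok(())
}
```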
#[tracing::instrument(skip(unstructured), ret)]
fn arbitrary_primitive_array_dyn<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: arrow::datatypes::ArrowPrimitiveType,
{
let array: PrimitiveArray<T> = arbitrary_primitive_array(unstructured, nullable, row_count)?;
Ok(Arc::new(array))
}
/// Make an arbitrary [`PrimitiveArray`] with any [`ArrowTimestampType`], in the given time zone.
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_timestamp_array<'a, T>(
unstructured: &mut Unstructured<'a>,
nullable: bool,
row_count: usize,
time_zone: Option<&Arc<str>>,
) -> Result<PrimitiveArray<T>, arbitrary::Error>
where
T: ArrowTimestampType,
{
let array = arbitrary_primitive_array::<T>(unstructured, nullable, row_count)?
.with_timezone_opt(time_zone.cloned());
Ok(array)
}
#[tracing::instrument(skip(unstructured), ret)]
fn arbitrary_timestamp_array_dyn<'a, T>(
unstructured: &mut Unstructured<'a>,
field: &Field,
row_count: usize,
time_zone: Option<&Arc<str>>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ArrowTimestampType,
{
let array =
arbitrary_timestamp_array::<T>(unstructured, field.is_nullable(), row_count, time_zone)?;
Ok(Arc::new(array))
}
/// Make an arbitrary [`GenericByteArray`](arrow::array::GenericByteArray), such as [`StringArray`](arrow::array::StringArray) or [`BinaryArray`](arrow::array::BinaryArray).
///
/// # Examples
///
/// ```rust
/// use arbitrary_arrow::arbitrary::Arbitrary as _;
/// use arbitrary_arrow::arbitrary_byte_array;
/// use arrow::datatypes::GenericStringType;
/// # let unstructured = &mut arbitrary_arrow::arbitrary::Unstructured::new(b"fake entropy");
/// let row_count = 10;
/// let array = arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
/// Option::<String>::arbitrary(unstructured)
/// })?;
/// # Ok::<(), arbitrary_arrow::arbitrary::Error>(())
/// ```
#[tracing::instrument(skip(generate), ret)]
pub fn arbitrary_byte_array<'a, T, V>(
row_count: usize,
generate: impl FnMut() -> Result<Option<V>, arbitrary::Error>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ByteArrayType,
V: AsRef<T::Native>,
{
let mut builder = GenericByteBuilder::<T>::with_capacity(row_count, 0);
for result in std::iter::repeat_with(generate).take(row_count) {
let value = result?;
builder.append_option(value);
}
let array = builder.finish();
Ok(Arc::new(array))
}
/// Make an arbitrary [`GenericByteViewArray`](arrow::array::GenericByteViewArray), that is, either a [`StringViewArray`](arrow::array::StringViewArray) or a [`BinaryViewArray`](arrow::array::BinaryViewArray).
///
/// # Examples
///
/// ```rust
/// use arbitrary_arrow::arbitrary_byte_view_array;
/// use arrow::datatypes::StringViewType;
///
/// use crate::arbitrary_arrow::arbitrary::Arbitrary as _;
/// # let unstructured = &mut arbitrary_arrow::arbitrary::Unstructured::new(b"fake entropy");
/// let row_count = 10;
/// let array = arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
/// String::arbitrary(unstructured).map(|s| Some(s))
/// })?;
/// # Ok::<(), arbitrary_arrow::arbitrary::Error>(())
/// ```
#[tracing::instrument(skip(generate), ret)]
pub fn arbitrary_byte_view_array<'a, T, V>(
row_count: usize,
generate: impl FnMut() -> Result<Option<V>, arbitrary::Error>,
) -> Result<Arc<dyn Array>, arbitrary::Error>
where
T: ByteViewType,
V: AsRef<T::Native>,
{
let mut builder = GenericByteViewBuilder::<T>::with_capacity(row_count);
for result in std::iter::repeat_with(generate).take(row_count) {
let value = result?;
builder.append_option(value);
}
let array = builder.finish();
Ok(Arc::new(array))
}
/// Make an [arbitrary] array of the type described by a [`Field`] of an Arrow [`Schema`](arrow::datatypes::Schema).
///
/// Implemented as a helper function, not via the [`Arbitrary`](arbitrary::Arbitrary) trait, because we need to construct arrays with a specific data type, nullability, and row count.
#[tracing::instrument(skip(unstructured), ret)]
pub fn arbitrary_array<'a>(
unstructured: &mut Unstructured<'a>,
field: &Field,
row_count: usize,
) -> Result<Arc<dyn Array>, arbitrary::Error> {
let array: Arc<dyn Array> = match field.data_type() {
DataType::Null => Arc::new(arrow::array::NullArray::new(row_count)),
DataType::Boolean => {
if field.is_nullable() {
let vec = std::iter::repeat_with(|| unstructured.arbitrary::<Option<bool>>())
.take(row_count)
.collect::<Result<Vec<_>, _>>()?;
Arc::new(arrow::array::BooleanArray::from(vec))
} else {
let vec = std::iter::repeat_with(|| unstructured.arbitrary::<bool>())
.take(row_count)
.collect::<Result<Vec<_>, _>>()?;
Arc::new(arrow::array::BooleanArray::from(vec))
}
}
DataType::Int8 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int8Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int16 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Int64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Int64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt8 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt8Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt16 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt32 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::UInt64 => arbitrary_primitive_array_dyn::<arrow::datatypes::UInt64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float16 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float16Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Float64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Float64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Timestamp(TimeUnit::Second, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampSecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Millisecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampMillisecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Microsecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampMicrosecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Timestamp(TimeUnit::Nanosecond, time_zone) => {
arbitrary_timestamp_array_dyn::<arrow::datatypes::TimestampNanosecondType>(
unstructured,
field,
row_count,
time_zone.as_ref(),
)?
}
DataType::Date32 => arbitrary_primitive_array_dyn::<arrow::datatypes::Date32Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Date64 => arbitrary_primitive_array_dyn::<arrow::datatypes::Date64Type>(
unstructured,
field.is_nullable(),
row_count,
)?,
DataType::Time32(TimeUnit::Second) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time32SecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time32(TimeUnit::Millisecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time32MillisecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time32(TimeUnit::Nanosecond | TimeUnit::Microsecond)
| DataType::Time64(TimeUnit::Second | TimeUnit::Millisecond) => {
return Err(arbitrary::Error::IncorrectFormat);
}
DataType::Time64(TimeUnit::Microsecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time64MicrosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Time64(TimeUnit::Nanosecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::Time64NanosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Duration(TimeUnit::Second) => arbitrary_primitive_array_dyn::<
arrow::datatypes::DurationSecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Duration(TimeUnit::Millisecond) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::DurationMillisecondType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Duration(TimeUnit::Microsecond) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::DurationMicrosecondType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Duration(TimeUnit::Nanosecond) => arbitrary_primitive_array_dyn::<
arrow::datatypes::DurationNanosecondType,
>(
unstructured, field.is_nullable(), row_count
)?,
DataType::Interval(IntervalUnit::YearMonth) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalYearMonthType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Interval(IntervalUnit::DayTime) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalDayTimeType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
arbitrary_primitive_array_dyn::<arrow::datatypes::IntervalMonthDayNanoType>(
unstructured,
field.is_nullable(),
row_count,
)?
}
DataType::Binary => {
if field.is_nullable() {
arbitrary_byte_array::<GenericBinaryType<i32>, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericBinaryType<i32>, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::FixedSizeBinary(size) => {
// annoyingly similar-but-different to `PrimitiveArray`
let values = {
let size =
usize::try_from(*size).map_err(|_error| arbitrary::Error::IncorrectFormat)?;
let capacity = size
.checked_mul(row_count)
.ok_or(arbitrary::Error::IncorrectFormat)?;
let mut buf = MutableBuffer::from_len_zeroed(capacity);
unstructured.fill_buffer(buf.as_slice_mut())?;
Buffer::from(buf)
};
let nulls = if field.is_nullable() {
let mut builder = BooleanBufferBuilder::new(row_count);
unstructured.fill_buffer(builder.as_slice_mut())?;
let buffer = NullBuffer::new(builder.finish());
Some(buffer)
} else {
None
};
let array = FixedSizeBinaryArray::new(*size, values, nulls);
Arc::new(array)
}
DataType::LargeBinary => {
if field.is_nullable() {
arbitrary_byte_array::<GenericBinaryType<i64>, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericBinaryType<i64>, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::BinaryView => {
if field.is_nullable() {
arbitrary_byte_view_array::<BinaryViewType, _>(row_count, || {
Option::<Box<[u8]>>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_view_array::<BinaryViewType, _>(row_count, || {
<Box<[u8]>>::arbitrary(unstructured).map(Some)
})?
}
}
// TODO I haven't found a good way to unify these branches #dev
//
// - string vs binary
// - large vs small
// - array vs view
DataType::Utf8 => {
if field.is_nullable() {
arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericStringType<i32>, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::LargeUtf8 => {
if field.is_nullable() {
arbitrary_byte_array::<GenericStringType<i64>, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_array::<GenericStringType<i64>, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::Utf8View => {
if field.is_nullable() {
arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
Option::<String>::arbitrary(unstructured)
})?
} else {
arbitrary_byte_view_array::<StringViewType, _>(row_count, || {
<String>::arbitrary(unstructured).map(Some)
})?
}
}
DataType::Decimal128(precision, scale) => {
let array = arbitrary_primitive_array::<arrow::datatypes::Decimal128Type>(
unstructured,
field.is_nullable(),
row_count,
)?
.with_precision_and_scale(*precision, *scale)
.map_err(|_error| arbitrary::Error::IncorrectFormat)?;
Arc::new(array)
}
DataType::Decimal256(precision, scale) => {
let array = arbitrary_primitive_array::<arrow::datatypes::Decimal256Type>(
unstructured,
field.is_nullable(),
row_count,
)?
.with_precision_and_scale(*precision, *scale)
.map_err(|_error| arbitrary::Error::IncorrectFormat)?;
Arc::new(array)
}
// TODO generate arrays for more complex datatypes
DataType::List(_)
| DataType::ListView(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_)
| DataType::LargeListView(_)
| DataType::Struct(_)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::Map(_, _)
| DataType::RunEndEncoded(_, _) => return Err(arbitrary::Error::IncorrectFormat),
};
Ok(array)
}
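To round out the dispatcher above, a small sketch of driving `arbitrary_array` from a schema field; the field name, data type, byte source, and row count are illustrative choices:

```rust
// Sketch: the field controls element type and nullability, the caller
// controls the row count.
use arbitrary::Unstructured;
use arbitrary_arrow::arbitrary_array;
use arrow::array::Array as _;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;

fn example() -> Result<(), arbitrary::Error> {
    let field = Field::new("score", DataType::Float64, true);
    let mut unstructured = Unstructured::new(&[0x5A; 256]);
    let array = arbitrary_array(&mut unstructured, &field, 16)?;
    assert_eq!(array.len(), 16);
    assert_eq!(array.data_type(), &DataType::Float64);
    Ok(())
}
```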


@ -0,0 +1,164 @@
use std::sync::Arc;
use arbitrary::Arbitrary;
use arrow::datatypes::DataType;
use arrow::datatypes::TimeUnit;
use crate::FieldBuilder;
use crate::FieldsBuilder;
use crate::IntervalUnitBuilder;
use crate::TimeUnitBuilder;
// TODO generate more edge values
// Arrow has many places where it will panic or error in unexpected places if the input didn't match expectations.
// (Example unexpected error: `arrow::util::pretty::pretty_format_batches` errors when it doesn't recognize a timezone.)
// Making a fuzzer avoid problematic output is not exactly a good idea, so expect these to be relaxed later, maybe?
/// An `i32` that is guaranteed to be `> 0`.
///
/// Arrow will panic on divide by zero if a fixed size structure is size 0.
/// It will most likely go out of bounds for negative numbers.
//
// TODO this should really be enforced by whatever is taking the `i32`, get rid of this
pub struct PositiveNonZeroI32(i32);
impl PositiveNonZeroI32 {
#[expect(missing_docs)]
#[must_use]
pub fn new(input: i32) -> Option<PositiveNonZeroI32> {
(input > 0).then_some(PositiveNonZeroI32(input))
}
}
impl<'a> Arbitrary<'a> for PositiveNonZeroI32 {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let num = u.int_in_range(1i32..=i32::MAX)?;
PositiveNonZeroI32::new(num).ok_or(arbitrary::Error::IncorrectFormat)
}
}
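A tiny illustration of the guarantee documented above (the values are arbitrary examples):

```rust
// Zero and negative sizes are rejected; positive sizes pass through.
use arbitrary_arrow::PositiveNonZeroI32;

fn main() {
    assert!(PositiveNonZeroI32::new(0).is_none());
    assert!(PositiveNonZeroI32::new(-4).is_none());
    assert!(PositiveNonZeroI32::new(16).is_some());
}
```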
/// Make an [arbitrary] [`arrow::datatypes::DataType`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](DataTypeBuilder::build) method to get the final [`DataType`].
/// This works around a technical limitation: we can't implement the trait [`arbitrary::Arbitrary`] for [`DataType`] in this unrelated crate.
/// TODO this can be cleaned up if merged into Arrow upstream #ecosystem/arrow
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum DataTypeBuilder {
Null,
Boolean,
Int8,
Int16,
Int32,
Int64,
UInt8,
UInt16,
UInt32,
UInt64,
Float16,
Float32,
Float64,
Timestamp(
TimeUnitBuilder,
Option<
// TODO maybe generate proper time zone names and offsets, at least most of the time
// TODO maybe split into feature `chrono-tz` and not set; without it, only do offsets?
Arc<str>,
>,
),
Date32,
Date64,
// only generate valid combinations
Time32Seconds,
Time32Milliseconds,
Time64Microseconds,
Time64Nanoseconds,
Duration(TimeUnitBuilder),
Interval(IntervalUnitBuilder),
Binary,
FixedSizeBinary(PositiveNonZeroI32),
LargeBinary,
BinaryView,
Utf8,
LargeUtf8,
Utf8View,
List(Box<FieldBuilder>),
ListView(Box<FieldBuilder>),
FixedSizeList(Box<FieldBuilder>, PositiveNonZeroI32),
LargeList(Box<FieldBuilder>),
LargeListView(Box<FieldBuilder>),
Struct(Box<FieldsBuilder>),
// TODO Union
// TODO maybe enforce `DataType::is_dictionary_key_type`?
Dictionary(Box<DataTypeBuilder>, Box<DataTypeBuilder>),
Decimal128(u8, i8),
Decimal256(u8, i8),
// TODO enforce Map shape, key hashability
// TODO Map(Box<FieldBuilder>, bool),
// TODO run_ends must be an integer type
// TODO RunEndEncoded(Box<FieldBuilder>, Box<FieldBuilder>),
}
impl DataTypeBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> DataType {
match self {
DataTypeBuilder::Null => DataType::Null,
DataTypeBuilder::Boolean => DataType::Boolean,
DataTypeBuilder::Int8 => DataType::Int8,
DataTypeBuilder::Int16 => DataType::Int16,
DataTypeBuilder::Int32 => DataType::Int32,
DataTypeBuilder::Int64 => DataType::Int64,
DataTypeBuilder::UInt8 => DataType::UInt8,
DataTypeBuilder::UInt16 => DataType::UInt16,
DataTypeBuilder::UInt32 => DataType::UInt32,
DataTypeBuilder::UInt64 => DataType::UInt64,
DataTypeBuilder::Float16 => DataType::Float16,
DataTypeBuilder::Float32 => DataType::Float32,
DataTypeBuilder::Float64 => DataType::Float64,
DataTypeBuilder::Timestamp(time_unit, time_zone) => {
DataType::Timestamp(time_unit.build(), time_zone)
}
DataTypeBuilder::Date32 => DataType::Date32,
DataTypeBuilder::Date64 => DataType::Date64,
DataTypeBuilder::Time32Seconds => DataType::Time32(TimeUnit::Second),
DataTypeBuilder::Time32Milliseconds => DataType::Time32(TimeUnit::Millisecond),
DataTypeBuilder::Time64Microseconds => DataType::Time64(TimeUnit::Microsecond),
DataTypeBuilder::Time64Nanoseconds => DataType::Time64(TimeUnit::Nanosecond),
DataTypeBuilder::Duration(builder) => DataType::Duration(builder.build()),
DataTypeBuilder::Interval(builder) => DataType::Interval(builder.build()),
DataTypeBuilder::Binary => DataType::Binary,
DataTypeBuilder::FixedSizeBinary(size) => DataType::FixedSizeBinary(size.0),
DataTypeBuilder::LargeBinary => DataType::LargeBinary,
DataTypeBuilder::BinaryView => DataType::BinaryView,
DataTypeBuilder::Utf8 => DataType::Utf8,
DataTypeBuilder::LargeUtf8 => DataType::LargeUtf8,
DataTypeBuilder::Utf8View => DataType::Utf8View,
DataTypeBuilder::List(field_builder) => DataType::List(Arc::new(field_builder.build())),
DataTypeBuilder::ListView(field_builder) => {
DataType::ListView(Arc::new(field_builder.build()))
}
DataTypeBuilder::FixedSizeList(field_builder, size) => {
DataType::FixedSizeList(Arc::new(field_builder.build()), size.0)
}
DataTypeBuilder::LargeList(field_builder) => {
DataType::LargeList(Arc::new(field_builder.build()))
}
DataTypeBuilder::LargeListView(field_builder) => {
DataType::LargeListView(Arc::new(field_builder.build()))
}
DataTypeBuilder::Struct(fields_builder) => DataType::Struct(fields_builder.build()),
// TODO Union
DataTypeBuilder::Dictionary(key_type, value_type) => {
DataType::Dictionary(Box::new(key_type.build()), Box::new(value_type.build()))
}
// TODO make only valid precision and scale values? now they're discarded in `arbitrary_array` <https://docs.rs/datafusion/latest/datafusion/common/arrow/datatypes/fn.validate_decimal_precision_and_scale.html>
DataTypeBuilder::Decimal128(precision, scale) => DataType::Decimal128(precision, scale),
DataTypeBuilder::Decimal256(precision, scale) => DataType::Decimal256(precision, scale),
// TODO generate arbitrary Map
// TODO generate arbitrary RunEndEncoded
}
}
}
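A short sketch of the builder round trip described in the doc comment above; the fixed input bytes are an illustrative stand-in for fuzzer-provided data:

```rust
// Sketch: entropy -> DataTypeBuilder -> DataType.
use arbitrary::Arbitrary as _;
use arbitrary::Unstructured;
use arbitrary_arrow::DataTypeBuilder;

fn example() -> arbitrary::Result<()> {
    let mut unstructured = Unstructured::new(&[7u8; 32]);
    // May return `Error::IncorrectFormat` for some inputs (e.g. invalid sizes),
    // which callers such as the fuzz targets simply skip.
    let data_type = DataTypeBuilder::arbitrary(&mut unstructured)?.build();
    println!("{data_type:?}");
    Ok(())
}
```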


@ -0,0 +1,55 @@
use arbitrary::Arbitrary;
use arrow::datatypes::Field;
use arrow::datatypes::Fields;
use crate::DataTypeBuilder;
/// Make an [arbitrary] [`Field`] for an Arrow [`Schema`](arrow::datatypes::Schema).
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct FieldBuilder {
pub name: String,
pub data_type: DataTypeBuilder,
/// Ignored (always set to `true`) when `data_type` is [`arrow::datatypes::DataType::Null`].
pub nullable: bool,
// TODO dict_id: i64,
// TODO dict_is_ordered: bool,
// TODO metadata: HashMap<String, String>,
}
impl FieldBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Field {
let Self {
name,
data_type,
nullable,
} = self;
let data_type = data_type.build();
let nullable = nullable || data_type == arrow::datatypes::DataType::Null;
Field::new(name, data_type, nullable)
}
}
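A small sketch of the nullability rule documented on `FieldBuilder::nullable` above (the field name is an arbitrary example):

```rust
// A Null-typed field is forced nullable even when `nullable` is false.
use arbitrary_arrow::DataTypeBuilder;
use arbitrary_arrow::FieldBuilder;

fn main() {
    let field = FieldBuilder {
        name: "placeholder".to_owned(),
        data_type: DataTypeBuilder::Null,
        nullable: false,
    }
    .build();
    assert!(field.is_nullable());
}
```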
/// Make an [arbitrary] [`Fields`] for an Arrow [`Schema`](arrow::datatypes::Schema).
///
/// This is just a vector of [`Field`] (see [`FieldBuilder`]), but Arrow uses a separate type for it.
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct FieldsBuilder {
pub builders: Vec<FieldBuilder>,
}
impl FieldsBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Fields {
let Self { builders } = self;
let fields = builders
.into_iter()
.map(FieldBuilder::build)
.collect::<Vec<_>>();
Fields::from(fields)
}
}


@ -0,0 +1,23 @@
use arbitrary::Arbitrary;
use arrow::datatypes::IntervalUnit;
/// Make an [arbitrary] [`IntervalUnit`].
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum IntervalUnitBuilder {
YearMonth,
DayTime,
MonthDayNano,
}
impl IntervalUnitBuilder {
#[must_use]
#[expect(missing_docs)]
pub const fn build(self) -> IntervalUnit {
match self {
IntervalUnitBuilder::YearMonth => IntervalUnit::YearMonth,
IntervalUnitBuilder::DayTime => IntervalUnit::DayTime,
IntervalUnitBuilder::MonthDayNano => IntervalUnit::MonthDayNano,
}
}
}


@ -0,0 +1,66 @@
//! Create [arbitrary] [Arrow](arrow) data structures, for fuzz testing.
//!
//! This could become part of the upstream [`arrow`] crate in the future.
#![allow(clippy::exhaustive_enums)]
#![allow(clippy::exhaustive_structs)]
pub use arbitrary;
pub use arrow;
mod array;
mod data_type;
mod field;
mod interval_unit;
mod record_batch;
mod schema;
mod time_unit;
pub use array::*;
pub use data_type::*;
pub use field::*;
pub use interval_unit::*;
pub use record_batch::*;
pub use schema::*;
pub use time_unit::*;
#[cfg(test)]
mod tests {
use arbitrary::Arbitrary as _;
use super::*;
#[maybe_tracing::test]
#[cfg_attr(miri, ignore = "too slow")]
fn random() {
let mut entropy = vec![0u8; 4096];
getrandom::fill(&mut entropy).unwrap();
let mut unstructured = arbitrary::Unstructured::new(&entropy);
match RecordBatchBuilder::arbitrary(&mut unstructured) {
Ok(builder) => {
let record_batch = builder.build();
println!("{record_batch:?}");
}
Err(arbitrary::Error::IncorrectFormat) => {
// ignore silently
}
Err(error) => panic!("unexpected error: {error}"),
};
}
#[maybe_tracing::test]
fn fixed() {
// A non-randomized version that runs faster under Miri.
let fake_entropy = vec![0u8; 4096];
let mut unstructured = arbitrary::Unstructured::new(&fake_entropy);
match RecordBatchBuilder::arbitrary(&mut unstructured) {
Ok(builder) => {
let record_batch = builder.build();
println!("{record_batch:?}");
}
Err(arbitrary::Error::IncorrectFormat) => {
panic!("unexpected error, fiddle with fake entropy to avoid")
}
Err(error) => panic!("unexpected error: {error}"),
};
}
}


@ -0,0 +1,93 @@
use std::sync::Arc;
use arrow::array::RecordBatch;
use arrow::datatypes::SchemaRef;
use crate::arbitrary_array;
use crate::SchemaBuilder;
/// Make an [arbitrary] [`RecordBatch`] with the given Arrow [`Schema`](arrow::datatypes::Schema).
pub fn record_batch_with_schema(
u: &mut arbitrary::Unstructured<'_>,
schema: SchemaRef,
) -> arbitrary::Result<RecordBatch> {
// TODO correct type for the below
let row_count = u.arbitrary_len::<[u8; 8]>()?;
let arrays = schema
.fields()
.iter()
.map(|field| arbitrary_array(u, field, row_count))
.collect::<Result<Vec<_>, _>>()?;
debug_assert_eq!(schema.fields().len(), arrays.len());
debug_assert!(arrays.iter().all(|array| array.len() == row_count));
debug_assert!(
arrays
.iter()
.zip(schema.fields().iter())
.all(|(array, field)| array.data_type() == field.data_type()),
"wrong array data type: {arrays:?} != {fields:?}",
arrays = arrays.iter().map(|a| a.data_type()).collect::<Vec<_>>(),
fields = schema
.fields()
.iter()
.map(|f| f.data_type())
.collect::<Vec<_>>(),
);
let options = arrow::array::RecordBatchOptions::new().with_row_count(Some(row_count));
#[expect(clippy::expect_used, reason = "not production code")]
let record_batch =
RecordBatch::try_new_with_options(schema, arrays, &options).expect("internal error");
Ok(record_batch)
}
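A hedged usage sketch for `record_batch_with_schema`; the schema, field names, and byte source are illustrative choices, not from the original code:

```rust
// Sketch: generate a batch whose columns follow a caller-supplied schema.
use std::sync::Arc;

use arbitrary::Unstructured;
use arbitrary_arrow::record_batch_with_schema;
use arrow::datatypes::DataType;
use arrow::datatypes::Field;
use arrow::datatypes::Schema;

fn example() -> arbitrary::Result<()> {
    let schema = Arc::new(Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]));
    let mut unstructured = Unstructured::new(&[0x3C; 512]);
    let record_batch = record_batch_with_schema(&mut unstructured, schema)?;
    assert_eq!(record_batch.num_columns(), 2);
    Ok(())
}
```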
/// Make an [arbitrary] [`RecordBatch`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](RecordBatchBuilder::build) method to get the final [`RecordBatch`].
/// This works around a technical limitation: we can't implement the trait [`arbitrary::Arbitrary`] for [`RecordBatch`] in this unrelated crate.
/// This can be cleaned up if merged into Arrow upstream.
#[derive(Debug)]
#[expect(missing_docs)]
pub struct RecordBatchBuilder {
pub record_batch: RecordBatch,
}
/// Implement [`arbitrary::Arbitrary`] manually.
/// [`RecordBatch`] has constraints that are hard to express via deriving [`Arbitrary`](arbitrary::Arbitrary):
///
/// - number of schema fields must match number of arrays
/// - each array must match schema field type
/// - all arrays must have the same row count
impl<'a> arbitrary::Arbitrary<'a> for RecordBatchBuilder {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let schema = SchemaBuilder::arbitrary(u)?.build();
let schema = Arc::new(schema);
let record_batch = record_batch_with_schema(u, schema)?;
Ok(RecordBatchBuilder { record_batch })
}
fn size_hint(depth: usize) -> (usize, Option<usize>) {
Self::try_size_hint(depth).unwrap_or_default()
}
fn try_size_hint(
depth: usize,
) -> arbitrary::Result<(usize, Option<usize>), arbitrary::MaxRecursionReached> {
arbitrary::size_hint::try_recursion_guard(depth, |depth| {
let size_hint = SchemaBuilder::try_size_hint(depth)?;
// TODO match on data_type and count the amount of entropy needed; but we don't have a row count yet
Ok(size_hint)
})
}
}
impl RecordBatchBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> RecordBatch {
let Self { record_batch } = self;
record_batch
}
}


@ -0,0 +1,27 @@
use arbitrary::Arbitrary;
use arrow::datatypes::Schema;
use crate::FieldsBuilder;
/// Make an [arbitrary] [`Schema`].
///
/// What is actually constructed is a "builder".
/// Call the [`build`](SchemaBuilder::build) method to get the final [`Schema`].
/// This works around a technical limitation: we can't implement the trait [`arbitrary::Arbitrary`] for [`Schema`] in this unrelated crate.
/// This can be cleaned up if merged into Arrow upstream.
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub struct SchemaBuilder {
pub fields: FieldsBuilder,
// TODO metadata: HashMap<String, String>
}
impl SchemaBuilder {
#[must_use]
#[expect(missing_docs)]
pub fn build(self) -> Schema {
let Self { fields } = self;
let fields = fields.build();
Schema::new(fields)
}
}


@ -0,0 +1,25 @@
use arbitrary::Arbitrary;
use arrow::datatypes::TimeUnit;
/// Make an [arbitrary] [`TimeUnit`].
#[derive(Arbitrary)]
#[expect(missing_docs)]
pub enum TimeUnitBuilder {
Second,
Millisecond,
Microsecond,
Nanosecond,
}
impl TimeUnitBuilder {
#[must_use]
#[expect(missing_docs)]
pub const fn build(self) -> TimeUnit {
match self {
TimeUnitBuilder::Second => TimeUnit::Second,
TimeUnitBuilder::Millisecond => TimeUnit::Millisecond,
TimeUnitBuilder::Microsecond => TimeUnit::Microsecond,
TimeUnitBuilder::Nanosecond => TimeUnit::Nanosecond,
}
}
}


@ -0,0 +1,38 @@
[package]
name = "kanto-backend-rocksdb"
version = "0.1.0"
description = "RocksDB backend for the KantoDB SQL database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
async-trait = { workspace = true }
dashmap = { workspace = true }
datafusion = { workspace = true }
futures-lite = { workspace = true }
gat-lending-iterator = { workspace = true }
kanto = { workspace = true }
kanto-index-format-v1 = { workspace = true }
kanto-key-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-record-format-v1 = { workspace = true }
kanto-tunables = { workspace = true }
maybe-tracing = { workspace = true }
rkyv = { workspace = true }
rkyv_util = { workspace = true }
rocky = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
kanto-testutil = { workspace = true }
test-log = { workspace = true }
[lints]
workspace = true


@ -0,0 +1,652 @@
use std::path::Path;
use std::sync::Arc;
use datafusion::sql::sqlparser;
use gat_lending_iterator::LendingIterator as _;
use kanto::parquet::data_type::AsBytes as _;
use kanto::KantoError;
use kanto_meta_format_v1::IndexId;
use kanto_meta_format_v1::SequenceId;
use kanto_meta_format_v1::TableId;
// TODO `concat_bytes!` <https://github.com/rust-lang/rust/issues/87555> #waiting #ecosystem/rust #dev
#[expect(
clippy::indexing_slicing,
clippy::expect_used,
reason = "TODO const functions are awkward"
)]
const fn concat_bytes_to_array<const N: usize>(sources: &[&[u8]]) -> [u8; N] {
let mut result = [0u8; N];
let mut dst = 0;
let mut src = 0;
while src < sources.len() {
let mut offset = 0;
while offset < sources[src].len() {
assert!(dst < N, "overflowing destination array");
result[dst] = sources[src][offset];
dst = dst
.checked_add(1)
.expect("constant calculation must be correct");
offset = offset
.checked_add(1)
.expect("constant calculation must be correct");
}
src = src
.checked_add(1)
.expect("constant calculation must be correct");
}
assert!(dst == N, "did not fill the whole array");
result
}
// We reserve a global keyspace in RocksDB, in the default column family, keys beginning with `1:u64_be`.
const fn make_system_key(ident: u64) -> [u8; 16] {
concat_bytes_to_array(&[&1u64.to_be_bytes(), &ident.to_be_bytes()])
}
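A worked example of the key layout produced by `make_system_key`: an 8-byte big-endian keyspace prefix of `1`, then the big-endian ident. The test below is hypothetical (not in the original commit) and uses ident 2, which the format-cookie key just below chooses:

```rust
#[test]
fn system_key_layout() {
    // Format-cookie key (ident 2) as concrete bytes.
    assert_eq!(
        make_system_key(2),
        [
            0, 0, 0, 0, 0, 0, 0, 1, // keyspace prefix 1u64, big-endian
            0, 0, 0, 0, 0, 0, 0, 2, // ident 2u64, big-endian
        ]
    );
}
```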
const FORMAT_MESSAGE_KEY: [u8; 16] = make_system_key(1);
const FORMAT_MESSAGE_CONTENT: &[u8] =
b"KantoDB database file\nSee <https://kantodb.com/> for more.\n";
const FORMAT_COOKIE_KEY: [u8; 16] = make_system_key(2);
trait RocksDbBackend {
/// Random data used to identify format.
///
/// Use pure entropy, don't get cute.
const FORMAT_COOKIE: [u8; 32];
/// Caller promises `FORMAT_COOKIE` has been checked already.
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError>;
}
/// Create a new RocksDB database directory.
pub fn create<P: AsRef<Path>>(path: P) -> Result<impl kanto::Backend, KantoError> {
let path = path.as_ref();
create_inner(path)
}
fn create_inner(path: &Path) -> Result<Database, KantoError> {
let mut opts = rocky::Options::new();
opts.create_if_missing(true);
// TODO let caller pass in TransactionDbOptions
let txdb_opts = rocky::TransactionDbOptions::new();
let column_families = vec![];
let rocksdb = rocky::Database::open(path, &opts, &txdb_opts, column_families).map_err(
|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "n19nx5m9yc5br",
error: std::io::Error::other(rocksdb_error),
},
)?;
// TODO there's a race condition between creating the RocksDB files and branding the database as "ours"; unclear how to properly brand a RocksDB database as belonging to this application
// TODO check for existence first? what if `create` is run on an existing RocksDB database
{
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
let tx = rocksdb.transaction_begin(&write_options, &tx_options);
let cf_default = rocksdb
.get_column_family("default")
.ok_or(KantoError::Internal {
code: "y4i5oznqrpr4o",
error: "trouble using default column family".into(),
})?;
{
let key = &FORMAT_MESSAGE_KEY;
let value = FORMAT_MESSAGE_CONTENT;
tx.put(&cf_default, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "fgkhwk4rbf5io",
error: std::io::Error::other(rocksdb_error),
})?;
}
{
let key = &FORMAT_COOKIE_KEY;
let value = &Database::FORMAT_COOKIE;
tx.put(&cf_default, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "yepp5upj9o79k",
error: std::io::Error::other(rocksdb_error),
})?;
}
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "wgkft87xp4kn6",
error: std::io::Error::other(rocksdb_error),
})?;
}
Database::open_from_rocksdb_v1(rocksdb)
}
/// Create a new *temporary* RocksDB database.
///
/// This is mostly useful for tests.
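///
/// # Examples
///
/// A minimal sketch (usage is illustrative; the directory is managed by [`tempfile`]):
///
/// ```no_run
/// # fn main() -> Result<(), kanto::KantoError> {
/// let (_tempdir, db) = kanto_backend_rocksdb::create_temp()?;
/// let _session = db.test_session();
/// # Ok(())
/// # }
/// ```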
pub fn create_temp() -> Result<(tempfile::TempDir, Database), KantoError> {
let dir = tempfile::tempdir().map_err(|io_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "9syfcddzphoqs",
error: io_error,
})?;
let db = create_inner(dir.path())?;
Ok((dir, db))
}
/// Open an existing RocksDB database directory.
pub fn open(path: &Path) -> Result<impl kanto::Backend + use<>, KantoError> {
// RocksDB has a funny dance about opening the database.
// They want you to list the column families that exist in the database (in order to configure them), but no extra ones (yes, you can do create if missing, but then you don't get to configure them).
// Load the "options" file from the database directory to figure out what column families exist.
// <https://github.com/facebook/rocksdb/wiki/Column-Families>
let latest_options =
rocky::LatestOptions::load_latest_options(path).map_err(|rocksdb_error| {
KantoError::Io {
op: kanto::error::Op::Init,
code: "he3ddmfxzymiw",
error: std::io::Error::other(rocksdb_error),
}
})?;
let rocksdb = {
let opts = latest_options.clone_db_options();
// TODO let caller pass in TransactionDbOptions
let txdb_opts = rocky::TransactionDbOptions::new();
let column_families = latest_options
.iter_column_families()
.collect::<Result<Vec<_>, _>>()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "umshco6957zj6",
error: std::io::Error::other(rocksdb_error),
})?;
rocky::Database::open(path, &opts, &txdb_opts, column_families)
}
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "gzubzoig3be86",
error: std::io::Error::other(rocksdb_error),
})?;
open_from_rocksdb(rocksdb)
}
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError> {
let tx = {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
rocksdb.transaction_begin(&write_options, &tx_options)
};
let cf_default = rocksdb
.get_column_family("default")
.ok_or(KantoError::Internal {
code: "sit1685wkky9k",
error: "trouble using default column family".into(),
})?;
{
let key = FORMAT_MESSAGE_KEY;
let want = FORMAT_MESSAGE_CONTENT;
let read_opts = rocky::ReadOptions::new();
let got = tx
.get_pinned(&cf_default, key, &read_opts)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "wfa4s57mk83py",
error: std::io::Error::other(rocksdb_error),
})?
.ok_or(KantoError::Init {
code: "ypjsnkcony3kh",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
})?;
if got.as_bytes() != want {
let error = KantoError::Init {
code: "fty1mzfs1fw6q",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
};
return Err(error);
}
}
let format_cookie = {
let read_opts = rocky::ReadOptions::new();
let key = FORMAT_COOKIE_KEY;
tx.get_pinned(&cf_default, key, &read_opts)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "d87g8itsmm7ag",
error: std::io::Error::other(rocksdb_error),
})?
.ok_or(KantoError::Init {
code: "9ogxqz1u5jboq",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
})?
};
tx.rollback().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "sdrh4rkxxt8ta",
error: std::io::Error::other(rocksdb_error),
})?;
if format_cookie == Database::FORMAT_COOKIE.as_bytes() {
Database::open_from_rocksdb(rocksdb)
} else {
let error = KantoError::Init {
code: "amsgxu6h1yj9r",
error: kanto::error::InitError::UnrecognizedDatabaseFormat,
};
Err(error)
}
}
/// A [`kanto::Backend`] that stores data in RocksDB.
#[derive(Clone)]
pub struct Database {
pub(crate) inner: Arc<InnerDatabase>,
}
pub(crate) struct ColumnFamilies {
pub name_defs: Arc<rocky::ColumnFamily>,
pub table_defs: Arc<rocky::ColumnFamily>,
pub index_defs: Arc<rocky::ColumnFamily>,
pub sequence_defs: Arc<rocky::ColumnFamily>,
// accessed only by the library abstraction
pub sequences: Arc<rocky::ColumnFamily>,
pub records: Arc<rocky::ColumnFamily>,
pub unique_indexes: Arc<rocky::ColumnFamily>,
pub multi_indexes: Arc<rocky::ColumnFamily>,
}
pub(crate) const SEQUENCE_SEQ_ID: SequenceId = SequenceId::MIN;
pub(crate) struct SystemSequences {
pub(crate) table_seq_id: SequenceId,
pub(crate) index_seq_id: SequenceId,
}
pub(crate) struct SystemData {
pub(crate) sequences: SystemSequences,
}
pub(crate) struct InnerDatabase {
pub(crate) rocksdb: rocky::Database,
pub(crate) sequence_tracker: crate::sequence::SequenceTracker,
pub(crate) column_families: ColumnFamilies,
pub(crate) system: SystemData,
}
impl Database {
fn ensure_cf_exist(
rocksdb: &rocky::Database,
cf_name: &str,
) -> Result<Arc<rocky::ColumnFamily>, rocky::RocksDbError> {
if let Some(cf) = rocksdb.get_column_family(cf_name) {
Ok(cf)
} else {
let opts = rocky::Options::new();
rocksdb.create_column_family(cf_name, &opts)
}
}
#[must_use]
pub(crate) const fn make_rocksdb_sequence_def_key(sequence_id: SequenceId) -> [u8; 8] {
sequence_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_table_def_key(table_id: TableId) -> [u8; 8] {
table_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_index_def_key(index_id: IndexId) -> [u8; 8] {
index_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_sequence_key(sequence_id: SequenceId) -> [u8; 8] {
sequence_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_record_key_prefix(table_id: TableId) -> [u8; 8] {
table_id.get_u64().to_be_bytes()
}
#[must_use]
pub(crate) const fn make_rocksdb_record_key_stop_before(table_id: TableId) -> [u8; 8] {
let table_id_num = table_id.get_u64();
debug_assert!(table_id_num < u64::MAX);
let stop_before = table_id_num.saturating_add(1);
stop_before.to_be_bytes()
}
pub(crate) fn make_rocksdb_record_key(table_id: TableId, row_key: &[u8]) -> Box<[u8]> {
// TODO later: optimize allocations and minimize copying
// TODO use types to make sure we don't confuse row_key vs rocksdb_key
let rocksdb_key_prefix = Self::make_rocksdb_record_key_prefix(table_id);
let mut k = Vec::with_capacity(rocksdb_key_prefix.len().saturating_add(row_key.len()));
k.extend_from_slice(&rocksdb_key_prefix);
k.extend_from_slice(row_key);
k.into_boxed_slice()
}
/// Create a new session with this backend registered as the default catalog.
///
/// This is mostly useful for tests.
#[must_use]
pub fn test_session(&self) -> kanto::Session {
let backend: Box<dyn kanto::Backend> = Box::new(self.clone());
kanto::Session::test_session(backend)
}
// TODO ugly api, experiment switching into builder style so all the inner helpers can be methods on a builder?
fn init_database(
rocksdb: &rocky::Database,
column_families: &ColumnFamilies,
sequence_tracker: &crate::sequence::SequenceTracker,
) -> Result<SystemData, KantoError> {
let tx = {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
rocksdb.transaction_begin(&write_options, &tx_options)
};
Self::init_sequence_of_sequences(&tx, &column_families.sequence_defs)?;
let table_seq_id = Self::init_sequence_def(
&tx,
&column_families.sequence_defs,
sequence_tracker,
"tables",
)?;
let index_seq_id = Self::init_sequence_def(
&tx,
&column_families.sequence_defs,
sequence_tracker,
"indexes",
)?;
let system_sequences = SystemSequences {
table_seq_id,
index_seq_id,
};
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "u375qjb7o3jyh",
error: std::io::Error::other(rocksdb_error),
})?;
let system_data = SystemData {
sequences: system_sequences,
};
Ok(system_data)
}
fn init_sequence_of_sequences(
tx: &rocky::Transaction,
sequence_defs_cf: &rocky::ColumnFamily,
) -> Result<(), KantoError> {
let sequence_def_key = Self::make_rocksdb_sequence_def_key(SEQUENCE_SEQ_ID);
let read_opts = rocky::ReadOptions::new();
let exists = tx
.get_for_update_pinned(
sequence_defs_cf,
sequence_def_key,
&read_opts,
rocky::Exclusivity::Exclusive,
)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "316dpxkcf8ruy",
error: std::io::Error::other(rocksdb_error),
})?
.is_some();
if exists {
// incremental updates would go here
Ok(())
} else {
const KANTO_INTERNAL_CATALOG_NAME: &str = "_kanto_internal";
const KANTO_INTERNAL_SCHEMA_NAME: &str = "kanto";
const KANTO_INTERNAL_SEQUENCE_OF_SEQUENCES_NAME: &str = "sequences";
let sequence_def = kanto_meta_format_v1::sequence_def::SequenceDef {
sequence_id: SEQUENCE_SEQ_ID,
catalog: KANTO_INTERNAL_CATALOG_NAME.to_owned(),
schema: KANTO_INTERNAL_SCHEMA_NAME.to_owned(),
sequence_name: KANTO_INTERNAL_SEQUENCE_OF_SEQUENCES_NAME.to_owned(),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::BoxedError>(&sequence_def).map_err(
|rkyv_error| KantoError::Internal {
code: "5hwja993uzb86",
error: Box::new(kanto::error::Message {
message: "failed to serialize sequence def",
error: Box::new(rkyv_error),
}),
},
)?;
tx.put(sequence_defs_cf, sequence_def_key, bytes)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "67soy6858owxa",
error: std::io::Error::other(rocksdb_error),
})?;
Ok(())
}
}
fn init_sequence_def(
tx: &rocky::Transaction,
sequence_defs_cf: &rocky::ColumnFamily,
sequence_tracker: &crate::sequence::SequenceTracker,
sequence_name: &str,
) -> Result<SequenceId, KantoError> {
const KANTO_INTERNAL_CATALOG_NAME: &str = "_kanto_internal";
const KANTO_INTERNAL_SCHEMA_NAME: &str = "kanto";
// Look for existing sequence_def.
// TODO should do lookup via index
{
let read_opts = rocky::ReadOptions::new();
let mut iter = tx.iter(read_opts, sequence_defs_cf);
while let Some(entry) =
iter.next()
.transpose()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "9rw1dttbwqxnc",
error: std::io::Error::other(rocksdb_error),
})?
{
let bytes = entry.value();
let sequence_def = rkyv::access::<
rkyv::Archived<kanto_meta_format_v1::sequence_def::SequenceDef>,
rkyv::rancor::BoxedError,
>(bytes)
.map_err(|rkyv_error| KantoError::Corrupt {
op: kanto::error::Op::Init,
code: "k41yd9xmpqxba",
error: Box::new(kanto::error::CorruptError::Rkyv { error: rkyv_error }),
})?;
{
if sequence_def.catalog != KANTO_INTERNAL_CATALOG_NAME {
continue;
}
if sequence_def.schema != KANTO_INTERNAL_SCHEMA_NAME {
continue;
}
if sequence_def.sequence_name != sequence_name {
continue;
}
}
// Found an existing record!
// For now, we don't bother updating it.. maybe we should.
let sequence_id = sequence_def.sequence_id.to_native();
return Ok(sequence_id);
}
}
// Not found
let sequence_id = {
let seq_id_num = sequence_tracker.inc(SEQUENCE_SEQ_ID)?;
SequenceId::new(seq_id_num).ok_or_else(|| KantoError::Corrupt {
op: kanto::error::Op::Init,
code: "sz8cgjb6jpyur",
error: Box::new(kanto::error::CorruptError::InvalidSequenceId {
sequence_id: seq_id_num,
sequence_def_debug: String::new(),
}),
})?
};
let sequence_def = kanto_meta_format_v1::sequence_def::SequenceDef {
sequence_id,
catalog: KANTO_INTERNAL_CATALOG_NAME.to_owned(),
schema: KANTO_INTERNAL_SCHEMA_NAME.to_owned(),
sequence_name: sequence_name.to_owned(),
};
let bytes =
rkyv::to_bytes::<rkyv::rancor::BoxedError>(&sequence_def).map_err(|rkyv_error| {
KantoError::Internal {
code: "88sauprbx8d8e",
error: Box::new(kanto::error::Message {
message: "failed to serialize sequence def",
error: Box::new(rkyv_error),
}),
}
})?;
let sequence_def_key = Self::make_rocksdb_sequence_def_key(sequence_id);
tx.put(sequence_defs_cf, sequence_def_key, bytes)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "kw743n8j4f4a6",
error: std::io::Error::other(rocksdb_error),
})?;
Ok(sequence_id)
}
fn open_from_rocksdb_v1(rocksdb: rocky::Database) -> Result<Self, KantoError> {
let ensure_cf = |cf_name: &str| -> Result<Arc<rocky::ColumnFamily>, KantoError> {
Database::ensure_cf_exist(&rocksdb, cf_name).map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::Init,
code: "m5w7fccb54wr6",
error: std::io::Error::other(rocksdb_error),
})
};
let column_families = ColumnFamilies {
name_defs: ensure_cf("name_defs")?,
table_defs: ensure_cf("table_defs")?,
index_defs: ensure_cf("index_defs")?,
sequence_defs: ensure_cf("sequence_defs")?,
sequences: ensure_cf("sequences")?,
records: ensure_cf("records")?,
unique_indexes: ensure_cf("unique_indexes")?,
multi_indexes: ensure_cf("multi_indexes")?,
};
let sequence_tracker = crate::sequence::SequenceTracker::new(
rocksdb.clone(),
column_families.sequences.clone(),
);
let system_data = Self::init_database(&rocksdb, &column_families, &sequence_tracker)?;
let inner = Arc::new(InnerDatabase {
rocksdb,
sequence_tracker,
column_families,
system: system_data,
});
Ok(Self { inner })
}
}
impl RocksDbBackend for Database {
const FORMAT_COOKIE: [u8; 32] = [
0x82, 0x48, 0x1b, 0xa3, 0xa1, 0xa3, 0x78, 0xae, 0x0f, 0x04, 0x12, 0x81, 0xf2, 0x0b, 0x63,
0xf8, 0x7d, 0xc6, 0x47, 0xcc, 0x80, 0xb0, 0x70, 0xc3, 0xb2, 0xa6, 0xf3, 0x60, 0xcb, 0x66,
0xda, 0x5f,
];
fn open_from_rocksdb(rocksdb: rocky::Database) -> Result<impl kanto::Backend, KantoError> {
Self::open_from_rocksdb_v1(rocksdb)
}
}
#[async_trait::async_trait]
impl kanto::Backend for Database {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn start_transaction(
&self,
access_mode: sqlparser::ast::TransactionAccessMode,
isolation_level: sqlparser::ast::TransactionIsolationLevel,
) -> Result<Box<dyn kanto::Transaction>, KantoError> {
match access_mode {
sqlparser::ast::TransactionAccessMode::ReadWrite => Ok(()),
sqlparser::ast::TransactionAccessMode::ReadOnly => {
let error = KantoError::UnimplementedSql {
code: "bhetj3emnc8n1",
sql_syntax: "START TRANSACTION READ ONLY",
};
Err(error)
}
}?;
#[expect(
clippy::match_same_arms,
reason = "TODO rustfmt fails to format with comments in OR patterns <https://github.com/rust-lang/rustfmt/issues/6491> #waiting #ecosystem/rust #severity/low #urgency/low"
)]
match isolation_level {
// TODO advocate to reorder the enum in order of increasing isolation; we normally write matches in order of the enum, but make an exception here #ecosystem/sqlparser-rs
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::ReadUncommitted => Ok(()),
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::ReadCommitted => Ok(()),
// We can choose to upgrade this to something we do support.
sqlparser::ast::TransactionIsolationLevel::RepeatableRead => Ok(()),
// This is what we actually do, with RocksDB snapshots.
sqlparser::ast::TransactionIsolationLevel::Snapshot => Ok(()),
sqlparser::ast::TransactionIsolationLevel::Serializable => {
let error = KantoError::UnimplementedSql {
code: "dy1ch15p7n3ze",
sql_syntax: "START TRANSACTION ISOLATION LEVEL SERIALIZABLE",
};
Err(error)
}
}?;
let transaction = crate::Transaction::new(self);
Ok(Box::new(transaction))
}
fn clone_box(&self) -> Box<dyn kanto::Backend> {
Box::new(self.clone())
}
}
impl PartialEq for Database {
fn eq(&self, other: &Self) -> bool {
// Use the pointer inside the `self.inner` `Arc` as our identity, as the outer value gets cloned.
std::ptr::eq(self.inner.as_ref(), other.inner.as_ref())
}
}
impl Eq for Database {}
impl std::hash::Hash for Database {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
// Use the pointer inside the `self.inner` `Arc` as our identity, as the outer value gets cloned.
let ptr = <*const _>::from(self.inner.as_ref());
#[expect(clippy::as_conversions)]
let num = ptr as usize;
state.write_usize(num);
}
}
impl std::fmt::Debug for Database {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Database").finish_non_exhaustive()
}
}

View file

@@ -0,0 +1,19 @@
//! `kanto_backend_rocksdb` provides a [KantoDB](https://kantodb.com/) database [backend](kanto::Backend) using [RocksDB](https://rocksdb.org/), delegating snapshot isolation and transaction handling fully to it.
//!
//! For more on RocksDB, see <https://github.com/facebook/rocksdb/wiki>
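//!
//! # Example
//!
//! A minimal sketch of creating a database directory (the path is illustrative):
//!
//! ```no_run
//! # fn main() -> Result<(), kanto::KantoError> {
//! let _backend = kanto_backend_rocksdb::create("/tmp/kanto-example-db")?;
//! # Ok(())
//! # }
//! ```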
mod database;
mod savepoint;
mod sequence;
mod table;
mod table_provider;
mod table_scan;
mod transaction;
pub use crate::database::create;
pub use crate::database::create_temp;
pub use crate::database::open;
pub use crate::database::Database;
pub(crate) use crate::savepoint::Savepoint;
pub(crate) use crate::savepoint::WeakSavepoint;
pub(crate) use crate::transaction::Transaction;

View file

@@ -0,0 +1,30 @@
# Does not parse in generic dialect
statement error ^SQL error: ParserError\("Expected: equals sign or TO, found: IO at Line: 1, Column: \d+"\) \(sql/xatq3n1zk65go\)$
SET STATISTICS IO ON;
statement ok
SET sql_dialect='mssql';
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS IO ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS IO OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS PROFILE ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS PROFILE OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS TIME ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS TIME OFF;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS XML ON;
statement error ^External error: not_supported_sql/h83zk8ron8bme: SET session parameter \(MS-SQL syntax\)$
SET STATISTICS XML OFF;

View file

@@ -0,0 +1,50 @@
use std::sync::Arc;
use kanto::KantoError;
use crate::Transaction;
pub(crate) struct Savepoint {
inner: Arc<InnerSavepoint>,
}
pub(crate) struct WeakSavepoint {
inner: std::sync::Weak<InnerSavepoint>,
}
struct InnerSavepoint {
transaction: Transaction,
}
impl Savepoint {
pub(crate) fn new(transaction: Transaction) -> Savepoint {
let inner = InnerSavepoint { transaction };
Savepoint {
inner: Arc::new(inner),
}
}
pub(crate) fn weak(&self) -> WeakSavepoint {
WeakSavepoint {
inner: Arc::downgrade(&self.inner),
}
}
}
#[async_trait::async_trait]
impl kanto::Savepoint for Savepoint {
fn as_any(&self) -> &dyn std::any::Any {
self
}
async fn rollback(self: Box<Self>) -> Result<(), KantoError> {
let transaction = self.inner.transaction.clone();
transaction.rollback_to_savepoint(*self).await
}
}
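// Identity comparison: a `Savepoint` matches a `WeakSavepoint` only when both refer to the
// same `InnerSavepoint` allocation.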
impl PartialEq<WeakSavepoint> for Savepoint {
fn eq(&self, other: &WeakSavepoint) -> bool {
std::sync::Weak::ptr_eq(&Arc::downgrade(&self.inner), &other.inner)
}
}

View file

@@ -0,0 +1,208 @@
use std::sync::Arc;
use dashmap::DashMap;
use kanto::KantoError;
use kanto_meta_format_v1::SequenceId;
use crate::Database;
#[derive(Debug)]
struct SeqState {
// TODO optimize with atomics, put only the reserve-next-chunk part behind locking (here `DashMap`).
// TODO could we preemptively reserve the next chunk in a separate thread / tokio task / etc?
cur: u64,
max: u64,
}
// For best results, there should be only one `SequenceTracker` per `kanto_backend_rocksdb::Database`.
pub(crate) struct SequenceTracker {
// Reserving sequence number chunks is explicitly done outside any transactions.
// You wouldn't want two concurrent transactions to get the same ID and conflict later.
// Use RocksDB directly, not via `crate::Database`, to avoid a reference cycle.
rocksdb: rocky::Database,
column_family: Arc<rocky::ColumnFamily>,
cache: DashMap<SequenceId, SeqState>,
}
impl SequenceTracker {
pub(crate) fn new(rocksdb: rocky::Database, column_family: Arc<rocky::ColumnFamily>) -> Self {
SequenceTracker {
rocksdb,
column_family,
cache: DashMap::new(),
}
}
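// Note: counter values are encoded little-endian below; this is unrelated to the big-endian
// ordering used for RocksDB keys, since these values are fetched by exact key and never
// range-scanned.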
#[must_use]
const fn encode_value(seq: u64) -> [u8; 8] {
seq.to_le_bytes()
}
#[must_use]
fn decode_value(buf: &[u8]) -> Option<u64> {
let input: [u8; 8] = buf.try_into().ok()?;
let num = u64::from_le_bytes(input);
Some(num)
}
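// Chunked allocation sketch: the persisted value is the next unreserved number.
// `reserve_chunk` reads it (starting from 1 when absent), keeps a `cur..=max` window in
// memory, and persists a number past `max` so a later or concurrent tracker starts from a
// fresh chunk. Crashing mid-chunk skips the unused numbers, so sequence values are unique
// but not gap-free.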
#[maybe_tracing::instrument(skip(self), ret, err)]
fn reserve_chunk(&self, seq_id: SequenceId) -> Result<SeqState, KantoError> {
let write_options = rocky::WriteOptions::new();
let tx_options = rocky::TransactionOptions::new();
let tx = self.rocksdb.transaction_begin(&write_options, &tx_options);
let key = Database::make_rocksdb_sequence_key(seq_id);
// TODO should we loop on transaction conflicts
let value = {
let read_opts = rocky::ReadOptions::new();
tx.get_for_update_pinned(
&self.column_family,
key,
&read_opts,
rocky::Exclusivity::Exclusive,
)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "p1w8mbq5acd6n",
error: std::io::Error::other(rocksdb_error),
})
}?;
let new_state = match value {
None => SeqState {
cur: 1,
max: kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE,
},
Some(bytes) => {
let bytes = bytes.as_ref();
tracing::trace!(?bytes, "loaded");
let cur = Self::decode_value(bytes).ok_or_else(|| KantoError::Corrupt {
// TODO repeating seq_id in two places in the error
// TODO unify seq_id -> sequence_id
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "cusj4jcx8qsih",
error: Box::new(kanto::error::CorruptError::SequenceData {
sequence_id: seq_id,
}),
})?;
let max = cur.saturating_add(kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE);
SeqState { cur, max }
}
};
tracing::trace!(?new_state);
{
let new_max = new_state.max.checked_add(1).ok_or(KantoError::Execution {
// TODO repeating seq_id in two places in the error
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "wwnmhoex9cx3q",
error: Box::new(kanto::error::ExecutionError::SequenceExhausted {
sequence_id: seq_id,
}),
})?;
let value = Self::encode_value(new_max);
tx.put(&self.column_family, key, value)
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "iyiqedewzpopw",
error: std::io::Error::other(rocksdb_error),
})?;
tx.commit().map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "5y5h6iy5nttws",
error: std::io::Error::other(rocksdb_error),
})?;
}
Ok(new_state)
}
#[maybe_tracing::instrument(skip(self), ret, err)]
pub(crate) fn inc(&self, seq_id: SequenceId) -> Result<u64, KantoError> {
match self.cache.entry(seq_id) {
dashmap::mapref::entry::Entry::Occupied(mut entry) => {
let state = entry.get_mut();
state.cur = state.cur.checked_add(1).ok_or(KantoError::Execution {
// TODO repeating seq_id in two places in the error
op: kanto::error::Op::SequenceUpdate {
sequence_id: seq_id,
},
code: "5jsjk93zjr1sq",
error: Box::new(kanto::error::ExecutionError::SequenceExhausted {
sequence_id: seq_id,
}),
})?;
if state.cur > state.max {
*state = self.reserve_chunk(seq_id)?;
}
Ok(state.cur)
}
dashmap::mapref::entry::Entry::Vacant(entry) => {
let state = self.reserve_chunk(seq_id)?;
let cur = state.cur;
let _state_mut = entry.insert(state);
Ok(cur)
}
}
}
#[maybe_tracing::instrument(skip(self), ret, err)]
pub(crate) fn inc_n(
&self,
seq_id: SequenceId,
n: usize,
) -> Result<kanto::arrow::array::UInt64Array, KantoError> {
let mut builder = kanto::arrow::array::UInt64Builder::with_capacity(n);
// TODO optimize this
for _row_idx in 0..n {
let seq = self.inc(seq_id)?;
builder.append_value(seq);
}
let array = builder.finish();
Ok(array)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[maybe_tracing::test]
fn seq_simple() {
let (_dir, backend) = crate::create_temp().unwrap();
let column_family = backend.inner.rocksdb.get_column_family("default").unwrap();
{
let tracker =
SequenceTracker::new(backend.inner.rocksdb.clone(), column_family.clone());
let seq_a = SequenceId::new(42u64).unwrap();
let seq_b = SequenceId::new(13u64).unwrap();
assert_eq!(tracker.inc(seq_a).unwrap(), 1);
assert_eq!(tracker.inc(seq_a).unwrap(), 2);
assert_eq!(tracker.inc(seq_b).unwrap(), 1);
assert_eq!(tracker.inc(seq_a).unwrap(), 3);
drop(tracker);
}
{
let tracker =
SequenceTracker::new(backend.inner.rocksdb.clone(), column_family.clone());
let seq_a = SequenceId::new(42u64).unwrap();
let seq_b = SequenceId::new(13u64).unwrap();
assert_eq!(
tracker.inc(seq_a).unwrap(),
kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE + 1
);
assert_eq!(
tracker.inc(seq_b).unwrap(),
kanto_tunables::SEQUENCE_RESERVATION_CHUNK_SIZE + 1
);
drop(tracker);
}
}
}

View file

@@ -0,0 +1,70 @@
use std::sync::Arc;
use kanto_meta_format_v1::TableId;
pub(crate) struct RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) inner: Arc<InnerRocksDbTable<TableDefSource>>,
}
impl<TableDefSource> Clone for RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn clone(&self) -> Self {
RocksDbTable {
inner: self.inner.clone(),
}
}
}
pub(crate) struct InnerRocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) tx: crate::Transaction,
pub(crate) data_cf: Arc<rocky::ColumnFamily>,
pub(crate) table_id: TableId,
pub(crate) table_def:
rkyv_util::owned::OwnedArchive<kanto_meta_format_v1::table_def::TableDef, TableDefSource>,
}
impl<TableDefSource> RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) fn new(
tx: crate::Transaction,
data_cf: Arc<rocky::ColumnFamily>,
table_id: TableId,
table_def: rkyv_util::owned::OwnedArchive<
kanto_meta_format_v1::table_def::TableDef,
TableDefSource,
>,
) -> Self {
let inner = Arc::new(InnerRocksDbTable {
tx,
data_cf,
table_id,
table_def,
});
RocksDbTable { inner }
}
pub(crate) fn open_table_provider(
&self,
) -> crate::table_provider::RocksDbTableProvider<TableDefSource> {
crate::table_provider::RocksDbTableProvider::new(self.clone())
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTable<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTable").finish()
}
}

View file

@@ -0,0 +1,84 @@
use std::sync::Arc;
pub(crate) struct RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table: crate::table::RocksDbTable<TableDefSource>,
}
impl<TableDefSource> RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
pub(crate) const fn new(
table: crate::table::RocksDbTable<TableDefSource>,
) -> RocksDbTableProvider<TableDefSource> {
RocksDbTableProvider { table }
}
pub(crate) fn name(&self) -> datafusion::sql::ResolvedTableReference {
datafusion::sql::ResolvedTableReference {
catalog: Arc::from(self.table.inner.table_def.catalog.as_str()),
schema: Arc::from(self.table.inner.table_def.schema.as_str()),
table: Arc::from(self.table.inner.table_def.table_name.as_str()),
}
}
}
#[async_trait::async_trait]
impl<TableDefSource> datafusion::catalog::TableProvider for RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone + Send + Sync + 'static,
{
// We choose not to implement `insert_into` here, since we'd need to implement `UPDATE` and `DELETE` independently anyway.
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn table_type(&self) -> datafusion::logical_expr::TableType {
datafusion::logical_expr::TableType::Base
}
#[maybe_tracing::instrument(ret)]
fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef {
// TODO maybe cache on first use, or always precompute #performance
kanto_meta_format_v1::table_def::make_arrow_schema(
self.table
.inner
.table_def
.fields
.values()
.filter(|field| field.is_live()),
)
}
#[maybe_tracing::instrument(skip(_session), ret, err)]
async fn scan(
&self,
_session: &dyn datafusion::catalog::Session,
projection: Option<&Vec<usize>>,
filters: &[datafusion::prelude::Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn datafusion::physical_plan::ExecutionPlan>, datafusion::error::DataFusionError>
{
assert!(
filters.is_empty(),
"filters should be empty until we declare we can support them"
);
let exec_plan = crate::table_scan::RocksDbTableScan::new(self.table.clone(), projection)?;
Ok(Arc::new(exec_plan))
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTableProvider<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTableProvider")
.field("table", &self.table)
.finish()
}
}

View file

@@ -0,0 +1,399 @@
use std::collections::HashSet;
use std::sync::Arc;
use gat_lending_iterator::LendingIterator as _;
use kanto::KantoError;
use kanto_meta_format_v1::table_def::ArchivedTableDef;
use kanto_meta_format_v1::FieldId;
use crate::Database;
pub(crate) struct RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
inner: Arc<InnerRocksDbTableScan<TableDefSource>>,
}
impl<TableDefSource> Clone for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn clone(&self) -> Self {
RocksDbTableScan {
inner: self.inner.clone(),
}
}
}
struct InnerRocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table: crate::table::RocksDbTable<TableDefSource>,
wanted_field_ids: Arc<HashSet<FieldId>>,
df_plan_properties: datafusion::physical_plan::PlanProperties,
}
impl<TableDefSource> RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
#[maybe_tracing::instrument(ret, err)]
fn make_wanted_field_ids(
table_def: &ArchivedTableDef,
projection: Option<&Vec<usize>>,
) -> Result<HashSet<FieldId>, KantoError> {
// It seems the handling of `projection` required by [`TableProvider::scan`] is just to limit which columns are produced (an optimization), while something higher up reorders and duplicates the output columns as required.
//
// Whether a `TableProvider` decides to implement that optimization or not, the produced `RecordBatch` and the return value of `ExecutionPlan::schema` must agree.
//
// We do projection via a `HashSet`, so the results always come out in table column order, which could break `SELECT col2, col1 FROM mytable`.
// We do have test coverage, but that might be unsatisfactory.
//
// TODO is projection guaranteed to be in-order? #ecosystem/datafusion
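// Example (illustrative): with live fields `[a, b, c]` and `projection = Some(vec![0, 2])`,
// `wanted_field_ids` holds the field ids of `a` and `c`, and the scan produces columns in
// table order, i.e. `[a, c]`.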
debug_assert!(projection.is_none_or(|p| p.is_sorted()));
// DataFusion only knows about live fields, so only look at those.
let live_fields = table_def
.fields
.values()
.filter(|field| field.is_live())
.collect::<Box<[_]>>();
let wanted_field_ids = if let Some(projection) = projection {
projection.iter()
.map(|col_idx| {
live_fields.get(*col_idx)
.map(|f| f.field_id.to_native())
.ok_or_else(|| KantoError::Internal {
code: "6x3tyc9c4iq6k",
// TODO why is rustfmt broken here #ecosystem/rust
error: format!("datafusion projection column index out of bounds: {col_idx} not in 0..{len}", len=live_fields.len()).into(),
})
})
.collect::<Result<HashSet<FieldId>, _>>()?
} else {
live_fields
.iter()
.map(|field| field.field_id.to_native())
.collect::<HashSet<FieldId>>()
};
Ok(wanted_field_ids)
}
pub(crate) fn new(
table: crate::table::RocksDbTable<TableDefSource>,
projection: Option<&Vec<usize>>,
) -> Result<Self, KantoError> {
let wanted_field_ids = Self::make_wanted_field_ids(&table.inner.table_def, projection)?;
let wanted_field_ids = Arc::new(wanted_field_ids);
// TODO cache somewhere explicitly; for now we're riding on `self.inner.eq_properties`
let schema = {
let fields = table
.inner
.table_def
.fields
.values()
// guaranteed to only contain live fields
.filter(|field| wanted_field_ids.contains(&field.field_id.to_native()));
#[cfg(debug_assertions)]
let fields = fields.inspect(|field| debug_assert!(field.is_live()));
kanto_meta_format_v1::table_def::make_arrow_schema(fields)
};
let df_plan_properties = {
let eq_properties = datafusion::physical_expr::EquivalenceProperties::new(schema);
let partitioning = datafusion::physical_plan::Partitioning::UnknownPartitioning(1);
let emission_type =
datafusion::physical_plan::execution_plan::EmissionType::Incremental;
let boundedness = datafusion::physical_plan::execution_plan::Boundedness::Bounded;
datafusion::physical_plan::PlanProperties::new(
eq_properties,
partitioning,
emission_type,
boundedness,
)
};
let inner = Arc::new(InnerRocksDbTableScan {
table,
wanted_field_ids,
df_plan_properties,
});
Ok(RocksDbTableScan { inner })
}
}
impl<TableDefSource> std::fmt::Debug for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RocksDbTableScan")
.field("table", &self.inner.table)
.finish()
}
}
impl<TableDefSource> datafusion::physical_plan::display::DisplayAs
for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fmt_as(
&self,
t: datafusion::physical_plan::DisplayFormatType,
f: &mut std::fmt::Formatter<'_>,
) -> std::fmt::Result {
match t {
datafusion::physical_plan::DisplayFormatType::Default
| datafusion::physical_plan::DisplayFormatType::Verbose => f
.debug_struct("RocksDbTableScan")
.field("table", &self.inner.table.inner.table_id)
.finish(),
}
}
}
impl<TableDefSource> datafusion::physical_plan::ExecutionPlan for RocksDbTableScan<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone + Send + Sync + 'static,
{
fn name(&self) -> &str {
Self::static_name()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
#[maybe_tracing::instrument(ret)]
fn properties(&self) -> &datafusion::physical_plan::PlanProperties {
&self.inner.df_plan_properties
}
#[maybe_tracing::instrument(ret)]
fn schema(&self) -> datafusion::arrow::datatypes::SchemaRef {
self.inner.df_plan_properties.eq_properties.schema().clone()
}
#[maybe_tracing::instrument(ret)]
fn children(&self) -> Vec<&Arc<(dyn datafusion::physical_plan::ExecutionPlan + 'static)>> {
vec![]
}
#[maybe_tracing::instrument(ret, err)]
fn with_new_children(
self: Arc<Self>,
children: Vec<Arc<dyn datafusion::physical_plan::ExecutionPlan>>,
) -> Result<Arc<dyn datafusion::physical_plan::ExecutionPlan>, datafusion::error::DataFusionError>
{
// TODO I don't really understand why ExecutionPlan children (aka inputs) can be replaced at this stage (it seems to blur query optimization passes vs execution time), but we're a table scan: we have no inputs! #ecosystem/datafusion
//
// DataFusion implementations are all over the place: silent ignore, return error, `unimplemented!`, `unreachable!`; some check len of new children, some assume.
// Then there's `with_new_children_if_necessary` helper that does check len.
if children.is_empty() {
// do nothing
Ok(self)
} else {
tracing::error!(?children, "unsupported: with_new_children");
let error = KantoError::Internal {
code: "4fsybg7tj9th1",
error: "unsupported: RocksDbTableScan with_new_children".into(),
};
Err(error.into())
}
}
#[maybe_tracing::instrument(skip(context), err)]
fn execute(
&self,
partition: usize,
context: Arc<datafusion::execution::context::TaskContext>,
) -> Result<
datafusion::physical_plan::SendableRecordBatchStream,
datafusion::error::DataFusionError,
> {
let schema = self.schema();
// We could support multiple partitions (by finding good keys to partition on), but we do not do that at this time.
if partition > 0 {
let empty = datafusion::physical_plan::EmptyRecordBatchStream::new(schema);
return Ok(Box::pin(empty));
}
let batch_size = context.session_config().batch_size();
let table_id = self.inner.table.inner.table_id;
let table_key_prefix = Database::make_rocksdb_record_key_prefix(table_id);
// TODO once we support filters or partitioning, start and prefix might no longer be the same
let table_key_start = table_key_prefix;
// TODO once we support filters or partitioning, stop_before gets more complex
let table_key_stop_before = Database::make_rocksdb_record_key_stop_before(table_id);
tracing::trace!(
?table_key_start,
?table_key_stop_before,
"execute table scan"
);
let iter = {
let read_opts = self.inner.table.inner.tx.make_read_options();
self.inner.table.inner.tx.inner.db_tx.range(
read_opts,
&self.inner.table.inner.data_cf,
table_key_start..table_key_stop_before,
)
};
let record_reader = kanto_record_format_v1::RecordReader::new_from_table_def(
self.inner.table.inner.table_def.clone(),
self.inner.wanted_field_ids.clone(),
batch_size,
);
let stream = RocksDbTableScanStream {
table_id,
batch_size,
iter,
record_reader,
};
let stream =
datafusion::physical_plan::stream::RecordBatchStreamAdapter::new(schema, stream);
Ok(Box::pin(stream))
}
#[maybe_tracing::instrument(ret, err)]
fn statistics(
&self,
) -> Result<datafusion::physical_plan::Statistics, datafusion::error::DataFusionError> {
let schema = self.schema();
let column_statistics: Vec<_> = schema
.fields
.iter()
.map(|_field| datafusion::physical_plan::ColumnStatistics {
null_count: datafusion::common::stats::Precision::Absent,
max_value: datafusion::common::stats::Precision::Absent,
min_value: datafusion::common::stats::Precision::Absent,
sum_value: datafusion::common::stats::Precision::Absent,
distinct_count: datafusion::common::stats::Precision::Absent,
})
.collect();
let statistics = datafusion::physical_plan::Statistics {
num_rows: datafusion::common::stats::Precision::Absent,
total_byte_size: datafusion::common::stats::Precision::Absent,
column_statistics,
};
Ok(statistics)
}
}
struct RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
table_id: kanto_meta_format_v1::TableId,
batch_size: usize,
iter: rocky::Iter,
record_reader: kanto_record_format_v1::RecordReader<
rkyv_util::owned::OwnedArchive<kanto_meta_format_v1::table_def::TableDef, TableDefSource>,
>,
}
impl<TableDefSource> RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
fn fill_record_reader_sync(&mut self) -> Result<(), KantoError> {
while let Some(entry) =
self.iter
.next()
.transpose()
.map_err(|rocksdb_error| KantoError::Io {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "nfiih93dhser6",
error: std::io::Error::other(rocksdb_error),
})?
{
let key = entry.key();
let value = entry.value();
tracing::trace!(?key, ?value, "iterator saw");
self.record_reader
.push_record(value)
.map_err(|record_error| {
use kanto_record_format_v1::error::RecordError;
match record_error {
RecordError::Corrupt { error } => KantoError::Corrupt {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "sf9f9qxnszuxr",
error: Box::new(kanto::error::CorruptError::InvalidRecord { error }),
},
error @ (RecordError::InvalidSchema { .. }
| RecordError::InvalidData { .. }
| RecordError::Internal { .. }
| RecordError::InternalArrow { .. }) => KantoError::Internal {
code: "f6j8ckqoua6aa",
error: Box::new(error),
},
}
})?;
if self.record_reader.num_records() >= self.batch_size {
break;
}
}
Ok(())
}
fn build_record_batch(&mut self) -> Result<kanto::arrow::array::RecordBatch, KantoError> {
let record_batch = self.record_reader.build().map_err(|record_error| {
use kanto_record_format_v1::error::RecordError;
match record_error {
RecordError::Corrupt { error } => KantoError::Corrupt {
op: kanto::error::Op::TableScan {
table_id: self.table_id,
},
code: "wm5sejkimmi94",
error: Box::new(kanto::error::CorruptError::InvalidRecord { error }),
},
error @ (RecordError::InvalidSchema { .. }
| RecordError::InvalidData { .. }
| RecordError::Internal { .. }
| RecordError::InternalArrow { .. }) => KantoError::Internal {
code: "qbqhmms3ghm16",
error: Box::new(error),
},
}
})?;
Ok(record_batch)
}
}
impl<TableDefSource> Unpin for RocksDbTableScanStream<TableDefSource> where
TableDefSource: rkyv_util::owned::StableBytes + Clone
{
}
impl<TableDefSource> futures_lite::Stream for RocksDbTableScanStream<TableDefSource>
where
TableDefSource: rkyv_util::owned::StableBytes + Clone,
{
type Item = Result<kanto::arrow::array::RecordBatch, datafusion::error::DataFusionError>;
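// Note (observation, not a contract): `fill_record_reader_sync` iterates RocksDB
// synchronously inside `poll_next`, so this stream never returns `Poll::Pending`; each poll
// yields a batch of roughly `batch_size` records at most.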
fn poll_next(
mut self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
if let Err(error) = self.fill_record_reader_sync() {
return std::task::Poll::Ready(Some(Err(error.into())));
}
if self.record_reader.num_records() == 0 {
return std::task::Poll::Ready(None);
}
match self.build_record_batch() {
Ok(record_batch) => std::task::Poll::Ready(Some(Ok(record_batch))),
Err(error) => std::task::Poll::Ready(Some(Err(error.into()))),
}
}
}

File diff suppressed because it is too large Load diff

View file

@@ -0,0 +1,141 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
// TODO write helpers to make these tests as simple as sqllogictest.
// Using Rust is better for e.g. running two sessions concurrently, and the result verification is better here.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_simple() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql(
"SELECT * FROM kanto.information_schema.tables WHERE table_schema<>'information_schema';",
)
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+---------------+--------------+------------+------------+
| table_catalog | table_schema | table_name | table_type |
+---------------+--------------+------------+------------+
| kanto | myschema | mytable | BASE TABLE |
+---------------+--------------+------------+------------+
",
&batches,
);
let stream = session
.sql("SELECT column_name, data_type FROM myschema.information_schema.columns WHERE table_catalog='kanto' AND table_schema='myschema' AND table_name='mytable';")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+-------------+-----------+
| column_name | data_type |
+-------------+-----------+
| mycolumn | Int64 |
+-------------+-----------+
",
&batches,
);
drop(session);
}
Ok(())
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_without_primary_key_uses_rowid() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
{
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql("SELECT column_name, data_type FROM myschema.information_schema.columns WHERE table_catalog='kanto' AND table_schema='myschema' AND table_name='mytable';")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+-------------+-----------+
| column_name | data_type |
+-------------+-----------+
| mycolumn | Int64 |
+-------------+-----------+
",
&batches,
);
}
drop(session);
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn create_table_or_replace_not_implemented() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
let result = session
.sql("CREATE OR REPLACE TABLE myschema.mytable (mycolumn BIGINT);")
.await;
match result {
Ok(_record_batch_stream) => panic!("unexpected success"),
Err(datafusion::error::DataFusionError::External(error)) => {
let error = error
.downcast::<kanto::KantoError>()
.expect("wrong kind of error");
assert!(
matches!(
*error,
kanto::KantoError::UnimplementedSql {
code: "65so99qr1hyck",
sql_syntax: "CREATE OR REPLACE TABLE"
}
),
"wrong error: {error:?}",
);
}
Err(error) => panic!("wrong error: {error:?}"),
}
}

View file

@@ -0,0 +1,66 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::stream::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn insert_select() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE kanto.myschema.mytable (a BIGINT NOT NULL, b BIGINT, PRIMARY KEY(a));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql("INSERT kanto.myschema.mytable (a, b) VALUES (42, 7), (13, 34);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let mut session = backend.test_session();
let stream = session
.sql("SELECT * FROM kanto.myschema.mytable ORDER BY a;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+----+----+
| a | b |
+----+----+
| 13 | 34 |
| 42 | 7 |
+----+----+
",
&batches,
);
}
Ok(())
}

View file

@@ -0,0 +1,50 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
// TODO write helpers to make these tests as simple as sqllogictest.
// Using Rust is better for e.g. running two sessions concurrently, and the result verification is better here.
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn reopen_database() -> Result<(), String> {
let (dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
{
let mut session = backend.test_session();
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
drop(backend);
let backend = kanto_backend_rocksdb::open(dir.path()).expect("database open");
{
let mut session = kanto::Session::test_session(Box::new(backend));
let stream = session
.sql("INSERT myschema.mytable (mycolumn) VALUES (42);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
Ok(())
}

View file

@@ -0,0 +1,29 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn select_constant() -> Result<(), String> {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
let stream = session
.sql("SELECT 'Hello, world' AS greeting;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
+--------------+
| greeting |
+--------------+
| Hello, world |
+--------------+
",
&batches,
);
Ok(())
}

View file

@@ -0,0 +1,184 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use futures_lite::stream::StreamExt as _;
use kanto_testutil::assert_batches_eq;
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn transaction_implicit_rollback_create_table() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session = backend.test_session();
{
let stream = session
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session.sql("ROLLBACK;").await.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session
.sql(
"SELECT * FROM kanto.information_schema.tables WHERE table_schema<>'information_schema';",
)
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
}
#[test_log::test(tokio::test(flavor = "multi_thread"))]
async fn transaction_changes_not_concurrently_visible() {
let (_dir, backend) = kanto_backend_rocksdb::create_temp().unwrap();
let mut session_a = backend.test_session();
let mut session_b = backend.test_session();
{
let stream = session_a
.sql("CREATE TABLE myschema.mytable (mycolumn BIGINT NOT NULL, PRIMARY KEY(mycolumn));")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a
.sql("INSERT INTO kanto.myschema.mytable (mycolumn) VALUES (42), (13);")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_b
.sql("SELECT * FROM kanto.myschema.mytable;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_b
.sql("START TRANSACTION ISOLATION LEVEL REPEATABLE READ;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
let stream = session_a.sql("COMMIT;").await.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
{
// Snapshot isolation -> won't see values even if they are committed.
let stream = session_b
.sql("SELECT * FROM kanto.myschema.mytable;")
.await
.unwrap();
let batches = stream.try_collect::<_, _, Vec<_>>().await.unwrap();
assert_batches_eq!(
r"
++
++
",
&batches,
);
}
}

View file

@@ -0,0 +1,21 @@
[package]
name = "kanto-index-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
datafusion = { workspace = true }
kanto-key-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
rkyv = { workspace = true }
thiserror = { workspace = true }
[lints]
workspace = true

View file

@@ -0,0 +1,529 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
//!
//! The entry point to this module is [`Index::new`].
use std::borrow::Cow;
use std::sync::Arc;
use datafusion::arrow::array::Array;
use datafusion::arrow::record_batch::RecordBatch;
use kanto_meta_format_v1::index_def::ArchivedExpr;
use kanto_meta_format_v1::index_def::ArchivedIndexDef;
use kanto_meta_format_v1::index_def::ArchivedOrderByExpr;
use kanto_meta_format_v1::table_def::ArchivedFieldDef;
use kanto_meta_format_v1::table_def::ArchivedTableDef;
use kanto_meta_format_v1::FieldId;
/// Initialization-time errors that are about the schema, not about the data.
#[derive(thiserror::Error, Debug)]
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
pub enum IndexInitError {
#[expect(missing_docs)]
#[error(
"index expression field not found: {field_id}\nindex_def={index_def_debug:#?}\ntable_def={table_def_debug:#?}"
)]
IndexExprFieldIdNotFound {
field_id: FieldId,
index_def_debug: String,
table_def_debug: String,
},
}
/// Execution-time errors that are about the input data.
#[derive(thiserror::Error, Debug)]
#[expect(
clippy::exhaustive_enums,
reason = "low-level package, callers will have to adapt to any change"
)]
pub enum IndexExecutionError {
#[expect(missing_docs)]
#[error("indexed field missing from record batch: {field_name}")]
IndexedFieldNotInBatch { field_name: String },
}
/// Used for indexes where there can only be one value.
///
/// This happens with
///
/// - `UNIQUE` `NULLS NOT DISTINCT` index
/// - `UNIQUE` `NULLS DISTINCT` index where all indexed fields are `NOT NULL`
#[derive(rkyv::Archive, rkyv::Serialize)]
struct IndexValueUnique<'data> {
#[rkyv(with = rkyv::with::AsOwned)]
row_key: Cow<'data, [u8]>,
// TODO index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)
}
/// Used for indexes where there may be multiple values.
/// In that case, the row key is stored in the key, allowing multiple values for the "same" index key.
///
/// This happens with
///
/// - `NOT UNIQUE` index
/// - `UNIQUE` `NULLS DISTINCT` index where indexed field can be `NULL`
#[derive(rkyv::Archive, rkyv::Serialize)]
struct IndexValueMulti<'data> {
/// Duplicated from key for ease of access.
// TODO remove?
#[rkyv(with = rkyv::with::AsOwned)]
row_key: Cow<'data, [u8]>,
// TODO index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)
}
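// Key layout sketch (from the builders below): a unique index key encodes only the indexed
// columns and the row key lives in the value; a multi index key encodes the indexed columns
// followed by the row key, so reading matching rows is a prefix scan.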
fn make_index_key_source_arrays(
field_defs: &[&ArchivedFieldDef],
record_batch: &RecordBatch,
) -> Result<Vec<Arc<dyn Array>>, IndexExecutionError> {
field_defs
.iter()
.map(|field| {
let want_name = &field.field_name;
let array = record_batch.column_by_name(want_name).ok_or_else(|| {
IndexExecutionError::IndexedFieldNotInBatch {
field_name: want_name.as_str().to_owned(),
}
})?;
Ok(array.clone())
})
.collect::<Result<Vec<_>, _>>()
}
/// Library for making index keys and values.
#[expect(clippy::exhaustive_enums)]
pub enum Index<'def> {
/// See [`UniqueIndex`].
Unique(UniqueIndex<'def>),
/// See [`MultiIndex`].
Multi(MultiIndex<'def>),
}
impl<'def> Index<'def> {
fn get_field<'table>(
table_def: &'table ArchivedTableDef,
index_def: &ArchivedIndexDef,
order_by_expr: &ArchivedOrderByExpr,
) -> Result<&'table ArchivedFieldDef, IndexInitError> {
match &order_by_expr.expr {
ArchivedExpr::Field { field_id } => {
let field = table_def.fields.get(field_id).ok_or_else(|| {
IndexInitError::IndexExprFieldIdNotFound {
field_id: field_id.to_native(),
index_def_debug: format!("{index_def:?}"),
table_def_debug: format!("{table_def:?}"),
}
})?;
Ok(field)
}
}
}
/// Make an index helper for the given `index_def` belonging to the given `table_def`.
///
/// # Examples
///
/// ```rust compile_fail
/// # // TODO `TableDef` is too complex to write out fully #dev #doc
/// let index = kanto_index_format_v1::Index::new(table_def, index_def);
/// match index {
/// kanto_index_format_v1::Index::Unique(unique_index) => {
/// let keys = unique_index.make_unique_index_keys(&record_batch);
/// }
/// kanto_index_format_v1::Index::Multi(multi_index) => { ... }
/// }
/// ```
pub fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
) -> Result<Index<'def>, IndexInitError> {
debug_assert_eq!(index_def.table, table_def.table_id);
let field_defs = index_def
.columns
.iter()
.map(|order_by_expr| Self::get_field(table_def, index_def, order_by_expr))
.collect::<Result<Box<[_]>, _>>()?;
let has_nullable = field_defs.iter().any(|field| field.nullable);
let can_use_unique_index = match &index_def.index_kind {
kanto_meta_format_v1::index_def::ArchivedIndexKind::Unique(unique_index_options) => {
!unique_index_options.nulls_distinct || !has_nullable
}
kanto_meta_format_v1::index_def::ArchivedIndexKind::Multi => false,
};
if can_use_unique_index {
Ok(Index::Unique(UniqueIndex::new(
table_def, index_def, field_defs,
)))
} else {
Ok(Index::Multi(MultiIndex::new(
table_def, index_def, field_defs,
)))
}
}
}
/// A unique index in the straightforward case.
/// Each index key stores exactly one value, which refers to a single row key.
pub struct UniqueIndex<'def> {
_table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
}
impl<'def> UniqueIndex<'def> {
#[must_use]
const fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
) -> UniqueIndex<'def> {
UniqueIndex {
_table_def: table_def,
index_def,
field_defs,
}
}
/// Make index keys for the `record_batch`.
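///
/// # Examples
///
/// A minimal sketch (not compiled here), assuming `unique_index` was obtained by
/// matching `Index::Unique(..)` from [`Index::new`] and that `record_batch` and
/// `row_key` already exist:
///
/// ```rust ignore
/// let keys = unique_index.make_unique_index_keys(&record_batch)?;
/// // One key per input row; pair each key with the value for that row's key.
/// let value = unique_index.index_value(row_key);
/// ```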
pub fn make_unique_index_keys(
&self,
record_batch: &RecordBatch,
) -> Result<datafusion::arrow::array::BinaryArray, IndexExecutionError> {
let index_key_source_arrays = make_index_key_source_arrays(&self.field_defs, record_batch)?;
let keys = kanto_key_format_v1::make_keys(&index_key_source_arrays);
Ok(keys)
}
/// Make the value to store in the index for this `row_key`.
#[must_use]
pub fn index_value(&self, row_key: &[u8]) -> Box<[u8]> {
assert!(
self.index_def.include.is_empty(),
"index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)"
);
assert!(
self.index_def.predicate.is_none(),
"partial indexes (#ref/unimplemented_sql/9134bk3fe98x6)"
);
let value = IndexValueUnique {
row_key: Cow::from(row_key),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::Panic>(&value).unwrap();
bytes.into_boxed_slice()
}
}
/// A multi-value index, where multiple records have the same key prefix.
/// To read record references from the index, a prefix scan is needed.
pub struct MultiIndex<'def> {
_table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
}
impl<'def> MultiIndex<'def> {
#[must_use]
const fn new(
table_def: &'def ArchivedTableDef,
index_def: &'def ArchivedIndexDef,
field_defs: Box<[&'def ArchivedFieldDef]>,
) -> MultiIndex<'def> {
MultiIndex {
_table_def: table_def,
index_def,
field_defs,
}
}
/// Make index keys for the `record_batch` with matching `row_keys`.
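///
/// # Examples
///
/// A minimal sketch (not compiled here), assuming `multi_index` was obtained by
/// matching `Index::Multi(..)` from [`Index::new`] and that `record_batch` and
/// `row_keys` (a `BinaryArray` with one row key per record) already exist:
///
/// ```rust ignore
/// let keys = multi_index.make_multi_index_keys(&record_batch, &row_keys)?;
/// // Each key embeds the row key, so reading back requires a prefix scan.
/// ```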
pub fn make_multi_index_keys(
&self,
record_batch: &RecordBatch,
row_keys: &datafusion::arrow::array::BinaryArray,
) -> Result<datafusion::arrow::array::BinaryArray, IndexExecutionError> {
let mut index_key_source_arrays =
make_index_key_source_arrays(&self.field_defs, record_batch)?;
index_key_source_arrays.push(Arc::new(row_keys.clone()));
let keys = kanto_key_format_v1::make_keys(&index_key_source_arrays);
Ok(keys)
}
/// Make the value to store in the index for this `row_key`.
#[must_use]
pub fn index_value(&self, row_key: &[u8]) -> Box<[u8]> {
assert!(
self.index_def.include.is_empty(),
"index included fields (#ref/unimplemented_sql/ip415b5s8sa6h)"
);
assert!(
self.index_def.predicate.is_none(),
"partial indexes (#ref/unimplemented_sql/9134bk3fe98x6)"
);
let value = IndexValueMulti {
row_key: Cow::from(row_key),
};
let bytes = rkyv::to_bytes::<rkyv::rancor::Panic>(&value).unwrap();
bytes.into_boxed_slice()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[expect(clippy::too_many_lines)]
fn unique_index() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(2).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(3).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: kanto_meta_format_v1::table_def::TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let index_def = kanto_meta_format_v1::index_def::IndexDef {
index_id: kanto_meta_format_v1::IndexId::new(42).unwrap(),
table: table_def.table_id,
index_name: "myindex".to_owned(),
columns: vec![
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[1].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[0].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
],
index_kind: kanto_meta_format_v1::index_def::IndexKind::Unique(
kanto_meta_format_v1::index_def::UniqueIndexOptions {
nulls_distinct: false,
},
),
include: vec![],
predicate: None,
};
let owned_archived_index_def = kanto_meta_format_v1::owned_archived(&index_def).unwrap();
let index = Index::new(&owned_archived_table_def, &owned_archived_index_def).unwrap();
// TODO `debug_assert_matches!` <https://doc.rust-lang.org/std/assert_matches/macro.debug_assert_matches.html>, <https://github.com/rust-lang/rust/issues/82775> #waiting #ecosystem/rust #severity/low #dev
assert!(matches!(index, Index::Unique(..)));
let unique_index = match index {
Index::Unique(unique_index) => unique_index,
Index::Multi(_) => panic!("index was supposed to be unique"),
};
{
let a: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt16Array::from(vec![3, 4, 5]));
let c: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt32Array::from(vec![
11, 10, 12,
]));
let record_batch = RecordBatch::try_from_iter([("a", a), ("c", c)]).unwrap();
let keys = unique_index.make_unique_index_keys(&record_batch).unwrap();
let got = keys
.into_iter()
// TODO `BinaryArray` non-nullable iteration: the `Option` here is ugly #ref/4b6891jiy4t5a #ecosystem/datafusion #severity/medium #urgency/low
.map(|option| option.unwrap())
.collect::<Vec<_>>();
assert_eq!(
&got,
&[
&[
1u8, 0, 0, 0, 11, // c
1u8, 0, 3, // a
],
&[
1u8, 0, 0, 0, 10, // c
1u8, 0, 4, // a
],
&[
1u8, 0, 0, 0, 12, // c
1u8, 0, 5, // a
],
],
);
}
{
const FAKE_ROW_KEY: &[u8] = b"fakerowkey";
let bytes = unique_index.index_value(FAKE_ROW_KEY);
let got_index_value =
rkyv::access::<rkyv::Archived<IndexValueUnique<'_>>, rkyv::rancor::Panic>(&bytes)
.unwrap();
assert_eq!(got_index_value.row_key.as_ref(), FAKE_ROW_KEY);
}
}
#[test]
#[expect(clippy::too_many_lines)]
fn multi_index() {
let fields = [
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(2).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(7).unwrap(),
deleted_in: None,
field_name: "a".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U16,
nullable: false,
default_value_arrow_bytes: None,
},
kanto_meta_format_v1::table_def::FieldDef {
field_id: FieldId::new(3).unwrap(),
added_in: kanto_meta_format_v1::table_def::TableDefGeneration::new(8).unwrap(),
deleted_in: None,
field_name: "c".to_owned(),
field_type: kanto_meta_format_v1::field_type::FieldType::U32,
nullable: false,
default_value_arrow_bytes: None,
},
];
let table_def = kanto_meta_format_v1::table_def::TableDef {
table_id: kanto_meta_format_v1::TableId::new(42).unwrap(),
catalog: "mycatalog".to_owned(),
schema: "myschema".to_owned(),
table_name: "mytable".to_owned(),
generation: kanto_meta_format_v1::table_def::TableDefGeneration::new(13).unwrap(),
row_key_kind: kanto_meta_format_v1::table_def::RowKeyKind::PrimaryKeys(
kanto_meta_format_v1::table_def::PrimaryKeysDef {
field_ids: kanto_meta_format_v1::table_def::PrimaryKeys::try_from(vec![
fields[0].field_id,
])
.unwrap(),
},
),
fields: fields
.into_iter()
.map(|field| (field.field_id, field))
.collect(),
};
let owned_archived_table_def = kanto_meta_format_v1::owned_archived(&table_def).unwrap();
let index_def = kanto_meta_format_v1::index_def::IndexDef {
index_id: kanto_meta_format_v1::IndexId::new(42).unwrap(),
table: table_def.table_id,
index_name: "myindex".to_owned(),
columns: vec![
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[1].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
kanto_meta_format_v1::index_def::OrderByExpr {
expr: kanto_meta_format_v1::index_def::Expr::Field {
field_id: table_def.fields[0].field_id,
},
sort_order: kanto_meta_format_v1::index_def::SortOrder::Ascending,
null_order: kanto_meta_format_v1::index_def::NullOrder::NullsLast,
},
],
index_kind: kanto_meta_format_v1::index_def::IndexKind::Multi,
include: vec![],
predicate: None,
};
let owned_archived_index_def = kanto_meta_format_v1::owned_archived(&index_def).unwrap();
let index = Index::new(&owned_archived_table_def, &owned_archived_index_def).unwrap();
// TODO `debug_assert_matches!` <https://doc.rust-lang.org/std/assert_matches/macro.debug_assert_matches.html>, <https://github.com/rust-lang/rust/issues/82775> #waiting #ecosystem/rust #severity/low #dev
assert!(matches!(index, Index::Multi(..)));
let multi_index = match index {
Index::Unique(_) => panic!("index was supposed to be multi"),
Index::Multi(multi_index) => multi_index,
};
{
let a: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt16Array::from(vec![3, 4, 5]));
let c: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt32Array::from(vec![
11, 10, 12,
]));
let record_batch = RecordBatch::try_from_iter([("a", a.clone()), ("c", c)]).unwrap();
let row_keys =
datafusion::arrow::array::BinaryArray::from_vec(vec![b"foo", b"bar", b"quux"]);
let keys = multi_index
.make_multi_index_keys(&record_batch, &row_keys)
.unwrap();
let got = keys
.into_iter()
// TODO `BinaryArray` non-nullable iteration: the `Option` here is ugly #ref/4b6891jiy4t5a #ecosystem/datafusion #severity/medium #urgency/low
.map(|option| option.unwrap())
.collect::<Vec<_>>();
assert_eq!(
&got,
&[
&[
1u8, 0, 0, 0, 11, // c
1, 0, 3, // a
2, b'f', b'o', b'o', 0, 0, 0, 0, 0, 3 // row key
],
&[
1u8, 0, 0, 0, 10, // c
1, 0, 4, // a
2, b'b', b'a', b'r', 0, 0, 0, 0, 0, 3 // row key
],
&[
1u8, 0, 0, 0, 12, // c
1, 0, 5, // a
2, b'q', b'u', b'u', b'x', 0, 0, 0, 0, 4 // row key
],
],
);
}
{
const FAKE_ROW_KEY: &[u8] = b"fakerowkey";
let bytes = multi_index.index_value(FAKE_ROW_KEY);
let got_index_value =
rkyv::access::<rkyv::Archived<IndexValueMulti<'_>>, rkyv::rancor::Panic>(&bytes)
.unwrap();
assert_eq!(got_index_value.row_key.as_ref(), FAKE_ROW_KEY);
}
}
}

51
crates/kanto/Cargo.toml Normal file
View file

@@ -0,0 +1,51 @@
[package]
name = "kanto"
version = "0.1.0"
description = "KantoDB SQL database as a library"
keywords = ["kantodb", "sql", "database"]
categories = ["database-implementations"]
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
async-trait = { workspace = true }
datafusion = { workspace = true }
futures-lite = { workspace = true }
kanto-index-format-v1 = { workspace = true }
kanto-meta-format-v1 = { workspace = true }
kanto-record-format-v1 = { workspace = true }
kanto-tunables = { workspace = true }
maybe-tracing = { workspace = true }
parking_lot = { workspace = true }
rkyv = { workspace = true }
smallvec = { workspace = true }
thiserror = { workspace = true }
tracing = { workspace = true }
[dev-dependencies]
hex = { workspace = true }
kanto-backend-rocksdb = { workspace = true }
libtest-mimic = { workspace = true }
regex = { workspace = true }
sqllogictest = { workspace = true }
tokio = { workspace = true }
tracing-subscriber = { workspace = true }
walkdir = { workspace = true }
[lints]
workspace = true
[[test]]
name = "sqllogictest"
path = "tests/sqllogictest.rs"
harness = false
[[test]]
name = "sqllogictest_sqlite"
path = "tests/sqllogictest_sqlite.rs"
harness = false

View file

@@ -0,0 +1,80 @@
use datafusion::sql::sqlparser;
use crate::KantoError;
use crate::Transaction;
/// A database backend stores and manages data, implementing most functionality related to the data it stores: low-level data formats, transactions, indexes, and so on.
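///
/// # Examples
///
/// A minimal sketch (not compiled here), assuming `backend` is some concrete
/// [`Backend`] implementation (for example the RocksDB one):
///
/// ```rust ignore
/// use kanto::sqlparser::ast::{TransactionAccessMode, TransactionIsolationLevel};
///
/// let tx = backend
///     .start_transaction(
///         TransactionAccessMode::ReadWrite,
///         TransactionIsolationLevel::Serializable,
///     )
///     .await?;
/// ```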
#[async_trait::async_trait]
pub trait Backend: Send + Sync + std::fmt::Debug + DynEq + DynHash + std::any::Any {
/// Return the `Backend` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Check equality between self and other.
fn is_equal(&self, other: &dyn Backend) -> bool {
self.dyn_eq(other.as_any())
}
/// Start a new transaction.
/// All data manipulation and access happens inside a transaction.
///
/// Transactions do not nest, but savepoints can be used to the same effect.
async fn start_transaction(
&self,
access_mode: sqlparser::ast::TransactionAccessMode,
isolation_level: sqlparser::ast::TransactionIsolationLevel,
) -> Result<Box<dyn Transaction>, KantoError>;
/// Kludge to make a `Box<dyn Backend>` cloneable.
fn clone_box(&self) -> Box<dyn Backend>;
}
impl Clone for Box<dyn Backend> {
fn clone(&self) -> Box<dyn Backend> {
self.clone_box()
}
}
/// To implement this helper trait, derive or implement [`Eq`] for your concrete type.
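///
/// A small sketch of how the blanket impl below picks this up for any
/// `Eq + 'static` type (`MyMarker` is just an illustrative stand-in):
///
/// ```rust
/// use kanto::DynEq;
///
/// #[derive(PartialEq, Eq)]
/// struct MyMarker(u32);
///
/// assert!(MyMarker(1).dyn_eq(&MyMarker(1)));
/// assert!(!MyMarker(1).dyn_eq(&MyMarker(2)));
/// ```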
#[expect(missing_docs)]
pub trait DynEq {
fn dyn_eq(&self, other: &dyn std::any::Any) -> bool;
}
impl<T: PartialEq + Eq + 'static> DynEq for T {
fn dyn_eq(&self, other: &dyn std::any::Any) -> bool {
if let Some(other) = other.downcast_ref::<Self>() {
self == other
} else {
false
}
}
}
/// To implement this helper trait, derive or implement [`Hash`] for your concrete type.
#[expect(missing_docs)]
pub trait DynHash {
fn dyn_hash(&self, state: &mut dyn std::hash::Hasher);
}
impl<T: std::hash::Hash> DynHash for T {
fn dyn_hash(&self, mut state: &mut dyn std::hash::Hasher) {
self.hash(&mut state);
}
}
impl std::hash::Hash for dyn Backend {
fn hash<H>(&self, state: &mut H)
where
H: std::hash::Hasher,
{
self.dyn_hash(state);
}
}
impl PartialEq for dyn Backend {
fn eq(&self, other: &Self) -> bool {
self.is_equal(other)
}
}
impl Eq for dyn Backend {}

View file

@@ -0,0 +1,7 @@
//! Defaults for things that can be changed in configuration or at runtime.
/// Default SQL catalog name when omitted.
pub const DEFAULT_CATALOG_NAME: &str = "kanto";
/// Default SQL schema name when omitted.
pub const DEFAULT_SCHEMA_NAME: &str = "public";

494
crates/kanto/src/error.rs Normal file
View file

@@ -0,0 +1,494 @@
//! Error types for KantoDB.
use datafusion::sql::sqlparser;
use kanto_meta_format_v1::FieldId;
use kanto_meta_format_v1::IndexId;
use kanto_meta_format_v1::SequenceId;
use kanto_meta_format_v1::TableId;
#[derive(Debug)]
#[non_exhaustive]
#[expect(missing_docs)]
pub enum Op {
Init,
CreateTable,
CreateIndex,
IterTables,
LookupName {
name_ref: datafusion::sql::ResolvedTableReference,
},
SaveName {
name_ref: datafusion::sql::ResolvedTableReference,
},
LoadTableDef {
table_id: TableId,
},
IterIndexes {
table_id: TableId,
},
LoadIndexDef {
index_id: IndexId,
},
TableScan {
table_id: TableId,
},
LoadRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
InsertRecords {
table_id: TableId,
},
InsertRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
UpdateRecords {
table_id: TableId,
},
UpdateRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
DeleteRecords {
table_id: TableId,
},
DeleteRecord {
table_id: TableId,
row_key: Box<[u8]>,
},
InsertIndex {
table_id: TableId,
index_id: IndexId,
},
DeleteIndex {
table_id: TableId,
index_id: IndexId,
},
SequenceUpdate {
sequence_id: SequenceId,
},
Rollback,
Commit,
}
impl std::fmt::Display for Op {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Op::Init => {
write!(f, "initializing")
}
Op::IterTables => {
write!(f, "iterate table definitions")
}
Op::LookupName { name_ref } => {
write!(f, "lookup name definition: {name_ref}")
}
Op::SaveName { name_ref } => {
write!(f, "save name definition: {name_ref}")
}
Op::LoadTableDef { table_id } => {
write!(f, "load table definition: {table_id}")
}
Op::IterIndexes { table_id } => {
write!(f, "iterate index definitions: table_id={table_id}")
}
Op::LoadIndexDef { index_id } => {
write!(f, "load index definition: {index_id}")
}
Op::TableScan { table_id } => {
write!(f, "table scan: table_id={table_id}")
}
Op::CreateTable => {
write!(f, "create table")
}
Op::CreateIndex => {
write!(f, "create index")
}
Op::LoadRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "load record: table_id={table_id}")
}
Op::InsertRecords { table_id } => {
write!(f, "insert records: table_id={table_id}")
}
Op::InsertRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "insert record: table_id={table_id}")
}
Op::UpdateRecords { table_id } => {
write!(f, "update records: table_id={table_id}")
}
Op::UpdateRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "update record: table_id={table_id}")
}
Op::DeleteRecords { table_id } => {
write!(f, "delete records: table_id={table_id}")
}
Op::DeleteRecord {
table_id,
row_key: _,
} => {
// `row_key` is a lot of noise
write!(f, "delete record: table_id={table_id}")
}
Op::InsertIndex { table_id, index_id } => {
write!(
f,
"insert into index: table_id={table_id}, index_id={index_id}"
)
}
Op::DeleteIndex { table_id, index_id } => {
write!(
f,
"delete from index: table_id={table_id}, index_id={index_id}"
)
}
Op::SequenceUpdate { sequence_id } => {
write!(f, "update sequence: sequence_id={sequence_id}")
}
Op::Rollback => {
write!(f, "rollback")
}
Op::Commit => {
write!(f, "commit")
}
}
}
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum CorruptError {
#[error("invalid table id: {table_id}")]
InvalidTableId { table_id: u64 },
#[error("invalid field id: {field_id}\ntable_def={table_def_debug:#?}")]
InvalidFieldId {
field_id: u64,
table_def_debug: String,
},
#[error("primary key field not found: {field_id}\ntable_def={table_def_debug:#?}")]
PrimaryKeyFieldIdNotFound {
field_id: FieldId,
table_def_debug: String,
},
#[error(
"primary key field missing from record batch: {field_id} {field_name:?}\ntable_def={table_def_debug:#?}\nrecord_batch_schema={record_batch_schema_debug}"
)]
PrimaryKeyFieldMissingFromRecordBatch {
field_id: FieldId,
field_name: String,
table_def_debug: String,
record_batch_schema_debug: String,
},
#[error("invalid table definition: {table_id}\ntable_def={table_def_debug:#?}")]
InvalidTableDef {
table_id: TableId,
table_def_debug: String,
},
#[error("invalid index id: index_id={index_id}\nindex_def={index_def_debug:#?}")]
InvalidIndexId {
index_id: u64,
index_def_debug: String,
},
#[error(
"invalid sequence id: sequence_id={sequence_id}\nsequence_def={sequence_def_debug:#?}"
)]
InvalidSequenceId {
sequence_id: u64,
sequence_def_debug: String,
},
#[error("corrupt sequence: sequence_id={sequence_id}")]
SequenceData { sequence_id: SequenceId },
#[error("invalid UTF-8 in {access}: {error}")]
InvalidUtf8 {
access: &'static str,
#[source]
error: std::str::Utf8Error,
},
#[error("invalid record value: {error}")]
InvalidRecord {
#[source]
error: kanto_record_format_v1::error::RecordCorruptError,
},
#[error("index error: {error}")]
Index {
#[source]
error: kanto_index_format_v1::IndexInitError,
},
#[error("table definition validate: {error}")]
Validate {
#[source]
error: kanto_meta_format_v1::table_def::TableDefValidateError,
},
#[error("rkyv error: {error}")]
Rkyv {
#[source]
error: rkyv::rancor::BoxedError,
},
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum InitError {
#[error("unrecognized database format")]
UnrecognizedDatabaseFormat,
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum ExecutionError {
#[error("table exists already: {table_ref}")]
TableExists {
table_ref: datafusion::sql::ResolvedTableReference,
},
#[error("sequence exhausted: sequence_id={sequence_id}")]
SequenceExhausted { sequence_id: SequenceId },
#[error("table definition has too many fields")]
TooManyFields,
#[error("conflict on primary key")]
ConflictOnPrimaryKey,
#[error("conflict on unique index")]
ConflictOnUniqueIndex { index_id: IndexId },
#[error("not nullable: column_name={column_name}")]
NotNullable { column_name: String },
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum PlanError {
#[error("transaction is aborted")]
TransactionAborted,
#[error("not in a transaction")]
NoTransaction,
#[error("savepoint not found: {savepoint_name}")]
SavepointNotFound {
savepoint_name: sqlparser::ast::Ident,
},
#[error("unrecognized sql dialect: {dialect_name}")]
UnrecognizedSqlDialect { dialect_name: String },
#[error("catalog not found: {catalog_name}")]
CatalogNotFound { catalog_name: String },
#[error("table not found: {table_ref}")]
TableNotFound {
table_ref: datafusion::sql::ResolvedTableReference,
},
#[error("column not found: {table_ref}.{column_ident}")]
ColumnNotFound {
table_ref: datafusion::sql::ResolvedTableReference,
column_ident: sqlparser::ast::Ident,
},
#[error("invalid table name: {table_name}")]
InvalidTableName {
table_name: sqlparser::ast::ObjectName,
},
#[error("primary key cannot be nullable: {column_name}")]
PrimaryKeyNullable { column_name: String },
#[error("primary key declared more than once")]
PrimaryKeyTooMany,
#[error("primary key referred to twice in table definition")]
PrimaryKeyDuplicate,
#[error("index must have name")]
IndexNameRequired,
#[error("invalid index name: {index_name}")]
InvalidIndexName {
index_name: sqlparser::ast::ObjectName,
},
#[error("index and table must be in same catalog and schema")]
IndexInDifferentSchema,
#[error("cannot set an unknown variable: {variable_name}")]
SetUnknownVariable {
variable_name: sqlparser::ast::ObjectName,
},
#[error("SET number of values does not match number of variables")]
SetWrongNumberOfValues,
#[error("SET value is wrong type: expected {expected} got {sql}", sql=format!("{}", .value).escape_default())]
SetValueIncorrectType {
expected: &'static str,
value: sqlparser::ast::Expr,
},
#[error("cannot combine ROLLBACK AND [NO] CHAIN and TO SAVEPOINT")]
RollbackAndChainToSavepoint,
#[error("cannot use CREATE TABLE .. ON COMMIT without TEMPORARY")]
CreateTableOnCommitWithoutTemporary,
#[error("cannot use CREATE {{ GLOBAL | LOCAL }} TABLE without TEMPORARY")]
CreateTableGlobalOrLocalWithoutTemporary,
#[error("cannot set nullable value to non-nullable column: column_name={column_name}")]
NotNullable { column_name: String },
}
#[derive(thiserror::Error, Debug)]
#[non_exhaustive]
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
pub enum KantoError {
#[error("io_error/{code}: {op}: {error}")]
Io {
op: Op,
code: &'static str,
#[source]
error: std::io::Error,
},
#[error("data corruption/{code}: {op}: {error}")]
Corrupt {
op: Op,
code: &'static str,
#[source]
error: Box<CorruptError>,
},
#[error("initialization error/{code}: {error}")]
Init {
code: &'static str,
error: InitError,
},
#[error("sql/{code}: {error}")]
Sql {
code: &'static str,
#[source]
error: sqlparser::parser::ParserError,
},
#[error("plan/{code}: {error}")]
Plan {
code: &'static str,
#[source]
error: Box<PlanError>,
},
#[error("execution/{code}: {op}: {error}")]
Execution {
op: Op,
code: &'static str,
#[source]
error: Box<ExecutionError>,
},
#[error("unimplemented_sql/{code}: {sql_syntax}")]
UnimplementedSql {
code: &'static str,
sql_syntax: &'static str,
},
#[error("unimplemented/{code}: {message}")]
Unimplemented {
code: &'static str,
message: &'static str,
},
#[error("not_supported_sql/{code}: {sql_syntax}")]
NotSupportedSql {
code: &'static str,
sql_syntax: &'static str,
},
#[error("internal_error/{code}: {error}")]
Internal {
code: &'static str,
error: Box<dyn std::error::Error + Send + Sync>,
},
}
impl From<KantoError> for datafusion::error::DataFusionError {
fn from(error: KantoError) -> datafusion::error::DataFusionError {
match error {
KantoError::Io { .. } => {
datafusion::error::DataFusionError::IoError(std::io::Error::other(error))
}
KantoError::Sql { code, error } => {
// TODO ugly way to smuggle extra information
datafusion::error::DataFusionError::SQL(error, Some(format!(" (sql/{code})")))
}
KantoError::Plan { .. } => datafusion::error::DataFusionError::Plan(error.to_string()),
KantoError::Execution { .. } => {
datafusion::error::DataFusionError::Execution(error.to_string())
}
_ => datafusion::error::DataFusionError::External(Box::from(error)),
}
}
}
/// Error type that prefixes a message to another error.
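///
/// # Examples
///
/// A small sketch of wrapping an arbitrary error with a context message:
///
/// ```rust
/// let wrapped = kanto::error::Message {
///     message: "while flushing",
///     error: Box::new(std::io::Error::other("disk full")),
/// };
/// assert_eq!(wrapped.to_string(), "while flushing: disk full");
/// ```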
#[expect(clippy::exhaustive_structs)]
#[derive(thiserror::Error, Debug)]
#[error("{message}: {error}")]
pub struct Message {
/// Message to show before the wrapped error.
pub message: &'static str,
/// Underlying error.
#[source]
pub error: Box<dyn std::error::Error + Send + Sync>,
}
#[cfg(test)]
mod tests {
use super::*;
#[expect(
clippy::missing_const_for_fn,
reason = "TODO `const fn` to avoid clippy noise <https://github.com/rust-lang/rust-clippy/issues/13938> #waiting #ecosystem/rust"
)]
#[test]
fn impl_error_send_sync_static() {
const fn assert_error<E: std::error::Error + Send + Sync + 'static>() {}
assert_error::<KantoError>();
}
}

27
crates/kanto/src/lib.rs Normal file
View file

@@ -0,0 +1,27 @@
//! Common functionality and data types for the Kanto database system.
mod backend;
pub mod defaults;
pub mod error;
mod session;
mod settings;
mod transaction;
mod transaction_context;
pub mod util;
// Re-exports of dependencies exposed by our public APIs.
pub use async_trait;
pub use datafusion;
pub use datafusion::arrow;
pub use datafusion::parquet;
pub use datafusion::sql::sqlparser;
pub use crate::backend::Backend;
pub use crate::backend::DynEq;
pub use crate::backend::DynHash;
pub use crate::error::KantoError;
pub use crate::session::Session;
pub(crate) use crate::settings::Settings;
pub use crate::transaction::Savepoint;
pub use crate::transaction::Transaction;
pub(crate) use crate::transaction_context::TransactionContext;

2250
crates/kanto/src/session.rs Normal file

File diff suppressed because it is too large Load diff

View file

@@ -0,0 +1,8 @@
use std::sync::Arc;
use datafusion::sql::sqlparser;
#[derive(Debug, Default)]
pub(crate) struct Settings {
pub(crate) sql_dialect: Option<Arc<dyn sqlparser::dialect::Dialect + Send + Sync>>,
}

View file

@@ -0,0 +1,102 @@
use std::sync::Arc;
use datafusion::sql::sqlparser;
use futures_lite::Stream;
use crate::KantoError;
/// A savepoint marks a point in the timeline of changes made within one transaction; the transaction can later be rolled back to that point.
#[async_trait::async_trait]
pub trait Savepoint: Send + Sync {
/// Return the `Savepoint` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Rollback to this savepoint, and release it.
///
/// This has to take a boxed `self` because methods on trait objects cannot consume `self` by value in Rust (<https://stackoverflow.com/questions/46620790/how-to-call-a-method-that-consumes-self-on-a-boxed-trait-object>).
async fn rollback(self: Box<Self>) -> Result<(), KantoError>;
}
/// A database transaction is the primary way of manipulating and accessing data stored in a [`kanto::Backend`](crate::Backend).
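///
/// # Examples
///
/// A minimal sketch (not compiled here), assuming `tx` is a `Box<dyn Transaction>`
/// obtained from [`Backend::start_transaction`](crate::Backend::start_transaction):
///
/// ```rust ignore
/// use datafusion::logical_expr::TransactionConclusion;
///
/// let savepoint = tx.savepoint().await?;
/// // ... run statements against `tx` ...
/// savepoint.rollback().await?;
/// tx.finish(&TransactionConclusion::Commit).await?;
/// ```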
#[async_trait::async_trait]
pub trait Transaction: Send + Sync {
/// Return the `Transaction` as a [`std::any::Any`].
fn as_any(&self) -> &dyn std::any::Any;
/// Create a [`Savepoint`] at this point in time.
async fn savepoint(&self) -> Result<Box<dyn Savepoint>, KantoError>;
/// Finish the transaction, either committing or abandoning the changes made in it.
async fn finish(
&self,
conclusion: &datafusion::logical_expr::TransactionConclusion,
) -> Result<(), KantoError>;
/// Find a table (that may have been altered in this same transaction).
async fn lookup_table(
&self,
table_ref: &datafusion::sql::ResolvedTableReference,
) -> Result<Option<Arc<dyn datafusion::catalog::TableProvider>>, KantoError>;
/// List all tables.
// TODO i wish Rust had good generators.
// This is really a `[(schema, [table])]`, but we'll keep it flat for simplicity.
#[expect(
clippy::type_complexity,
reason = "TODO would be nice to simplify the iterator item type somehow, but it is a fallible named-items iterator (#ref/i5rfyyqimjyxs) #severity/low #urgency/low"
)]
fn stream_tables(
&self,
) -> Box<
dyn Stream<
Item = Result<
(
datafusion::sql::ResolvedTableReference,
Box<dyn datafusion::catalog::TableProvider>,
),
KantoError,
>,
> + Send,
>;
/// Create a table.
async fn create_table(
&self,
table_ref: datafusion::sql::ResolvedTableReference,
create_table: sqlparser::ast::CreateTable,
) -> Result<(), KantoError>;
/// Create an index.
async fn create_index(
&self,
table_ref: datafusion::sql::ResolvedTableReference,
index_name: Arc<str>,
unique: bool,
nulls_distinct: bool,
columns: Vec<sqlparser::ast::OrderByExpr>,
) -> Result<(), KantoError>;
/// Insert records.
async fn insert(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
/// Delete records.
async fn delete(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
/// Update records.
async fn update(
&self,
session_state: datafusion::execution::SessionState,
table_ref: datafusion::sql::ResolvedTableReference,
input: Arc<datafusion::logical_expr::LogicalPlan>,
) -> Result<(), datafusion::error::DataFusionError>;
}

View file

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use std::sync::Arc;
use datafusion::sql::sqlparser;
use futures_lite::StreamExt as _;
use smallvec::SmallVec;
use crate::session::ImplicitOrExplicit;
use crate::util::trigger::Trigger;
use crate::Backend;
use crate::KantoError;
use crate::Settings;
use crate::Transaction;
enum SavepointKind {
NestedTx,
NamedSavepoint { name: sqlparser::ast::Ident },
}
pub(crate) struct SavepointState {
kind: SavepointKind,
savepoints: HashMap<Box<dyn Backend>, Box<dyn crate::Savepoint>>,
settings: Settings,
}
impl SavepointState {
pub(crate) async fn rollback(self) -> Result<(), KantoError> {
for backend_savepoint in self.savepoints.into_values() {
// On error, drop the remaining nested transaction savepoints.
// Some of the backend transactions may be in an unclear state; caller is expected to not use this transaction any more.
backend_savepoint.rollback().await?;
}
Ok(())
}
}
#[derive(Clone)]
pub(crate) struct TransactionContext {
inner: Arc<InnerTransactionContext>,
}
struct InnerTransactionContext {
implicit: ImplicitOrExplicit,
aborted: Trigger,
transactions: HashMap<Box<dyn Backend>, Box<dyn Transaction>>,
/// Blocking mutex, held only briefly for pure data manipulation with no I/O or other blocking work.
savepoints:
parking_lot::Mutex<SmallVec<[SavepointState; kanto_tunables::TYPICAL_MAX_TRANSACTIONS]>>,
}
impl TransactionContext {
pub(crate) fn new(
implicit: ImplicitOrExplicit,
transactions: HashMap<Box<dyn Backend>, Box<dyn Transaction>>,
) -> TransactionContext {
let inner = Arc::new(InnerTransactionContext {
implicit,
aborted: Trigger::new(),
transactions,
savepoints: parking_lot::Mutex::new(SmallVec::new()),
});
TransactionContext { inner }
}
pub(crate) fn set_aborted(&self) {
self.inner.aborted.trigger();
}
pub(crate) fn is_aborted(&self) -> bool {
self.inner.aborted.get()
}
pub(crate) fn is_implicit(&self) -> bool {
match self.inner.implicit {
ImplicitOrExplicit::Implicit => true,
ImplicitOrExplicit::Explicit => false,
}
}
pub(crate) fn get_transaction<'a>(
&'a self,
backend: &dyn Backend,
) -> Result<&'a dyn Transaction, KantoError> {
let transaction = self
.inner
.transactions
.get(backend)
.ok_or(KantoError::Internal {
code: "jnrmoeu8rpqow",
error: "backend had no transaction".into(),
})?;
Ok(transaction.as_ref())
}
pub(crate) fn transactions(&self) -> impl Iterator<Item = (&dyn Backend, &dyn Transaction)> {
self.inner
.transactions
.iter()
.map(|(backend, tx)| (backend.as_ref(), tx.as_ref()))
}
pub(crate) async fn savepoint(
&self,
savepoint_name: Option<sqlparser::ast::Ident>,
) -> Result<(), KantoError> {
let savepoints = futures_lite::stream::iter(&self.inner.transactions)
.then(async |(backend, tx)| {
let savepoint = tx.savepoint().await?;
Ok((backend.clone(), savepoint))
})
.try_collect()
.await?;
let kind = if let Some(name) = savepoint_name {
SavepointKind::NamedSavepoint { name }
} else {
SavepointKind::NestedTx
};
let savepoint_state = SavepointState {
kind,
savepoints,
settings: Settings::default(),
};
let mut guard = self.inner.savepoints.lock();
guard.push(savepoint_state);
Ok(())
}
pub(crate) fn pop_to_last_nested_tx(&self) -> Option<SavepointState> {
let mut guard = self.inner.savepoints.lock();
guard
.iter()
.rposition(|savepoint| matches!(savepoint.kind, SavepointKind::NestedTx))
.and_then(|idx| guard.drain(idx..).next())
}
pub(crate) fn pop_to_savepoint(
&self,
savepoint_name: &sqlparser::ast::Ident,
) -> Option<SavepointState> {
let mut guard = self.inner.savepoints.lock();
guard
.iter()
.rposition(|savepoint| {
matches!(&savepoint.kind, SavepointKind::NamedSavepoint { name }
if name == savepoint_name)
})
.and_then(|idx| guard.drain(idx..).next())
}
pub(crate) fn get_setting<T>(&self, getter: impl Fn(&Settings) -> Option<T>) -> Option<T> {
let savepoints = self.inner.savepoints.lock();
savepoints
.iter()
.rev()
.find_map(|savepoint_state| getter(&savepoint_state.settings))
}
}

16
crates/kanto/src/util.rs Normal file
View file

@@ -0,0 +1,16 @@
//! Utility functions for implementing [`kanto::Backend`](crate::Backend).
use std::sync::Arc;
mod sql_type_to_field_type;
pub(crate) mod trigger;
pub use self::sql_type_to_field_type::sql_type_to_field_type;
/// Return a stream that has no rows and an empty schema.
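///
/// # Examples
///
/// A small usage sketch:
///
/// ```rust
/// let stream = kanto::util::no_rows_empty_schema();
/// assert_eq!(stream.schema().fields().len(), 0);
/// ```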
#[must_use]
pub fn no_rows_empty_schema() -> datafusion::physical_plan::SendableRecordBatchStream {
let schema = Arc::new(datafusion::arrow::datatypes::Schema::empty());
let data = datafusion::physical_plan::EmptyRecordBatchStream::new(schema);
Box::pin(data)
}

View file

@@ -0,0 +1,265 @@
//! Utility functions for implementing [`kanto::Backend`](crate::Backend).
use datafusion::sql::sqlparser;
use kanto_meta_format_v1::field_type::FieldType;
use crate::KantoError;
/// Convert a SQL data type to a KantoDB field type.
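///
/// # Examples
///
/// A small sketch; the exact mapping for each SQL type is in the match below:
///
/// ```rust
/// use kanto::sqlparser::ast::DataType;
/// use kanto_meta_format_v1::field_type::FieldType;
///
/// let field_type = kanto::util::sql_type_to_field_type(&DataType::Text).unwrap();
/// assert!(matches!(field_type, FieldType::String));
/// ```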
#[expect(clippy::too_many_lines, reason = "large match")]
pub fn sql_type_to_field_type(
sql_data_type: &sqlparser::ast::DataType,
) -> Result<FieldType, KantoError> {
// TODO this duplicates unexported `convert_data_type` called by [`SqlToRel::build_schema`](https://docs.rs/datafusion/latest/datafusion/sql/planner/struct.SqlToRel.html#method.build_schema), that one does sql->dftype
use sqlparser::ast::DataType as SqlType;
match sql_data_type {
SqlType::Character(_) | SqlType::Char(_) | SqlType::FixedString(_) => {
Err(KantoError::UnimplementedSql {
code: "hdhf8ygp9zna4",
sql_syntax: "type { CHARACTER | CHAR | FIXEDSTRING }",
})
}
SqlType::CharacterVarying(_)
| SqlType::CharVarying(_)
| SqlType::Varchar(_)
| SqlType::Nvarchar(_)
| SqlType::Text
| SqlType::String(_)
| SqlType::TinyText
| SqlType::MediumText
| SqlType::LongText => Ok(FieldType::String),
SqlType::Uuid => Err(KantoError::UnimplementedSql {
code: "531ax7gb73ce1",
sql_syntax: "type UUID",
}),
SqlType::CharacterLargeObject(_length)
| SqlType::CharLargeObject(_length)
| SqlType::Clob(_length) => Err(KantoError::UnimplementedSql {
code: "ew7uhufhkzj9w",
sql_syntax: "type { CHARACTER LARGE OBJECT | CHAR LARGE OBJECT | CLOB }",
}),
SqlType::Binary(_length) => Err(KantoError::UnimplementedSql {
code: "i16wqwmie17eg",
sql_syntax: "type BINARY",
}),
SqlType::Varbinary(_)
| SqlType::Blob(_)
| SqlType::Bytes(_)
| SqlType::Bytea
| SqlType::TinyBlob
| SqlType::MediumBlob
| SqlType::LongBlob => Ok(FieldType::Binary),
SqlType::Numeric(_) | SqlType::Decimal(_) | SqlType::Dec(_) => {
Err(KantoError::UnimplementedSql {
code: "uoakmzwax448h",
sql_syntax: "type { NUMERIC | DECIMAL | DEC }",
})
}
SqlType::BigNumeric(_) | SqlType::BigDecimal(_) => Err(KantoError::UnimplementedSql {
code: "an57g4d4uzwew",
sql_syntax: "type { BIGNUMERIC | BIGDECIMAL } (BigQuery syntax)",
}),
SqlType::Float(Some(precision)) if (1..=24).contains(precision) => Ok(FieldType::F32),
SqlType::Float(None)
| SqlType::Double(sqlparser::ast::ExactNumberInfo::None)
| SqlType::DoublePrecision
| SqlType::Float64
| SqlType::Float8 => Ok(FieldType::F64),
SqlType::Float(Some(precision)) if (25..=53).contains(precision) => Ok(FieldType::F64),
SqlType::Float(Some(_precision)) => Err(KantoError::UnimplementedSql {
code: "xfwbr486jaw5q",
sql_syntax: "type FLOAT(precision)",
}),
SqlType::TinyInt(None) => Ok(FieldType::I8),
SqlType::TinyInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "hueii8mjgk8qq",
sql_syntax: "type TINYINT(precision)",
}),
SqlType::UnsignedTinyInt(None) | SqlType::UInt8 => Ok(FieldType::U8),
SqlType::UnsignedTinyInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "ewt6tneh7ecps",
sql_syntax: "type UNSIGNED TINYINT(precision)",
}),
SqlType::Int2(None) | SqlType::SmallInt(None) | SqlType::Int16 => Ok(FieldType::I16),
SqlType::UnsignedInt2(None) | SqlType::UnsignedSmallInt(None) | SqlType::UInt16 => {
Ok(FieldType::U16)
}
SqlType::Int2(Some(_))
| SqlType::UnsignedInt2(Some(_))
| SqlType::SmallInt(Some(_))
| SqlType::UnsignedSmallInt(Some(_)) => Err(KantoError::UnimplementedSql {
code: "g57gd8499t4hg",
sql_syntax: "type { (UNSIGNED) INT2(precision) | (UNSIGNED) SMALLINT(precision) }",
}),
SqlType::MediumInt(_) | SqlType::UnsignedMediumInt(_) => {
Err(KantoError::UnimplementedSql {
code: "m7snzsrqj68sw",
sql_syntax: "type { MEDIUMINT (UNSIGNED) }",
})
}
SqlType::Int(None) | SqlType::Integer(None) | SqlType::Int4(None) | SqlType::Int32 => {
Ok(FieldType::I32)
}
SqlType::Int(Some(_)) | SqlType::Integer(Some(_)) | SqlType::Int4(Some(_)) => {
Err(KantoError::UnimplementedSql {
code: "tiptbb5d8buxa",
sql_syntax: "type { INTEGER(precision) | INT(precision) | INT4(precision) }",
})
}
SqlType::Int8(None) | SqlType::Int64 | SqlType::BigInt(None) => Ok(FieldType::I64),
SqlType::Int128 => Err(KantoError::UnimplementedSql {
code: "4sg3mftgxu9j4",
sql_syntax: "type INT128 (ClickHouse syntax)",
}),
SqlType::Int256 => Err(KantoError::UnimplementedSql {
code: "ex9mnxxipbqqs",
sql_syntax: "type INT256 (ClickHouse syntax)",
}),
SqlType::UnsignedInt(None)
| SqlType::UnsignedInt4(None)
| SqlType::UnsignedInteger(None)
| SqlType::UInt32 => Ok(FieldType::U32),
SqlType::UnsignedInt(Some(_))
| SqlType::UnsignedInt4(Some(_))
| SqlType::UnsignedInteger(Some(_)) => Err(KantoError::UnimplementedSql {
code: "fzt1q8mfgkr6n",
sql_syntax: "type { UNSIGNED INTEGER(precision) | UNSIGNED INT(precision) | UNSIGNED INT4(precision) }",
}),
SqlType::UInt64 | SqlType::UnsignedBigInt(None) | SqlType::UnsignedInt8(None) => {
Ok(FieldType::U64)
}
SqlType::UInt128 => Err(KantoError::UnimplementedSql {
code: "riob7j6jahpte",
sql_syntax: "type UINT128 (ClickHouse syntax)",
}),
SqlType::UInt256 => Err(KantoError::UnimplementedSql {
code: "q6tzjzd7xf3ga",
sql_syntax: "type UINT256 (ClickHouse syntax)",
}),
SqlType::BigInt(Some(_)) | SqlType::Int8(Some(_)) => Err(KantoError::UnimplementedSql {
code: "88fzntq3mz14g",
sql_syntax: "type { BIGINT(precision) | INT8(precision) }",
}),
SqlType::UnsignedBigInt(Some(_)) | SqlType::UnsignedInt8(Some(_)) => {
Err(KantoError::UnimplementedSql {
code: "sze33j7qzt91s",
sql_syntax: "type { UNSIGNED BIGINT(precision) | UNSIGNED INT8(precision) }",
})
}
SqlType::Float4 | SqlType::Float32 | SqlType::Real => Ok(FieldType::F32),
SqlType::Double(sqlparser::ast::ExactNumberInfo::Precision(precision))
if (25..=53).contains(precision) =>
{
Ok(FieldType::F64)
}
SqlType::Double(sqlparser::ast::ExactNumberInfo::Precision(_precision)) => {
Err(KantoError::UnimplementedSql {
code: "wo43sei9ubpon",
sql_syntax: "type DOUBLE(precision)",
})
}
SqlType::Double(sqlparser::ast::ExactNumberInfo::PrecisionAndScale(_precision, _scale)) => {
Err(KantoError::NotSupportedSql {
code: "zgxdduxydzpoq",
sql_syntax: "type DOUBLE(precision, scale) (MySQL syntax, deprecated)",
})
}
SqlType::Bool | SqlType::Boolean => Ok(FieldType::Boolean),
SqlType::Date => Err(KantoError::UnimplementedSql {
code: "1r5ez1z8j7ryo",
sql_syntax: "type DATE",
}),
SqlType::Date32 => Err(KantoError::UnimplementedSql {
code: "781y5fxih7pnn",
sql_syntax: "type DATE32",
}),
SqlType::Time(_, _) => Err(KantoError::UnimplementedSql {
code: "bppzhg7ck7xhr",
sql_syntax: "type TIME",
}),
SqlType::Datetime(_) => Err(KantoError::UnimplementedSql {
code: "dojz85ngwo5wr",
sql_syntax: "type DATETIME (MySQL syntax)",
}),
SqlType::Datetime64(_, _) => Err(KantoError::UnimplementedSql {
code: "oak3s5g3rnutq",
sql_syntax: "type DATETIME (ClickHouse syntax)",
}),
SqlType::Timestamp(_, _) => Err(KantoError::UnimplementedSql {
code: "njad51sbxzwnn",
sql_syntax: "type TIMESTAMP",
}),
SqlType::Interval => Err(KantoError::UnimplementedSql {
code: "5wxynwyfua69s",
sql_syntax: "type INTERVAL",
}),
SqlType::JSON | SqlType::JSONB => Err(KantoError::UnimplementedSql {
code: "abcw878jr35qc",
sql_syntax: "type { JSON | JSONB }",
}),
SqlType::Regclass => Err(KantoError::UnimplementedSql {
code: "quzjzcsusttgc",
sql_syntax: "type REGCLASS (Postgres syntax)",
}),
SqlType::Bit(_) | SqlType::BitVarying(_) => Err(KantoError::UnimplementedSql {
code: "88uneott9owc1",
sql_syntax: "type BIT (VARYING)",
}),
SqlType::Custom(..) => Err(KantoError::UnimplementedSql {
code: "dt315zutjqpzc",
sql_syntax: "custom type",
}),
SqlType::Array(_) => Err(KantoError::UnimplementedSql {
code: "hnmqj3qbtutn4",
sql_syntax: "type ARRAY",
}),
SqlType::Map(_, _) => Err(KantoError::UnimplementedSql {
code: "7ts8zgnafnp7g",
sql_syntax: "type MAP",
}),
SqlType::Tuple(_) => Err(KantoError::UnimplementedSql {
code: "x7qwkwdhdznxe",
sql_syntax: "type TUPLE",
}),
SqlType::Nested(_) => Err(KantoError::UnimplementedSql {
code: "fa6nofmo1s49g",
sql_syntax: "type NESTED",
}),
SqlType::Enum(_, _) => Err(KantoError::UnimplementedSql {
code: "1geqh5qxdeoko",
sql_syntax: "type ENUM",
}),
SqlType::Set(_) => Err(KantoError::UnimplementedSql {
code: "hm53sq3gco3sw",
sql_syntax: "type SET",
}),
SqlType::Struct(_, _) => Err(KantoError::UnimplementedSql {
code: "nsapujxjjk9hg",
sql_syntax: "type STRUCT",
}),
SqlType::Union(_) => Err(KantoError::UnimplementedSql {
code: "oownfhnfj5b9r",
sql_syntax: "type UNION",
}),
SqlType::Nullable(_) => Err(KantoError::UnimplementedSql {
code: "ra5kspnstauu6",
sql_syntax: "type NULLABLE (ClickHouse syntax)",
}),
SqlType::LowCardinality(_) => Err(KantoError::UnimplementedSql {
code: "78exey8xjz3yk",
sql_syntax: "type LOWCARDINALITY",
}),
SqlType::Unspecified => Err(KantoError::UnimplementedSql {
code: "916n48obgiqh6",
sql_syntax: "type UNSPECIFIED",
}),
SqlType::Trigger => Err(KantoError::UnimplementedSql {
code: "zf5jx1ykoc5s1",
sql_syntax: "type TRIGGER",
}),
SqlType::AnyType => Err(KantoError::NotSupportedSql {
code: "bwjrzsbzg59gy",
sql_syntax: "type ANY TYPE (BigQuery syntax)",
}),
}
}

View file

@@ -0,0 +1,29 @@
/// Trigger keeps track of whether something happened.
/// Once triggered, it will always remain triggered.
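///
/// A minimal usage sketch (crate-internal, so not run as a doctest):
///
/// ```rust ignore
/// let aborted = Trigger::new();
/// assert!(!aborted.get());
/// aborted.trigger();
/// assert!(aborted.get());
/// ```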
pub(crate) struct Trigger {
triggered: std::sync::atomic::AtomicBool,
}
impl Trigger {
#[must_use]
pub(crate) const fn new() -> Trigger {
Trigger {
triggered: std::sync::atomic::AtomicBool::new(false),
}
}
pub(crate) fn trigger(&self) {
self.triggered
.store(true, std::sync::atomic::Ordering::Relaxed);
}
pub(crate) fn get(&self) -> bool {
self.triggered.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl Default for Trigger {
fn default() -> Self {
Trigger::new()
}
}

View file

@@ -0,0 +1,626 @@
#![expect(
clippy::print_stderr,
clippy::unimplemented,
clippy::unwrap_used,
clippy::expect_used,
reason = "test code"
)]
use std::collections::HashSet;
use std::ffi::OsStr;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use futures_lite::stream::StreamExt as _;
use kanto::arrow::array::AsArray as _;
struct TestDatabase(Arc<tokio::sync::RwLock<kanto::Session>>);
impl sqllogictest::MakeConnection for TestDatabase {
type Conn = TestSession;
type MakeFuture =
std::future::Ready<Result<TestSession, <Self::Conn as sqllogictest::AsyncDB>::Error>>;
#[maybe_tracing::instrument(skip(self))]
fn make(&mut self) -> Self::MakeFuture {
std::future::ready(Ok(TestSession(self.0.clone())))
}
}
struct TestSession(Arc<tokio::sync::RwLock<kanto::Session>>);
fn convert_f64_to_sqllogic_string(
f: f64,
expectation_type: &sqllogictest::DefaultColumnType,
) -> String {
match expectation_type {
// SQLite's tests use implicit conversion to integer by marking result columns as integers.
sqllogictest::DefaultColumnType::Integer => {
let s = format!("{f:.0}");
if s == "-0" {
"0".to_owned()
} else {
s
}
}
sqllogictest::DefaultColumnType::FloatingPoint | sqllogictest::DefaultColumnType::Any => {
// 3 digits of precision is what SQLite sqllogictest seems to expect.
let s = format!("{f:.3}");
if s == "-0.000" {
"0.000".to_owned()
} else {
s
}
}
sqllogictest::DefaultColumnType::Text => {
format!("<unexpected SLT column expectation for Float64: {expectation_type:?}>")
}
}
}
fn convert_binary_to_sqllogic_strings<'arr>(
iter: impl Iterator<Item = Option<&'arr [u8]>> + 'arr,
) -> Box<dyn Iterator<Item = String> + 'arr> {
let iter = iter.map(|opt_bytes| {
opt_bytes
.map(|b| {
let mut hex = String::new();
hex.push_str(r"X'");
hex.push_str(&hex::encode(b));
hex.push('\'');
hex
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
#[expect(clippy::too_many_lines, reason = "large match")]
fn convert_array_to_sqllogic_strings<'arr>(
array: &'arr Arc<dyn datafusion::arrow::array::Array>,
expectation_type: sqllogictest::DefaultColumnType,
data_type: &datafusion::arrow::datatypes::DataType,
) -> Box<dyn Iterator<Item = String> + 'arr> {
match data_type {
datafusion::arrow::datatypes::DataType::Int8 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int8Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt8 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt8Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Int32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int32Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt32Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Int64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Int64Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::UInt64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::UInt64Type>()
.unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Float32 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Float32Type>()
.unwrap();
let iter = array.iter().map(move |opt_num| {
opt_num
.map(|f| convert_f64_to_sqllogic_string(f64::from(f), &expectation_type))
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Float64 => {
let array = array
.as_primitive_opt::<kanto::arrow::datatypes::Float64Type>()
.unwrap();
let iter = array.iter().map(move |opt_num| {
opt_num
.map(|f| convert_f64_to_sqllogic_string(f, &expectation_type))
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Utf8 => {
let array = array.as_string_opt::<i32>().unwrap();
let iter = array.iter().map(|opt_str| {
opt_str
.map(|s| {
if s.is_empty() {
"(empty)".to_owned()
} else {
let printable = b' '..=b'~'; // printable ASCII, '~' inclusive per sqllogictest rules
let t = s
.bytes()
.map(|c| {
if printable.contains(&c) {
c.into()
} else {
'@'
}
})
.collect::<String>();
t
}
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Utf8View => {
let array = array.as_string_view_opt().unwrap();
let iter = array.iter().map(|opt_str| {
opt_str
.map(|s| {
if s.is_empty() {
"(empty)".to_owned()
} else {
let printable = b' '..=b'~'; // printable ASCII, '~' inclusive per sqllogictest rules
let t = s
.bytes()
.map(|c| {
if printable.contains(&c) {
c.into()
} else {
'@'
}
})
.collect::<String>();
t
}
})
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
datafusion::arrow::datatypes::DataType::Binary => {
let array = array.as_binary_opt::<i32>().unwrap();
convert_binary_to_sqllogic_strings(array.iter())
}
datafusion::arrow::datatypes::DataType::BinaryView => {
let array = array.as_binary_view_opt().unwrap();
convert_binary_to_sqllogic_strings(array.iter())
}
datafusion::arrow::datatypes::DataType::Boolean => {
let array = array.as_boolean_opt().unwrap();
let iter = array.iter().map(|opt_num| {
opt_num
.map(|num| num.to_string())
.unwrap_or_else(|| "NULL".to_owned())
});
let as_dyn: Box<dyn Iterator<Item = String>> = Box::new(iter);
as_dyn
}
_ => unimplemented!("sqllogictest: data type not implemented: {data_type:?}"),
}
}
fn is_hidden(entry: &walkdir::DirEntry) -> bool {
entry.file_name().to_string_lossy().starts_with('.')
}
#[async_trait::async_trait]
impl sqllogictest::AsyncDB for TestSession {
type Error = datafusion::error::DataFusionError;
type ColumnType = sqllogictest::DefaultColumnType;
#[maybe_tracing::instrument(span_name = "TestSession::run", skip(self), err)]
async fn run(
&mut self,
sql: &str,
) -> Result<sqllogictest::DBOutput<Self::ColumnType>, Self::Error> {
tracing::trace!("acquiring session lock");
let mut guard = self.0.write().await;
tracing::trace!("running sql");
let stream = guard.sql(sql).await?;
tracing::trace!("done with sql");
let types: Vec<Self::ColumnType> = stream
.schema()
.fields()
.iter()
.map(|field| {
// TODO this is wrong, we need the datatype formatting info from the `*.slt` <https://github.com/risinglightdb/sqllogictest-rs/issues/227> #ecosystem/sqllogictest-rs
match field.data_type() {
datafusion::arrow::datatypes::DataType::Int32
| datafusion::arrow::datatypes::DataType::Int64
| datafusion::arrow::datatypes::DataType::UInt32
| datafusion::arrow::datatypes::DataType::UInt64 => {
sqllogictest::column_type::DefaultColumnType::Integer
}
// TODO recognize more data types
_ => sqllogictest::column_type::DefaultColumnType::Any,
}
})
.collect();
let mut rows: Vec<Vec<String>> = Vec::new();
tracing::trace!("collecting");
let batches = stream.try_collect::<_, _, Vec<_>>().await?;
tracing::trace!("collect done");
for batch in batches {
let mut value_per_column = batch
.columns()
.iter()
.enumerate()
.map(|(column_idx, array)| {
debug_assert!(column_idx < types.len());
#[expect(clippy::indexing_slicing, reason = "test code")]
let type_ = types[column_idx].clone();
convert_array_to_sqllogic_strings(
array,
type_,
batch.schema().field(column_idx).data_type(),
)
})
.collect::<Vec<_>>();
loop {
let row: Vec<_> = value_per_column
.iter_mut()
.map_while(Iterator::next)
.collect();
if row.is_empty() {
break;
}
rows.push(row);
}
}
tracing::trace!("prepping output");
let output = sqllogictest::DBOutput::Rows { types, rows };
drop(guard);
Ok(output)
}
fn engine_name(&self) -> &'static str {
"kanto"
}
#[maybe_tracing::instrument()]
async fn sleep(dur: Duration) {
let _wat: () = tokio::time::sleep(dur).await;
}
#[maybe_tracing::instrument(skip(_command), ret, err(level=tracing::Level::WARN))]
async fn run_command(_command: std::process::Command) -> std::io::Result<std::process::Output> {
unimplemented!("security hazard");
}
async fn shutdown(&mut self) -> () {
// nothing
}
}
fn is_datafusion_limitation(error: &dyn std::error::Error) -> Option<&'static str> {
// TODO i'm having serious trouble unpacking this `dyn Error`.
// do ugly string matching instead.
//
// TODO note these as ecosystem limitations, report upstream
let datafusion_limitations = [
(
// TODO DataFusion non-unique column names <https://github.com/apache/arrow-datafusion/issues/6543>, <https://github.com/apache/arrow-datafusion/issues/8379> #ecosystem/datafusion #waiting #severity/high #urgency/medium
"duplicate column names",
regex::Regex::new(r"Projections require unique expression names but the expression .* the same name").unwrap()
),
(
// TODO DataFusion non-unique column names <https://github.com/apache/arrow-datafusion/issues/6543>, <https://github.com/apache/arrow-datafusion/issues/8379> #ecosystem/datafusion #waiting #severity/high #urgency/medium
"duplicate column names (qual vs unqual)",
regex::Regex::new(r"Schema error: Schema contains qualified field name .* and unqualified field name .* which would be ambiguous").unwrap(),
),
(
"convert to null",
regex::Regex::new(r"Error during planning: Cannot automatically convert .* to Null").unwrap(),
),
(
// TODO DataFusion does not support `avg(DISTINCT)`: <https://github.com/apache/arrow-datafusion/issues/2408> #ecosystem/datafusion #waiting #severity/medium #urgency/low
"AVG(DISTINCT)",
regex::Regex::new(r"Execution error: avg\(DISTINCT\) aggregations are not available").unwrap(),
),
(
"non-aggregate values in projection",
regex::Regex::new(r"Error during planning: Projection references non-aggregate values: Expression .* could not be resolved from available columns: .*").unwrap(),
),
(
"AVG(NULL)",
regex::Regex::new(r"Error during planning: No function matches the given name and argument types 'AVG\(Null\)'. You might need to add explicit type casts.").unwrap(),
),
(
"Correlated column is not allowed in predicate",
regex::Regex::new(r"Error during planning: Correlated column is not allowed in predicate:").unwrap(),
),
(
"HAVING references non-aggregate",
regex::Regex::new(r"Error during planning: HAVING clause references non-aggregate values: Expression .* could not be resolved from available columns: .*").unwrap(),
),
(
"scalar subquery",
regex::Regex::new(r"This feature is not implemented: Physical plan does not support logical expression ScalarSubquery\(<subquery>\)").unwrap(),
),
];
let error_message = error.to_string();
datafusion_limitations
.iter()
.find(|(_explanation, re)| re.is_match(&error_message))
.map(|(explanation, _re)| *explanation)
}
#[maybe_tracing::instrument(skip(tester), ret)]
fn execute_sqllogic_test<D: sqllogictest::AsyncDB, M: sqllogictest::MakeConnection<Conn = D>>(
tester: &mut sqllogictest::Runner<D, M>,
path: &Path,
kludge_datafusion_limitations: bool,
) -> Result<(), libtest_mimic::Failed> {
let script = std::fs::read_to_string(path)?;
let script = {
// TODO sqllogictest-rs fails to parse comment at end of line <https://github.com/risinglightdb/sqllogictest-rs/issues/122> #waiting #ecosystem/sqllogictest-rs
let trailing_comment = regex::Regex::new(r"(?mR)#.*$").unwrap();
trailing_comment.replace_all(&script, "")
};
let name = path.as_os_str().to_string_lossy();
tester
.run_script_with_name(&script, Arc::from(name))
.or_else(|error| match error.kind() {
sqllogictest::TestErrorKind::Fail {
sql: _,
err,
kind: sqllogictest::RecordKind::Query,
} => {
if kludge_datafusion_limitations {
if let Some(explanation) = is_datafusion_limitation(&err) {
eprintln!("Ignoring known DataFusion limitation: {explanation}");
return Ok(());
}
}
Err(error)
}
sqllogictest::TestErrorKind::QueryResultMismatch {
sql: _,
expected,
actual,
} => {
{
// TODO hash mismatch likely due to expected column type API is wrong way around: <https://github.com/risinglightdb/sqllogictest-rs/issues/227> #ecosystem/sqllogictest-rs #severity/high #urgency/medium
let re = regex::Regex::new(r"^\d+ values hashing to [0-9a-f]+$").unwrap();
if re.is_match(&expected) && re.is_match(&actual) {
return Ok(());
}
}
Err(error)
}
_ => Err(error),
})?;
Ok(())
}
// kludge around <https://github.com/risinglightdb/sqllogictest-rs/issues/108>
#[maybe_tracing::instrument(skip(_normalizer), ret)]
fn sql_result_validator(
_normalizer: sqllogictest::Normalizer,
actual: &[Vec<String>],
expected: &[String],
) -> bool {
// TODO expected column type API is wrong way around: <https://github.com/risinglightdb/sqllogictest-rs/issues/227>
fn kludge_compare<S: AsRef<str>>(got: &[S], want: &[String]) -> bool {
fn kludge_string_compare<S: AsRef<str>>(got: S, want: &str) -> bool {
let got = got.as_ref();
got == want || got == format!("{want}.000")
}
got.len() == want.len()
&& got
.iter()
.zip(want)
.all(|(got, want)| kludge_string_compare(got, want))
}
let column_wise = || {
let actual: Vec<_> = actual.iter().map(|v| v.join(" ")).collect();
tracing::trace!(?actual, "column-wise joined");
actual
};
let value_wise = || {
// SQLite "value-wise" compatibility.
let actual: Vec<_> = actual.iter().flatten().map(String::as_str).collect();
tracing::trace!(?actual, "value-wise flattened");
actual
};
kludge_compare(&value_wise(), expected) || kludge_compare(&column_wise()[..], expected)
}
fn uses_create_table_primary_key(path: &Path) -> bool {
// TODO `CREATE TABLE with column option { PRIMARY KEY | UNIQUE }` (#ref/unimplemented_sql/hwcjakao83bue) #severity/high #urgency/medium
let re_oneline =
regex::Regex::new(r"(?m)^CREATE TABLE tab\d+\(pk INTEGER PRIMARY KEY,").unwrap();
let re_wrapped =
regex::Regex::new(r"(?m)^CREATE TABLE t\d+\(\n\s+[a-z0-9]+ INTEGER PRIMARY KEY,").unwrap();
let script = std::fs::read_to_string(path).unwrap();
re_oneline.is_match(&script) || re_wrapped.is_match(&script)
}
fn uses_create_view(path: &Path) -> bool {
// TODO `CREATE VIEW` (#ref/unimplemented_sql/b9m5uhu9pnnsw)
let re = regex::Regex::new(r"(?m)^CREATE VIEW ").unwrap();
let script = std::fs::read_to_string(path).unwrap();
re.is_match(&script)
}
pub(crate) fn setup_tests(
slt_root: impl AsRef<Path>,
slt_ext: &str,
kludge_datafusion_limitations: bool,
kludge_sqlite_tests: bool,
) -> Vec<libtest_mimic::Trial> {
let known_broken = HashSet::from([
"rocksdb/evidence/slt_lang_update", // TODO cannot update row_id tables yet (#ref/unimplemented/i61jpodrazscs)
"rocksdb/evidence/slt_lang_droptable", // TODO `DROP TABLE` (#ref/unimplemented_sql/47hnz51gohsx6)
"rocksdb/evidence/slt_lang_dropindex", // TODO `DROP INDEX` (#ref/unimplemented_sql/4kgkg4jhqhfrw)
]);
find_slt_files(&slt_root, slt_ext)
.flat_map(|path| {
[("rocksdb", kanto_backend_rocksdb::create_temp)].map(|(backend_name, backend_fn)| {
let path_pretty = path
.strip_prefix(&slt_root)
.expect("found slt file is not under root")
.with_extension("")
.to_string_lossy()
.into_owned();
let name = format!("{backend_name}/{path_pretty}");
libtest_mimic::Trial::test(&name, {
let path = path.clone();
move || {
let (_dir, backend) = backend_fn().expect("cannot setup backend");
let backend = Box::new(backend);
test_sql_session(
path,
backend_name,
backend,
kludge_datafusion_limitations,
if kludge_sqlite_tests {
sqllogictest::ResultMode::ValueWise
} else {
sqllogictest::ResultMode::RowWise
},
)
}
})
.with_ignored_flag(
kludge_sqlite_tests
&& (known_broken.contains(name.as_str())
|| uses_create_table_primary_key(&path)
|| uses_create_view(&path)),
)
})
})
.collect::<Vec<_>>()
}
pub(crate) fn test_sql_session(
path: impl AsRef<Path>,
backend_name: &str,
backend: Box<dyn kanto::Backend>,
kludge_datafusion_limitations: bool,
result_mode: sqllogictest::ResultMode,
) -> Result<(), libtest_mimic::Failed> {
// `test_log::test` doesn't work because of the function argument, so set up tracing manually
{
let subscriber = ::tracing_subscriber::FmtSubscriber::builder()
.with_env_filter(tracing_subscriber::EnvFilter::new(
std::env::var("RUST_LOG")
.unwrap_or_else(|_| format!("kanto=debug,kanto_backend_{backend_name}=debug")),
))
.compact()
.with_test_writer()
.finish();
let _ignore_error = tracing::subscriber::set_global_default(subscriber);
}
let reactor = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap();
// actual test logic
let session = kanto::Session::test_session(backend);
let test_database = TestDatabase(Arc::new(tokio::sync::RwLock::new(session)));
let mut tester = sqllogictest::Runner::new(test_database);
tester.add_label("datafusion");
tester.add_label("kantodb");
tester.with_hash_threshold(8);
tester.with_validator(sql_result_validator);
reactor.block_on(async {
{
let output = tester
.apply_record(sqllogictest::Record::Control(
sqllogictest::Control::ResultMode(result_mode),
))
.await;
debug_assert!(matches!(output, sqllogictest::RecordOutput::Nothing));
}
execute_sqllogic_test(&mut tester, path.as_ref(), kludge_datafusion_limitations)
})?;
drop(tester);
reactor.shutdown_background();
Ok(())
}
pub(crate) fn find_slt_files<P, E>(
input_dir: P,
input_extension: E,
) -> impl Iterator<Item = PathBuf>
where
P: AsRef<Path>,
E: AsRef<OsStr>,
{
let input_extension = input_extension.as_ref().to_owned();
walkdir::WalkDir::new(&input_dir)
.sort_by_file_name()
.into_iter()
.filter_entry(|entry| !is_hidden(entry))
.map(|result| result.expect("read dir"))
.filter(|entry| entry.file_type().is_file())
.map(walkdir::DirEntry::into_path)
.filter(move |path| path.extension().is_some_and(|ext| ext == input_extension))
}

View file

@@ -0,0 +1,16 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use std::path::PathBuf;
mod slt_util;
fn main() {
let args = libtest_mimic::Arguments::from_args();
let input_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../../testdata/sql");
let tests = slt_util::setup_tests(input_dir, "slt", false, false);
libtest_mimic::run(&args, tests).exit();
}

View file

@@ -0,0 +1,33 @@
#![expect(
missing_docs,
reason = "rustc lint bug <https://github.com/rust-lang/rust/issues/137561> #ecosystem/rust #waiting"
)]
use std::path::Path;
mod slt_util;
fn main() {
const SQLLOGICTEST_PATH_ENV: &str = "SQLLOGICTEST_PATH";
let args = libtest_mimic::Arguments::from_args();
let tests = if let Some(sqlite_test_path) = std::env::var_os(SQLLOGICTEST_PATH_ENV) {
let input_dir = Path::new(&sqlite_test_path).join("test");
slt_util::setup_tests(input_dir, "test", true, true)
} else {
// We don't know what tests to generate without the env var.
// Generate a dummy test that reports what's missing.
Vec::from([
libtest_mimic::Trial::test("sqlite_tests_not_configured", || {
// TODO replace with our fork that has `skipif` additions
const URL: &str = "https://www.sqlite.org/sqllogictest/";
Err(libtest_mimic::Failed::from(format!(
"set {SQLLOGICTEST_PATH_ENV} to checkout of {URL}"
)))
})
.with_ignored_flag(true),
])
};
libtest_mimic::run(&args, tests).exit();
}

View file

@@ -0,0 +1,17 @@
[package]
name = "kanto-key-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
datafusion = { workspace = true }
[lints]
workspace = true

View file

@@ -0,0 +1,59 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
use std::sync::Arc;
/// Encode rows from `key_source_arrays` into binary keys whose byte order matches the row sort order (ascending, `NULL`s last).
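///
/// # Example
///
/// A minimal, untested sketch (kept out of the doctest suite via `ignore`); the column values are arbitrary:
///
/// ```ignore
/// use std::sync::Arc;
///
/// use datafusion::arrow::array::Array as _;
/// use datafusion::arrow::array::ArrayRef;
/// use datafusion::arrow::array::StringArray;
/// use datafusion::arrow::array::UInt64Array;
///
/// // illustrative values only
/// let ids: ArrayRef = Arc::new(UInt64Array::from(vec![1u64, 2]));
/// let names: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"]));
/// // one binary key per input row, encoding `(id, name)`
/// let keys = kanto_key_format_v1::make_keys(&[ids, names]);
/// assert_eq!(keys.len(), 2);
/// ```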
#[must_use]
pub fn make_keys(
key_source_arrays: &[Arc<dyn datafusion::arrow::array::Array>],
) -> datafusion::arrow::array::BinaryArray {
let sort_fields = key_source_arrays
.iter()
.map(|array| {
datafusion::arrow::row::SortField::new_with_options(
array.data_type().clone(),
datafusion::arrow::compute::SortOptions {
descending: false,
nulls_first: false,
},
)
})
.collect::<Vec<_>>();
let row_converter = datafusion::arrow::row::RowConverter::new(sort_fields)
.expect("internal error #ea7ht1xeexpgo: misuse of RowConverter::new");
let rows = row_converter
.convert_columns(key_source_arrays)
.expect("internal error #z9i4xb8me68gk: misuse of RowConverter::convert_columns");
// TODO this fallibility sucks; do better when we take over from `arrow-row`
rows.try_into_binary()
.expect("internal error #yq9xnqtzhbn1k: RowConverter must produce BinaryArray")
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_simple() {
let array: datafusion::arrow::array::ArrayRef =
Arc::new(datafusion::arrow::array::UInt64Array::from(vec![42, 13]));
let keys = make_keys(&[array]);
let got = format!(
"{}",
datafusion::arrow::util::pretty::pretty_format_columns("key", &[Arc::new(keys)])
.unwrap(),
);
assert_eq!(
got,
"\
+--------------------+\n\
| key |\n\
+--------------------+\n\
| 01000000000000002a |\n\
| 01000000000000000d |\n\
+--------------------+"
);
}
}

View file

@@ -0,0 +1,17 @@
[package]
name = "maybe-tracing"
version = "0.1.0"
description = "Workarounds for `tracing` crate not working with Miri and llvm-cov"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[lib]
proc-macro = true
[lints]
workspace = true

View file

@@ -0,0 +1,102 @@
//! Workarounds that temporarily disable `tracing` or `tracing_subscriber`.
//!
//! - `tracing::instrument` breaks code coverage.
//! - `test_log::test` breaks Miri isolation mode.
use proc_macro::Delimiter;
use proc_macro::Group;
use proc_macro::Ident;
use proc_macro::Punct;
use proc_macro::Spacing;
use proc_macro::Span;
use proc_macro::TokenStream;
use proc_macro::TokenTree;
/// Instrument the call with tracing, except when recording code coverage.
///
/// This is a workaround for <https://github.com/tokio-rs/tracing/issues/2082>.
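///
/// # Example
///
/// A minimal sketch of the intended usage (the function and its argument are made up for illustration):
///
/// ```ignore
/// // made-up function; any fn the crate wants traced works the same way
/// #[maybe_tracing::instrument(skip(rows), ret)]
/// fn count_rows(rows: &[u64]) -> usize {
///     rows.len()
/// }
/// ```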
// TODO tracing breaks llvm-cov <https://github.com/tokio-rs/tracing/issues/2082> #waiting #ecosystem/tracing
#[proc_macro_attribute]
pub fn instrument(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
if cfg!(not(coverage)) {
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("tracing", Span::call_site())),
TokenTree::Punct(Punct::new(':', Spacing::Joint)),
TokenTree::Punct(Punct::new(':', Spacing::Alone)),
TokenTree::Ident(Ident::new("instrument", Span::call_site())),
TokenTree::Group(Group::new(Delimiter::Parenthesis, attr)),
]),
)),
]));
}
tokens.extend(item);
tokens
}
/// Define a `#[test]` that has `tracing` configured with `env_logger`, except under Miri.
///
/// This is a wrapper for `test_log::test` that disables itself when run under [Miri](https://github.com/rust-lang/miri/).
/// Miri isolation mode does not support logging wall clock times (<https://github.com/rust-lang/miri/issues/3740>).
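///
/// # Example
///
/// A minimal sketch of the intended usage (the test name and body are made up):
///
/// ```ignore
/// // Expands to `#[test_log::test]` normally, and to a plain `#[test]` under Miri.
/// #[maybe_tracing::test]
/// fn smoke() {
///     tracing::debug!("only captured when a subscriber is installed");
///     assert_eq!(1 + 1, 2);
/// }
/// ```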
// TODO tracing breaks miri due to missing wall clock <https://github.com/rust-lang/miri/issues/3740> #waiting #ecosystem/tracing
#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut tokens = TokenStream::new();
// can't `if cfg!(miri)` at proc-macro time; the proc macro itself does not seem to run under the Miri interpreter
// output `#[cfg_attr(miri, test)]`
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("cfg_attr", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("miri", Span::call_site())),
TokenTree::Punct(Punct::new(',', Spacing::Alone)),
TokenTree::Ident(Ident::new("test", Span::call_site())),
]),
)),
]),
)),
]));
// output `#[cfg_attr(not(miri), test_log::test)]`
tokens.extend(TokenStream::from_iter([
TokenTree::Punct(Punct::new('#', Spacing::Alone)),
TokenTree::Group(Group::new(
Delimiter::Bracket,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("cfg_attr", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([
TokenTree::Ident(Ident::new("not", Span::call_site())),
TokenTree::Group(Group::new(
Delimiter::Parenthesis,
TokenStream::from_iter([TokenTree::Ident(Ident::new(
"miri",
Span::call_site(),
))]),
)),
TokenTree::Punct(Punct::new(',', Spacing::Alone)),
TokenTree::Ident(Ident::new("test_log", Span::call_site())),
TokenTree::Punct(Punct::new(':', Spacing::Joint)),
TokenTree::Punct(Punct::new(':', Spacing::Alone)),
TokenTree::Ident(Ident::new("test", Span::call_site())),
TokenTree::Group(Group::new(Delimiter::Parenthesis, attr)),
]),
)),
]),
)),
]));
tokens.extend(item);
tokens
}

View file

@@ -0,0 +1,24 @@
[package]
name = "kanto-meta-format-v1"
version = "0.1.0"
description = "Low-level data format helper for KantoDB database"
homepage.workspace = true
repository.workspace = true
authors.workspace = true
license.workspace = true
publish = false # TODO publish crate #severity/high #urgency/medium
edition.workspace = true
rust-version.workspace = true
[dependencies]
arbitrary = { workspace = true }
datafusion = { workspace = true }
duplicate = { workspace = true }
indexmap = { workspace = true }
paste = { workspace = true }
rkyv = { workspace = true }
rkyv_util = { workspace = true }
thiserror = { workspace = true }
[lints]
workspace = true

View file

@@ -0,0 +1,58 @@
//! Data type of a field in a database record.
/// Data type of a field in a database record.
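///
/// # Example
///
/// A short sketch of the Arrow mapping implemented below (marked `ignore` so it stays out of the doctest suite):
///
/// ```ignore
/// use datafusion::arrow::datatypes::DataType;
/// use kanto_meta_format_v1::field_type::FieldType;
///
/// assert_eq!(DataType::from(&FieldType::U64), DataType::UInt64);
/// assert_eq!(DataType::from(&FieldType::String), DataType::Utf8View);
/// ```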
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug))]
#[expect(missing_docs)]
pub enum FieldType {
U64,
U32,
U16,
U8,
I64,
I32,
I16,
I8,
F64,
F32,
String,
Binary,
Boolean,
// TODO more types
// TODO also parameterized, e.g. `VARCHAR(30)`, `FLOAT(23)`
}
impl ArchivedFieldType {
/// If this field maps to a fixed-width primitive Arrow type, return its width in bytes.
#[must_use]
pub fn fixed_arrow_bytes_width(&self) -> Option<usize> {
let data_type = datafusion::arrow::datatypes::DataType::from(self);
data_type.primitive_width()
}
}
#[duplicate::duplicate_item(
T;
[ FieldType ];
[ ArchivedFieldType ];
)]
impl From<&T> for datafusion::arrow::datatypes::DataType {
fn from(field_type: &T) -> Self {
use datafusion::arrow::datatypes::DataType;
match field_type {
T::U64 => DataType::UInt64,
T::U32 => DataType::UInt32,
T::U16 => DataType::UInt16,
T::U8 => DataType::UInt8,
T::I64 => DataType::Int64,
T::I32 => DataType::Int32,
T::I16 => DataType::Int16,
T::I8 => DataType::Int8,
T::F64 => DataType::Float64,
T::F32 => DataType::Float32,
T::String => DataType::Utf8View,
T::Binary => DataType::BinaryView,
T::Boolean => DataType::Boolean,
}
}
}

View file

@@ -0,0 +1,92 @@
//! Definition of an index for record lookup and consistency enforcement.
#![expect(
missing_docs,
reason = "waiting <https://github.com/rkyv/rkyv/issues/596> <https://github.com/rkyv/rkyv/issues/597>"
)]
use super::FieldId;
use super::IndexId;
use super::TableId;
/// An expression used in the index key.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub enum Expr {
/// Use a field from the record as index key.
Field {
/// ID of the field to use.
field_id: FieldId,
},
// TODO support more expressions
}
/// Sort order.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum SortOrder {
/// Ascending order.
Ascending,
/// Descending order.
Descending,
}
/// Sort order for values containing `NULL`s.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum NullOrder {
/// Nulls come before all other values.
NullsFirst,
/// Nulls come after all other values.
NullsLast,
}
/// Order the index by an expression.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct OrderByExpr {
/// Expression evaluating to the value stored.
pub expr: Expr,
/// Sort ascending or descending.
pub sort_order: SortOrder,
/// Do `NULL` values come before or after non-`NULL` values.
pub null_order: NullOrder,
}
/// Unique index specific options.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct UniqueIndexOptions {
/// Are `NULL` values considered equal or different.
pub nulls_distinct: bool,
}
/// What kind of an index is this.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug))]
pub enum IndexKind {
/// Every indexed key maps to at most one record.
Unique(UniqueIndexOptions),
/// Multiple records can share the same indexed key.
Multi,
}
/// Definition of an index for record lookup and consistency enforcement.
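///
/// # Example
///
/// A minimal, untested sketch of a unique single-column index definition
/// (the IDs and the index name are made up):
///
/// ```ignore
/// use kanto_meta_format_v1::index_def::*;
/// use kanto_meta_format_v1::FieldId;
/// use kanto_meta_format_v1::IndexId;
/// use kanto_meta_format_v1::TableId;
///
/// // all IDs and names below are illustrative only
/// let index_def = IndexDef {
///     index_id: IndexId::new(7).unwrap(),
///     table: TableId::new(1).unwrap(),
///     index_name: "users_by_email".to_owned(),
///     columns: vec![OrderByExpr {
///         expr: Expr::Field { field_id: FieldId::new(2).unwrap() },
///         sort_order: SortOrder::Ascending,
///         null_order: NullOrder::NullsLast,
///     }],
///     index_kind: IndexKind::Unique(UniqueIndexOptions { nulls_distinct: true }),
///     include: Vec::new(),
///     predicate: None,
/// };
/// ```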
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct IndexDef {
/// Unique ID of the index.
pub index_id: IndexId,
/// The table this index is for.
pub table: TableId,
/// Name of the index.
///
/// Indexes inherit their catalog and schema from the table they are for.
pub index_name: String,
/// Columns (or expressions) indexed.
pub columns: Vec<OrderByExpr>,
/// Is this a unique index, or can multiple records have the same values for the columns indexed.
pub index_kind: IndexKind,
/// What fields to store in the index record.
pub include: Vec<Expr>,
/// Only include records where this predicate evaluates to true.
pub predicate: Option<Expr>,
}

View file

@@ -0,0 +1,62 @@
//! Internal storage formats for KantoDB.
//!
//! Unless you are implementing a `kanto::Backend`, you should **not** be using these directly.
#![allow(
clippy::exhaustive_structs,
clippy::exhaustive_enums,
reason = "exhaustive is ok as the wire format has to be stable anyway"
)]
pub mod field_type;
pub mod index_def;
pub mod name_def;
pub mod sequence_def;
pub mod table_def;
mod util;
pub use datafusion;
pub use datafusion::arrow;
pub use indexmap;
pub use rkyv;
pub use rkyv_util;
use crate::util::idgen::make_id_type;
make_id_type!(
/// ID of a [`FieldDef`](crate::table_def::FieldDef), unique within one [`TableDef`](crate::table_def::TableDef).
pub FieldId
);
make_id_type!(
/// Unique ID of an [`IndexDef`](crate::index_def::IndexDef).
pub IndexId
);
make_id_type!(
/// Unique ID of a [`SequenceDef`](crate::sequence_def::SequenceDef).
pub SequenceId
);
make_id_type!(
/// Unique ID of a [`TableDef`](crate::table_def::TableDef).
pub TableId
);
/// Return a serialized copy of `T` that allows accessing it as `ArchivedT`.
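///
/// # Example
///
/// A minimal, untested sketch (it assumes the returned `OwnedArchive` dereferences to the archived type):
///
/// ```ignore
/// use kanto_meta_format_v1::field_type::ArchivedFieldType;
/// use kanto_meta_format_v1::field_type::FieldType;
/// use kanto_meta_format_v1::owned_archived;
///
/// let archived = owned_archived(&FieldType::U64).expect("serialize FieldType");
/// // assumption: `OwnedArchive` derefs to `ArchivedFieldType`
/// assert!(matches!(*archived, ArchivedFieldType::U64));
/// ```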
pub fn owned_archived<T>(
value: &T,
) -> Result<rkyv_util::owned::OwnedArchive<T, rkyv::util::AlignedVec>, rkyv::rancor::BoxedError>
where
T: for<'arena> rkyv::Serialize<
rkyv::api::high::HighSerializer<
rkyv::util::AlignedVec,
rkyv::ser::allocator::ArenaHandle<'arena>,
rkyv::rancor::BoxedError,
>,
>,
T::Archived: rkyv::Portable
+ for<'a> rkyv::bytecheck::CheckBytes<
rkyv::api::high::HighValidator<'a, rkyv::rancor::BoxedError>,
>,
{
let bytes = rkyv::to_bytes::<rkyv::rancor::BoxedError>(value)?;
rkyv_util::owned::OwnedArchive::<T, _>::new::<rkyv::rancor::BoxedError>(bytes)
}

View file

@@ -0,0 +1,46 @@
//! Named entities.
#![expect(
missing_docs,
reason = "waiting <https://github.com/rkyv/rkyv/issues/596> <https://github.com/rkyv/rkyv/issues/597>"
)]
use super::IndexId;
use super::SequenceId;
use super::TableId;
/// What kind of an entity the name refers to.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub enum NameKind {
/// Name is for a [`TableDef`](crate::table_def::TableDef).
Table {
/// ID of the table.
table_id: TableId,
},
/// Name is for an [`IndexDef`](crate::index_def::IndexDef).
Index {
/// ID of the index.
index_id: IndexId,
},
/// Name is for a [`SequenceDef`](crate::sequence_def::SequenceDef).
Sequence {
/// ID of the sequence.
sequence_id: SequenceId,
},
}
/// SQL defines tables, indexes, sequences, and so on to all share one namespace.
/// This is the data type all such names resolve to, and refers to the ultimate destination object.
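///
/// # Example
///
/// A minimal, untested sketch (the catalog, schema, and table name are made up):
///
/// ```ignore
/// use kanto_meta_format_v1::name_def::NameDef;
/// use kanto_meta_format_v1::name_def::NameKind;
/// use kanto_meta_format_v1::TableId;
///
/// // illustrative values only
/// let name = NameDef {
///     catalog: "kanto".to_owned(),
///     schema: "public".to_owned(),
///     name: "users".to_owned(),
///     kind: NameKind::Table {
///         table_id: TableId::new(1).unwrap(),
///     },
/// };
/// ```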
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct NameDef {
/// SQL catalog (top-level namespace) the name is in.
pub catalog: String,
/// SQL schema (second-level namespace) the name is in.
pub schema: String,
/// Name of the object.
pub name: String,
/// What kind of an object it is, and a reference to the object.
pub kind: NameKind,
}

View file

@@ -0,0 +1,17 @@
//! Sequence definitions.
use super::SequenceId;
/// Definition of a sequence.
///
/// Not much interesting here.
#[derive(rkyv::Archive, rkyv::Serialize)]
#[expect(missing_docs)]
#[rkyv(attr(expect(missing_docs)))]
pub struct SequenceDef {
pub sequence_id: SequenceId,
pub catalog: String,
pub schema: String,
pub sequence_name: String,
// sequences don't seem to need anything interesting stored
}

View file

@@ -0,0 +1,404 @@
//! Definition of a table of records.
use std::collections::HashMap;
use std::fmt::Debug;
use std::sync::Arc;
mod primary_keys;
pub use self::primary_keys::PrimaryKeys;
use crate::field_type::FieldType;
use crate::FieldId;
use crate::SequenceId;
use crate::TableId;
super::util::idgen::make_id_type!(
/// Generation counter of a [`TableDef`].
/// Every edit increments this.
pub TableDefGeneration
);
/// Definition of a *row ID* for the table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct RowIdDef {
/// Allocate IDs from this sequence.
pub sequence_id: SequenceId,
}
/// Definition of a *primary key* for the table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct PrimaryKeysDef {
/// Field IDs in `PrimaryKeys` are guaranteed to be present in `TableDef.fields`.
///
/// This is guaranteed to be non-empty and have no duplicates.
pub field_ids: PrimaryKeys,
}
/// What kind of a *row key* this table uses.
#[derive(rkyv::Archive, rkyv::Serialize, Debug, arbitrary::Arbitrary)]
#[rkyv(derive(Debug))]
pub enum RowKeyKind {
/// Row keys are assigned from a sequence.
RowId(RowIdDef),
/// Row keys are composed of field values or expressions using the field values.
PrimaryKeys(PrimaryKeysDef),
}
/// Definition of a field in a [`TableDef`].
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct FieldDef {
/// ID of the field, unique within the [`TableDef`].
pub field_id: FieldId,
/// In what generation of [`TableDef`] this field was added.
///
/// If a field is deleted and later a field with the same name is added, that's a separate field with a separate [`FieldId`].
pub added_in: TableDefGeneration,
/// In what generation, if any, of [`TableDef`] this field was deleted.
pub deleted_in: Option<TableDefGeneration>,
/// Unique among live fields (see [`fields_live_in_latest`](ArchivedTableDef::fields_live_in_latest)).
pub field_name: String,
/// Data type of the database column.
pub field_type: FieldType,
/// Can a value of this field be `NULL`.
///
/// Fields that are in `PrimaryKeys` cannot be nullable.
pub nullable: bool,
/// Default value, as Arrow bytes suitable for the data type.
// TODO support setting default values
// TODO maybe only allow setting this field via a validating helper? probably better to do validation at whatever level would manage `generation` etc.
pub default_value_arrow_bytes: Option<Box<[u8]>>,
}
impl ArchivedFieldDef {
/// Is the field live at the latest [generation](TableDef::generation).
#[must_use]
pub fn is_live(&self) -> bool {
self.deleted_in.is_none()
}
}
impl<'a> arbitrary::Arbitrary<'a> for FieldDef {
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let field_id = u.arbitrary()?;
let added_in: TableDefGeneration = u.arbitrary()?;
let deleted_in: Option<TableDefGeneration> = if u.ratio(1u8, 10u8)? {
None
} else {
TableDefGeneration::arbitrary_greater_than(u, added_in)?
};
debug_assert!(deleted_in.is_none_or(|del| added_in < del));
let field_name = u.arbitrary()?;
let field_type = u.arbitrary()?;
let default_value_arrow_bytes: Option<_> = u.arbitrary()?;
let nullable = if deleted_in.is_some() && default_value_arrow_bytes.is_none() {
// must be nullable since it's deleted and has no default value
true
} else {
u.arbitrary()?
};
Ok(FieldDef {
field_id,
added_in,
deleted_in,
field_name,
field_type,
nullable,
default_value_arrow_bytes,
})
}
}
/// Definition of a table.
#[derive(rkyv::Archive, rkyv::Serialize, Debug)]
#[expect(missing_docs)]
#[rkyv(derive(Debug), attr(expect(missing_docs)))]
pub struct TableDef {
pub table_id: TableId,
pub catalog: String,
pub schema: String,
pub table_name: String,
/// Every update increases generation.
pub generation: TableDefGeneration,
/// What is the primary key of the table: a (possibly composite) key of field values, or an automatically assigned row ID.
pub row_key_kind: RowKeyKind,
/// The fields stored in this table.
///
/// Guaranteed to contain deleted [`FieldDef`]s for as long as they have records using a [`TableDefGeneration`] where the field was live.
//
// TODO `rkyv::collections::btree_map::ArchivedBTreeMap::get` is impossible to use <https://github.com/rkyv/rkyv/issues/585> #waiting #ecosystem/rkyv
// our inserts happen to be in increasing `FieldId` order so insertion order and sorted order are the same; still want a sorted container here to guarantee correctness
pub fields: indexmap::IndexMap<FieldId, FieldDef>,
}
impl<'a> arbitrary::Arbitrary<'a> for TableDef {
#[expect(
clippy::unwrap_used,
clippy::unwrap_in_result,
reason = "not production code #severity/low #urgency/low"
)]
fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
let table_id = u.arbitrary()?;
let catalog = u.arbitrary()?;
let schema = u.arbitrary()?;
let table_name = u.arbitrary()?;
// TODO unique `field_id`
// TODO unique `field_name`
// TODO primary keys cannot be nullable
let fields: indexmap::IndexMap<_, _> = u
.arbitrary_iter::<FieldDef>()?
// awkward callback, don't want to collect to temporary just to map the `Ok` values
.map(|result| result.map(|field| (field.field_id, field)))
.collect::<Result<_, _>>()?;
let generation = {
// choose a generation between the largest one mentioned in the fields and the maximum valid value
let min = fields
.values()
.map(|field| field.added_in)
.chain(fields.values().filter_map(|field| field.deleted_in))
.max()
.map(|generation| generation.get_u64())
.unwrap_or(1);
#[expect(
clippy::range_minus_one,
reason = "Unstructured insists on InclusiveRange"
)]
let num = u.int_in_range(min..=u64::MAX - 1)?;
TableDefGeneration::new(num).unwrap()
};
let row_id = {
let num_primary_keys = u.choose_index(fields.len())?;
if num_primary_keys == 0 {
RowKeyKind::RowId(RowIdDef {
sequence_id: u.arbitrary()?,
})
} else {
// take a random subset in random order
let mut tmp: Vec<_> = fields.keys().collect();
let primary_key_ids = std::iter::from_fn(|| {
if tmp.is_empty() {
None
} else {
Some(u.choose_index(tmp.len()).map(|idx| *tmp.swap_remove(idx)))
}
})
.take(num_primary_keys)
.collect::<Result<Vec<_>, arbitrary::Error>>()?;
let primary_keys = PrimaryKeys::try_from(primary_key_ids).unwrap();
RowKeyKind::PrimaryKeys(PrimaryKeysDef {
field_ids: primary_keys,
})
}
};
let table_def = TableDef {
table_id,
catalog,
schema,
table_name,
generation,
row_key_kind: row_id,
fields,
};
Ok(table_def)
}
}
/// Errors from [`validate`](ArchivedTableDef::validate).
#[expect(clippy::module_name_repetitions)]
#[expect(missing_docs)]
#[derive(thiserror::Error, Debug)]
pub enum TableDefValidateError {
#[error("primary key field id not found: {field_id} not in {table_def_debug}")]
PrimaryKeyFieldIdNotFound {
field_id: FieldId,
table_def_debug: String,
},
#[error("primary key field cannot be deleted: {field_id} in {table_def_debug}")]
PrimaryKeyDeleted {
field_id: FieldId,
table_def_debug: String,
},
#[error("primary key field cannot be nullable: {field_id} in {table_def_debug}")]
PrimaryKeyNullable {
field_id: FieldId,
table_def_debug: String,
},
#[error("field id mismatch: {field_id} in {table_def_debug}")]
FieldIdMismatch {
field_id: FieldId,
table_def_debug: String,
},
#[error("field deleted before creation: {field_id} in {table_def_debug}")]
FieldDeletedBeforeCreation {
field_id: FieldId,
table_def_debug: String,
},
#[error("field name cannot be empty: {field_id} in {table_def_debug}")]
FieldNameEmpty {
field_id: FieldId,
table_def_debug: String,
},
#[error("duplicate field name: {field_id} vs {previous_field_id} in {table_def_debug}")]
FieldNameDuplicate {
field_id: FieldId,
previous_field_id: FieldId,
table_def_debug: String,
},
#[error("fixed size field default value has incorrect size: {field_id} in {table_def_debug}")]
FieldDefaultValueLength {
field_id: FieldId,
table_def_debug: String,
field_type_size: usize,
default_value_length: usize,
},
}
impl ArchivedTableDef {
/// Validate the internal consistency of the table definition.
///
/// Returns only the first error noticed.
pub fn validate(&self) -> Result<(), TableDefValidateError> {
// TODO maybe split into helper functions, one per check?
// TODO maybe field.validate()?;
match &self.row_key_kind {
ArchivedRowKeyKind::RowId(_seq_row_id_def) => {
// nothing
}
ArchivedRowKeyKind::PrimaryKeys(primary_keys_row_id_def) => {
for field_id in primary_keys_row_id_def.field_ids.iter() {
let field = self.fields.get(field_id).ok_or_else(|| {
TableDefValidateError::PrimaryKeyFieldIdNotFound {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
}
})?;
if !field.is_live() {
let error = TableDefValidateError::PrimaryKeyDeleted {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if field.nullable {
let error = TableDefValidateError::PrimaryKeyNullable {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
}
}
};
let mut live_field_names = HashMap::new();
for (field_id, field) in self.fields.iter() {
if &field.field_id != field_id {
let error = TableDefValidateError::FieldIdMismatch {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if let Some(deleted_in) = field.deleted_in.as_ref() {
if deleted_in <= &field.added_in {
let error = TableDefValidateError::FieldDeletedBeforeCreation {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
}
if field.field_name.is_empty() {
let error = TableDefValidateError::FieldNameEmpty {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
if field.is_live() {
let field_name = field.field_name.as_str();
let entry = live_field_names.entry(field_name);
match entry {
std::collections::hash_map::Entry::Occupied(occupied_entry) => {
let error = TableDefValidateError::FieldNameDuplicate {
field_id: field_id.to_native(),
previous_field_id: *occupied_entry.get(),
table_def_debug: format!("{self:?}"),
};
return Err(error);
}
std::collections::hash_map::Entry::Vacant(vacant_entry) => {
let _mut_field_id = vacant_entry.insert(field.field_id.to_native());
}
}
}
if let Some(default_value) = field.default_value_arrow_bytes.as_deref() {
if let Some(field_type_size) = field.field_type.fixed_arrow_bytes_width() {
let default_value_length = default_value.len();
if default_value_length != field_type_size {
let error = TableDefValidateError::FieldDefaultValueLength {
field_id: field_id.to_native(),
table_def_debug: format!("{self:?}"),
field_type_size,
default_value_length,
};
return Err(error);
}
}
}
}
Ok(())
}
/// Which fields are live in the given generation.
pub fn fields_live_in_gen(
&self,
generation: TableDefGeneration,
) -> impl Iterator<Item = &ArchivedFieldDef> {
self.fields.values().filter(move |field| {
// a row written at this generation stores this field
let is_in_record = generation >= field.added_in
&& field
.deleted_in
.as_ref()
.is_none_or(|del| generation < *del);
is_in_record
})
}
/// Which fields are live in the latest generation.
pub fn fields_live_in_latest(&self) -> impl Iterator<Item = &ArchivedFieldDef> {
self.fields
.values()
.filter(|field| field.deleted_in.as_ref().is_none())
}
}
/// Make an Arrow [`Schema`](datafusion::arrow::datatypes::Schema) for some fields.
///
/// Using a subset of fields is useful since [`TableDef`] can include deleted fields.
/// See [`ArchivedTableDef::fields_live_in_latest`].
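///
/// # Example
///
/// A minimal, untested sketch (`archived_table_def` is assumed to be an [`ArchivedTableDef`] loaded from storage):
///
/// ```ignore
/// let schema = make_arrow_schema(archived_table_def.fields_live_in_latest());
/// // one Arrow field per live field
/// assert_eq!(
///     schema.fields().len(),
///     archived_table_def.fields_live_in_latest().count(),
/// );
/// ```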
#[must_use]
pub fn make_arrow_schema<'defs>(
fields: impl Iterator<Item = &'defs ArchivedFieldDef>,
) -> datafusion::arrow::datatypes::SchemaRef {
let mut builder = {
let (_min, max) = fields.size_hint();
if let Some(capacity) = max {
datafusion::arrow::datatypes::SchemaBuilder::with_capacity(capacity)
} else {
datafusion::arrow::datatypes::SchemaBuilder::new()
}
};
for field in fields {
let arrow_field = datafusion::arrow::datatypes::Field::new(
field.field_name.to_owned(),
datafusion::arrow::datatypes::DataType::from(&field.field_type),
field.nullable,
);
builder.push(arrow_field);
}
Arc::new(builder.finish())
}

View file

@@ -0,0 +1,44 @@
use crate::util::unique_nonempty::UniqueNonEmpty;
use crate::FieldId;
/// Slice of primary key [`FieldId`] that is guaranteed to be non-empty and have no duplicates.
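///
/// # Example
///
/// A minimal, untested sketch (the field IDs are arbitrary):
///
/// ```ignore
/// use kanto_meta_format_v1::table_def::PrimaryKeys;
/// use kanto_meta_format_v1::FieldId;
///
/// let ids = vec![FieldId::new(1).unwrap(), FieldId::new(2).unwrap()];
/// let primary_keys = PrimaryKeys::try_from(ids).unwrap();
/// assert_eq!(primary_keys.len(), 2);
///
/// // the empty list is rejected
/// assert!(PrimaryKeys::try_from(Vec::<FieldId>::new()).is_err());
/// ```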
#[derive(rkyv::Archive, rkyv::Serialize, arbitrary::Arbitrary)]
pub struct PrimaryKeys(UniqueNonEmpty<FieldId>);
impl TryFrom<Vec<FieldId>> for PrimaryKeys {
// TODO report why it failed #urgency/low #severity/low
type Error = ();
fn try_from(value: Vec<FieldId>) -> Result<Self, Self::Error> {
let une = UniqueNonEmpty::try_from(value)?;
Ok(Self(une))
}
}
impl std::ops::Deref for PrimaryKeys {
type Target = [FieldId];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::ops::Deref for ArchivedPrimaryKeys {
type Target = [<FieldId as rkyv::Archive>::Archived];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl std::fmt::Debug for PrimaryKeys {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}
impl std::fmt::Debug for ArchivedPrimaryKeys {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_set().entries(self.0.iter()).finish()
}
}

View file

@@ -0,0 +1,2 @@
pub(crate) mod idgen;
pub(crate) mod unique_nonempty;

View file

@@ -0,0 +1,194 @@
macro_rules! make_id_type {
(
$(#[$meta:meta])*
$vis:vis $name:ident
) => {
$(#[$meta])*
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, rkyv::Archive, rkyv::Serialize)]
#[rkyv(derive(Ord, PartialOrd, Eq, PartialEq, Hash, Debug))]
#[rkyv(compare(PartialOrd, PartialEq))]
#[non_exhaustive]
#[rkyv(attr(non_exhaustive))]
$vis struct $name(::std::num::NonZeroU64);
#[allow(clippy::allow_attributes, reason = "generated code")]
// Not all functionality is used by every ID type.
#[allow(dead_code)]
impl $name {
/// Minimum valid value.
$vis const MIN: $name = $name::new(1).expect("constant calculation must be correct");
/// Maximum valid value.
$vis const MAX: $name = $name::new(u64::MAX.saturating_sub(1)).expect("constant calculation must be correct");
/// Construct a new identifier.
/// The values `0` and `u64::MAX` are invalid.
#[must_use]
$vis const fn new(num: u64) -> Option<Self> {
if num == u64::MAX {
None
} else {
// limitation of const functions: no `Option::map`.
match ::std::num::NonZeroU64::new(num) {
Some(nz) => Some(Self(nz)),
None => None,
}
}
}
/// Get the value as a [`u64`].
#[must_use]
$vis const fn get_u64(&self) -> u64 {
self.0.get()
}
/// Get the value as a [`NonZeroU64`](::std::num::NonZeroU64).
#[must_use]
#[allow(dead_code)]
$vis const fn get_non_zero(&self) -> ::std::num::NonZeroU64 {
self.0
}
/// Return the next possible value, or [`None`] if that would be the invalid value [`u64::MAX`].
#[must_use]
$vis const fn next(&self) -> Option<Self> {
// `new` will refuse if we hit `u64::MAX`
Self::new(self.get_u64().saturating_add(1))
}
/// Iterate all valid identifier values.
$vis fn iter_all() -> impl Iterator<Item=$name> {
::std::iter::successors(Some(Self::MIN), |prev: &Self| prev.next())
}
fn arbitrary_greater_than(u: &mut ::arbitrary::Unstructured<'_>, greater_than: Self) -> ::arbitrary::Result<Option<Self>> {
if let Some(min) = greater_than.next() {
let num = u.int_in_range(min.get_u64()..=u64::MAX-1)?;
let id = Self::new(num).unwrap();
Ok(Some(id))
} else {
Ok(None)
}
}
}
impl ::paste::paste!([<Archived $name>]) {
#[doc = concat!(
"Convert the [`Archived",
stringify!($name),
"`] to [`",
stringify!($name),
"`].",
)]
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn to_native(&self) -> $name {
$name(self.0.to_native())
}
/// Get the value as a [`u64`].
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn get_u64(&self) -> u64 {
self.0.get()
}
/// Get the value as a [`NonZeroU64`](::std::num::NonZeroU64).
#[must_use]
#[expect(clippy::allow_attributes)]
#[allow(dead_code)]
$vis const fn get_non_zero(&self) -> ::std::num::NonZeroU64 {
self.0.to_native()
}
}
#[::duplicate::duplicate_item(
ident;
[$name];
[::paste::paste!([<Archived $name>])];
)]
impl ::std::fmt::Display for ident {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::result::Result<(), ::std::fmt::Error> {
::std::write!(f, "{}", self.get_u64())
}
}
impl ::std::fmt::Debug for $name {
fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::result::Result<(), ::std::fmt::Error> {
::std::write!(f, "{}({})", stringify!($name), self.get_u64())
}
}
impl ::std::cmp::PartialEq<u64> for $name {
fn eq(&self, other: &u64) -> bool {
self.0.get() == *other
}
}
impl ::std::cmp::PartialEq<$name> for u64 {
fn eq(&self, other: &$name) -> bool {
*self == other.get_u64()
}
}
impl<'a> ::arbitrary::Arbitrary<'a> for $name {
fn arbitrary(u: &mut ::arbitrary::Unstructured<'a>) -> ::arbitrary::Result<Self> {
let num = u.int_in_range(1..=u64::MAX-1)?;
Ok(Self::new(num).unwrap())
}
}
}
}
pub(crate) use make_id_type;
#[cfg(test)]
mod tests {
use std::num::NonZeroU64;
use super::*;
make_id_type!(FooId);
#[test]
fn roundtrip() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(foo_id.get_u64(), 42u64);
assert_eq!(foo_id.get_non_zero(), NonZeroU64::new(42u64).unwrap());
}
#[test]
// TODO `const fn` to avoid clippy noise <https://github.com/rust-lang/rust-clippy/issues/13938> #waiting #ecosystem/rust
const fn copy() {
let foo_id = FooId::new(42).unwrap();
let one = foo_id;
let two = foo_id;
_ = one;
_ = two;
}
#[test]
fn debug() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id:?}"), "FooId(42)");
}
#[test]
fn debug_alternate() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id:#?}"), "FooId(42)");
}
#[test]
fn display() {
let foo_id = FooId::new(42).unwrap();
assert_eq!(format!("{foo_id}"), "42");
}
#[test]