musikr: fix iostream lifecycle

Finally pass over ownership of the RsIOStream to the normal
IOStream.
This commit is contained in:
Alexander Capehart 2025-02-17 17:25:16 -07:00
parent da43ebda96
commit 12caac1f80
No known key found for this signature in database
GPG key ID: 37DBE3621FE9AD47
5 changed files with 120 additions and 139 deletions

View file

@ -11,7 +11,7 @@ namespace taglib_shim
class WrappedRsIOStream : public TagLib::IOStream class WrappedRsIOStream : public TagLib::IOStream
{ {
public: public:
explicit WrappedRsIOStream(RsIOStream& stream); explicit WrappedRsIOStream(rust::Box<RsIOStream> stream);
~WrappedRsIOStream() override; ~WrappedRsIOStream() override;
// TagLib::IOStream interface implementation // TagLib::IOStream interface implementation
@ -29,28 +29,28 @@ namespace taglib_shim
bool isOpen() const override; bool isOpen() const override;
private: private:
RsIOStream& rust_stream; rust::Box<RsIOStream> rust_stream;
}; };
WrappedRsIOStream::WrappedRsIOStream(RsIOStream& stream) : rust_stream(stream) {} WrappedRsIOStream::WrappedRsIOStream(rust::Box<RsIOStream> stream) : rust_stream(std::move(stream)) {}
WrappedRsIOStream::~WrappedRsIOStream() = default; WrappedRsIOStream::~WrappedRsIOStream() = default;
TagLib::FileName WrappedRsIOStream::name() const TagLib::FileName WrappedRsIOStream::name() const
{ {
return rust::string(rust_stream.name()).c_str(); return rust::string(rust_stream->name()).c_str();
} }
TagLib::ByteVector WrappedRsIOStream::readBlock(size_t length) TagLib::ByteVector WrappedRsIOStream::readBlock(size_t length)
{ {
std::vector<uint8_t> buffer(length); std::vector<uint8_t> buffer(length);
size_t bytes_read = rust_stream.read(rust::Slice<uint8_t>(buffer.data(), length)); size_t bytes_read = rust_stream->read(rust::Slice<uint8_t>(buffer.data(), length));
return TagLib::ByteVector(reinterpret_cast<char *>(buffer.data()), bytes_read); return TagLib::ByteVector(reinterpret_cast<char *>(buffer.data()), bytes_read);
} }
void WrappedRsIOStream::writeBlock(const TagLib::ByteVector &data) void WrappedRsIOStream::writeBlock(const TagLib::ByteVector &data)
{ {
rust_stream.write(rust::Slice<const uint8_t>( rust_stream->write(rust::Slice<const uint8_t>(
reinterpret_cast<const uint8_t *>(data.data()), data.size())); reinterpret_cast<const uint8_t *>(data.data()), data.size()));
} }
@ -118,7 +118,7 @@ namespace taglib_shim
default: default:
throw std::runtime_error("Invalid seek position"); throw std::runtime_error("Invalid seek position");
} }
rust_stream.seek(offset, whence); rust_stream->seek(offset, whence);
} }
void WrappedRsIOStream::clear() void WrappedRsIOStream::clear()
@ -129,22 +129,22 @@ namespace taglib_shim
void WrappedRsIOStream::truncate(TagLib::offset_t length) void WrappedRsIOStream::truncate(TagLib::offset_t length)
{ {
rust_stream.truncate(length); rust_stream->truncate(length);
} }
TagLib::offset_t WrappedRsIOStream::tell() const TagLib::offset_t WrappedRsIOStream::tell() const
{ {
return rust_stream.tell(); return rust_stream->tell();
} }
TagLib::offset_t WrappedRsIOStream::length() TagLib::offset_t WrappedRsIOStream::length()
{ {
return rust_stream.length(); return rust_stream->length();
} }
bool WrappedRsIOStream::readOnly() const bool WrappedRsIOStream::readOnly() const
{ {
return rust_stream.is_readonly(); return rust_stream->is_readonly();
} }
bool WrappedRsIOStream::isOpen() const bool WrappedRsIOStream::isOpen() const
@ -153,9 +153,9 @@ namespace taglib_shim
} }
// Factory function to create a new RustIOStream // Factory function to create a new RustIOStream
std::unique_ptr<TagLib::IOStream> wrap_RsIOStream(RsIOStream& stream) std::unique_ptr<TagLib::IOStream> wrap_RsIOStream(rust::Box<RsIOStream> stream)
{ {
return std::unique_ptr<TagLib::IOStream>(new WrappedRsIOStream(stream)); return std::unique_ptr<TagLib::IOStream>(new WrappedRsIOStream(std::move(stream)));
} }
} // namespace taglib_shim } // namespace taglib_shim

View file

@ -12,5 +12,5 @@ struct RsIOStream;
namespace taglib_shim namespace taglib_shim
{ {
// Factory functions with external linkage // Factory functions with external linkage
std::unique_ptr<TagLib::IOStream> wrap_RsIOStream(RsIOStream& stream); std::unique_ptr<TagLib::IOStream> wrap_RsIOStream(rust::Box<RsIOStream> stream);
} // namespace taglib_shim } // namespace taglib_shim

View file

@ -15,6 +15,86 @@ impl<'local, 'a> JInputStream<'local> {
} }
impl<'local> IOStream for JInputStream<'local> { impl<'local> IOStream for JInputStream<'local> {
fn read_block(&mut self, buf: &mut [u8]) -> usize {
// Create a direct ByteBuffer from the Rust slice
let byte_buffer = unsafe {
self.env
.borrow_mut()
.new_direct_byte_buffer(buf.as_mut_ptr(), buf.len())
.expect("Failed to create ByteBuffer")
};
// Call readBlock safely
let success = self
.env
.borrow_mut()
.call_method(
&self.input,
"readBlock",
"(Ljava/nio/ByteBuffer;)Z",
&[JValue::Object(&byte_buffer)],
)
.and_then(|result| result.z())
.expect("Failed to call readBlock");
if !success {
return 0;
}
buf.len()
}
fn write_block(&mut self, _data: &[u8]) {
panic!("JInputStream is read-only");
}
fn seek(&mut self, pos: SeekFrom) {
let (method, offset) = match pos {
SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64),
SeekFrom::Current(offset) => ("seekFromCurrent", offset),
SeekFrom::End(offset) => ("seekFromEnd", offset),
};
// Call the appropriate seek method safely
let success = self
.env
.borrow_mut()
.call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)])
.and_then(|result| result.z())
.expect("Failed to seek");
if !success {
panic!("Failed to seek");
}
}
fn truncate(&mut self, _length: i64) {
panic!("JInputStream is read-only");
}
fn tell(&self) -> i64 {
let position = self
.env
.borrow_mut()
.call_method(&self.input, "tell", "()J", &[])
.and_then(|result| result.j())
.expect("Failed to get position");
if position == i64::MIN {
panic!("Failed to get position");
}
position
}
fn length(&self) -> i64 {
self.env
.borrow_mut()
.call_method(&self.input, "length", "()J", &[])
.and_then(|result| result.j())
.expect("Failed to get length")
}
fn name(&self) -> String { fn name(&self) -> String {
// Call the Java name() method safely // Call the Java name() method safely
let name = self let name = self
@ -35,92 +115,3 @@ impl<'local> IOStream for JInputStream<'local> {
true // JInputStream is always read-only true // JInputStream is always read-only
} }
} }
impl<'local> Read for JInputStream<'local> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// Create a direct ByteBuffer from the Rust slice
let byte_buffer = unsafe {
self.env
.borrow_mut()
.new_direct_byte_buffer(buf.as_mut_ptr(), buf.len())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
};
// Call readBlock safely
let success = self
.env
.borrow_mut()
.call_method(
&self.input,
"readBlock",
"(Ljava/nio/ByteBuffer;)Z",
&[JValue::Object(&byte_buffer)],
)
.and_then(|result| result.z())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
if !success {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to read block",
));
}
Ok(buf.len())
}
}
impl<'local> Write for JInputStream<'local> {
fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied,
"JInputStream is read-only",
))
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(()) // Nothing to flush in a read-only stream
}
}
impl<'local, 'a> Seek for JInputStream<'local> {
fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
let (method, offset) = match pos {
SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64),
SeekFrom::Current(offset) => ("seekFromCurrent", offset),
SeekFrom::End(offset) => ("seekFromEnd", offset),
};
// Call the appropriate seek method safely
let success = self
.env
.borrow_mut()
.call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)])
.and_then(|result| result.z())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
if !success {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to seek",
));
}
// Return current position safely
let position = self
.env
.borrow_mut()
.call_method(&self.input, "tell", "()J", &[])
.and_then(|result| result.j())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
if position == i64::MIN {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"Failed to get position",
));
}
Ok(position as u64)
}
}

View file

@ -5,16 +5,16 @@ mod bridge_impl {
// Expose Rust IOStream to C++ // Expose Rust IOStream to C++
extern "Rust" { extern "Rust" {
#[cxx_name = "RsIOStream"] #[cxx_name = "RsIOStream"]
type DynIOStream<'a>; type DynIOStream<'io_stream>;
fn name(self: &mut DynIOStream<'_>) -> String; fn name(self: &DynIOStream<'_>) -> String;
fn read(self: &mut DynIOStream<'_>, buffer: &mut [u8]) -> usize; fn read(self: &mut DynIOStream<'_>, buffer: &mut [u8]) -> usize;
fn write(self: &mut DynIOStream<'_>, data: &[u8]); fn write(self: &mut DynIOStream<'_>, data: &[u8]);
fn seek(self: &mut DynIOStream<'_>, offset: i64, whence: i32); fn seek(self: &mut DynIOStream<'_>, offset: i64, whence: i32);
fn truncate(self: &mut DynIOStream<'_>, length: i64); fn truncate(self: &mut DynIOStream<'_>, length: i64);
fn tell(self: &mut DynIOStream<'_>) -> i64; fn tell(self: &DynIOStream<'_>) -> i64;
fn length(self: &mut DynIOStream<'_>) -> i64; fn length(self: &DynIOStream<'_>) -> i64;
fn is_readonly(self: &mut DynIOStream<'_>) -> bool; fn is_readonly(self: &DynIOStream<'_>) -> bool;
} }
#[namespace = "taglib_shim"] #[namespace = "taglib_shim"]
@ -42,8 +42,8 @@ mod bridge_impl {
#[namespace = "TagLib"] #[namespace = "TagLib"]
#[cxx_name = "IOStream"] #[cxx_name = "IOStream"]
type CPPIOStream; type CPPIOStream<'io_stream>;
fn wrap_RsIOStream(stream: Pin<&mut DynIOStream>) -> UniquePtr<CPPIOStream>; fn wrap_RsIOStream<'io_stream>(stream: Box<DynIOStream<'io_stream>>) -> UniquePtr<CPPIOStream<'io_stream>>;
#[namespace = "TagLib"] #[namespace = "TagLib"]
#[cxx_name = "FileRef"] #[cxx_name = "FileRef"]

View file

@ -3,22 +3,26 @@ use cxx::UniquePtr;
use std::io::{Read, Seek, SeekFrom, Write}; use std::io::{Read, Seek, SeekFrom, Write};
use std::pin::Pin; use std::pin::Pin;
pub trait IOStream: Read + Write + Seek { pub trait IOStream {
fn read_block(&mut self, buffer: &mut [u8]) -> usize;
fn write_block(&mut self, data: &[u8]);
fn seek(&mut self, pos: SeekFrom);
fn truncate(&mut self, length: i64);
fn tell(&self) -> i64;
fn length(&self) -> i64;
fn name(&self) -> String; fn name(&self) -> String;
fn is_readonly(&self) -> bool; fn is_readonly(&self) -> bool;
} }
pub(super) struct BridgedIOStream<'io_stream> { pub(super) struct BridgedIOStream<'io_stream> {
rs_stream: Pin<Box<DynIOStream<'io_stream>>>, cpp_stream: UniquePtr<CPPIOStream<'io_stream>>,
cpp_stream: UniquePtr<CPPIOStream>,
} }
impl<'io_stream> BridgedIOStream<'io_stream> { impl<'io_stream> BridgedIOStream<'io_stream> {
pub fn new<T: IOStream + 'io_stream>(stream: T) -> Self { pub fn new<T: IOStream + 'io_stream>(stream: T) -> Self {
let mut rs_stream = Box::pin(DynIOStream(Box::new(stream))); let rs_stream: Box<DynIOStream<'io_stream>> = Box::new(DynIOStream(Box::new(stream)));
let cpp_stream = bridge::wrap_RsIOStream(rs_stream.as_mut()); let cpp_stream: UniquePtr<CPPIOStream<'io_stream>> = bridge::wrap_RsIOStream(rs_stream);
BridgedIOStream { BridgedIOStream {
rs_stream,
cpp_stream, cpp_stream,
} }
} }
@ -28,31 +32,21 @@ impl<'io_stream> BridgedIOStream<'io_stream> {
} }
} }
impl<'io_stream> Drop for BridgedIOStream<'io_stream> {
fn drop(&mut self) {
unsafe {
// CPP stream references the rust stream, so it must be dropped first
std::ptr::drop_in_place(&mut self.cpp_stream);
std::ptr::drop_in_place(&mut self.rs_stream);
};
}
}
#[repr(C)] #[repr(C)]
pub(super) struct DynIOStream<'io_stream>(Box<dyn IOStream + 'io_stream>); pub(super) struct DynIOStream<'io_stream>(Box<dyn IOStream + 'io_stream>);
impl<'io_stream> DynIOStream<'io_stream> { impl<'io_stream> DynIOStream<'io_stream> {
// Implement the exposed functions for cxx bridge // Implement the exposed functions for cxx bridge
pub fn name(&mut self) -> String { pub fn name(&self) -> String {
self.0.name() self.0.name()
} }
pub fn read(&mut self, buffer: &mut [u8]) -> usize { pub fn read(&mut self, buffer: &mut [u8]) -> usize {
self.0.read(buffer).unwrap_or(0) self.0.read_block(buffer)
} }
pub fn write(&mut self, data: &[u8]) { pub fn write(&mut self, data: &[u8]) {
self.0.write_all(data).unwrap(); self.0.write_block(data);
} }
pub fn seek(&mut self, offset: i64, whence: i32) { pub fn seek(&mut self, offset: i64, whence: i32) {
@ -62,23 +56,19 @@ impl<'io_stream> DynIOStream<'io_stream> {
2 => SeekFrom::End(offset), 2 => SeekFrom::End(offset),
_ => panic!("Invalid seek whence"), _ => panic!("Invalid seek whence"),
}; };
self.0.seek(pos).unwrap(); self.0.seek(pos);
} }
pub fn truncate(&mut self, length: i64) { pub fn truncate(&mut self, length: i64) {
self.0.seek(SeekFrom::Start(length as u64)).unwrap(); self.0.truncate(length);
// TODO: Actually implement truncate once we have a better trait bound
} }
pub fn tell(&mut self) -> i64 { pub fn tell(&self) -> i64 {
self.0.seek(SeekFrom::Current(0)).unwrap() as i64 self.0.tell()
} }
pub fn length(&mut self) -> i64 { pub fn length(&self) -> i64 {
let current = self.0.seek(SeekFrom::Current(0)).unwrap(); self.0.length()
let end = self.0.seek(SeekFrom::End(0)).unwrap();
self.0.seek(SeekFrom::Start(current)).unwrap();
end as i64
} }
pub fn is_readonly(&self) -> bool { pub fn is_readonly(&self) -> bool {