musikr: fix rustiostream ownership problem

Sadly this requires me to more or less give an &mut off to TagLib to
own while also owning the core type to drop later, but since stream
is only owned to be dropped it's fine.
This commit is contained in:
Alexander Capehart 2025-02-14 14:22:20 -07:00
parent 3ecdbf289b
commit 74edd1dbdf
No known key found for this signature in database
GPG key ID: 37DBE3621FE9AD47
8 changed files with 124 additions and 107 deletions

View file

@ -120,7 +120,6 @@ fn main() {
.file("shim/file_shim.cpp") .file("shim/file_shim.cpp")
.file("shim/tk_shim.cpp") .file("shim/tk_shim.cpp")
.include(format!("taglib/pkg/{}/include", target)) .include(format!("taglib/pkg/{}/include", target))
.include("shim")
.include(".") // Add the current directory to include path .include(".") // Add the current directory to include path
.flag_if_supported("-std=c++14"); .flag_if_supported("-std=c++14");

View file

@ -2,6 +2,7 @@
#include <stdexcept> #include <stdexcept>
#include <rust/cxx.h> #include <rust/cxx.h>
#include <vector> #include <vector>
#include "metadatajni/src/taglib/bridge.rs.h"
// These are the functions we'll define in Rust // These are the functions we'll define in Rust
extern "C" extern "C"
@ -20,7 +21,7 @@ namespace taglib_shim
{ {
// Factory function to create a new RustIOStream // Factory function to create a new RustIOStream
std::unique_ptr<RustIOStream> new_rust_iostream(RustStream *stream) std::unique_ptr<RustIOStream> new_RustIOStream(BridgeStream& stream)
{ {
return std::unique_ptr<RustIOStream>(new RustIOStream(stream)); return std::unique_ptr<RustIOStream>(new RustIOStream(stream));
} }
@ -31,27 +32,26 @@ namespace taglib_shim
return std::make_unique<TagLib::FileRef>(stream.release(), true); return std::make_unique<TagLib::FileRef>(stream.release(), true);
} }
RustIOStream::RustIOStream(RustStream *stream) : rust_stream(stream) {} RustIOStream::RustIOStream(BridgeStream& stream) : rust_stream(stream) {}
RustIOStream::~RustIOStream() = default; RustIOStream::~RustIOStream() = default;
TagLib::FileName RustIOStream::name() const TagLib::FileName RustIOStream::name() const
{ {
return rust_stream_name(rust_stream); return rust::string(rust_stream.name()).c_str();
} }
TagLib::ByteVector RustIOStream::readBlock(size_t length) TagLib::ByteVector RustIOStream::readBlock(size_t length)
{ {
std::vector<uint8_t> buffer(length); std::vector<uint8_t> buffer(length);
size_t bytes_read = rust_stream_read(rust_stream, buffer.data(), length); size_t bytes_read = rust_stream.read(rust::Slice<uint8_t>(buffer.data(), length));
return TagLib::ByteVector(reinterpret_cast<char *>(buffer.data()), bytes_read); return TagLib::ByteVector(reinterpret_cast<char *>(buffer.data()), bytes_read);
} }
void RustIOStream::writeBlock(const TagLib::ByteVector &data) void RustIOStream::writeBlock(const TagLib::ByteVector &data)
{ {
rust_stream_write(rust_stream, rust_stream.write(rust::Slice<const uint8_t>(
reinterpret_cast<const uint8_t *>(data.data()), reinterpret_cast<const uint8_t *>(data.data()), data.size()));
data.size());
} }
void RustIOStream::insert(const TagLib::ByteVector &data, TagLib::offset_t start, size_t replace) void RustIOStream::insert(const TagLib::ByteVector &data, TagLib::offset_t start, size_t replace)
@ -118,7 +118,7 @@ namespace taglib_shim
default: default:
throw std::runtime_error("Invalid seek position"); throw std::runtime_error("Invalid seek position");
} }
rust_stream_seek(rust_stream, offset, whence); rust_stream.seek(offset, whence);
} }
void RustIOStream::clear() void RustIOStream::clear()
@ -129,22 +129,22 @@ namespace taglib_shim
void RustIOStream::truncate(TagLib::offset_t length) void RustIOStream::truncate(TagLib::offset_t length)
{ {
rust_stream_truncate(rust_stream, length); rust_stream.truncate(length);
} }
TagLib::offset_t RustIOStream::tell() const TagLib::offset_t RustIOStream::tell() const
{ {
return rust_stream_tell(rust_stream); return rust_stream.tell();
} }
TagLib::offset_t RustIOStream::length() TagLib::offset_t RustIOStream::length()
{ {
return rust_stream_length(rust_stream); return rust_stream.length();
} }
bool RustIOStream::readOnly() const bool RustIOStream::readOnly() const
{ {
return rust_stream_is_readonly(rust_stream); return rust_stream.is_readonly();
} }
bool RustIOStream::isOpen() const bool RustIOStream::isOpen() const

View file

@ -4,18 +4,19 @@
#include <string> #include <string>
#include <taglib/tiostream.h> #include <taglib/tiostream.h>
#include <taglib/fileref.h> #include <taglib/fileref.h>
#include "rust/cxx.h"
// Forward declare the bridge type
struct BridgeStream;
namespace taglib_shim namespace taglib_shim
{ {
// Forward declaration of the Rust-side stream
struct RustStream;
// C++ implementation of TagLib::IOStream that delegates to Rust // C++ implementation of TagLib::IOStream that delegates to Rust
class RustIOStream : public TagLib::IOStream class RustIOStream : public TagLib::IOStream
{ {
public: public:
explicit RustIOStream(RustStream *stream); explicit RustIOStream(BridgeStream& stream);
~RustIOStream() override; ~RustIOStream() override;
// TagLib::IOStream interface implementation // TagLib::IOStream interface implementation
@ -33,11 +34,12 @@ namespace taglib_shim
bool isOpen() const override; bool isOpen() const override;
private: private:
RustStream *rust_stream; BridgeStream& rust_stream;
}; };
// Factory functions // Factory functions with external linkage
std::unique_ptr<RustIOStream> new_rust_iostream(RustStream *stream); std::unique_ptr<RustIOStream> new_RustIOStream(BridgeStream& stream);
std::unique_ptr<TagLib::FileRef> new_FileRef_from_stream(std::unique_ptr<RustIOStream> stream); std::unique_ptr<TagLib::FileRef> new_FileRef_from_stream(std::unique_ptr<RustIOStream> stream);
} // namespace taglib_shim } // namespace taglib_shim

View file

@ -1,32 +1,34 @@
use crate::taglib::stream::IOStream; use crate::taglib::stream::IOStream;
use jni::objects::{JObject, JValue}; use jni::objects::{JObject, JValue};
use jni::JNIEnv;
use std::io::{Read, Seek, SeekFrom, Write}; use std::io::{Read, Seek, SeekFrom, Write};
use crate::SharedEnv;
pub struct JInputStream<'local, 'a> { pub struct JInputStream<'local> {
env: &'a mut JNIEnv<'local>, env: SharedEnv<'local>,
input: JObject<'local>, input: JObject<'local>,
} }
impl<'local, 'a> JInputStream<'local, 'a> { impl<'local, 'a> JInputStream<'local> {
pub fn new( pub fn new(
env: &'a mut JNIEnv<'local>, env: SharedEnv<'local>,
input: JObject<'local>, input: JObject<'local>,
) -> Self { ) -> Self {
Self { env, input } Self { env, input }
} }
} }
impl<'local, 'a> IOStream for JInputStream<'local, 'a> { impl<'local> IOStream for JInputStream<'local> {
fn name(&mut self) -> String { fn name(&mut self) -> String {
// Call the Java name() method safely // Call the Java name() method safely
let name = self let name = self
.env .env
.borrow_mut()
.call_method(&self.input, "name", "()Ljava/lang/String;", &[]) .call_method(&self.input, "name", "()Ljava/lang/String;", &[])
.and_then(|result| result.l()) .and_then(|result| result.l())
.expect("Failed to call name() method"); .expect("Failed to call name() method");
self.env self.env
.borrow_mut()
.get_string(&name.into()) .get_string(&name.into())
.expect("Failed to convert Java string") .expect("Failed to convert Java string")
.into() .into()
@ -37,11 +39,12 @@ impl<'local, 'a> IOStream for JInputStream<'local, 'a> {
} }
} }
impl<'local, 'a> Read for JInputStream<'local, 'a> { impl<'local> Read for JInputStream<'local> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
// Create a direct ByteBuffer from the Rust slice // Create a direct ByteBuffer from the Rust slice
let byte_buffer = unsafe { let byte_buffer = unsafe {
self.env self.env
.borrow_mut()
.new_direct_byte_buffer(buf.as_mut_ptr(), buf.len()) .new_direct_byte_buffer(buf.as_mut_ptr(), buf.len())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))? .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?
}; };
@ -49,6 +52,7 @@ impl<'local, 'a> Read for JInputStream<'local, 'a> {
// Call readBlock safely // Call readBlock safely
let success = self let success = self
.env .env
.borrow_mut()
.call_method( .call_method(
&self.input, &self.input,
"readBlock", "readBlock",
@ -69,7 +73,7 @@ impl<'local, 'a> Read for JInputStream<'local, 'a> {
} }
} }
impl<'local, 'a> Write for JInputStream<'local, 'a> { impl<'local> Write for JInputStream<'local> {
fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> { fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
Err(std::io::Error::new( Err(std::io::Error::new(
std::io::ErrorKind::PermissionDenied, std::io::ErrorKind::PermissionDenied,
@ -82,7 +86,7 @@ impl<'local, 'a> Write for JInputStream<'local, 'a> {
} }
} }
impl<'local, 'a> Seek for JInputStream<'local, 'a> { impl<'local, 'a> Seek for JInputStream<'local> {
fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> { fn seek(&mut self, pos: SeekFrom) -> std::io::Result<u64> {
let (method, offset) = match pos { let (method, offset) = match pos {
SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64), SeekFrom::Start(offset) => ("seekFromBeginning", offset as i64),
@ -93,6 +97,7 @@ impl<'local, 'a> Seek for JInputStream<'local, 'a> {
// Call the appropriate seek method safely // Call the appropriate seek method safely
let success = self let success = self
.env .env
.borrow_mut()
.call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)]) .call_method(&self.input, method, "(J)Z", &[JValue::Long(offset)])
.and_then(|result| result.z()) .and_then(|result| result.z())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;
@ -107,6 +112,7 @@ impl<'local, 'a> Seek for JInputStream<'local, 'a> {
// Return current position safely // Return current position safely
let position = self let position = self
.env .env
.borrow_mut()
.call_method(&self.input, "tell", "()J", &[]) .call_method(&self.input, "tell", "()J", &[])
.and_then(|result| result.j()) .and_then(|result| result.j())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?; .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e.to_string()))?;

View file

@ -1,3 +1,5 @@
use std::cell::RefCell;
use std::rc::Rc;
use jni::objects::{JClass, JObject}; use jni::objects::{JClass, JObject};
use jni::sys::jstring; use jni::sys::jstring;
use jni::JNIEnv; use jni::JNIEnv;
@ -8,7 +10,8 @@ mod jstream;
use taglib::file::FileRef; use taglib::file::FileRef;
use jstream::JInputStream; use jstream::JInputStream;
pub use taglib::*;
type SharedEnv<'local> = Rc<RefCell<JNIEnv<'local>>>;
#[no_mangle] #[no_mangle]
pub extern "C" fn Java_org_oxycblt_musikr_metadata_MetadataJNI_openFile<'local>( pub extern "C" fn Java_org_oxycblt_musikr_metadata_MetadataJNI_openFile<'local>(
@ -17,7 +20,8 @@ pub extern "C" fn Java_org_oxycblt_musikr_metadata_MetadataJNI_openFile<'local>(
input: JObject<'local>, input: JObject<'local>,
) -> jstring { ) -> jstring {
// Create JInputStream from the Java input stream // Create JInputStream from the Java input stream
let stream = JInputStream::new(&mut env, input); let shared_env = Rc::new(RefCell::new(env));
let mut stream = JInputStream::new(shared_env.clone(), input);
let file_ref = FileRef::new(stream); let file_ref = FileRef::new(stream);
// file_ref.file().and_then(|file| { // file_ref.file().and_then(|file| {
// let audio_properties = file.audio_properties().map(|props| AudioProperties { // let audio_properties = file.audio_properties().map(|props| AudioProperties {
@ -55,6 +59,6 @@ pub extern "C" fn Java_org_oxycblt_musikr_metadata_MetadataJNI_openFile<'local>(
// }); // });
// Return the title // Return the title
let output = env.new_string("title").expect("Couldn't create string!"); let output = shared_env.borrow_mut().new_string("title").expect("Couldn't create string!");
output.into_raw() output.into_raw()
} }

View file

@ -1,5 +1,23 @@
use super::stream::BridgeStream;
#[cxx::bridge] #[cxx::bridge]
mod bridge_impl { mod bridge_impl {
// Expose Rust IOStream to C++
extern "Rust" {
#[cxx_name = "BridgeStream"]
type BridgeStream<'a>;
fn name(self: &mut BridgeStream<'_>) -> String;
fn read(self: &mut BridgeStream<'_>, buffer: &mut [u8]) -> usize;
fn write(self: &mut BridgeStream<'_>, data: &[u8]);
fn seek(self: &mut BridgeStream<'_>, offset: i64, whence: i32);
fn truncate(self: &mut BridgeStream<'_>, length: i64);
fn tell(self: &mut BridgeStream<'_>) -> i64;
fn length(self: &mut BridgeStream<'_>) -> i64;
fn is_readonly(self: &BridgeStream<'_>) -> bool;
}
#[namespace = "taglib_shim"]
unsafe extern "C++" { unsafe extern "C++" {
include!("taglib/taglib.h"); include!("taglib/taglib.h");
include!("taglib/tstring.h"); include!("taglib/tstring.h");
@ -10,6 +28,9 @@ mod bridge_impl {
include!("shim/file_shim.hpp"); include!("shim/file_shim.hpp");
include!("shim/tk_shim.hpp"); include!("shim/tk_shim.hpp");
#[cxx_name = "RustIOStream"]
type RustIOStream;
#[namespace = "TagLib"] #[namespace = "TagLib"]
#[cxx_name = "FileRef"] #[cxx_name = "FileRef"]
type TFileRef; type TFileRef;
@ -18,15 +39,9 @@ mod bridge_impl {
#[cxx_name = "file"] #[cxx_name = "file"]
fn thisFile(self: Pin<&TFileRef>) -> *mut BaseFile; fn thisFile(self: Pin<&TFileRef>) -> *mut BaseFile;
#[namespace = "taglib_shim"] // Create a RustIOStream from a BridgeStream
type RustIOStream; unsafe fn new_RustIOStream(stream: Pin<&mut BridgeStream>) -> UniquePtr<RustIOStream>;
// Create a FileRef from an iostream // Create a FileRef from an iostream
#[namespace = "taglib_shim"]
unsafe fn new_rust_iostream(stream: *mut RustStream) -> UniquePtr<RustIOStream>;
#[namespace = "taglib_shim"]
type RustStream;
#[namespace = "taglib_shim"]
fn new_FileRef_from_stream(stream: UniquePtr<RustIOStream>) -> UniquePtr<TFileRef>; fn new_FileRef_from_stream(stream: UniquePtr<RustIOStream>) -> UniquePtr<TFileRef>;
#[namespace = "TagLib"] #[namespace = "TagLib"]

View file

@ -7,20 +7,19 @@ use super::stream::IOStream;
use std::pin::Pin; use std::pin::Pin;
use std::marker::PhantomData; use std::marker::PhantomData;
pub struct FileRef<'a, T: IOStream + 'a> { pub struct FileRef<'a> {
data: PhantomData<&'a T>, stream: BridgeStream<'a>,
ptr: UniquePtr<TFileRef> file_ref: UniquePtr<TFileRef>
} }
impl <'a, T: IOStream + 'a> FileRef<'a, T> { impl <'a> FileRef<'a> {
pub fn new(stream: T) -> FileRef<'a, T> { pub fn new<T : IOStream + 'a>(stream: T) -> FileRef<'a> {
let bridge_stream = BridgeStream::new(stream); let mut bridge_stream = BridgeStream::new(stream);
let raw_stream = Box::into_raw(Box::new(bridge_stream)) as *mut bridge::RustStream; let iostream = unsafe { bridge::new_RustIOStream(Pin::new(&mut bridge_stream)) };
let iostream = unsafe { bridge::new_rust_iostream(raw_stream) };
let file_ref = bridge::new_FileRef_from_stream(iostream); let file_ref = bridge::new_FileRef_from_stream(iostream);
FileRef { FileRef {
data: PhantomData::<&'a T>, stream: bridge_stream,
ptr: file_ref file_ref
} }
} }
@ -32,7 +31,7 @@ impl <'a, T: IOStream + 'a> FileRef<'a, T> {
// not change address by C++ semantics. // not change address by C++ semantics.
// - The file data is a pointer that does not depend on the // - The file data is a pointer that does not depend on the
// address of self. // address of self.
let this = Pin::new_unchecked(&*self.ptr); let this = Pin::new_unchecked(&*self.file_ref);
// Note: This is not the rust ptr "is_null", but a taglib isNull method // Note: This is not the rust ptr "is_null", but a taglib isNull method
// that checks for file validity. Without this check, we can get corrupted // that checks for file validity. Without this check, we can get corrupted
// file ptrs. // file ptrs.
@ -187,3 +186,14 @@ impl AudioProperties {
} }
} }
} }
impl <'a> Drop for FileRef<'a> {
fn drop(&mut self) {
// First drop the file, since it has a pointer to the stream.
// Then drop the stream
unsafe {
std::ptr::drop_in_place(&mut self.file_ref);
std::ptr::drop_in_place(&mut self.stream);
}
}
}

View file

@ -1,6 +1,7 @@
use std::ffi::{c_void, CString}; use std::ffi::{c_void, CString};
use std::io::{Read, Seek, SeekFrom, Write}; use std::io::{Read, Seek, SeekFrom, Write};
use std::os::raw::c_char; use std::os::raw::c_char;
use cxx::CxxString;
pub trait IOStream: Read + Write + Seek { pub trait IOStream: Read + Write + Seek {
fn name(&mut self) -> String; fn name(&mut self) -> String;
@ -14,67 +15,47 @@ impl<'a> BridgeStream<'a> {
pub fn new<T: IOStream + 'a>(stream: T) -> Self { pub fn new<T: IOStream + 'a>(stream: T) -> Self {
BridgeStream(Box::new(stream)) BridgeStream(Box::new(stream))
} }
}
#[no_mangle] // Implement the exposed functions for cxx bridge
pub extern "C" fn rust_stream_name(stream: *mut c_void) -> *const c_char { pub fn name(&mut self) -> String {
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; self.0.name()
let name = stream.0.name(); }
// Note: This leaks memory, but TagLib only calls this once during construction
// and keeps the pointer, so it's fine
CString::new(name).unwrap().into_raw()
}
#[no_mangle] pub fn read(&mut self, buffer: &mut [u8]) -> usize {
pub extern "C" fn rust_stream_read(stream: *mut c_void, buffer: *mut u8, length: usize) -> usize { self.0.read(buffer).unwrap_or(0)
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; }
let buffer = unsafe { std::slice::from_raw_parts_mut(buffer, length) };
stream.0.read(buffer).unwrap_or(0)
}
#[no_mangle] pub fn write(&mut self, data: &[u8]) {
pub extern "C" fn rust_stream_write(stream: *mut c_void, data: *const u8, length: usize) { self.0.write_all(data).unwrap();
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; }
let data = unsafe { std::slice::from_raw_parts(data, length) };
stream.0.write_all(data).unwrap();
}
#[no_mangle] pub fn seek(&mut self, offset: i64, whence: i32) {
pub extern "C" fn rust_stream_seek(stream: *mut c_void, offset: i64, whence: i32) { let pos = match whence {
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; 0 => SeekFrom::Start(offset as u64),
let pos = match whence { 1 => SeekFrom::Current(offset),
0 => SeekFrom::Start(offset as u64), 2 => SeekFrom::End(offset),
1 => SeekFrom::Current(offset), _ => panic!("Invalid seek whence"),
2 => SeekFrom::End(offset), };
_ => panic!("Invalid seek whence"), self.0.seek(pos).unwrap();
}; }
stream.0.seek(pos).unwrap();
}
#[no_mangle] pub fn truncate(&mut self, length: i64) {
pub extern "C" fn rust_stream_truncate(stream: *mut c_void, length: i64) { self.0.seek(SeekFrom::Start(length as u64)).unwrap();
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; // TODO: Actually implement truncate once we have a better trait bound
stream.0.seek(SeekFrom::Start(length as u64)).unwrap(); }
// TODO: Actually implement truncate once we have a better trait bound
}
#[no_mangle] pub fn tell(&mut self) -> i64 {
pub extern "C" fn rust_stream_tell(stream: *mut c_void) -> i64 { self.0.seek(SeekFrom::Current(0)).unwrap() as i64
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; }
stream.0.seek(SeekFrom::Current(0)).unwrap() as i64
}
#[no_mangle] pub fn length(&mut self) -> i64 {
pub extern "C" fn rust_stream_length(stream: *mut c_void) -> i64 { let current = self.0.seek(SeekFrom::Current(0)).unwrap();
let stream = unsafe { &mut *(stream as *mut BridgeStream<'_>) }; let end = self.0.seek(SeekFrom::End(0)).unwrap();
let current = stream.0.seek(SeekFrom::Current(0)).unwrap(); self.0.seek(SeekFrom::Start(current)).unwrap();
let end = stream.0.seek(SeekFrom::End(0)).unwrap(); end as i64
stream.0.seek(SeekFrom::Start(current)).unwrap(); }
end as i64
}
#[no_mangle] pub fn is_readonly(&self) -> bool {
pub extern "C" fn rust_stream_is_readonly(stream: *const c_void) -> bool { self.0.is_readonly()
let stream = unsafe { &*(stream as *const BridgeStream<'_>) }; }
stream.0.is_readonly()
} }