diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 2e8658fc35..118d9da312 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -25,10 +25,13 @@ import typing from typing import List, Tuple import cython +cimport cpython from cpython.bytes cimport PyBytes_FromStringAndSize from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t -from libc.string cimport memset +from libc.string cimport memset, memcpy from libcpp.vector cimport vector as c_vector +from libc.stdlib cimport malloc, free +from libc.errno cimport EIO if typing.TYPE_CHECKING: from typing import Self @@ -40,8 +43,13 @@ cdef extern from "adbc.h" nogil: pass cdef struct CArrowArray"ArrowArray": pass + cdef struct CArrowArrayStream"ArrowArrayStream": - pass + int (*get_schema)(CArrowArrayStream* stream, CArrowSchema* out) nogil noexcept + int (*get_next)(CArrowArrayStream* stream, CArrowArray* out) nogil noexcept + const char* (*get_last_error)(CArrowArrayStream*) nogil noexcept + void (*release)(CArrowArrayStream*) nogil noexcept + void* private_data # ADBC ctypedef uint8_t CAdbcStatusCode"AdbcStatusCode" @@ -460,6 +468,20 @@ cdef class _AdbcHandle: f"with open {self._child_type}") +cdef void pycapsule_stream_deleter(object stream_capsule): + cdef: + CArrowArrayStream* stream + # Do not invoke the deleter on a used/moved capsule + stream = cpython.PyCapsule_GetPointer( + stream_capsule, 'arrowarraystream' + ) + if stream.release != NULL: + print("calling the release callback") + stream.release(stream) + + free(stream) + + cdef class ArrowSchemaHandle: """ A wrapper for an allocated ArrowSchema. @@ -486,6 +508,26 @@ cdef class ArrowArrayHandle: return &self.array +def _create_stream_capsule(): + """ + Create PyCapsule holding a newly allocated (blank) ArrowArrayStream + """ + cdef CArrowArrayStream* stream = malloc( + cython.sizeof(CArrowArrayStream) + ) + memset(stream, 0, cython.sizeof(CArrowArrayStream)) + + return cpython.PyCapsule_New( + stream, 'arrowarraystream', pycapsule_stream_deleter + ) + + +cdef CArrowArrayStream* _get_stream_pointer(stream_capsule): + return cpython.PyCapsule_GetPointer( + stream_capsule, 'arrowarraystream' + ) + + cdef class ArrowArrayStreamHandle: """ A wrapper for an allocated ArrowArrayStream. @@ -878,6 +920,7 @@ cdef class AdbcStatement(_AdbcHandle): cdef: AdbcConnection connection CAdbcStatement statement + bint closed def __init__(self, AdbcConnection connection) -> None: super().__init__("(no child type)") @@ -893,6 +936,7 @@ cdef class AdbcStatement(_AdbcHandle): check_error(status, &c_error) connection._open_child() + self.closed = False def bind(self, data, schema) -> None: """ @@ -960,6 +1004,7 @@ cdef class AdbcStatement(_AdbcHandle): cdef CAdbcError c_error = empty_error() cdef CAdbcStatusCode status self.connection._close_child() + self.closed = True with self._lock: if self.statement.private_data == NULL: return @@ -968,28 +1013,31 @@ cdef class AdbcStatement(_AdbcHandle): status = AdbcStatementRelease(&self.statement, &c_error) check_error(status, &c_error) - def execute_query(self) -> Tuple[ArrowArrayStreamHandle, int]: + def execute_query(self) -> Tuple["PyCapsule", int]: """ Execute the query and get the result set. Returns ------- - ArrowArrayStreamHandle + PyCapsule holding an ArrowArrayStream The result set. int The number of rows if known, else -1. """ cdef CAdbcError c_error = empty_error() - cdef ArrowArrayStreamHandle stream = ArrowArrayStreamHandle() cdef int64_t rows_affected = 0 + + stream_capsule = _create_stream_capsule() + cdef CArrowArrayStream* stream = _get_stream_pointer(stream_capsule) + with nogil: status = AdbcStatementExecuteQuery( &self.statement, - &stream.stream, + stream, &rows_affected, &c_error) check_error(status, &c_error) - return (stream, rows_affected) + return (stream_capsule, rows_affected) def execute_partitions(self) -> Tuple[List[bytes], ArrowSchemaHandle, int]: """ @@ -1132,3 +1180,85 @@ cdef class AdbcStatement(_AdbcHandle): status = AdbcStatementSetSubstraitPlan( &self.statement, c_plan, length, &c_error) check_error(status, &c_error) + + +# Implementation of an ArrowArrayStream that keeps a dependent object valid + + +cdef struct ArrowArrayStreamWrapper: + cpython.PyObject* parent_statement + CArrowArrayStream* parent_array_stream + bint error_set + + +cdef void wrapper_array_stream_release(CArrowArrayStream* array_stream) nogil noexcept: + cdef ArrowArrayStreamWrapper* data + + if array_stream.private_data != NULL: + data = array_stream.private_data + data.parent_array_stream.release(data.parent_array_stream) + + with gil: + cpython.Py_DECREF(data.parent_statement) + + free(array_stream.private_data) + + array_stream.release = NULL + + +cdef const char* wrapper_array_stream_get_last_error(CArrowArrayStream* array_stream) nogil noexcept: + cdef ArrowArrayStreamWrapper* data = array_stream.private_data + if data.error_set: + return "AdbcStatement already closed" + return data.parent_array_stream.get_last_error(data.parent_array_stream) + + +cdef int wrapper_array_stream_get_schema(CArrowArrayStream* array_stream, CArrowSchema* out) nogil noexcept: + cdef ArrowArrayStreamWrapper* data = array_stream.private_data + if (data.parent_statement).closed: + data.error_set = True + return EIO + return data.parent_array_stream.get_schema(data.parent_array_stream, out) + + +cdef int wrapper_array_stream_get_next(CArrowArrayStream* array_stream, CArrowArray* out) nogil noexcept: + cdef ArrowArrayStreamWrapper* data = (array_stream.private_data) + if (data.parent_statement).closed: + data.error_set = True + return EIO + return data.parent_array_stream.get_next(data.parent_array_stream, out) + + +def export_array_stream(object array_stream_capsule, AdbcStatement parent_statement): + """ + Given an ArrowArrayStream PyCapsule, return a new ArrowArrayStream capsule + wrapping the original stream and statement object. + """ + cdef CArrowArrayStream* array_stream = _get_stream_pointer(array_stream_capsule) + + array_stream_capsule_exported = _create_stream_capsule() + cdef CArrowArrayStream* array_stream_exported = _get_stream_pointer( + array_stream_capsule_exported) + + # move input array stream + cdef CArrowArrayStream* array_stream_moved = malloc( + cython.sizeof(CArrowArrayStream)) + memset(array_stream_moved, 0, cython.sizeof(CArrowArrayStream)) + memcpy(array_stream_moved, array_stream, sizeof(CArrowArrayStream)) + array_stream.release = NULL + + array_stream_exported.private_data = NULL + array_stream_exported.get_last_error = &wrapper_array_stream_get_last_error + array_stream_exported.get_schema = &wrapper_array_stream_get_schema + array_stream_exported.get_next = &wrapper_array_stream_get_next + array_stream_exported.release = &wrapper_array_stream_release + + cdef ArrowArrayStreamWrapper* data = malloc( + cython.sizeof(ArrowArrayStreamWrapper)) + data.parent_array_stream = array_stream_moved + data.parent_statement = parent_statement + cpython.Py_INCREF(parent_statement) + data.error_set = False + array_stream_exported.private_data = data + + return array_stream_capsule_exported