Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/base/flamec/supermatrix/hip/include/FLASH_Queue_hip.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,12 @@ void FLASH_Queue_set_hip_num_blocks( dim_t n_blocks );
dim_t FLASH_Queue_get_hip_num_blocks( void );

FLA_Error FLASH_Queue_bind_hip( int thread );
FLA_Error FLASH_Queue_alloc_hip( dim_t size, FLA_Datatype datatype, void** buffer_hip );
FLA_Error FLASH_Queue_free_async_hip( void* buffer_hip );
FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_alloc_async_hip( int thread, dim_t size, FLA_Datatype datatype, void** buffer_hip );
FLA_Error FLASH_Queue_free_async_hip( int thread, void* buffer_hip );
FLA_Error FLASH_Queue_write_async_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_read_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip );
FLA_Error FLASH_Queue_sync_stream_hip( int thread );
FLA_Error FLASH_Queue_sync_device_hip( int device );
FLA_Error FLASH_Queue_sync_hip( );

Expand Down
98 changes: 55 additions & 43 deletions src/base/flamec/supermatrix/hip/main/FLASH_Queue_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,23 @@ FLA_Error FLASH_Queue_bind_hip( int thread )
{

// Bind a HIP device to this thread.
hipSetDevice( thread );
if ( hipSetDevice( thread ) != hipSuccess ) return FLA_FAILURE;

// initialize its rocBLAS handle
if ( handles[thread] == NULL )
rocblas_create_handle( &(handles[thread]) );
{
//hipStream_t stream;
//if ( hipStreamCreate(&stream) != hipSuccess ) return FLA_FAILURE;
if ( rocblas_create_handle( &(handles[thread]) ) != rocblas_status_success ) return FLA_FAILURE;
//if ( rocblas_set_stream( handles[thread], stream ) != rocblas_status_success ) return FLA_FAILURE;
}

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_alloc_hip( dim_t size,
FLA_Error FLASH_Queue_alloc_async_hip( int thread,
dim_t size,
FLA_Datatype datatype,
void** buffer_hip )
/*----------------------------------------------------------------------------
Expand All @@ -269,11 +275,14 @@ FLA_Error FLASH_Queue_alloc_hip( dim_t size,

----------------------------------------------------------------------------*/
{
hipError_t status;

// Allocate memory for a block on HIP.
status = hipMalloc( buffer_hip,
size * FLA_Obj_datatype_size( datatype ) );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
//fprintf(stdout, "Trying to allocate %ld bytes on device %d\n", size, thread);
hipError_t status = hipMallocAsync( buffer_hip,
size * FLA_Obj_datatype_size( datatype ),
stream );

// Check to see if the allocation was successful.
if ( status != hipSuccess )
Expand All @@ -284,28 +293,34 @@ FLA_Error FLASH_Queue_alloc_hip( dim_t size,
FLA_Check_error_code( FLA_MALLOC_GPU_RETURNED_NULL_POINTER );
}

//fprintf( stdout, "allocating on thread %d for pointer %p\n", thread, *buffer_hip );

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_free_async_hip( void* buffer_hip )
FLA_Error FLASH_Queue_free_async_hip( int thread, void* buffer_hip )
/*----------------------------------------------------------------------------

FLASH_Queue_free_async_hip

----------------------------------------------------------------------------*/
{
// Free memory for a block on HIP.
hipFreeAsync( (hipStream_t) 0, buffer_hip );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
hipFreeAsync( stream, buffer_hip );

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip )
FLA_Error FLASH_Queue_write_async_hip( int thread,
FLA_Obj obj,
void* buffer_hip )
/*----------------------------------------------------------------------------

FLASH_Queue_write_hip
FLASH_Queue_write_async_hip

----------------------------------------------------------------------------*/
{
Expand All @@ -317,11 +332,13 @@ FLA_Error FLASH_Queue_write_hip( FLA_Obj obj, void* buffer_hip )
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipMemcpyAsync( buffer_hip,
FLA_Obj_buffer_at_view( obj ),
count,
hipMemcpyHostToDevice,
(hipStream_t) 0 );
stream );

if ( err != hipSuccess )
{
Expand All @@ -342,39 +359,11 @@ FLA_Error FLASH_Queue_read_hip( int thread, FLA_Obj obj, void* buffer_hip )

----------------------------------------------------------------------------*/
{
if ( flash_malloc_managed_hip )
{
// inject a stream sync on the rocBLAS stream to ensure completion
hipError_t err = hipStreamSynchronize( (hipStream_t) 0 );
if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to synchronize on HIP stream. err=%d\n",
err );
return FLA_FAILURE;
}
return FLA_SUCCESS;
}

// Read the memory of a block on HIP to main memory.
hipSetDevice( thread );
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
const hipError_t err = hipMemcpy( FLA_Obj_buffer_at_view( obj ),
buffer_hip,
count,
hipMemcpyDeviceToHost );
FLA_Error err1 = FLASH_Queue_read_async_hip( thread, obj, buffer_hip );
if ( err1 != FLA_SUCCESS ) return err1;

if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to read block from HIP device. Size=%ld, err=%d\n",
count, err );
return FLA_FAILURE;
}

return FLA_SUCCESS;
return FLASH_Queue_sync_stream_hip( thread );
}

FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip )
Expand All @@ -394,11 +383,13 @@ FLA_Error FLASH_Queue_read_async_hip( int thread, FLA_Obj obj, void* buffer_hip
const size_t count = FLA_Obj_elem_size( obj )
* FLA_Obj_col_stride( obj )
* FLA_Obj_width( obj );
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipMemcpyAsync( FLA_Obj_buffer_at_view( obj ),
buffer_hip,
count,
hipMemcpyDeviceToHost,
(hipStream_t) 0 );
stream );

if ( err != hipSuccess )
{
Expand Down Expand Up @@ -431,6 +422,27 @@ FLA_Error FLASH_Queue_sync_device_hip( int device )
return FLA_SUCCESS;
}

FLA_Error FLASH_Queue_sync_stream_hip( int thread )
/*----------------------------------------------------------------------------

FLASH_Queue_sync_stream_hip

----------------------------------------------------------------------------*/
{
hipStream_t stream;
rocblas_get_stream( handles[thread], &stream );
const hipError_t err = hipStreamSynchronize( stream );
if ( err != hipSuccess )
{
fprintf( stderr,
"Failure to sync HIP stream. Thread=%d, err=%d\n",
thread, err );
return FLA_FAILURE;
}

return FLA_SUCCESS;
}


FLA_Error FLASH_Queue_sync_hip( )
/*----------------------------------------------------------------------------
Expand Down
15 changes: 12 additions & 3 deletions src/base/flamec/supermatrix/main/FLASH_Queue_exec.c
Original file line number Diff line number Diff line change
Expand Up @@ -2243,7 +2243,16 @@ void FLASH_Queue_create_hip( int thread, void *arg )
{
// Allocate the memory on the HIP device for all the blocks a priori.
for ( i = 0; i < hip_n_blocks; i++ )
FLASH_Queue_alloc_hip( block_size, datatype, &(args->hip[thread * hip_n_blocks + i].buffer_hip) );
FLASH_Queue_alloc_async_hip( thread,
block_size,
datatype,
&(args->hip[thread * hip_n_blocks + i].buffer_hip) );
}
else
{
// write something into the buffer_hip pointer to make it unique for tracking
for ( i = 0; i < hip_n_blocks; i++ )
args->hip[thread * hip_n_blocks + i].buffer_hip = (void*) (thread * hip_n_blocks + i);
}

return;
Expand Down Expand Up @@ -2279,7 +2288,7 @@ void FLASH_Queue_destroy_hip( int thread, void *arg )
if ( hip_obj.obj.base != NULL && !hip_obj.clean )
FLASH_Queue_read_async_hip( thread, hip_obj.obj, hip_obj.buffer_hip );
// Free the memory on the HIP for all the blocks.
FLASH_Queue_free_async_hip( hip_obj.buffer_hip );
FLASH_Queue_free_async_hip( thread, hip_obj.buffer_hip );
}

return;
Expand Down Expand Up @@ -2786,7 +2795,7 @@ void FLASH_Queue_update_block_hip( FLA_Obj obj,

// Move the block to the HIP device.
if ( transfer )
FLASH_Queue_write_hip( hip_obj.obj, hip_obj.buffer_hip );
FLASH_Queue_write_async_hip( thread, hip_obj.obj, hip_obj.buffer_hip );

return;
}
Expand Down