AI-MXNetCAPI
view release on metacpan or search on metacpan
* \param is_train int flag: nonzero runs the forward pass in training mode, zero runs it for evaluation
* \return 0 when success, -1 when failure happens
*/
int MXExecutorForward(ExecutorHandle handle, int is_train); /* is_train is passed as a plain int flag, not a C bool */
/*!
* \brief Executor run backward
*
* Runs the backward pass of the executor identified by handle.
*
* \param handle executor handle
* \param len length of the head_grads array
* \param head_grads NDArray handles for the heads' gradients
*        (declared as `in` in the prototype below)
*
* NOTE(review): the array parameter is named `in` rather than head_grads,
* presumably so that a SWIG typemap keyed on that name applies -- do not
* rename it without checking the typemaps.
*
* \return 0 when success, -1 when failure happens
*/
int MXExecutorBackward(ExecutorHandle handle,
mx_uint len,
NDArrayHandle *in);
/*!
* \brief Get executor's head (output) NDArrays
*
* Both out_size and out_array are output parameters written by the call.
*
* \param handle executor handle
* \param out_size receives the number of output NDArray handles
* \param out_array receives a pointer to the array of output NDArray handles
* \return 0 when success, -1 when failure happens
*
* NOTE(review): ownership/lifetime of the array returned through out_array
* is not stated here -- confirm against the MXNet C API before freeing it.
*/
int MXExecutorOutputs(ExecutorHandle handle,
mx_uint *out_size,
NDArrayHandle **out_array);
/*!
* \brief Generate (bind) an Executor from a symbol
*
* Every array parameter in the prototype below is named `in`; the \param
* entries map onto them in declaration order: in_args, arg_grad_store,
* grad_req_type, aux_states.
* NOTE(review): the repeated `in` name is presumably what this binding's
* SWIG typemaps match on -- do not rename without checking them.
*
* \param symbol_handle symbol handle
* \param dev_type device type
* \param dev_id device id
* \param len number of elements in in_args (presumably also the length of
*        arg_grad_store and grad_req_type -- confirm)
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array (mx_uint codes)
* \param aux_states_len length of the aux_states array
* \param aux_states auxiliary states array
* \param out receives the handle of the newly created executor
* \return 0 when success, -1 when failure happens
*/
int MXExecutorBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
mx_uint len,
NDArrayHandle *in,
NDArrayHandle *in,
mx_uint *in,
mx_uint aux_states_len,
NDArrayHandle *in,
ExecutorHandle *out);
/*!
* \brief Generate an Executor from a symbol, with a group2ctx map.
* This is an advanced function that allows specifying a group2ctx map.
* The user can annotate the "ctx_group" attribute to name each group.
*
* Every array parameter in the prototype below is named `in`; the \param
* entries map onto them in declaration order: map_keys, map_dev_types,
* map_dev_ids, in_args, arg_grad_store, grad_req_type, aux_states.
* NOTE(review): the repeated `in` name is presumably what this binding's
* SWIG typemaps match on -- do not rename without checking them.
*
* \param symbol_handle symbol handle
* \param dev_type device type of default context
* \param dev_id device id of default context
* \param num_map_keys size of the group2ctx map (presumably the length of
*        map_keys, map_dev_types and map_dev_ids -- confirm)
* \param map_keys keys of group2ctx map (group names)
* \param map_dev_types device type of group2ctx map
* \param map_dev_ids device id of group2ctx map
* \param len number of elements in in_args (presumably also the length of
*        arg_grad_store and grad_req_type -- confirm)
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array (mx_uint codes)
* \param aux_states_len length of the aux_states array
* \param aux_states auxiliary states array
* \param out receives the handle of the newly created executor
* \return 0 when success, -1 when failure happens
*/
int MXExecutorBindX(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
mx_uint num_map_keys,
const char** in,
const int* in,
const int* in,
mx_uint len,
NDArrayHandle *in,
NDArrayHandle *in,
mx_uint *in,
mx_uint aux_states_len,
NDArrayHandle *in,
ExecutorHandle *out);
/*!
* \brief Generate an Executor from a symbol, with a group2ctx map and an
* optional executor to share memory with.
* This is an advanced function that allows specifying a group2ctx map.
* The user can annotate the "ctx_group" attribute to name each group.
*
* Every array parameter in the prototype below is named `in`; the \param
* entries map onto them in declaration order: map_keys, map_dev_types,
* map_dev_ids, in_args, arg_grad_store, grad_req_type, aux_states.
* NOTE(review): the repeated `in` name is presumably what this binding's
* SWIG typemaps match on -- do not rename without checking them.
*
* \param symbol_handle symbol handle
* \param dev_type device type of default context
* \param dev_id device id of default context
* \param num_map_keys size of the group2ctx map (presumably the length of
*        map_keys, map_dev_types and map_dev_ids -- confirm)
* \param map_keys keys of group2ctx map (group names)
* \param map_dev_types device type of group2ctx map
* \param map_dev_ids device id of group2ctx map
* \param len number of elements in in_args (presumably also the length of
*        arg_grad_store and grad_req_type -- confirm)
* \param in_args in args array
* \param arg_grad_store arg grads handle array
* \param grad_req_type grad req array (mx_uint codes)
* \param aux_states_len length of the aux_states array
* \param aux_states auxiliary states array
* \param shared_exec input executor handle for memory sharing
* \param out receives the handle of the newly created executor
* \return 0 when success, -1 when failure happens
*/
int MXExecutorBindEX(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
mx_uint num_map_keys,
const char** in,
const int* in,
const int* in,
mx_uint len,
NDArrayHandle *in,
NDArrayHandle *in,
mx_uint *in,
mx_uint aux_states_len,
NDArrayHandle *in,
ExecutorHandle shared_exec,
ExecutorHandle *out);
int MXExecutorSimpleBind(SymbolHandle symbol_handle,
int dev_type,
int dev_id,
const mx_uint num_g2c_keys,
const char** in, // g2c_keys,
const int* in, // g2c_dev_types,
const int* in, // g2c_dev_ids,
const mx_uint provided_grad_req_list_len,
const char** in, // provided_grad_req_names,
const char** in, // provided_grad_req_types,
const mx_uint num_provided_arg_shapes,
const char** in, // provided_arg_shape_names,
const mx_uint* in, // provided_arg_shape_data,
const mx_uint* in, // provided_arg_shape_idx,
const mx_uint num_provided_arg_dtypes,
const char** in, // provided_arg_dtype_names,
const int* in, // provided_arg_dtypes,
const mx_uint num_shared_arg_names,
const char** in, // shared_arg_name_list,
//------------
int* shared_buffer_len,
const char** shared_buffer_name_list,
NDArrayHandle* shared_buffer_handle_list,
const char*** updated_shared_buffer_name_list,
NDArrayHandle** updated_shared_buffer_handle_list,
//------------------
mx_uint* num_in_args,
NDArrayHandle** in_args,
NDArrayHandle** arg_grads,
( run in 2.021 seconds using v1.01-cache-2.11-cpan-140bd7fdf52 )