[QST] Adding new parameter to Conv2dFprop in Python #2166

Open
IzanCatalan opened this issue Mar 12, 2025 · 13 comments
IzanCatalan commented Mar 12, 2025

What is your question?
Hi, I have implemented a modification to the conv2d kernel in CUTLASS. I added three new attributes: two tensors, mask0s and mask1s, with the same data type as the filters, and an integer flag, all of which can be used in the convolution kernel:

void operator()(Params const &params, SharedStorage &shared_storage) {
  ...

To do so, I modified the CUTLASS code in three places:

  1. Inside HostTensor:
  /// Host-side memory allocation
  std::vector<StorageUnit> host_;
  std::vector<StorageUnit> host_mask1s;
  std::vector<StorageUnit> host_mask0s;
  int* host_flag;

  /// Device-side memory
  device_memory::allocation<StorageUnit> device_;
  device_memory::allocation<StorageUnit> device_masks1s;
  device_memory::allocation<StorageUnit> device_masks0s;
  device_memory::allocation<int> device_flag;

 /// Accesses the tensor reference pointing to data
  TensorRef host_ref_complete(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), 
                                                                        layout_, 
                                                                        host_data_ptr_offset_mask0s(ptr_element_offset),
                                                                        host_data_ptr_offset_mask1s(ptr_element_offset),
                                                                        host_flag); }

  /// Accesses the tensor reference pointing to data
  ConstTensorRef device_ref_complete(LongIndex ptr_element_offset=0) const {
    return TensorRef(device_data_ptr_offset(ptr_element_offset), 
    layout_, device_data_ptr_offset_mask0s(ptr_element_offset), 
    device_data_ptr_offset_mask1s(ptr_element_offset), 
    device_data_flag());
  }


  /// Copies data from device to host
  void sync_host() {
    if (device_backed()) {
      device_memory::copy_to_host(
          host_.data(), device_.get(), device_.size());

      device_memory::copy_to_host(
          host_mask1s.data(), device_masks1s.get(), device_masks1s.size());
      
      device_memory::copy_to_host(
          host_mask0s.data(), device_masks0s.get(), device_masks0s.size());

      device_memory::copy_to_host(
          host_flag, device_flag.get(), device_flag.size());
      
    }
  }

  /// Copies data from host to device
  void sync_device() {
    if (device_backed()) {
      device_memory::copy_to_device(
          device_.get(), host_.data(), host_.size());
      
      device_memory::copy_to_device(
          device_masks1s.get(), host_mask1s.data(), host_mask1s.size());
      
      device_memory::copy_to_device(
          device_masks0s.get(), host_mask0s.data(), host_mask0s.size());

      device_memory::copy_to_device(
          device_flag.get(), host_flag, 1);
          
    }
  }


I have shown only the most critical functions; inside the class I also modified the constructors and reset methods to allocate memory on the host and later copy it to the GPU with copy_to_device().

  2. In TensorRef:
/// Pointer
Element* ptr_;
Element* ptr_mask0s;
Element* ptr_mask1s;
int *check;

CUTLASS_HOST_DEVICE
TensorRef(
  Element *ptr,                   ///< pointer to start of tensor
  Layout const &layout,           ///< layout object containing stride and mapping function
  Element *ptr_mask0s,            ///< pointer to start of tensor mask0s
  Element *ptr_mask1s,            ///< pointer to start of tensor mask1s
  int *device_flag                ///< pointer to check flag
):
  ptr_(ptr), layout_(layout), check(device_flag), ptr_mask0s(ptr_mask0s), ptr_mask1s(ptr_mask1s){ 
  }

/// Returns the pointer to referenced data
CUTLASS_HOST_DEVICE
Element * data() const { return ptr_; }

/// Returns the pointer to mask0s data
CUTLASS_HOST_DEVICE
Element * data_mask0s() const { return ptr_mask0s; }

/// Returns the pointer to mask1s data
CUTLASS_HOST_DEVICE
Element * data_mask1s() const { return ptr_mask1s; }

/// Returns the pointer to the check flag
CUTLASS_HOST_DEVICE
int * check_flag() const { return check; }
  3. In the conv kernel itself:

/// Parameters structure
  struct Params {
    ConvProblemSize problem_size;
    cutlass::gemm::GemmCoord grid_tiled_shape;
    gemm::GemmCoord implicit_gemm_problem_size;
    int swizzle_log_tile;
    int gemm_k_iterations;
    int gemm_k_iterations_per_channel;
    int *first_call;
    typename Mma::IteratorA::Params iterator_A;
    typename Mma::IteratorA::Element const *ptr_A;
    typename Mma::IteratorB::Params iterator_B;
    typename Mma::IteratorB::Element *ptr_B;
    typename Mma::IteratorB::Element *ptr_B_mask0s;
    typename Mma::IteratorB::Element *ptr_B_mask1s;
    typename Epilogue::OutputTileIterator::Params iterator_C;
    typename Epilogue::OutputTileIterator::Element *ptr_C;
    typename Epilogue::OutputTileIterator::Params iterator_D;
    typename Epilogue::OutputTileIterator::Element *ptr_D;
    typename EpilogueOutputOp::Params output_op;
    int *semaphore;
    SplitKMode split_k_mode;

    //
    // Methods
    //

    CUTLASS_HOST_DEVICE
    Params(): swizzle_log_tile(0), gemm_k_iterations(0) { }

    /// 
    CUTLASS_HOST_DEVICE
    Params(
      Arguments const &args,
      int *semaphore = nullptr
    ):
      problem_size(args.problem_size),
      implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)),
      iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())),
      ptr_A(args.ref_A.data()),
      iterator_B(args.problem_size, args.ref_B.layout()),
      ptr_B(args.ref_B.data()),
      iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)),
      ptr_C(args.ref_C.data()),
      iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)),
      ptr_D(args.ref_D.data()),
      output_op(args.output_op),
      first_call(args.ref_B.check_flag()),
      ptr_B_mask0s(args.ref_B.data_mask0s()),
      ptr_B_mask1s(args.ref_B.data_mask1s()),
      semaphore(semaphore),
      split_k_mode(args.split_k_mode)
    {
      gemm_k_iterations = implicit_gemm_k_iterations(
        kConvolutionalOperator,
        ThreadblockShape::kK,
        args.problem_size,
        kIteratorAlgorithm,
        kGroupMode,
        ThreadblockShape::kN);

      gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel(
          kConvolutionalOperator, args.problem_size, kIteratorAlgorithm);

      ThreadblockSwizzle threadblock_swizzle;
      printf("construct params on kernel/implicit Gemem lin256...\n");

      grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
        implicit_gemm_problem_size,
        {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
        args.problem_size.split_k_slices);

      swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
    }
  };

All of this code currently works: it runs example 16 without problems or memory errors.

My question now is how to adapt the host code to Python. The device code works well and takes the same arguments as before; the only difference is that the TensorRef class can now carry three more attributes. However, when defining the convolution in Python, there is a part where the arguments are specified.

I tried to modify this part of the code by adding two ElementB pointers and an integer, using the new TensorRef constructor I added earlier:

struct ${operation_name}_TemporaryArgs {
    int conv_kind;
    cutlass::conv::Conv2dProblemSize problem_size;
    ElementA* ptr_A;
    ElementB* ptr_B;
    ElementB* ptr_B_0;
    ElementB* ptr_B_1;
    ElementC* ptr_C;
    ElementC* ptr_D;
    int tensor_c_numel;
    typename EpilogueOutputOp::Params epilogue_params;
    int split_k_mode;
  };

  typename ${operation_name}${operation_suffix}::Arguments
  construct_arguments(${operation_name}_TemporaryArgs args) {
    cutlass::conv::Operator conv_operator = static_cast<cutlass::conv::Operator>(args.conv_kind);
    auto tc_A = cutlass::conv::implicit_gemm_tensor_a_extent(conv_operator, args.problem_size);
    auto tc_B = cutlass::conv::implicit_gemm_tensor_b_extent(conv_operator, args.problem_size);
    auto tc_C = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);
    auto tc_D = cutlass::conv::implicit_gemm_tensor_c_extent(conv_operator, args.problem_size);

    auto size_C = tc_C.at(0) * tc_C.at(1) * tc_C.at(2) * tc_C.at(3);
    if (args.tensor_c_numel >= 0 && args.tensor_c_numel == tc_C.at(3) && args.tensor_c_numel < size_C) {
      // C is interpreted as bias
      tc_C = {0, 0, 0, 0};
    }


    int device_flag = 0;
    cutlass::TensorRef<ElementA, LayoutA> tref_A(args.ptr_A, LayoutA::packed(tc_A));
    cutlass::TensorRef<ElementB, LayoutA> tref_B(args.ptr_B, LayoutB::packed(tc_B), args.ptr_B_0, args.ptr_B_1, &device_flag);
    cutlass::TensorRef<ElementC, LayoutA> tref_C(args.ptr_C, LayoutC::packed(tc_C));
    cutlass::TensorRef<ElementC, LayoutA> tref_D(args.ptr_D, LayoutC::packed(tc_D));

    return {
      args.problem_size,
      tref_A,
      tref_B,
      tref_C,
      tref_D,
      args.epilogue_params,
      static_cast<cutlass::conv::SplitKMode>(args.split_k_mode)
    };
  }

But of course, it gives a CUDA illegal-address error when sync() is executed on the arguments, because the two pointers and the integer I added have not been correctly copied to GPU memory:


Traceback (most recent call last):
  File "python/prova.py", line 81, in <module>
    plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/conv.py", line 939, in run
    return super().run(
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/conv.py", line 890, in run
    arguments.sync()
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/backend/conv2d_operation.py", line 191, in sync
    return super().sync()
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/backend/arguments.py", line 107, in sync
    raise RuntimeError("CUDA Error %s" % str(err))
RuntimeError: CUDA Error cudaError_t.cudaErrorIllegalAddress

So, the question is: where do I need to add these pointers and the integer so that they are properly initialized in host CPU memory and copied to the GPU?

In example 16 the HostTensor class is used, but in Python there is no such class, so I am a little lost here.
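Since there is no HostTensor on the Python side, each new operand needs a genuine device allocation; handing the kernel the address of a host-side scalar is exactly the kind of thing that produces an illegal-address error. As a minimal sketch (assuming the existing ArgumentBase.tensor_to_ptr helper accepts arbitrary numpy arrays, as it does for A/B/C/D; the names and shapes below are made up for illustration), the flag could be modeled as a 1-element int32 tensor and the masks as weight-shaped tensors:

```python
import numpy as np

# Sketch, not the CUTLASS API: model the scalar flag as a 1-element int32
# array so the same host-to-device path used for A/B/C/D can allocate and
# copy it, instead of passing the address of a host-side local.
host_flag = np.zeros(1, dtype=np.int32)

# The masks mirror the weight tensor's shape and dtype.
weight_shape = (32, 3, 3, 16)  # K, R, S, C -- hypothetical extents
host_mask0s = np.ones(weight_shape, dtype=np.float16)
host_mask1s = np.zeros(weight_shape, dtype=np.float16)

# Inside ArgumentBase.__init__ one could then reuse the existing helper,
# mirroring the ptr_A/ptr_B lines shown later in this thread (hypothetical):
# self.ptr_device_flag = self.tensor_to_ptr(host_flag, "device_flag")
# self.ptr_mask0s = self.tensor_to_ptr(host_mask0s, "ptr_mask0s")
```

This keeps the flag's lifetime tied to the same buffer-tracking machinery as the other operands, rather than relying on a stack temporary.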

Any feedback would be appreciated.

Thanks.

@IzanCatalan IzanCatalan changed the title [QST] Adding new parameter to Conv2dFprop in PyThon [QST] Adding new parameter to Conv2dFprop in Python Mar 12, 2025
jackkosaian (Contributor) commented:

Thanks for the detailed messages. I think you'll need to modify the ctypes structure for TensorRef in order to accommodate the extra members you added to TensorRef in C++:

class TensorRef_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D)
    ]

    def __init__(self, tensor_ref):
        setattr(self, "ptr", tensor_ref.data())
        setattr(self, "layout", Layout4D(tensor_ref.layout()))

class TensorRef2D_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("stride", ctypes.c_int)
    ]

IzanCatalan (Author) commented:

@jackkosaian I modified the part of the code you pointed me to. The original was:

class TensorRef_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D)
    ]

    def __init__(self, tensor_ref):
        setattr(self, "ptr", tensor_ref.data())
        setattr(self, "layout", Layout4D(tensor_ref.layout()))

class TensorRef2D_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("stride", ctypes.c_int)
    ]

and changed it to:

class TensorRef_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D),
        ("ptr_mask0s", ctypes.c_void_p),
        ("ptr_mask1s", ctypes.c_void_p),
        ("device_flag", ctypes.c_int)
    ]

    def __init__(self, tensor_ref):
        setattr(self, "ptr", tensor_ref.data())
        setattr(self, "layout", Layout4D(tensor_ref.layout()))
        setattr(self, "ptr_mask0s", tensor_ref.data_mask0s())
        setattr(self, "ptr_mask1s", tensor_ref.data_mask1s())
        setattr(self, "device_flag", tensor_ref.check_flag())


class TensorRef2D_(ctypes.Structure):
    _fields_ = [
        ("ptr", ctypes.c_void_p),
        ("layout", Layout4D),
        ("ptr_mask0s", ctypes.c_void_p),
        ("ptr_mask1s", ctypes.c_void_p),
        ("device_flag", ctypes.c_int)
    ]


def get_conv2d_arguments(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _Conv2dArguments(ctypes.Structure):
        print("inside _Conv2dArguments get_conv2d_arguments")
        _fields_ = [
            ("conv_kind", ctypes.c_int),
            ("problem_size", Conv2DProblemSize_),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("ptr_mask0s", ctypes.c_void_p),
            ("ptr_mask1s", ctypes.c_void_p),
            ("device_flag", ctypes.c_int),
            ("tensor_C_numel", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("split_k_mode", ctypes.c_int)
        ]

    return _Conv2dArguments, _EpilogueOutputOpParams

However, I still get the same error. I believe this is because the pointers ptr_mask0s, ptr_mask1s and the flag are not correctly initialized in the host code, as I did in the HostTensor code, for instance:

/// Resizes internal memory allocations without affecting layout or extent
  void reserve(
    size_t count,                                        ///< size of tensor in elements
    bool device_backed_ = true) {                        ///< if true, device memory is also allocated
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
    CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")");
#endif

    device_.reset();
    device_masks1s.reset();
    device_masks0s.reset();
    device_flag.reset();
    host_.clear();
    host_mask1s.clear();
    host_mask0s.clear();
    delete host_flag;
    host_flag = nullptr;

    size_t count_container = count_to_container_storage_unit_count(count);
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
    CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")");
#endif    
    host_.resize(count_container);
    host_mask0s.resize(count_container);
    host_mask1s.resize(count_container);
    host_flag = new int(0);

    // Allocate memory
    StorageUnit* device_memory = nullptr;
    StorageUnit* device_memory_mask0s = nullptr;
    StorageUnit* device_memory_mask1s = nullptr;
    int* device_memory_flag = nullptr;
    if (device_backed_) {
#if (CUTLASS_DEBUG_TRACE_LEVEL > 1)
      CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")");
#endif
      device_memory = device_memory::allocate<StorageUnit>(count_container);
      device_memory_mask0s = device_memory::allocate<StorageUnit>(count_container);
      device_memory_mask1s = device_memory::allocate<StorageUnit>(count_container);
      device_memory_flag = device_memory::allocate<int>(1);
    }
    device_.reset(device_memory, device_backed_ ? count_container : 0);
    device_masks1s.reset(device_memory_mask1s, device_backed_ ? count_container : 0);
    device_masks0s.reset(device_memory_mask0s, device_backed_ ? count_container : 0);
    device_flag.reset(device_memory_flag, device_backed_ ? 1 : 0);
  }

  /// Updates the extent and layout of the HostTensor. Allocates memory according to the new
  /// extent and layout.
  void reset(
    TensorCoord const &extent,                           ///< extent of logical tensor
    Layout const &layout,                                ///< layout object of tensor
    bool device_backed_ = true) {                        ///< if true, device memory is also allocated. 
    extent_ = extent;
    layout_ = layout;

    host_.clear();
    host_mask1s.clear();
    host_mask0s.clear();
    host_flag = nullptr;                                  // not deleted because host_flag was never initialized at this point


    device_.reset();
    device_masks1s.reset();
    device_masks0s.reset();
    device_flag.reset();

    reserve(size_t(layout_.capacity(extent_)), device_backed_);
  }

/// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity.
  /// To force allocation, call reset().
  void resize(
    TensorCoord const &extent,                           ///< extent of logical tensor
    Layout const &layout,                                ///< layout object of tensor
    bool device_backed_ = true) {                        ///< if true, device memory is also allocated. 

    extent_ = extent;
    layout_ = layout;

    LongIndex new_size = size_t(layout_.capacity(extent_));
    LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_)));

    if (static_cast<decltype(host_.size())>(new_size_container) > host_.size()) {
      reserve(new_size, device_backed_);
    }
  }

I think the solution is to create the two tensors and the flag in the Python host code, initialize them, and then sync() them with GPU memory. To do so, I modified parts of the code in several files:

In arguments.py:

def __init__(
        self,
        A: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        B: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        C: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        D: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        ptr_mask0s: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        ptr_mask1s: "Union[cuda.CUdeviceptr, np.ndarray, torch.Tensor, cp.ndarray]",
        **kwargs,
    ) -> None:
        # tensor_C can be interpreted as the bias with bias=True in keyword args
        self.bias = kwargs.get("bias", False)
        print("Inside ArgumentBase")
        self.stream = kwargs.get("stream", cuda.CUstream(0))

        # RMM buffers used to track tensor lifetime
        self.buffers = {}
        # Host tensor to copy the computed result back
        self.host_tensors = {}

        self.ptr_A = self.tensor_to_ptr(A, "A")
        self.ptr_B = self.tensor_to_ptr(B, "B")
        self.ptr_C = self.tensor_to_ptr(C, "C")
        self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True)
        self.ptr_mask0s = self.tensor_to_ptr(ptr_mask0s, "ptr_mask0s", is_output=True)
        self.ptr_mask1s = self.tensor_to_ptr(ptr_mask1s, "ptr_mask1s", is_output=True)
        if C is not None:
            if not isinstance(C, cuda.CUdeviceptr):
                self.tensor_c_numel = prod(C.shape)

In ctypes.py:

def get_conv2d_arguments(epilogue_functor):
    _EpilogueOutputOpParams = epilogue_functor.epilogue_type

    class _Conv2dArguments(ctypes.Structure):
        print("inside _Conv2dArguments get_conv2d_arguments")
        _fields_ = [
            ("conv_kind", ctypes.c_int),
            ("problem_size", Conv2DProblemSize_),
            ("ptr_A", ctypes.c_void_p),
            ("ptr_B", ctypes.c_void_p),
            ("ptr_C", ctypes.c_void_p),
            ("ptr_D", ctypes.c_void_p),
            ("ptr_mask0s", ctypes.c_void_p),
            ("ptr_mask1s", ctypes.c_void_p),
            ("device_flag", ctypes.c_int),
            ("tensor_C_numel", ctypes.c_int),
            ("output_op", _EpilogueOutputOpParams),
            ("split_k_mode", ctypes.c_int)
        ]

    return _Conv2dArguments, _EpilogueOutputOpParams

In conv2d_operation.py:

class Conv2dArguments(ArgumentBase):

    def __init__(self, operation, problem_size, A, B, C, D, ptr_mask0s, ptr_mask1s, device_flag,
        split_k_mode=SplitKMode.Serial, **kwargs, ) -> None:
        self.operation = operation
        self.conv_kind = operation.conv_kind
        self.layout_A = operation.A.layout
        self.layout_B = operation.B.layout
        self.layout_C = operation.C.layout

        self.element_A = operation.A.element
        self.element_B = operation.B.element
        self.element_C = operation.C.element

        self.device_flag = device_flag

        if self.layout_C == LayoutType.TensorNC32HW32:
            raise Exception("Layout type TensorNC32HW32 is not currently supported")

        super().__init__(A, B, C, D, ptr_mask0s, ptr_mask1s, **kwargs)

        if "split_k_slices" in kwargs.keys() and kwargs["split_k_slices"] > 1:
            self.split_k_mode = split_k_mode
            self.split_k_slices = kwargs["split_k_slices"]
        else:
            self.split_k_mode = SplitKMode.Serial
            self.split_k_slices = 1

        if "output_op" in kwargs.keys() and self.split_k_mode != SplitKMode.Parallel:
            self.output_op = kwargs["output_op"]
        else:
            self.output_op = self.operation.epilogue_type(1.0, 0.0)

        self.problem_size = problem_size
        self.problem_size.split_k_slices = self.split_k_slices
        print("Inside Conv2dArguments init", self.problem_size, problem_size)
        self.initialize()

    def get_arguments(self):
        print("Inside Conv2dArguments get_arguments", self.ptr_B)
        tc_numel = -1
        if hasattr(self, "tensor_c_numel"):
            tc_numel = self.tensor_c_numel

        self.c_arguments = self.operation.argument_type(
            int(self.conv_kind),
            self.problem_size.ctype,
            int(to_device_ptr(self.ptr_A)),
            int(to_device_ptr(self.ptr_B)),
            int(to_device_ptr(self.ptr_C)),
            int(to_device_ptr(self.ptr_D)),
            int(to_device_ptr(self.ptr_mask0s)),
            int(to_device_ptr(self.ptr_mask1s)),
            int(to_device_ptr(self.device_flag)),
            tc_numel,
            self.output_op,
            int(self.split_k_mode)
        )

And last, in conv.py:

class Conv2dFprop(Conv2d):
    def __init__(
        self,
        input=None, weight=None, C=None, output=None, ptr_mask0s=None, ptr_mask1s=None, device_flag=0, alpha=1, beta=0,
        element=None,
        element_input=None, element_weight=None, element_C=None, element_output=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None):
        print("🔍 Call trace inside Conv2dFprop:")
        A, B, D = input, weight, output
        element_A, element_B, element_D = element_input, element_weight, element_output
        print("weight:", weight, "element_weight:", element_weight, "B: ", B)
        super().__init__(
            "fprop", A, B, C, D, ptr_mask0s, ptr_mask1s, device_flag, alpha, beta, element,
            element_A, element_B, element_C, element_D,
            element_accumulator, cc, kernel_cc)

    def run(
        self, input=None, weight=None, C=None, output=None, ptr_mask0s=None, ptr_mask1s=None, device_flag=0, alpha=None, beta=None,
        stride=(1, 1), padding=(0, 0), dilation=(1, 1), split_k=("serial", 1),
        sync: bool = True, print_module: bool = False,
        stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
        print("🔍 Call trace inside Conv2dFprop.run:")
        A, B, D = input, weight, output
        return super().run(
            A, B, C, D, ptr_mask0s, ptr_mask1s, device_flag, alpha, beta, stride, padding, dilation, split_k, sync, print_module, stream)

class Conv2d(OperationBase):
    def __init__(
        self, kind="fprop",
        A=None, B=None, C=None, D=None, ptr_mask0s=None, ptr_mask1s=None, device_flag=0, alpha=1.0, beta=0.0,
        element=None,
        element_A=None, element_B=None, element_C=None, element_D=None,
        element_accumulator=None,
        cc: int = None, kernel_cc: int = None
    ):
        super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d)
        # Verify the kernel cc
        if self.current_cc == 90:
            # The Conv2d kernel on Hopper (SM90) is currently unsupported
            # Revert to use SM80-tagged kernels
            cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.")
            self.specified_kernel_cc = 80
            self._reset_options(80)

        # The arch is used in testing
        self.arch = self.current_cc
        self.name = "conv2d" + kind

        # The convolution kind. (concept: cutlass_library.library.ConvKind)
        self.conv_kind = datatypes.getattr_enum(ConvKind, kind)

        # The element types (concept: cutlass library types) of A, B, C, and D
        elements = []
        layouts = []

        # Complete the data types based on user-provided arguments
        for elt, tens, name in zip([element_A, element_B, element_C, element_D],
                                   [A, B, C, D],
                                   ["A", "B", "C", "D"]):
            if elt is not None and tens is not None:
                raise Exception(f'Must not specify both element_{name} and tensor {name}')
            if elt is None and tens is None and element is None:
                raise Exception(f'Must specify one of element_{name}, tensor {name}, or generic element.')

            elt_to_set = None
            lay_to_set = None

            if tens is not None:
                elt_to_set, _ = datatypes.get_datatype_and_layout(tens)
            else:
                elt_to_set = elt if elt is not None else element

            assert elt_to_set is not None

            # Currently we only support layout TensorNHWC
            lay_to_set = cutlass.LayoutType.TensorNHWC
            elements.append(datatypes.library_type(elt_to_set))
            layouts.append(lay_to_set)

        self._element_a, self._element_b, self._element_c, self._element_d = elements
        self._layout_a, self._layout_b, self._layout_c, self._layout_d = layouts

        self.A, self.B, self.C, self.D, self.ptr_mask0s, self.ptr_mask1s, self.device_flag, self.alpha, self.beta = A, B, C, D, ptr_mask0s, ptr_mask1s, device_flag, alpha, beta

        if element_accumulator is None:
            self._element_accumulator = self._element_c
        else:
            self._element_accumulator = datatypes.library_type(element_accumulator)

        # Default inputs if none is supplied in run()
        self.A = A
        self.B = B
        self.C = C
        self.D = D

        self.alpha = alpha
        self.beta = beta

        # We only specify the stride of the swizzling functor here
        # The actual swizzling functor is determined in run based on conv_kind and stride
        self._swizzling_stride = 1

        # Arguments that will be set to default value in _reset_operations
        # The default tile_description and op_class are fetched from manifest of cutlass library
        self._tile_description = None
        self.op_class = None
        # The default identity epilogue will be created
        self.epilogue_functor = None

        self._reset_operations()

        # Arguments that will be determined online based on arguments of "run"
        # based on stride, input/output channels, alignment, and conv_kind
        self._iterator_algorithm = None
        self._stride_support = None
        print(f"End Conv.py init: self.B = {self.B}, self._element_b = {self._element_b}")
        traceback.print_stack()

def run(self, A=None, B=None, C=None, D=None, ptr_mask0s=None, ptr_mask1s=None, device_flag=0,
            stride=(1, 1), padding=(0, 0), dilation=(1, 1),
            alpha=None, beta=None,
            split_k=("serial", 1), sync: bool = True,
            print_module: bool = False,
            stream: cuda.CUstream = cuda.CUstream(0)) -> Conv2dArguments:
        
        super().run_setup()

        A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A")
        B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B")
        C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C")
        D = self._verify_tensor(D, self.D, self._element_d, self._layout_d, "D")
        ptr_mask0s = self._verify_tensor(ptr_mask0s, self.ptr_mask0s, self._element_b, self._layout_b, "ptr_mask0s")
        ptr_mask1s = self._verify_tensor(ptr_mask1s, self.ptr_mask1s, self._element_b, self._layout_b, "ptr_mask1s")
        alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha")
        beta = self._verify_scalar(beta, self.beta, self._element_c, "beta")

        # handle the case when there is no C
        if C is None:
            if beta != 0:
                raise Exception(f"With beta {beta} != 0, C has to be provided.")
            else:
                C = D

        # Construct problem size based on input
        # It also verifies whether the A, B, C, D, stride, padding, and dilation are matching
        problem_size = self._get_and_verify_conv_problem_size(A, B, C, stride, padding, dilation)

        # Propose stride support based on input
        stride_support = self._propose_stride_support(stride)

        # Propose swizzling functor
        swizzling_functor = self._propose_swizzling_functor(stride)

        shape_a = datatypes.get_tensor_shape(A, op="CONV")
        shape_b = datatypes.get_tensor_shape(B, op="CONV")
        shape_c = datatypes.get_tensor_shape(C, op="CONV")

        # Get the alignment
        alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A")
        alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B")
        alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C")

        alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A)
        alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B)
        alignment_c = check.update_alignment(alignment_c, self.alignment_pref_C)

        # Propose iterator algorithm based on input
        if self._iterator_algorithm is None:
            # Propose a default iterator algorithm based on the problem size
            iterator_algorithm = self._propose_iterator_algorithm(problem_size, alignment_a, alignment_b)
        else:
            if (self._validate_iterator_algorithm(self._iterator_algorithm, problem_size, alignment_a, alignment_b)):
                iterator_algorithm = self._iterator_algorithm
            else:
                raise Exception(f"Iterator algorithm {self._iterator_algorithm} is invalid for current problem.")

        epilogue_args = [alpha, beta]

        if hasattr(self, "_activation_args"):
            if isinstance(self._activation_args, list):
                epilogue_args += self._activation_args
            else:
                epilogue_args.append(self._activation_args)

        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor = self._create_epilogue_functor_activation(epilogue.identity)
        else:
            epilogue_functor = self.epilogue_functor

        # The alignment is determined by the iterator function (I believe)
        self.compile(tile_description=self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b,
                     alignment_C=alignment_c, iterator_algorithm=iterator_algorithm, stride_support=stride_support,
                     swizzling_functor=swizzling_functor, epilogue_functor=epilogue_functor, print_module=print_module)

        # Create reduction operation for parallel split-k
        if split_k[0] == "parallel" and split_k[1] > 1:
            epilogue_functor_reduction = self._reset_epilogue_functor_alignment(alignment_c, self.epilogue_functor)
            self.reduction_operation = ReductionOperation(
                shape=MatrixCoord(4, 32 * alignment_c), C=self.operation.C,
                element_accumulator=self._element_accumulator,
                element_compute=self._element_accumulator,
                epilogue_functor=epilogue_functor_reduction,
                count=alignment_c
            )
            if print_module:
                print(self.reduction_operation.rt_module.emit())
            compiler.add_module([self.reduction_operation,])

        arguments = Conv2dArguments(
            operation=self.operation, problem_size=problem_size,
            A=A, B=B, C=C, D=D, ptr_mask0s=ptr_mask0s, ptr_mask1s=ptr_mask1s, device_flag=device_flag,
            output_op=self.operation.epilogue_type(*epilogue_args),
            split_k_mode=datatypes.getattr_enum(SplitKMode, split_k[0]),
            split_k_slices=split_k[1],
            stream=stream
        )

        self.operation.run(arguments)

        if split_k[0] == "parallel" and split_k[1] > 1:
            implicit_gemm_size = arguments.problem_size.implicit_gemm_size(self.conv_kind)
            reduction_arguments = ReductionArguments(
                self.reduction_operation,
                problem_size=[implicit_gemm_size.m, implicit_gemm_size.n],
                partitions=split_k[1],
                workspace=arguments.ptr_D,
                destination=D,
                source=C,
                output_op=self.reduction_operation.epilogue_type(*epilogue_args),
                stream=stream
            )
            self.reduction_operation.run(reduction_arguments)

        if sync:
            if split_k[0] == "parallel" and split_k[1] > 1:
                reduction_arguments.sync()

                # Free memory allocated by args because we are not
                # calling `arguments.sync()` in this case (which will free memory)
                arguments.free()
            else:
                arguments.sync()

        return arguments

I executed the code again with all of these modifications, and now I get a different error, this time from the _verify_tensor() function:

Traceback (most recent call last):
  File "python/prova.py", line 81, in <module>
    plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/conv.py", line 941, in run
    return super().run(
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/conv.py", line 782, in run
    ptr_mask0s = self._verify_tensor(ptr_mask0s, self.ptr_mask0s, self._element_b, self._layout_b, "ptr_mask0s")
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/op.py", line 204, in _verify_tensor
    self._verify_type_and_layout(tensor, ref_dtype, ref_layout, name)
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/op/conv.py", line 693, in _verify_type_and_layout
    dtype, _ = datatypes.get_datatype_and_layout(tensor)
  File "/mnt/beegfs/gap/izcagal@upvnet.upv.es/cutlass/python/cutlass/utils/datatypes.py", line 278, in get_datatype_and_layout
    raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.")
Exception: Unable to convert tensor of type <class 'tuple'> to Python-bound CUTLASS datatype and layout.
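
For reference, the exception comes from a type dispatch that only accepts tensor-like objects: a plain Python tuple (for example, the stride or padding tuple landing in a tensor slot when run()'s positional arguments shift) cannot be mapped to a CUTLASS datatype and layout. A minimal sketch of that behavior (illustrative only, not the actual CUTLASS implementation):

```python
def get_datatype_and_layout_sketch(tensor):
    # Sketch: a tensor-like object is recognized by its attributes;
    # anything else (e.g. a tuple like the stride (1, 1)) is rejected.
    if hasattr(tensor, "dtype") and hasattr(tensor, "layout"):
        return tensor.dtype, tensor.layout
    raise Exception(
        f"Unable to convert tensor of type {type(tensor)} "
        "to Python-bound CUTLASS datatype and layout.")
```

So one thing to check is whether every positional argument of run() still lines up with the modified signature.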

Could you please tell me how to proceed, and which parts of the Python host code I still need to modify beyond the ones mentioned above? Are the steps I have shown correct for achieving my goal?

@IzanCatalan
Author

@jackkosaian, sorry, but do you have any update on this? It is a bit urgent :(

@jackkosaian
Contributor

It's a bit hard for me to tell which changes were made from large pasted chunks of code alone. Do you have a fork of CUTLASS containing your changes that you could link here? That would make it easier to help.

@IzanCatalan
Author

@jackkosaian Sorry, I didn't fork CUTLASS; I am modifying the code locally. I can, however, share the four files I modified (conv.py, conv2d_operation.py, arguments.py and ctypes.py), or a diff between my files and the originals, if that helps clarify things.

@jackkosaian
Contributor

Yes, seeing the diff of your whole repo (both C++ and Python changes) would be helpful.

@IzanCatalan
Author

IzanCatalan commented Mar 20, 2025

@jackkosaian I attach here a ZIP file containing diffs of the following files:

C++ files (current modifications, working): Host Tensor, Tensor_ref, the conv kernel, and example 16.

Python files: conv.py, conv2d_operation.py, ctypes.py, and arguments.py.

diffFiles.zip

Let me know your thoughts when you check the code.

@jackkosaian
Contributor

Thanks. What is the Python code that you are running that calls into your modified CUTLASS Python/C++?

@jackkosaian
Contributor

jackkosaian commented Mar 20, 2025

I got the same error you are mentioning when I ran the following code:

plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

The issue is that the ptr_mask and device_flag variables were not being passed to plan.run().

Are you calling plan.run() like what follows?

plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)
plan.run(input, weight, tensor_C, output, mask0, mask1, device_flag, stride, padding, dilation, alpha, beta, print_module=print_module)

The Python issue went away after I changed the code to this. There is a compilation issue after that, which seems to be related to logic added in your modifications.

@IzanCatalan
Author

@jackkosaian, thank you for your help, this is my host Python code:

# Compute the output size [N, P, Q, K]
N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)

dtype = torch.int8
type_A = torch.int8
type_B = torch.int8
type_C = torch.int8
type_D = torch.int8

torch.manual_seed(1234)
print("HOST: create tensors")

input = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_A, device="cuda"
).to(memory_format=torch.channels_last)

weight = torch.randint(
    low=-128, high=127, size=(K, C, R, S), dtype=type_B, device="cuda"
).to(memory_format=torch.channels_last)

tensor_C = torch.randint(
    low=-128, high=127, size=(N, K, P, Q), dtype=type_C, device="cuda"
).to(memory_format=torch.channels_last)

output = torch.zeros_like(tensor_C)

tensor_D = torch.randint(
    low=-128, high=127, size=(N, C, H, W), dtype=type_D, device="cuda"
).to(memory_format=torch.channels_last)

alpha = 1
beta = 0

plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.int32)
plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)

As you said, I did not pass the mask0s, mask1s, and device_flag parameters. However, my goal is to avoid exposing these parameters in the user code here.

Can these parameters be auto-initialized (or similar) from inside conv.py, e.g. within the run() function?

It would be better for my purposes if the user could keep calling the cutlass.Conv2dFprop constructor and run() as usual, so that no extra work is pushed onto the user.

@jackkosaian
Contributor

Yes, you should be able to do this. You will need to remove these arguments from the constructor and run().

You could then use torch or CUDA Python to allocate and initialize the memory needed for these extra parameters from within run().

@IzanCatalan
Author

IzanCatalan commented Mar 21, 2025

@jackkosaian, which part of run() is best for initializing these parameters, and what, in your opinion, is the proper way to do it?

I mean, before

arguments = Conv2dArguments(

because inside Conv2dArguments the get_arguments() function already calls int(to_device_ptr(self.ptr_mask0s)). In practice this returns a CUDA device pointer, so I think the pointers must be initialized before that point.
What do you think?

@jackkosaian
Contributor

Anywhere above that should be fine, but it probably makes the most sense to put it here, where we get the tensors for A, B, etc.
