File size: 775 Bytes
92455fe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
#pragma once
#include <iostream>
#include <string>
#include <vector>

#include "torch/torch.h"

#include "metrics.h"
/// Checking strategy attached to a Checkee and reported by get_checker_mode().
///
/// NOTE(review): the meaning of each mode is implemented outside this header;
/// the per-enumerator remarks below are inferred from the names only and
/// should be confirmed against the checker implementation.
enum class CheckerMode {
    kElementWise = 0,  // presumably an element-by-element comparison -- TODO confirm
    kRowIndex = 1,     // presumably a row-index based comparison -- TODO confirm
    kJustDump = 2,     // presumably dumps the tensor without comparing -- TODO confirm
};
/// One entry in the result list returned by case_run_kernel() and
/// case_run_ref_kernel(): a tensor pointer, the mode to check it with, and a
/// label.
///
/// The struct remains an aggregate (C++14 and later), so existing
/// `Checkee{&t, mode, "name"}` initialization keeps working. The default
/// member initializers only affect default-initialization (`Checkee c;`),
/// which previously left `tensor` and `mode` with indeterminate values.
struct Checkee {
    /// Raw pointer to the tensor. This struct declares no destructor, so it
    /// never frees the pointee.
    /// NOTE(review): the actual owner and required lifetime are not visible
    /// in this header -- confirm with the case implementations.
    torch::Tensor *tensor = nullptr;

    /// Mode to check this tensor with. Defaults to the zero-valued
    /// enumerator, which is the same value `Checkee{}` already produced.
    CheckerMode mode = CheckerMode::kElementWise;

    /// Label for this tensor. NOTE(review): presumably used in checker
    /// output -- confirm.
    std::string name;
};
/// Case setup hook; takes no arguments and returns nothing.
/// NOTE(review): what is initialized is defined by the case implementation,
/// which is not visible in this header.
void case_initialize();

/// Returns an integer count.
/// NOTE(review): presumably the number of valid indices accepted by
/// case_get_input() -- confirm against the implementation.
int get_params_count();

/// Returns an opaque pointer for the given index.
/// NOTE(review): presumably the result is what case_run_kernel(),
/// case_run_ref_kernel() and case_destroy() expect as `input`; the concrete
/// pointee type is case-specific and not visible here -- confirm.
void* case_get_input(int index);
/// Takes an opaque input pointer and a PerfMetrics pointer and returns a list
/// of Checkee entries.
/// NOTE(review): presumably runs the kernel under test and records timing
/// into `metrics`; whether `metrics` may be null is not visible here -- confirm.
std::vector<Checkee> case_run_kernel(void* input, PerfMetrics* metrics);

/// Takes an opaque input pointer and returns a list of Checkee entries.
/// NOTE(review): presumably runs the reference implementation whose outputs
/// are compared with those of case_run_kernel() -- confirm.
std::vector<Checkee> case_run_ref_kernel(void* input);
/// Returns a C string naming the case.
/// NOTE(review): storage duration of the returned string is not visible
/// here; presumably callers must not free it -- confirm.
const char* case_get_name();

/// Takes two float pointers named `rtol` and `atol`.
/// NOTE(review): presumably out-parameters that receive the relative and
/// absolute tolerances used when comparing outputs; null handling is not
/// visible here -- confirm.
void get_error_tolerance(float* rtol, float* atol);

/// Takes an opaque input pointer and returns nothing.
/// NOTE(review): presumably releases an object obtained from
/// case_get_input() -- confirm.
void case_destroy(void* input);
// Returns a CheckerMode value.
// NOTE(review): how this relates to the per-tensor Checkee::mode field is not
// visible in this header -- confirm against the implementation.
CheckerMode get_checker_mode();
// using OutputData = torch::Tensor;
// void ref_kernel(const BlockwiseMatmulInputs &data);
// BlockwiseMatmulInputs generate_input(int m, int n, int k, int seed);