2021/06/18

实现一个任意1-4维的矩阵类 Tensor

请先阅读主程序，然后在 source.cpp 中实现打印矩阵的函数（主程序在 #include <iostream> 之前就包含了 source.cpp，因此 source.cpp 需要自行包含所需的头文件）：

void Tensor_print(int dimensions, const int sizes[], const double data[])

注意

由于只能直接打印出1维和2维的矩阵，当矩阵大于2维时，需要按存储顺序逐个打印出高维矩阵中的各个2维矩阵，并在每个2维矩阵之前用 data[i]（3维）或 data[i][j]（4维）标出它的下标，格式见下方示例输出。
EXAMPLE INPUT

1.3

5.1

2.8

6.3

EXAMPLE OUTPUT

Tensor of 5

1.3

1.3

1.3

1.3

1.3



Tensor of 3x4

    5.1    5.1    5.1    5.1

    5.1    5.1    5.1    5.1

    5.1    5.1    5.1    5.1



Tensor of 3x4x5

data[0]

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

data[1]

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

data[2]

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8

    2.8    2.8    2.8    2.8    2.8


Tensor of 2x3x4x5

data[0][0]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

data[0][1]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

data[0][2]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

data[1][0]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

data[1][1]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

data[1][2]

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

    6.3    6.3    6.3    6.3    6.3

主程序 (不能修改)

#include "source.cpp"

#include <iostream>
using namespace std;

// Aborts the program when a check fails.
// If `valid` is true this returns immediately; otherwise it prints `err_msg`
// followed by a newline to standard output and terminates with exit code 1.
// NOTE(review): exit() is declared in <cstdlib>, which this harness does not
// include directly; it compiles only because another header pulls it in.
void _assert(bool valid, const char err_msg[]) {
    if (valid) return;
    cout << err_msg << endl;
    exit(1);
}

// A dense tensor of doubles with 1 to 4 dimensions, stored row-major
// (the last index varies fastest, see get()).
//
// Unused trailing dimensions are stored with size 1 in `sizes`, so numel()
// can always multiply all four entries. A size argument of -1 means
// "dimension absent".
//
// NOTE(review): the class owns `data` via new[]/delete[] but declares no copy
// constructor or copy assignment, so copying a Tensor would make two objects
// delete the same buffer (Rule of Three). main() never copies a Tensor, so the
// problem is latent; the code is left as-is because this harness must not be
// modified.
class Tensor
{
private:
    double * data;    // heap buffer of numel() elements, row-major
    int sizes[4];     // size of each dimension; unused dimensions are 1
    int dimensions;   // number of dimensions in use (1-4)

public:
    // Creates a tensor whose dimension sizes are the supplied arguments; every
    // element starts at 0. Each supplied size must be > 0, otherwise _assert
    // aborts ("dimension N size must be greater than 0").
    // NOTE(review): the sizes are not required to be contiguous: for example
    // Tensor(2, -1, 3) gives dimensions == 3 with sizes[1] == 1.
    Tensor(int size0, int size1=-1, int size2=-1, int size3=-1) {
        _assert(size0 > 0, "第0维大小必须大于0");
        if (size1 != -1) _assert(size1 > 0, "第1维大小必须大于0");
        if (size2 != -1) _assert(size2 > 0, "第2维大小必须大于0");
        if (size3 != -1) _assert(size3 > 0, "第3维大小必须大于0");

        // Start as 1-D with all unused sizes set to 1, then widen for every
        // size argument that was supplied.
        this->dimensions = 1;
        this->sizes[0] = size0;
        this->sizes[1] = this->sizes[2] = this->sizes[3] = 1;
        if (size1 != -1) {
            this->dimensions = 2;
            this->sizes[1] = size1;
        }
        if (size2 != -1) {
            this->dimensions = 3;
            this->sizes[2] = size2;
        }
        if (size3 != -1) {
            this->dimensions = 4;
            this->sizes[3] = size3;
        }

        int totel_size = this->numel();
        this->data = new double[totel_size];
        for (int i = 0; i < totel_size; ++ i)
            this->data[i] = 0;
    }

    // Returns the total number of elements (product of all four stored sizes;
    // unused dimensions contribute a factor of 1).
    // NOTE(review): could be a const member function.
    int numel() { // number of elements
        return this->sizes[0] * this->sizes[1] * this->sizes[2] * this->sizes[3];
    }

    // Sets every element to `value`.
    // NOTE(review): numel() is re-evaluated on every loop iteration.
    void fill(double value) {
        for (int i = 0; i < this->numel(); ++ i)
            this->data[i] = value;
    }

    // Releases the element buffer.
    ~Tensor() {
        delete [] this->data;
    }

    // Returns a reference to the element at (x0, x1, x2, x3).
    // Indices of unused dimensions should be left at their default -1. Because
    // unused sizes are 1, an index of 0 also passes the check there. Any
    // out-of-range index aborts via _assert ("dimension N out of range").
    double & get(int x0, int x1=-1, int x2=-1, int x3=-1) {
        // Check that every index is within bounds.
        _assert(x0 >= 0 && x0 < this->sizes[0], "第0维越界");
        _assert((this->dimensions < 2 && x1 == -1) || (x1 >= 0 && x1 < this->sizes[1]), "第1维越界");
        _assert((this->dimensions < 3 && x2 == -1) || (x2 >= 0 && x2 < this->sizes[2]), "第2维越界");
        _assert((this->dimensions < 4 && x3 == -1) || (x3 >= 0 && x3 < this->sizes[3]), "第3维越界");

        // Row-major linear index, accumulated one dimension at a time.
        int index = x0;
        if (this->dimensions > 1) index = index * this->sizes[1] + x1;
        if (this->dimensions > 2) index = index * this->sizes[2] + x2;
        if (this->dimensions > 3) index = index * this->sizes[3] + x3;
        return this->data[index];
    }

    // Prints the tensor by delegating to Tensor_print(), which must be
    // provided by the included source.cpp.
    void print() const {
        return Tensor_print(this->dimensions, this->sizes, this->data);
    }

};

// Test driver: builds tensors of 1, 2, 3 and 4 dimensions, fills each one with
// a value read from standard input, and prints it via Tensor::print().
// An extra newline separates consecutive tests; none follows the last one.
int main() {
    double value;
    // Test 1: 1-D tensor with 5 elements
    Tensor t1(5);
    cin >> value;
    t1.fill(value);
    t1.print();
    cout << endl;
    
    // Test 2: 2-D tensor, 3x4
    Tensor t2(3, 4);
    cin >> value;
    t2.fill(value);
    t2.print();
    cout << endl;

    // Test 3: 3-D tensor, 3x4x5
    Tensor t3(3, 4, 5);
    cin >> value;
    t3.fill(value);
    t3.print();
    cout << endl;

    // Test 4: 4-D tensor, 2x3x4x5
    Tensor t4(2, 3, 4, 5);
    cin >> value;
    t4.fill(value);
    t4.print();
}

参考答案

#include <iostream>
using namespace std;
// Prints a `rows` x `cols` row-major matrix starting at `matrix`: one line per
// row, each element preceded by four spaces.
static void Tensor_print_matrix(int rows, int cols, const double matrix[]) {
    for (int r = 0; r < rows; ++r, matrix += cols) {
        for (int c = 0; c < cols; ++c) {
            std::cout << "    " << matrix[c];
        }
        std::cout << '\n';
    }
}

// Prints the label of 2-D slice number `slice` as "data[i]...[k]". The indices
// are `slice` decomposed over the `lead` leading dimensions in row-major order
// (the last leading dimension varies fastest).
static void Tensor_print_slice_label(int lead, const int sizes[], long long slice) {
    std::cout << "data";
    for (int d = 0; d < lead; ++d) {
        long long stride = 1;  // slices spanned by one step of index d
        for (int e = d + 1; e < lead; ++e) {
            stride *= sizes[e];
        }
        std::cout << '[' << (slice / stride) % sizes[d] << ']';
    }
    std::cout << '\n';
}

/// Prints a row-major tensor to standard output.
///
/// Output format:
///   - a header line "Tensor of S0xS1x...";
///   - 1-D: one element per line;
///   - 2-D: one matrix row per line, each element preceded by four spaces;
///   - N-D (N >= 3): every 2-D slice spanned by the last two dimensions, in
///     storage order, each preceded by a label such as "data[i]" (3-D) or
///     "data[i][j]" (4-D).
///
/// @param dimensions number of dimensions; values < 1 print nothing.
/// @param sizes      dimension sizes; only entries [0, dimensions) are read.
/// @param data       the elements, row-major (last index varies fastest).
void Tensor_print(int dimensions, const int sizes[], const double data[]) {
    if (dimensions < 1) {
        return;
    }

    std::cout << "Tensor of " << sizes[0];
    for (int d = 1; d < dimensions; ++d) {
        std::cout << 'x' << sizes[d];
    }
    std::cout << '\n';  // '\n' rather than endl: no flush after every line

    if (dimensions == 1) {
        for (int i = 0; i < sizes[0]; ++i) {
            std::cout << data[i] << '\n';
        }
        return;
    }

    const int rows = sizes[dimensions - 2];
    const int cols = sizes[dimensions - 1];
    if (dimensions == 2) {
        Tensor_print_matrix(rows, cols, data);
        return;
    }

    // Every combination of the dimensions in front of the last two selects a
    // separate 2-D slice. Slices are contiguous and stored in order, so slice s
    // starts at offset s * rows * cols (computed in 64 bits).
    const int lead = dimensions - 2;
    long long slices = 1;
    for (int d = 0; d < lead; ++d) {
        slices *= sizes[d];
    }
    const long long slice_size = static_cast<long long>(rows) * cols;
    for (long long s = 0; s < slices; ++s) {
        Tensor_print_slice_label(lead, sizes, s);
        Tensor_print_matrix(rows, cols, data + s * slice_size);
    }
}
//Provided by Yizuodi
上一篇: [离谱答案]C.3 下一篇: C.2
支持 Markdown 语法