kyopro_library

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub dyktr06/kyopro_library

✔ test/library_checker/data_structure/unionfind_with_potential_non_commutative_group.test.cpp

Depends on: Weighted Union-Find (lib/data_structure/weighted_union_find.hpp) and ModInt (lib/math/modint.hpp)

Code

#define PROBLEM "https://judge.yosupo.jp/problem/unionfind_with_potential_non_commutative_group"
#include <bits/stdc++.h>
using namespace std;

#include "../../../lib/data_structure/weighted_union_find.hpp"
#include "../../../lib/math/modint.hpp"

using mint = ModInt<998244353>;

struct Potential{
    mint A[2][2];
};
Potential operator+(Potential r, Potential l){
    Potential res;
    for(int i = 0; i < 2; i++){
        for(int j = 0; j < 2; j++){
            for(int k = 0; k < 2; k++){
                res.A[i][k] += l.A[i][j] * r.A[j][k];
            }
        }
    }
    return res;
}
Potential operator-(Potential r){
    swap(r.A[0][0], r.A[1][1]);
    r.A[0][1] = -r.A[0][1];
    r.A[1][0] = -r.A[1][0];
    return r;
}
Potential operator-(Potential l, Potential r){
    return l + (-r);
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, q; cin >> n >> q;
    WeightedUnionFind<Potential> uf(n, {{{1,0},{0,1}}});
    while(q--){
        int t; cin >> t;
        if(t == 0){
            int u, v; cin >> u >> v;
            Potential x;
            for(int i = 0; i < 2; i++){
                for(int j = 0; j < 2; j++){
                    cin >> x.A[i][j];
                }
            }
            if(uf.same(u, v)){
                Potential y = uf.diff(u, v);
                int is_ok = 1;
                for(int i = 0; i < 2; i++){
                    for(int j = 0; j < 2; j++){
                        if(x.A[i][j] != y.A[i][j]){
                            is_ok = 0;
                        }
                    }
                }
                cout << is_ok << "\n";
            } else{
                uf.unite(u, v, x);
                cout << 1 << "\n";
            }
        } else{
            int u, v; cin >> u >> v;
            if(uf.same(u, v)){
                Potential x = uf.diff(u, v);
                cout << x.A[0][0] << " " << x.A[0][1] << " " << x.A[1][0] << " " << x.A[1][1] << "\n";
            } else{
                cout << -1 << "\n";
            }
        }
    }
}
#line 1 "test/library_checker/data_structure/unionfind_with_potential_non_commutative_group.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/unionfind_with_potential_non_commutative_group"
#include <bits/stdc++.h>
using namespace std;

#line 2 "lib/data_structure/weighted_union_find.hpp"

/**
 * @brief Weighted Union-Find
 * @docs docs/data_structure/weighted_union_find.md
 */

#line 10 "lib/data_structure/weighted_union_find.hpp"

// Union-find that also tracks a potential ("weight") per element, with union by
// size and path compression. The group operation is written additively but is
// never reordered, so non-commutative groups work as well.
// Assumes T provides associative operator+, and operator- with a - b == a + (-b),
// and that ex - a is the inverse of a — TODO confirm for new T.
// Invariant: after unite(x, y, w) succeeds, diff(x, y) == w.
template <typename T>
struct WeightedUnionFind{
    int V;                       // number of elements; valid indices are 0..V-1
    std::vector<int> par;        // root: -(component size); otherwise: parent index
    std::vector<T> diff_weight;  // potential of an element relative to par[element]
    T ex;                        // identity element; inverses are formed as ex - a

    // Creates N singleton components, every potential initialised to e.
    WeightedUnionFind(const int N, const T &e = 0) : V(N), par(N, -1), diff_weight(N, e), ex(e){}

    // Returns the root of x's component. Compresses the path on the way, and
    // rewrites diff_weight[x] so that it becomes x's potential relative to the root.
    int root(const int x){
        assert(0 <= x && x < V);
        const int parent = par[x];
        if(parent < 0) return x;
        const int top = root(parent);
        // parent's weight is now relative to top; compose x's own offset with it.
        diff_weight[x] = diff_weight[x] + diff_weight[parent];
        par[x] = top;
        return top;
    }

    // Potential of x relative to the root of its component.
    T weight(const int x){
        (void)root(x);  // called for its side effect of making diff_weight[x] root-relative
        return diff_weight[x];
    }

    // weight(x) - weight(y). Only meaningful when x and y share a component,
    // because otherwise the two weights are relative to different roots.
    T diff(const int x, const int y){
        T wx = weight(x);
        T wy = weight(y);
        return wx - wy;
    }

    // Merges the components of x and y so that weight(x) - weight(y) == w.
    // If they are already connected, this does nothing and w is not checked.
    void unite(const int x, const int y, const T &w){
        int a = x, b = y;
        T wt = w;
        int ra = root(a);
        int rb = root(b);
        if(ra == rb) return;

        // Hang the smaller tree (ra) below the larger one (rb); ties keep ra -> rb.
        // Swapping the endpoints inverts the constraint.
        if(-par[ra] > -par[rb]){
            std::swap(ra, rb);
            std::swap(a, b);
            wt = ex - wt;
        }

        par[rb] += par[ra];
        par[ra] = rb;
        diff_weight[ra] = ex - diff_weight[a] + wt + diff_weight[b];
    }

    // True when x and y belong to the same component.
    bool same(const int x, const int y){
        const int rx = root(x);
        const int ry = root(y);
        return rx == ry;
    }

    // Number of elements in x's component.
    int size(const int x){
        const int r = root(x);
        return -par[r];
    }
};
#line 2 "lib/math/modint.hpp"

#line 5 "lib/math/modint.hpp"

/**
 * @brief ModInt
 * @docs docs/math/modint.md
 */

template <long long Modulus>
struct ModInt{
    static_assert(Modulus >= 1, "ModInt: Modulus must be positive");

    long long val;  // canonical representative, kept in [0, Modulus)

    // NOTE(review): returns int, so a Modulus above INT_MAX would be truncated here.
    static constexpr int mod() { return Modulus; }

    // Accepts any long long, including negative values, and reduces it into [0, Modulus).
    constexpr ModInt(const long long _val = 0) noexcept : val(_val) {
        normalize();
    }

    // Reduces val into [0, Modulus). This is constexpr so that the constexpr
    // constructor above can actually be evaluated in constant expressions.
    constexpr void normalize(){
        val = (val % Modulus + Modulus) % Modulus;
    }

    constexpr ModInt &operator+=(const ModInt &rhs) noexcept {
        if(val += rhs.val, val >= Modulus) val -= Modulus;
        return *this;
    }
    constexpr ModInt &operator-=(const ModInt &rhs) noexcept {
        if(val -= rhs.val, val < 0) val += Modulus;
        return *this;
    }
    constexpr ModInt &operator*=(const ModInt &rhs) noexcept {
        // Both operands are < Modulus, so (Modulus - 1)^2 must fit in a long long;
        // otherwise the product below would overflow silently.
        static_assert(Modulus <= 3037000500LL, "ModInt: Modulus too large for 64-bit multiplication");
        val = val * rhs.val % Modulus;
        return *this;
    }
    // Multiplies by rhs.inv(). Only meaningful when rhs is invertible
    // (see inv(long long) below).
    inline ModInt &operator/=(const ModInt &rhs) noexcept {
        return *this *= rhs.inv();
    }
    inline ModInt &operator++() noexcept {
        if(++val >= Modulus) val -= Modulus;
        return *this;
    }
    inline ModInt operator++(int) noexcept {
        ModInt t = val;
        if(++val >= Modulus) val -= Modulus;
        return t;
    }
    inline ModInt &operator--() noexcept {
        if(--val < 0) val += Modulus;
        return *this;
    }
    inline ModInt operator--(int) noexcept {
        ModInt t = val;
        if(--val < 0) val += Modulus;
        return t;
    }
    constexpr ModInt operator-() const noexcept { return (Modulus - val) % Modulus; }

    // Inverse of this element (see inv(long long)).
    inline ModInt inv(void) const { return inv(val); }

    // this^n by binary exponentiation; n must be non-negative.
    ModInt pow(long long n) const {
        assert(0 <= n);
        ModInt x = *this, r = 1;
        while(n){
            if(n & 1) r *= x;
            x *= x;
            n >>= 1;
        }
        return r;
    }

    // Modular inverse of n via the extended Euclidean algorithm.
    // Only an actual inverse when gcd(n, Modulus) == 1. Otherwise the result is
    // not an inverse; for example n == 0 yields 0. Callers in this struct pass
    // an already-normalized value in [0, Modulus).
    ModInt inv(const long long n) const {
        long long a = n, b = Modulus, u = 1, v = 0;
        while(b){
            long long t = a / b;
            a -= t * b; std::swap(a, b);
            u -= t * v; std::swap(u, v);
        }
        u %= Modulus;
        if(u < 0) u += Modulus;
        return u;
    }

    friend constexpr ModInt operator+(const ModInt &lhs, const ModInt &rhs) noexcept { return ModInt(lhs) += rhs; }
    friend constexpr ModInt operator-(const ModInt &lhs, const ModInt &rhs) noexcept { return ModInt(lhs) -= rhs; }
    friend constexpr ModInt operator*(const ModInt &lhs, const ModInt &rhs) noexcept { return ModInt(lhs) *= rhs; }
    friend inline ModInt operator/(const ModInt &lhs, const ModInt &rhs) noexcept { return ModInt(lhs) /= rhs; }
    friend constexpr bool operator==(const ModInt &lhs, const ModInt &rhs) noexcept { return lhs.val == rhs.val; }
    friend constexpr bool operator!=(const ModInt &lhs, const ModInt &rhs) noexcept { return lhs.val != rhs.val; }
    friend inline std::istream &operator>>(std::istream &is, ModInt &x) noexcept {
        is >> x.val;
        x.normalize();
        return is;
    }
    friend inline std::ostream &operator<<(std::ostream &os, const ModInt &x) noexcept { return os << x.val; }
};
#line 7 "test/library_checker/data_structure/unionfind_with_potential_non_commutative_group.test.cpp"

using mint = ModInt<998244353>;

struct Potential{
    mint A[2][2];
};
Potential operator+(Potential r, Potential l){
    Potential res;
    for(int i = 0; i < 2; i++){
        for(int j = 0; j < 2; j++){
            for(int k = 0; k < 2; k++){
                res.A[i][k] += l.A[i][j] * r.A[j][k];
            }
        }
    }
    return res;
}
Potential operator-(Potential r){
    swap(r.A[0][0], r.A[1][1]);
    r.A[0][1] = -r.A[0][1];
    r.A[1][0] = -r.A[1][0];
    return r;
}
Potential operator-(Potential l, Potential r){
    return l + (-r);
}

int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, q; cin >> n >> q;
    WeightedUnionFind<Potential> uf(n, {{{1,0},{0,1}}});
    while(q--){
        int t; cin >> t;
        if(t == 0){
            int u, v; cin >> u >> v;
            Potential x;
            for(int i = 0; i < 2; i++){
                for(int j = 0; j < 2; j++){
                    cin >> x.A[i][j];
                }
            }
            if(uf.same(u, v)){
                Potential y = uf.diff(u, v);
                int is_ok = 1;
                for(int i = 0; i < 2; i++){
                    for(int j = 0; j < 2; j++){
                        if(x.A[i][j] != y.A[i][j]){
                            is_ok = 0;
                        }
                    }
                }
                cout << is_ok << "\n";
            } else{
                uf.unite(u, v, x);
                cout << 1 << "\n";
            }
        } else{
            int u, v; cin >> u >> v;
            if(uf.same(u, v)){
                Potential x = uf.diff(u, v);
                cout << x.A[0][0] << " " << x.A[0][1] << " " << x.A[1][0] << " " << x.A[1][1] << "\n";
            } else{
                cout << -1 << "\n";
            }
        }
    }
}
Back to top page