competitive-programing

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub Astral-23/competitive-programing

:heavy_check_mark: verify/treedp.test.cpp

Depends on

Code

#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"
#include "../Utility/template.hpp"
#include "../Utility/modint.hpp"
#include "../Algorithm/treedp.hpp"
vec<ll> A;

using mint = modint998244353;

// DP value aggregated over a set of vertices.
//   val : accumulated sum (A[v] is added per vertex by addroot, edges rescale it in mp)
//   siz : number of vertices aggregated (addroot adds 1 per vertex)
struct S {
    mint val;
    // Default member initializer: the previous empty default constructor left
    // `siz` indeterminate, and TDP default-constructs S when resizing dp[v].
    int siz = 0;
    S() = default;
    S(mint v, int s) : val(v), siz(s) {}
};

// Merges two aggregates component-wise: sums add up, vertex counts add up.
S op(S l, S r) {
    return S(l.val + r.val, l.siz + r.siz);
}

// Identity element of op: zero sum over zero vertices.
S e() {
    const S identity(0, 0);
    return identity;
}

// Attaches vertex v as the root of an aggregate: contributes A[v] to the sum
// and one more vertex to the count.
S addroot(S res, int v) {
    return S(res.val + A[v], res.siz + 1);
}

struct F{
    mint b, c;
    F(){}
    F(mint t, mint y) : b(t), c(y) {}
};

// Pushes an aggregate through edge f: each of the s.siz summed values x
// becomes f.b * x + f.c, so the sum becomes f.b * val + f.c * siz.
S mp(F f, S s) {
    const auto scaled = s.val * f.b;
    const auto shifted = f.c * s.siz;
    s.val = scaled + shifted;
    return s;
}

int main() {
    int n;
    cin >> n;
    A = vec<ll>(n, 0);
    rep(i, 0, n) cin >> A[i];

    TDP<S, op, e, addroot, F, mp> tdp(n);
    rep(i, 0, n-1) {
        int u, v, b, c;
        cin >> u >> v >> b >> c;
        tdp.add_edge(u, v, F(b, c), F(b, c));
    }

    auto ans = tdp.exe();
    rep(i, 0, n) cout << ans[i].val.x << '\n';

}
#line 1 "verify/treedp.test.cpp"
#define PROBLEM "https://judge.yosupo.jp/problem/tree_path_composite_sum"
#line 1 "Utility/template.hpp"
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
#define rep(i, s, t) for (ll i = s; i < (ll)(t); i++)
#define rrep(i, s, t) for (ll i = (ll)(t) - 1; i >= (ll)(s); i--)
#define all(x) begin(x), end(x)

#define TT template <typename T>
TT using vec = vector<T>;
// Lowers x to y when y is strictly smaller; returns whether x was changed.
template <class T1, class T2> bool chmin(T1 &x, T2 y) {
    if (x > y) {
        x = y;
        return true;
    }
    return false;
}
// Raises x to y when y is strictly larger; returns whether x was changed.
template <class T1, class T2> bool chmax(T1 &x, T2 y) {
    if (x < y) {
        x = y;
        return true;
    }
    return false;
}
// Configures the standard streams once, when the global `io_setup` object is
// constructed: C stdio syncing off, cin untied from cout, and floating-point
// output fixed with 15 digits after the decimal point.
struct io_setup {
    io_setup() {
        std::ios::sync_with_stdio(false);
        std::cin.tie(nullptr);
        std::cout.setf(std::ios::fixed, std::ios::floatfield);
        std::cout.precision(15);
    }
} io_setup;

/*
@brief verify用テンプレート
*/
#line 1 "Utility/modint.hpp"

// Dynamic modulus: remove the `template <uint32_t mod>` line and declare a variable `mod` above instead.
// Integer modulo the compile-time constant `mod`, stored reduced in [0, mod).
//
// The constructor and operator+ form sums below 2 * mod in 32 bits, and
// operator- relies on unsigned wrap-around, so mod must be below 2^31.
// inv() uses Fermat's little theorem and is therefore only valid for prime mod.
template <std::uint32_t mod> struct modint {
    using mm = modint;
    std::uint32_t x;  // canonical representative, 0 <= x < mod

    modint() : x(0) {}

    // Converts any integer-like value; negative inputs are mapped into [0, mod).
    // The remainder lies in (-mod, mod), so adding mod gives (0, 2 * mod).
    template <typename T>
    modint(T a = 0) : x(static_cast<long long>(a) % mod + mod) {
        if (x >= mod) x -= mod;
    }

    friend mm operator+(mm a, mm b) {
        a.x += b.x;
        if (a.x >= mod) a.x -= mod;
        return a;
    }
    friend mm operator-(mm a, mm b) {
        a.x -= b.x;
        // On underflow the unsigned value wraps to at least 2^32 - mod >= mod,
        // and adding mod wraps it back into range.
        if (a.x >= mod) a.x += mod;
        return a;
    }

    mm operator-() const { return mm(mod - x); }

    // If only + and - are needed, everything below may be omitted.

    // The 64-bit product is reduced by the converting constructor.
    friend mm operator*(mm a, mm b) {
        return mm(static_cast<std::uint64_t>(a.x) * b.x);
    }
    friend mm operator/(mm a, mm b) { return a * b.inv(); }
    friend mm &operator+=(mm &a, mm b) { return a = a + b; }
    friend mm &operator-=(mm &a, mm b) { return a = a - b; }
    friend mm &operator*=(mm &a, mm b) { return a = a * b; }
    friend mm &operator/=(mm &a, mm b) { return a = a * b.inv(); }

    // Multiplicative inverse via x^(mod - 2); asserts x != 0.
    mm inv() const {
        assert(x != 0);
        return pow(mod - 2);
    }

    // Binary exponentiation; y is expected to be non-negative.
    mm pow(long long y) const {
        mm res = 1;
        mm v = *this;
        while (y) {
            if (y & 1) res *= v;
            v *= v;
            y /= 2;
        }
        return res;
    }

    // Reads a signed integer from `is` (previously this always read from cin,
    // ignoring the stream argument) and reduces it.
    friend std::istream &operator>>(std::istream &is, mm &a) {
        long long t = 0;
        is >> t;
        a = mm(t);
        return is;
    }

    friend std::ostream &operator<<(std::ostream &os, mm a) { return os << a.x; }

    // const-qualified so that const values can be compared.
    bool operator==(mm a) const { return x == a.x; }
    bool operator!=(mm a) const { return x != a.x; }

    bool operator<(const mm &a) const { return x < a.x; }
};
using modint998244353 = modint<998244353>;
using modint1000000007 = modint<1'000'000'007>;
/*
@brief modint
*/
#line 1 "Algorithm/treedp.hpp"
template <class S,
          S (*op)(S, S),
          S (*e)(),
          S (*addroot)(S, int),
          class F,
          S (*mp)(F, S)>
struct TDP {
    using pif = pair<int, F>;
    using vs = vec<S>;
    using vvs = vec<vs>;

    int n;
    vec<vec<pif>> g;
    vvs dp;
    // dp[v][i] :=  (v → g[v][i])の辺について、
    // g[v][i]を根とする部分木の結果
    vs ans;

    TDP(int n) : n(n) {
        g.resize(n);
        dp = vvs(n);
        ans = vs(n, e());
    }

  private:
    S dfs(int v, int p) {
        S res = e();
        int d = g[v].size();
        dp[v].resize(d);
        rep(i, 0, d) {
            int to = g[v][i].first;
            if (to == p) continue;
            dp[v][i] = dfs(to, v);
            res = op(res, mp(g[v][i].second, dp[v][i]));
            // 部分木の結果を集約。
            // 本実装では辺を加味 ->
            // 部分木集約の順を徹底している(辻褄が合うならいつでも良い)
        }
        // 辺・頂点をaddした影響を反映したものを返す。
        return addroot(res, v);
    }

    void bfs(int v, S par, int p) {
        int d = g[v].size();
        rep(i, 0, d) if (g[v][i].first == p) dp[v][i] = par;
        // 親の結果を渡しておく。

        vs L(d + 1, e());
        vs R(d + 1, e());

        rep(i, 0, d) L[i + 1] = op(L[i], mp(g[v][i].second, dp[v][i]));
        rrep(i, 0, d) R[i] = op(mp(g[v][i].second, dp[v][i]), R[i + 1]);
        // 本実装では辺を加味 ->
        // 部分木集約の順を徹底している(辻褄が合うならいつでも良い)

        ans[v] = addroot(L[d], v);
        // 辺・頂点をaddした影響を反映したものを返す。ansに格納する時だけ何か弄りたいならここを弄る。

        rep(i, 0, d) {
            int to = g[v][i].first;
            if (to == p) continue;
            S nx = op(L[i], R[i + 1]);
            // 本実装では辺を加味 ->
            // 部分木集約の順を徹底している(辻褄が合うならいつでも良い)

            bfs(to, addroot(nx, v), v);
            // to -> vの向きに辺・頂点をaddした影響を反映したものを返す。
        }
    }

  public:
    // s -> t に重みfの辺、 t -> sに重みhの辺
    void add_edge(int s, int t, F f, F h) {
        g[s].emplace_back(t, f);
        g[t].emplace_back(s, h);
    }

    vec<S> exe() {
        dfs(0, -1);
        bfs(0, e(), -1);
        return ans;
    }
};

/*
@brief 全方位木DP
@docs doc/treedp.md
*/
#line 5 "verify/treedp.test.cpp"
vec<ll> A;

using mint = modint998244353;

// DP value aggregated over a set of vertices.
//   val : accumulated sum (A[v] is added per vertex by addroot, edges rescale it in mp)
//   siz : number of vertices aggregated (addroot adds 1 per vertex)
struct S {
    mint val;
    // Default member initializer: the previous empty default constructor left
    // `siz` indeterminate, and TDP default-constructs S when resizing dp[v].
    int siz = 0;
    S() = default;
    S(mint v, int s) : val(v), siz(s) {}
};

// Merges two aggregates component-wise: sums add up, vertex counts add up.
S op(S l, S r) {
    return S(l.val + r.val, l.siz + r.siz);
}

// Identity element of op: zero sum over zero vertices.
S e() {
    const S identity(0, 0);
    return identity;
}

// Attaches vertex v as the root of an aggregate: contributes A[v] to the sum
// and one more vertex to the count.
S addroot(S res, int v) {
    return S(res.val + A[v], res.siz + 1);
}

struct F{
    mint b, c;
    F(){}
    F(mint t, mint y) : b(t), c(y) {}
};

// Pushes an aggregate through edge f: each of the s.siz summed values x
// becomes f.b * x + f.c, so the sum becomes f.b * val + f.c * siz.
S mp(F f, S s) {
    const auto scaled = s.val * f.b;
    const auto shifted = f.c * s.siz;
    s.val = scaled + shifted;
    return s;
}

int main() {
    int n;
    cin >> n;
    A = vec<ll>(n, 0);
    rep(i, 0, n) cin >> A[i];

    TDP<S, op, e, addroot, F, mp> tdp(n);
    rep(i, 0, n-1) {
        int u, v, b, c;
        cin >> u >> v >> b >> c;
        tdp.add_edge(u, v, F(b, c), F(b, c));
    }

    auto ans = tdp.exe();
    rep(i, 0, n) cout << ans[i].val.x << '\n';

}
Back to top page