Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
232 changes: 116 additions & 116 deletions src/eval_constants.hpp

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions src/eval_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,13 @@ class PScore {
return static_cast<Value>((mg() * alpha + eg() * (max - alpha)) / max);
}

// Eg scaling
template<i32 max>
PScore scale_eg(i32 alpha) const {
assert(0 <= alpha && alpha <= max);
return PScore{mg(), static_cast<i16>(eg() * alpha / max)};
}

// complexity_add
PScore complexity_add(Score val) {
const Score e = eg();
Expand Down
13 changes: 13 additions & 0 deletions src/evaluation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,16 @@ PScore apply_winnable(const Position& pos, PScore& score, usize phase) {
return score.complexity_add(winnable);
}

PScore apply_eg_scale(const Position& pos, PScore& eval) {
// Strong pawn scaling
const Color strong_side = eval.eg() > 0 ? Color::White : Color::Black;

const isize strong_pawn_count = pos.ipiece_count(strong_side, PieceType::Pawn);
const isize pcmul = 8 - strong_pawn_count;

return eval.scale_eg<128>(static_cast<i32>(128 - pcmul * pcmul)); // 64 - 128
}

Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
const Color us = pos.active_color();
usize phase = pos.piece_count(Color::White, PieceType::Knight)
Expand Down Expand Up @@ -542,6 +552,9 @@ Score evaluate_white_pov(const Position& pos, const PsqtState& psqt_state) {
// Winnable
eval = apply_winnable(pos, eval, phase);

// Eg scaling
eval = apply_eg_scale(pos, eval);

return static_cast<Score>(eval.phase<24>(static_cast<i32>(phase)));
};

Expand Down
9 changes: 9 additions & 0 deletions src/tuning/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,9 @@ PairHandle Graph::record_pair_scalar(OpType op, PairHandle lhs, f64 scalar) {
case OpType::ScalarDivPair:
res = f64x2::scalar_div(scalar, l);
break;
case OpType::ScaleEg:
res = f64x2::make(l.first(), l.second() * scalar);
break;
default:
break;
}
Expand Down Expand Up @@ -474,6 +477,12 @@ void Graph::backward() {
pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update);
break;
}
case OpType::ScaleEg: {
const f64x2 grad_out = pair_grads[out_idx];
f64x2 update = f64x2::make(grad_out.first(), grad_out.second() * node.scalar());
pair_grads[node.lhs()] = f64x2::add(pair_grads[node.lhs()], update);
break;
}
case OpType::PairMulValue:
case OpType::ValueMulPair: {
const f64x2 grad_out = pair_grads[out_idx];
Expand Down
1 change: 1 addition & 0 deletions src/tuning/operations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ enum class OpType : u32 {
PairDivValue,
ValueDivPair,
PairAddClampedSecond, // For complexity
ScaleEg, // For eg scaling

// Pair-Pair Ops
PairMulPair,
Expand Down
4 changes: 4 additions & 0 deletions src/tuning/value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,4 +331,8 @@ PairHandle PairHandle::complexity_add(ValueHandle value) const {
return Graph::get().record_pair_value(OpType::PairAddClampedSecond, *this, value);
}

PairHandle PairHandle::scale_eg_impl(f64 ratio) const {
return Graph::get().record_pair_scalar(OpType::ScaleEg, *this, ratio);
}

} // namespace Clockwork::Autograd
8 changes: 8 additions & 0 deletions src/tuning/value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,14 @@ struct PairHandle {
PairHandle sigmoid() const;

PairHandle complexity_add(ValueHandle value) const;

template<i32 max>
PairHandle scale_eg(f64 alpha) const {
return scale_eg_impl(alpha / static_cast<f64>(max));
}

private:
PairHandle scale_eg_impl(f64 ratio) const;
};

// Operation decls
Expand Down
Loading