Skip to content

Commit 4d10ff2

Browse files
author
Stephan Schiffels
committed
added new options to bound coalescence rates and to unbound the cross coalescence rate
1 parent 19d47d9 commit 4d10ff2

File tree

4 files changed

+149
-66
lines changed

4 files changed

+149
-66
lines changed

Makefile

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@ build/maximize : model/*.d powell.d brent.d maximization_step.d logger.d maximiz
1111
build/test/msmc : model/*.d powell.d brent.d maximization_step.d expectation_step.d msmc.d branchlength.d logger.d
1212
dmd -O ${GSL} -odbuild/test -ofbuild/test/msmc $^
1313

14-
build/test : model/*.d test.d
15-
dmd ${GSL} -odbuild -ofbuild/test $^
16-
1714
build/decode : model/*.d decode.d branchlength.d
1815
dmd ${GSL} -odbuild -ofbuild/decode $^
1916

@@ -29,7 +26,7 @@ testcoverage : model/*.d unittest.d powell.d brent.d maximization_step.d expecta
2926
mv *.lst code_coverage/
3027

3128
unittest : model/*.d unittest.d powell.d brent.d maximization_step.d expectation_step.d logger.d branchlength.d
32-
dmd -unittest ${GSL} -odbuild -ofbuild/unittest $^
29+
dmd -debug -unittest ${GSL} -odbuild -ofbuild/unittest $^
3330
build/unittest
3431

3532
clean :

maximization_step.d

Lines changed: 98 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,12 @@ import model.triple_index_marginal;
2929
import powell;
3030
import logger;
3131

32-
MSMCmodel getMaximization(double[] eVec, double[][] eMat, MSMCmodel params, in size_t[] timeSegmentPattern,
33-
bool fixedPopSize, bool fixedRecombination)
34-
{
35-
auto minFunc = new MinFunc(eVec, eMat, params, timeSegmentPattern, fixedPopSize, fixedRecombination);
32+
MSMCmodel getMaximization(double[] eVec, double[][] eMat, MSMCmodel params,
33+
in size_t[] timeSegmentPattern, bool fixedPopSize, bool fixedRecombination, bool boundCrossCoal,
34+
double loBoundLambda, double hiBoundLambda) {
35+
36+
auto minFunc = new MinFunc(eVec, eMat, params, timeSegmentPattern, fixedPopSize,
37+
fixedRecombination, boundCrossCoal, loBoundLambda, hiBoundLambda);
3638

3739
auto powell = new Powell!MinFunc(minFunc);
3840
auto x = minFunc.initialValues();
@@ -50,17 +52,22 @@ class MinFunc {
5052
size_t nrSubpopPairs, nrParams;
5153
const double[] expectationResultVec;
5254
const double[][] expectationResultMat;
53-
bool fixedPopSize, fixedRecombination;
55+
bool fixedPopSize, fixedRecombination, boundCrossCoal;
56+
double loBoundLambda, hiBoundLambda;
5457

55-
this(in double[] expectationResultVec, in double[][] expectationResultMat, MSMCmodel initialParams,
56-
in size_t[] timeSegmentPattern, bool fixedPopSize, bool fixedRecombination)
57-
{
58+
this(in double[] expectationResultVec, in double[][] expectationResultMat, MSMCmodel
59+
initialParams, in size_t[] timeSegmentPattern, bool fixedPopSize, bool fixedRecombination,
60+
bool boundCrossCoal, double loBoundLambda, double hiBoundLambda) {
5861
this.initialParams = initialParams;
5962
this.timeSegmentPattern = timeSegmentPattern;
6063
this.expectationResultVec = expectationResultVec;
6164
this.expectationResultMat = expectationResultMat;
6265
this.fixedPopSize = fixedPopSize;
6366
this.fixedRecombination = fixedRecombination;
67+
this.boundCrossCoal = boundCrossCoal;
68+
this.loBoundLambda = loBoundLambda;
69+
this.hiBoundLambda = hiBoundLambda;
70+
6471
nrSubpopPairs = initialParams.nrSubpopulations * (initialParams.nrSubpopulations + 1) / 2;
6572
nrParams = nrSubpopPairs * cast(size_t)timeSegmentPattern.length;
6673
if(!fixedRecombination)
@@ -81,14 +88,23 @@ class MinFunc {
8188
body {
8289
auto x = getXfromLambdaVec(initialParams.lambdaVec);
8390
if(!fixedRecombination)
84-
x ~= log(initialParams.recombinationRate);
91+
x ~= toScaledRecombination(initialParams.recombinationRate);
8592
return x;
8693
}
8794

95+
double toScaledRecombination(double rec) {
96+
return log(rec);
97+
}
98+
99+
double fromScaledRecombination(double scaledRec) {
100+
return exp(scaledRec);
101+
}
102+
88103
double[] getXfromLambdaVec(double[] lambdaVec)
89104
out(x) {
90105
if(fixedPopSize)
91-
assert(x.length == timeSegmentPattern.length * (nrSubpopPairs - initialParams.nrSubpopulations));
106+
assert(x.length == timeSegmentPattern.length *
107+
(nrSubpopPairs - initialParams.nrSubpopulations));
92108
else
93109
assert(x.length == timeSegmentPattern.length * nrSubpopPairs);
94110
}
@@ -104,30 +120,76 @@ class MinFunc {
104120
auto p2 = initialParams.subpopLabels[triple.ind2];
105121
if(p1 == p2) {
106122
if(!fixedPopSize) {
107-
ret ~= log(lambdaVec[lIndex]);
123+
auto l = lambdaVec[lIndex];
124+
if(l < loBoundLambda)
125+
l = loBoundLambda + 0.000000001;
126+
if(l > hiBoundLambda)
127+
l = hiBoundLambda - 0.000000001;
128+
ret ~= toScaledLambda(l);
108129
}
109130
}
110131
else {
111-
auto marginalIndex1 = initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
112-
auto marginalIndex2 = initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
132+
auto marginalIndex1 =
133+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
134+
auto marginalIndex2 =
135+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
113136
auto lambda1 = lambdaVec[marginalIndex1];
114137
auto lambda2 = lambdaVec[marginalIndex2];
115138
auto lambda12 = lambdaVec[lIndex];
116139
if(lambda12 >= 0.5 * (lambda1 + lambda2))
117140
lambda12 = 0.4999999999 * (lambda1 + lambda2);
118-
auto ratio = 2.0 * lambda12 / (lambda1 + lambda2);
119-
ret ~= tan(ratio * PI - PI_2);
141+
ret ~= toScaledCrossLambda(lambda12, lambda1, lambda2);
120142
}
121143
}
122144
count += nrIntervalsInSegment;
123145
}
124146
return ret;
125147
}
126148

149+
double toScaledLambda(double lambda) {
150+
if(hiBoundLambda < double.infinity) {
151+
auto frac = (lambda - loBoundLambda) / (hiBoundLambda - loBoundLambda);
152+
return tan(frac * PI - PI_2);
153+
}
154+
else
155+
return log(lambda - loBoundLambda);
156+
}
157+
158+
double fromScaledLambda(double scaledLambda) {
159+
if(hiBoundLambda < double.infinity) {
160+
auto scaledFrac = (atan(scaledLambda) + PI_2) / PI;
161+
return scaledFrac * (hiBoundLambda - loBoundLambda) + loBoundLambda;
162+
}
163+
else {
164+
return exp(scaledLambda) + loBoundLambda;
165+
}
166+
}
167+
168+
double toScaledCrossLambda(double crossLambda, double lambda1, double lambda2) {
169+
if(boundCrossCoal) {
170+
auto ratio = 2.0 * crossLambda / (lambda1 + lambda2);
171+
return tan(ratio * PI - PI_2);
172+
}
173+
else
174+
return toScaledLambda(crossLambda);
175+
}
176+
177+
double fromScaledCrossLambda(double scaledCrossLambda, double lambda1, double lambda2) {
178+
if(boundCrossCoal) {
179+
auto ratio = (atan(scaledCrossLambda) + PI_2) / PI;
180+
return ratio * 0.5 * (lambda1 + lambda2);
181+
}
182+
else
183+
return fromScaledLambda(scaledCrossLambda);
184+
}
185+
127186
MSMCmodel makeParamsFromVec(in double[] x) {
128187
auto lambdaVec = fixedPopSize ? getLambdaVecFromXfixedPop(x) : getLambdaVecFromX(x);
129-
auto recombinationRate = fixedRecombination ? initialParams.recombinationRate : getRecombinationRateFromX(x);
130-
return new MSMCmodel(initialParams.mutationRate, recombinationRate, initialParams.subpopLabels, lambdaVec, initialParams.nrTimeIntervals, initialParams.nrTtotIntervals, initialParams.emissionRate.directedEmissions);
188+
auto recombinationRate =
189+
fixedRecombination ? initialParams.recombinationRate : getRecombinationRateFromX(x);
190+
return new MSMCmodel(initialParams.mutationRate, recombinationRate, initialParams.subpopLabels,
191+
lambdaVec, initialParams.nrTimeIntervals, initialParams.nrTtotIntervals,
192+
initialParams.emissionRate.directedEmissions);
131193
}
132194

133195
double[] getLambdaVecFromXfixedPop(in double[] x)
@@ -150,13 +212,13 @@ class MinFunc {
150212

151213
if(p1 != p2) {
152214
auto marginalIndex1 =
153-
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
154-
auto marginalIndex2 =
155-
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
215+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
216+
auto marginalIndex2 =
217+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
156218
auto lambda1 = lambdaVec[marginalIndex1];
157219
auto lambda2 = lambdaVec[marginalIndex2];
158-
auto ratio = (atan(x[segmentIndex * valuesPerTime + xIndex]) + PI_2) / PI;
159-
lambdaVec[lIndex] = ratio * 0.5 * (lambda1 + lambda2);
220+
auto scaledCrossLambda = x[segmentIndex * valuesPerTime + xIndex];
221+
lambdaVec[lIndex] = fromScaledCrossLambda(scaledCrossLambda, lambda1, lambda2);
160222
xIndex += 1;
161223
}
162224
}
@@ -178,7 +240,7 @@ class MinFunc {
178240
foreach(subpopPairIndex; 0 .. nrSubpopPairs) {
179241
auto lIndex = timeIndex * nrSubpopPairs + subpopPairIndex;
180242
auto xIndex = segmentIndex * nrSubpopPairs + subpopPairIndex;
181-
lambdaVec[lIndex] = exp(x[xIndex]);
243+
lambdaVec[lIndex] = fromScaledLambda(x[xIndex]);
182244
}
183245
foreach(subpopPairIndex; 0 .. nrSubpopPairs) {
184246
auto lIndex = timeIndex * nrSubpopPairs + subpopPairIndex;
@@ -190,13 +252,12 @@ class MinFunc {
190252
if(p1 != p2) {
191253
auto xIndex = segmentIndex * nrSubpopPairs + subpopPairIndex;
192254
auto marginalIndex1 =
193-
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
255+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p1][p1];
194256
auto marginalIndex2 =
195-
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
257+
initialParams.marginalIndex.subpopulationTripleToMarginalIndexMap[triple.time][p2][p2];
196258
auto lambda1 = lambdaVec[marginalIndex1];
197259
auto lambda2 = lambdaVec[marginalIndex2];
198-
auto ratio = (atan(x[xIndex]) + PI_2) / PI;
199-
lambdaVec[lIndex] = ratio * 0.5 * (lambda1 + lambda2);
260+
lambdaVec[lIndex] = fromScaledCrossLambda(x[xIndex], lambda1, lambda2);
200261
}
201262
}
202263
timeIndex += 1;
@@ -211,17 +272,19 @@ class MinFunc {
211272
assert(!fixedRecombination);
212273
}
213274
body {
214-
return exp(x[$ - 1]);
275+
return fromScaledRecombination(x[$ - 1]);
215276
}
216277

217278
double logLikelihood(MSMCmodel params) {
218279
double ret = 0.0;
219280
foreach(au; 0 .. initialParams.nrMarginals) {
220281
foreach(bv; 0 .. initialParams.nrMarginals) {
221-
ret += expectationResultMat[au][bv] * log(params.transitionRate.transitionProbabilityQ2(au, bv));
282+
ret +=
283+
expectationResultMat[au][bv] * log(params.transitionRate.transitionProbabilityQ2(au, bv));
222284
}
223285
ret += expectationResultVec[au] * log(
224-
params.transitionRate.transitionProbabilityQ1(au) + params.transitionRate.transitionProbabilityQ2(au, au)
286+
params.transitionRate.transitionProbabilityQ1(au) +
287+
params.transitionRate.transitionProbabilityQ2(au, au)
225288
);
226289
}
227290
return ret;
@@ -239,26 +302,26 @@ unittest {
239302
auto expectationResultMat = new double[][](params.nrMarginals, params.nrMarginals);
240303
auto timeSegmentPattern = [2UL, 2];
241304

242-
auto minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, false, false);
305+
auto minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, false, false, false, 0, double.infinity);
243306
auto rho = 0.001;
244307
auto x = minFunc.getXfromLambdaVec(lambdaVec);
245-
x ~= log(rho);
308+
x ~= minFunc.toScaledRecombination(rho);
246309
auto lambdaFromX = minFunc.getLambdaVecFromX(x);
247310
auto rhoFromX = minFunc.getRecombinationRateFromX(x);
248311
foreach(i; 0 .. lambdaVec.length)
249312
assert(approxEqual(lambdaFromX[i], lambdaVec[i], 1.0e-8, 0.0), text(lambdaFromX[i], " ", lambdaVec[i]));
250313
assert(approxEqual(rhoFromX, rho, 1.0e-8, 0.0));
251314

252-
minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, true, false);
315+
minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, true, false, false, 0, double.infinity);
253316
x = minFunc.getXfromLambdaVec(lambdaVec);
254-
x ~= log(rho);
317+
x ~= minFunc.toScaledRecombination(rho);
255318
lambdaFromX = minFunc.getLambdaVecFromXfixedPop(x);
256319
rhoFromX = minFunc.getRecombinationRateFromX(x);
257320
foreach(i; 0 .. lambdaVec.length)
258321
assert(approxEqual(lambdaFromX[i], lambdaVec[i], 1.0e-8, 0.0), text(lambdaFromX[i], " ", lambdaVec[i]));
259322
assert(approxEqual(rhoFromX, rho, 1.0e-8, 0.0));
260323

261-
minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, false, true);
324+
minFunc = new MinFunc(expectationResultVec, expectationResultMat, params, timeSegmentPattern, false, true, false, 0, double.infinity);
262325
x = minFunc.getXfromLambdaVec(lambdaVec);
263326
lambdaFromX = minFunc.getLambdaVecFromX(x);
264327
foreach(i; 0 .. lambdaVec.length)

msmc.d

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ auto memory = false;
5252
auto naiveImplementation = false;
5353
auto fixedPopSize = false;
5454
auto fixedRecombination = false;
55+
auto unboundedCrossCoal = false;
56+
double loBoundLambda = 0.0;
57+
double hiBoundLambda = double.infinity;
58+
5559
bool directedEmissions = false;
5660
bool skipAmbiguous = false;
5761
string[] inputFileNames, treeFileNames;
@@ -152,37 +156,52 @@ void parseCommandLine(string[] args) {
152156
"fixedPopSize", &fixedPopSize,
153157
"fixedRecombination|R", &fixedRecombination,
154158
"initialLambdaVec", &handleLambdaVecString,
155-
"treeFileNames", &handleTreeFileNames
159+
"treeFileNames", &handleTreeFileNames,
160+
"unboundedCrossCoal", &unboundedCrossCoal,
161+
"loBoundLambda", &loBoundLambda,
162+
"hiBoundLambda", &hiBoundLambda
156163
);
157164
if(nrThreads)
158165
std.parallelism.defaultPoolThreads(nrThreads);
159166
enforce(args.length > 1, "need at least one input file");
160167
enforce(hmmStrideWidth > 0, "hmmStrideWidth must be positive");
161168
inputFileNames = args[1 .. $];
169+
162170
if(indices.length == 0)
163171
inferDefaultIndices();
172+
164173
if(subpopLabels.length == 0)
165174
inferDefaultSubpopLabels(indices.length);
175+
166176
enforce(indices.length == subpopLabels.length, "nr haplotypes in subpopLabels and indices must be equal");
177+
167178
inputData = readDataFromFiles(inputFileNames, directedEmissions, indices, skipAmbiguous);
179+
168180
if(isNaN(mutationRate)) {
169181
stderr.write("estimating scaled mutation rate: ");
170182
mutationRate = getTheta(inputData, indices.length) / 2.0;
171183
stderr.writeln(mutationRate);
172184
}
173185
recombinationRate = mutationRate * rhoOverMu;
174186
nrTimeSegments = timeSegmentPattern.reduce!"a+b"();
187+
175188
if(nrTtotSegments == 0)
176189
nrTtotSegments = nrTimeSegments;
190+
177191
auto nrSubpops = MarginalTripleIndex.computeNrSubpops(subpopLabels);
192+
178193
auto nrMarginals = nrTimeSegments * nrSubpops * (nrSubpops + 1) / 2;
194+
179195
if(lambdaVec.length > 0) {
180196
// this is necessary because we read in a scaled lambdaVec.
181197
lambdaVec[] *= mutationRate;
182198
enforce(lambdaVec.length == nrMarginals, "initialLambdaVec must have correct length");
183199
}
184200
enforce(treeFileNames.length == 0 || treeFileNames.length == inputFileNames.length);
185201

202+
enforce(hiBoundLambda > loBoundLambda,
203+
"higher bound of lambda needs to be higher than the lower bounda");
204+
186205
logFileName = outFilePrefix ~ ".log";
187206
loopFileName = outFilePrefix ~ ".loop.txt";
188207
finalFileName = outFilePrefix ~ ".final.txt";
@@ -207,7 +226,10 @@ void printGlobalParams() {
207226
logInfo(format("hmmStrideWidth: %s\n", hmmStrideWidth));
208227
logInfo(format("fixedPopSize: %s\n", fixedPopSize));
209228
logInfo(format("fixedRecombination: %s\n", fixedRecombination));
229+
logInfo(format("unboundedCrossCoal: %s\n", unboundedCrossCoal));
210230
logInfo(format("initialLambdaVec: %s\n", lambdaVec));
231+
logInfo(format("loBoundLambda: %s\n", loBoundLambda));
232+
logInfo(format("hiBoundLambda: %s\n", hiBoundLambda));
211233
logInfo(format("directedEmissions: %s\n", directedEmissions));
212234
logInfo(format("skipAmbiguous: %s\n", skipAmbiguous));
213235
logInfo(format("indices: %s\n", indices));
@@ -232,11 +254,11 @@ void inferDefaultSubpopLabels(size_t nrHaplotypes) {
232254
void run() {
233255
MSMCmodel params;
234256
if(lambdaVec.length > 0)
235-
params = new MSMCmodel(mutationRate, recombinationRate, subpopLabels, lambdaVec, nrTimeSegments, nrTtotSegments,
236-
directedEmissions);
257+
params = new MSMCmodel(mutationRate, recombinationRate, subpopLabels, lambdaVec,
258+
nrTimeSegments, nrTtotSegments, directedEmissions);
237259
else
238-
params = MSMCmodel.withTrivialLambda(mutationRate, recombinationRate, subpopLabels, nrTimeSegments, nrTtotSegments,
239-
directedEmissions);
260+
params = MSMCmodel.withTrivialLambda(mutationRate, recombinationRate, subpopLabels,
261+
nrTimeSegments, nrTtotSegments, directedEmissions);
240262

241263
auto nrFiles = inputData.length;
242264
if(params.nrHaplotypes > 2) {
@@ -271,7 +293,8 @@ void run() {
271293
auto filename = outFilePrefix ~ format(".loop_%s.expectationMatrix.txt", iteration);
272294
printMatrix(filename, eVec, eMat);
273295
}
274-
auto newParams = getMaximization(eVec, eMat, params, timeSegmentPattern, fixedPopSize, fixedRecombination);
296+
auto newParams = getMaximization(eVec, eMat, params, timeSegmentPattern, fixedPopSize,
297+
fixedRecombination, !unboundedCrossCoal, loBoundLambda, hiBoundLambda);
275298
params = newParams;
276299
}
277300

0 commit comments

Comments
 (0)